博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch .detach() .detach_() 和 .data用于切断反向传播
阅读量:6553 次
发布时间:2019-06-24

本文共 5998 字,大约阅读时间需要 19 分钟。

参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-autograd/#detachsource

当我们再训练网络的时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;或者值训练部分分支网络,并不让其梯度对主网络的梯度造成影响,这时候我们就需要使用detach()函数来切断一些分支的反向传播

1   detach()[source]

返回一个新的Variable,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个Variable永远不需要计算其梯度,不具有grad。

即使之后重新将它的requires_grad置为true,它也不会具有梯度grad

这样我们就会继续使用这个新的Variable进行计算,后面当我们进行反向传播时,到该调用detach()的Variable就会停止,不能再继续向前进行传播

源码为:

def detach(self):        """Returns a new Variable, detached from the current graph.        Result will never require gradient. If the input is volatile, the output        will be volatile too.        .. note::          Returned Variable uses the same data tensor, as the original one, and          in-place modifications on either of them will be seen, and may trigger          errors in correctness checks.        """        result = NoGrad()(self)  # this is needed, because it merges version counters        result._grad_fn = None      return result

可见函数进行的操作有:

  • 将grad_fn设置为None
  • 将Variablerequires_grad设置为False

如果输入 volatile=True(即不需要保存记录,当只需要结果而不需要更新参数时这么设置来加快运算速度),那么返回的Variable volatile=True。(volatile已经弃用)

注意:

返回的Variable和原始的Variable公用同一个data tensorin-place函数修改会在两个Variable上同时体现(因为它们共享data tensor),当要对其调用backward()时可能会导致错误。

举例:

比如正常的例子是:

import torcha = torch.tensor([1, 2, 3.], requires_grad=True)print(a.grad)out = a.sigmoid()out.sum().backward()print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py Nonetensor([0.1966, 0.1050, 0.0452])

 

当使用detach()但是没有进行更改时,并不会影响backward():

import torcha = torch.tensor([1, 2, 3.], requires_grad=True)print(a.grad)out = a.sigmoid()print(out)#添加detach(),c的requires_grad为Falsec = out.detach()print(c)#这时候没有对c进行更改,所以并不会影响backward()out.sum().backward()print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py Nonetensor([0.7311, 0.8808, 0.9526], grad_fn=
)tensor([0.7311, 0.8808, 0.9526])tensor([0.1966, 0.1050, 0.0452])

可见c,out之间的区别是c是没有梯度的,out是有梯度的

 

如果这里使用的是c进行sum()操作并进行backward(),则会报错:

import torcha = torch.tensor([1, 2, 3.], requires_grad=True)print(a.grad)out = a.sigmoid()print(out)#添加detach(),c的requires_grad为Falsec = out.detach()print(c)#使用新生成的Variable进行反向传播c.sum().backward()print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py Nonetensor([0.7311, 0.8808, 0.9526], grad_fn=
)tensor([0.7311, 0.8808, 0.9526])Traceback (most recent call last): File "test.py", line 13, in
c.sum().backward() File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward torch.autograd.backward(self, gradient, retain_graph, create_graph) File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward allow_unreachable=True) # allow_unreachable flagRuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

 

如果此时对c进行了更改,这个更改会被autograd追踪,在对out.sum()进行backward()时也会报错,因为此时的值进行backward()得到的梯度是错误的:

import torcha = torch.tensor([1, 2, 3.], requires_grad=True)print(a.grad)out = a.sigmoid()print(out)#添加detach(),c的requires_grad为Falsec = out.detach()print(c)c.zero_() #使用in place函数对其进行修改#会发现c的修改同时会影响out的值print(c)print(out)#这时候对c进行更改,所以会影响backward(),这时候就不能进行backward(),会报错out.sum().backward()print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py Nonetensor([0.7311, 0.8808, 0.9526], grad_fn=
)tensor([0.7311, 0.8808, 0.9526])tensor([0., 0., 0.])tensor([0., 0., 0.], grad_fn=
)Traceback (most recent call last): File "test.py", line 16, in
out.sum().backward() File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward torch.autograd.backward(self, gradient, retain_graph, create_graph) File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward allow_unreachable=True) # allow_unreachable flagRuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

 

2   .data

如果上面的操作使用的是.data,效果会不同:

这里的不同在于.data的修改不会被autograd追踪,这样当进行backward()时它不会报错,回得到一个错误的backward值

import torcha = torch.tensor([1, 2, 3.], requires_grad=True)print(a.grad)out = a.sigmoid()print(out)c = out.dataprint(c)c.zero_() #使用in place函数对其进行修改#会发现c的修改同时也会影响out的值print(c)print(out)#这里的不同在于.data的修改不会被autograd追踪,这样当进行backward()时它不会报错,回得到一个错误的backward值out.sum().backward()print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py Nonetensor([0.7311, 0.8808, 0.9526], grad_fn=
)tensor([0.7311, 0.8808, 0.9526])tensor([0., 0., 0.])tensor([0., 0., 0.], grad_fn=
)tensor([0., 0., 0.])

 

上面的内容实现的原理是:

In-place 正确性检查

所有的Variable都会记录用在他们身上的 in-place operations。如果pytorch检测到variable在一个Function中已经被保存用来backward,但是之后它又被in-place operations修改。当这种情况发生时,在backward的时候,pytorch就会报错。这种机制保证了,如果你用了in-place operations,但是在backward过程中没有报错,那么梯度的计算就是正确的。

 

⚠️下面结果正确是因为改变的是sum()的结果,中间值a.sigmoid()并没有被影响,所以其对求梯度并没有影响:

import torcha = torch.tensor([1, 2, 3.], requires_grad=True)print(a.grad)out = a.sigmoid().sum() #但是如果sum写在这里,而不是写在backward()前,得到的结果是正确的print(out)c = out.dataprint(c)c.zero_() #使用in place函数对其进行修改#会发现c的修改同时也会影响out的值print(c)print(out)#没有写在这里out.backward()print(a.grad)

返回:

(deeplearning) userdeMBP:pytorch user$ python test.py Nonetensor(2.5644, grad_fn=
)tensor(2.5644)tensor(0.)tensor(0., grad_fn=
)tensor([0.1966, 0.1050, 0.0452])

 

3   detach_()[source]

将一个Variable从创建它的图中分离,并把它设置成叶子variable

其实就相当于变量之间的关系本来是x -> m -> y,这里的叶子variable是x,但是这个时候对m进行了.detach_()操作,其实就是进行了两个操作:

  • 将m的grad_fn的值设置为None,这样m就不会再与前一个节点x关联,这里的关系就会变成x, m -> y,此时的m就变成了叶子结点
  • 然后会将m的requires_grad设置为False,这样对y进行backward()时就不会求m的梯度

 

⚠️

这么一看其实detach()和detach_()很像,两个的区别就是detach_()是对本身的更改,detach()则是生成了一个新的variable

比如x -> m -> y中如果对m进行detach(),后面如果反悔想还是对原来的计算图进行操作还是可以的

但是如果是进行了detach_(),那么原来的计算图也发生了变化,就不能反悔了

 

转载于:https://www.cnblogs.com/wanghui-garcia/p/10677071.html

你可能感兴趣的文章
《FPGA全程进阶---实战演练》第十一章 VGA五彩缤纷
查看>>
C# for循环①护栏长度 ②广场砖面积 ③判断闰年平年
查看>>
mysql数据库中,查看数据库的字符集(所有库的字符集或者某个特定库的字符集)...
查看>>
LintCode刷题——打劫房屋I、II、III
查看>>
第七次课程作业
查看>>
C++ 文本查询2.0(逻辑查询)
查看>>
Objective-C学习总结-13协议1
查看>>
web学习方向
查看>>
寒假训练营第四次作业
查看>>
SQLServer 维护脚本分享(05)内存(Memory)
查看>>
A*算法实现
查看>>
第一周 从C走进C++ 002 命令行参数
查看>>
【java】itext pdf 分页
查看>>
看看这个电脑的配置
查看>>
用户自定义控件(.ascx)
查看>>
[转]【NoSQL】NoSQL入门级资料整理(CAP原理、最终一致性)
查看>>
RequireJS进阶(二)
查看>>
.NET中数组的隐秘特性
查看>>
Console-算法-一个偶数总能表示为两个素数之和
查看>>
我设计的网站的分布式架构
查看>>