PyTorch里的requires_grad、volatile及no_grad

1. requires_grad

Variable變量的requires_grad的屬性默認為False,若一個節(jié)點requires_grad被設(shè)置為True,那么所有依賴它的節(jié)點的requires_grad都為True。

x=Variable(torch.ones(1))
w=Variable(torch.ones(1),requires_grad=True)
y=x*w
x.requires_grad,w.requires_grad,y.requires_grad
Out[23]: (False, True, True)

y依賴于w,w的requires_grad=True,因此y的requires_grad=True (類似or操作)

2. volatile

volatile=True是Variable的另一個重要的標識,它能夠?qū)⑺幸蕾囁墓?jié)點全部設(shè)為volatile=True,其優(yōu)先級比requires_grad=True高。因而volatile=True的節(jié)點不會求導,即使requires_grad=True,也不會進行反向傳播,對于不需要反向傳播的情景(inference,測試推斷),該參數(shù)可以實現(xiàn)一定速度的提升,并節(jié)省一半的顯存,因為其不需要保存梯度。
前方高能預警:如果你看完了前面volatile,請及時把它從你的腦海中擦除掉,因為

UserWarning: volatile was removed (Variable.volatile is always False)

該屬性已經(jīng)在0.4版本中被移除了,并提示你可以使用with torch.no_grad()代替該功能

x = torch.tensor([1], requires_grad=True)
with torch.no_grad():
...   y = x * 2
      y.requires_grad
False

@torch.no_grad()
def doubler(x):
      return x * 2
z = doubler(x)

z.requires_grad
False

參考文章: https://blog.csdn.net/jiangpeng59/article/details/80667335

注意:torch.Tensor生成的tensor,requires_grad默認為False。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容