1. requires_grad
Variable變量的requires_grad的屬性默認為False,若一個節(jié)點requires_grad被設(shè)置為True,那么所有依賴它的節(jié)點的requires_grad都為True。
x=Variable(torch.ones(1))
w=Variable(torch.ones(1),requires_grad=True)
y=x*w
x.requires_grad,w.requires_grad,y.requires_grad
Out[23]: (False, True, True)
y依賴于w,w的requires_grad=True,因此y的requires_grad=True (類似or操作)
2. volatile
volatile=True是Variable的另一個重要的標識,它能夠?qū)⑺幸蕾囁墓?jié)點全部設(shè)為volatile=True,其優(yōu)先級比requires_grad=True高。因而volatile=True的節(jié)點不會求導,即使requires_grad=True,也不會進行反向傳播,對于不需要反向傳播的情景(inference,測試推斷),該參數(shù)可以實現(xiàn)一定速度的提升,并節(jié)省一半的顯存,因為其不需要保存梯度。
前方高能預警:如果你看完了前面volatile,請及時把它從你的腦海中擦除掉,因為
UserWarning: volatile was removed (Variable.volatile is always False)
該屬性已經(jīng)在0.4版本中被移除了,并提示你可以使用with torch.no_grad()代替該功能
x = torch.tensor([1], requires_grad=True)
with torch.no_grad():
... y = x * 2
y.requires_grad
False
@torch.no_grad()
def doubler(x):
return x * 2
z = doubler(x)
z.requires_grad
False
參考文章: https://blog.csdn.net/jiangpeng59/article/details/80667335