Catalogue
PyTorch 的自动求导
先看一个小例子:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
|
import torch from torch.autograd import Variable import torch.nn as nn
x = Variable(torch.ones(1, 2), requires_grad=True) linear = nn.Linear(2, 2, bias=False) linear.weight.data.copy_(torch.Tensor([[1, 2], [1, 2]]))
y = x for i in range(2): y = linear(y)
loss = y.sum() loss.backward() print(linear.weight.grad) print(x.grad)
|
其中 $y = Wcdot(Wcdot x)$, 使用 .backward()
即可自动求导了.
参考资料
近期评论