计算cross_entropy

转onehot格式

1.首先用到scatter函数(scatter_原地改变数值)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
    >>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
## dim=0 时 index 的取值指定写入的行号(沿第 0 维散布);dim=1 时 index 指定写入的列号(沿第 1 维散布)
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])

>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000, 0.0000, 1.2300, 0.0000],
[ 0.0000, 0.0000, 0.0000, 1.2300]])

In []: z = torch.zeros(2, 4).scatter_(1, torch.tensor([[1, 2], [1, 3]]), 1.23)

In []: z
Out[]:
tensor([[0.0000, 1.2300, 1.2300, 0.0000],
[0.0000, 1.2300, 0.0000, 1.2300]])

具体代码为

1
2
3
4
label = torch.tensor([0, 0, 1])
# One-hot encode the labels: one row per sample, 3 classes per row.
# Bug fix: the original created `onehot_` but then called .zero_()/.scatter_()
# on the undefined name `onehot`, raising NameError. Use one consistent name,
# and build the zero tensor directly instead of zeroing uninitialized memory.
onehot = torch.zeros(label.shape[0], 3)
# Writes 1 at column `label[i]` of row i (dim=1 → index picks the column).
onehot.scatter_(1, torch.reshape(label, (label.shape[0], 1)), 1)

计算 cross entropy

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch

# Ground-truth class index for each of the three samples below.
# NOTE(review): the name `lable` is a misspelling of `label`, kept as-is
# because later code in this file refers to it by this exact name.
lable = torch.tensor([0, 0, 1])

# Raw, unnormalized scores (logits): 3 samples x 3 classes.
fc_out = torch.tensor([
    [2.5, -2, 0.8989],
    [3, 0.8, -865],
    [0.00000000000001, 2, 4.9],
])

class CrossEntropyLoss():
    """Hand-rolled cross-entropy loss with mean reduction.

    Matches ``torch.nn.CrossEntropyLoss()`` for integer class targets:
    builds a one-hot matrix from the labels and averages the negative
    log-probability of the correct class over the batch.
    """

    def __init__(self):
        super(CrossEntropyLoss, self).__init__()

    def forward(self, fc_out, label):
        """Compute mean cross-entropy.

        fc_out: (batch, num_classes) logits tensor.
        label:  (batch,) integer class indices.
        Returns a scalar tensor (mean loss over the batch).
        """
        # Bug fix: the original read the module-level global `lable`,
        # silently ignoring the `label` parameter.
        # Fix 2: infer the class count from the logits instead of
        # hard-coding 3, so any (batch, C) input works.
        one_hot = torch.zeros(fc_out.shape[0], fc_out.shape[1])
        one_hot.scatter_(1, torch.reshape(label, (fc_out.shape[0], 1)), 1)
        # Fix 3: log_softmax is numerically stable, unlike
        # log(softmax(...)) which underflows for very negative logits.
        log_prob = torch.log_softmax(fc_out, 1)
        # One-hot mask selects the log-prob of the true class per row.
        loss = -torch.sum(one_hot * log_prob) / fc_out.shape[0]

        return loss


# Sanity check: compare PyTorch's built-in cross entropy against the
# hand-rolled implementation on the same logits and labels.
builtin_criterion = torch.nn.CrossEntropyLoss()
custom_criterion = CrossEntropyLoss()

builtin_loss = builtin_criterion(fc_out, lable)
custom_loss = custom_criterion.forward(fc_out, lable)

print(builtin_loss)
print(custom_loss)