PyTorch 统计计算

norm 范数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
a = torch.full([8], 1.)

b = a.view(2, 4)

c = a.view(2, 2, 2)
>>>
>>> b
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>>
>>> c
tensor([[[1., 1.],
[1., 1.]],

[[1., 1.],
[1., 1.]]])
>>>
>>> a.norm(1), b.norm(1), c.norm(1)
(tensor(8.), tensor(8.), tensor(8.))
>>>
>>> a.norm(2), b.norm(2), c.norm(2)
(tensor(2.8284), tensor(2.8284), tensor(2.8284))
>>>
>>> b.norm(1, dim=1)
tensor([4., 4.])
>>>
>>> b.norm(2, dim=1)
tensor([2., 2.])
>>>
>>> c.norm(1, dim=0)
tensor([[2., 2.],
[2., 2.]])
>>>
>>> b.norm(2, dim=0)
tensor([1.4142, 1.4142, 1.4142, 1.4142])

max, min, argmax, argmin, sum

max, min, sum

1
2
3
4
5
6
7
8
9
10
11
>>> a = torch.arange(8).view(2, 4).float()
>>>
>>> a
tensor([[0., 1., 2., 3.],
[4., 5., 6., 7.]])
>>>
>>> a.min(), a.max(), a.mean(), a.prod()
(tensor(0.), tensor(7.), tensor(3.5000), tensor(0.))
>>>
>>> a.sum()
tensor(28.)

argmax, argmin

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> 
... a.argmax(), a.argmin()
(tensor(7), tensor(0))
>>>
>>> # 可以传入dim
... a = torch.randn(4, 10)
>>>
>>> a[0]
tensor([-0.8606, 0.8238, -0.0901, 1.2483, -0.1413, 2.4067, 1.1060, -1.0993,
0.3217, 0.9500])
>>>
>>> a.argmax()
tensor(5)
>>>
>>> a.argmax(dim=1)
tensor([5, 3, 7, 3])

keepdim=True 保持原来的维度数（被归约的维度保留为 1，便于后续广播）

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

>>> a.shape
torch.Size([4, 10])
>>>
>>> a.max(dim=1)
torch.return_types.max(
values=tensor([2.4067, 1.7488, 2.2724, 1.5806]),
indices=tensor([5, 3, 7, 3]))
>>>
>>> a.argmax(dim=1)
tensor([5, 3, 7, 3])
>>>
>>> a.max(dim=1, keepdim=True)
torch.return_types.max(
values=tensor([[2.4067],
[1.7488],
[2.2724],
[1.5806]]),
indices=tensor([[5],
[3],
[7],
[3]]))
>>>
>>> a.argmax(dim=1, keepdim=True)
tensor([[5],
[3],
[7],
[3]])

topk, kthvalue

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

>>> a = torch.randn(4, 10)
>>> a.topk(3, dim=1)
torch.return_types.topk(
values=tensor([[1.8580, 1.0516, 0.9508],
[1.4077, 0.8320, 0.5378],
[1.8832, 0.6575, 0.4265],
[1.5381, 0.0498, 0.0174]]),
indices=tensor([[4, 6, 5],
[1, 2, 3],
[7, 5, 0],
[0, 2, 3]]))
>>>
>>> a.topk(3, dim=1, largest=False)
torch.return_types.topk(
values=tensor([[-1.3991, -0.9355, -0.8185],
[-0.6529, -0.3630, -0.1805],
[-2.8063, -2.1974, -1.0160],
[-1.8190, -1.7599, -1.2404]]),
indices=tensor([[2, 1, 9],
[0, 8, 7],
[9, 1, 8],
[7, 6, 1]]))
>>>
>>> a.kthvalue(8, dim=1)
torch.return_types.kthvalue(
values=tensor([0.9508, 0.5378, 0.4265, 0.0174]),
indices=tensor([5, 3, 0, 3]))
>>>
>>> a.kthvalue(8)
torch.return_types.kthvalue(
values=tensor([0.9508, 0.5378, 0.4265, 0.0174]),
indices=tensor([5, 3, 0, 3]))

比较（注意：下面示例中比较运算返回 dtype=torch.uint8，这是旧版本行为；PyTorch 1.2+ 中返回 torch.bool）

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

>>> a = torch.randn(4, 10)
>>>
>>> a > 0
tensor([[1, 1, 0, 1, 1, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 0, 1, 0, 1, 0, 1],
[1, 0, 0, 1, 0, 0, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 1, 1, 0, 1, 1]], dtype=torch.uint8)
>>>
>>> torch.gt(a, 0)
tensor([[1, 1, 0, 1, 1, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 0, 1, 0, 1, 0, 1],
[1, 0, 0, 1, 0, 0, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 1, 1, 0, 1, 1]], dtype=torch.uint8)
>>>
>>> a != 0
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.uint8)
>>>
>>> a = torch.ones(2, 3)
>>>
>>> b = torch.randn(2, 3)
>>>
>>> torch.eq(a, b)
tensor([[0, 0, 0],
[0, 0, 0]], dtype=torch.uint8)
>>>
>>> torch.eq(a, a)
tensor([[1, 1, 1],
[1, 1, 1]], dtype=torch.uint8)
>>>
>>> torch.equal(a, a)
True