bytetensor于longtensor对于索引的影响

今天发现pytorch里, bytetensor(torch.uint8)和longtensor(torch.long)对于索引有不同的影响

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch

boxes = torch.randint(6, (3,4))
boxes

# [1., 4., 5., 2.],
# [5., 5., 1., 5.]])
mask1 = torch.tensor([0,1,0], dtype=torch.long).view(-1,1)
mask2 = torch.tensor([0,1,0], dtype=torch.uint8).view(-1,1)

boxes[mask1.expand_as(boxes)]
#tensor([[[0., 4., 1., 1.],
# [0., 4., 1., 1.],
# [0., 4., 1., 1.],
# [0., 4., 1., 1.]],

# [[1., 4., 5., 2.],
# [1., 4., 5., 2.],
# [1., 4., 5., 2.],
# [1., 4., 5., 2.]],

# [[0., 4., 1., 1.],
# [0., 4., 1., 1.],
# [0., 4., 1., 1.],
# [0., 4., 1., 1.]]])

boxes[mask2.expand_as(boxes)]
#tensor([1., 4., 5., 2.])

从上面代码可以看出,longtensor对于数组的索引是按照值对应的位置来索引

而bytetensor索引的位置是当前为非零的位置输出