tensor1[tensor2]
刚看到这个结构有点懵,不知道它是具体怎么工作的
example.py
a = torch.arange(16)
b = torch.tensor([2,2,0,1,0,0,1,0,2,1,0,0,1,0,0,0],dtype=torch.uint8)
print(a)
print(b)
print(a[b])
index_list = [[4,3,2,1,0]]
c = torch.LongTensor(index_list)
# print(a)
print(a[c])
print(a.shape,c.shape,a[c].shape)
d = []
for i,index in enumerate(index_list):
d.append(a[index])
print(d)
'''output tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) tensor([2, 2, 0, 1, 0, 0, 1, 0, 2, 1, 0, 0, 1, 0, 0, 0], dtype=torch.uint8) tensor([ 0, 1, 3, 6, 8, 9, 12]) tensor([[4, 3, 2, 1, 0]]) torch.Size([16]) torch.Size([1, 5]) torch.Size([1, 5]) [tensor([4, 3, 2, 1, 0])] '''
索引为torch.uint8类型
可以看到在tensor2
为bool/uint8
类型时,tensor2 更像是一个mask,将原有tensor进行筛选一遍,取出tensor2 对应位置不为0的元素
索引为torch.long类型
这个时候就比较麻烦了,tensor2
中存的更像是tensor1
中的位置id, 这个时候a[b].shape == b.shape
相当于在 tensor2 中将所有的元素替换成tensor1中指定位置的元素,写了一个替代脚本:
a = torch.arange(16)
index_list = [[4,3,2,1,0]]
c = torch.LongTensor(index_list)
print(a[c])
d = []
for i,index in enumerate(index_list):
d.append(a[index])
print(d)
# a[c] == d
## ***的tensor
a = torch.arange(12).view(4,3)
print(a[c])
print(a.shape,c.shape,a[c].shape)
d = []
for i,index in enumerate(index_list):
d.append(a[index])
print(d)
''' tensor([[[ 6, 7, 8], [ 9, 10, 11], [ 6, 7, 8], [ 3, 4, 5], [ 0, 1, 2]]]) torch.Size([4, 3]) torch.Size([1, 5]) torch.Size([1, 5, 3]) [tensor([[ 6, 7, 8], [ 9, 10, 11], [ 6, 7, 8], [ 3, 4, 5], [ 0, 1, 2]])] '''