pytorch技巧1: 数组排序后复原

pytorch用法1: 数组排序后复原

主要是利用torch.sort函数里返回的第二个参数index,这个index表示的是排序后的数字在原来数组中的位置。

比如:

l = torch.randint(10,(10,))
a, idx1 = torch.sort(l)

结果为:
l: tensor([3., 3., 8., 7., 9., 9., 7., 4., 5., 1.])
a: tensor([1., 3., 3., 4., 5., 7., 7., 8., 9., 9.])
idx1: tensor([9, 0, 1, 7, 8, 3, 6, 2, 4, 5])

这里的index结果比如是从0到N的,这里的idx1中的值就对应a中的值在b中的位置,比如idx1中的第一个9表示a中的1在数组l中出现在第9个。

如果我们将这个index再进行排序就会发现第二次得到的index保留了一些有用的信息:

b, idx2 = torch.sort(idx1)

b: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
idx2: tensor([1, 2, 7, 5, 8, 9, 6, 3, 4, 0])

idx2的结果表示的是b中的值在idx1中出现的原位置,而idx1中的数的位置表示的正是数组a中的数字在原始数组l中的数的位置。于是如果使用a按照idx2的结果来选择数字,就得到了原始数组l:

a.index_select(0,idx2) #0表示选择的维度

结果正是数组l。

组合起来如下所示:

import torch

l = torch.randint(10,(10,))
a, idx1 = torch.sort(l)
b, idx2 = torch.sort(idx1)
print(l)
print(a,idx1)
print(b,idx2)
print(a.index_select(0,idx2))

结果为:
pytorch技巧1: 数组排序后复原