pytorch总结学习系列-操作
算术操作
在PyTorch中,同一种操作可能有很多种形式,下⾯用加法作为例⼦。
加法形式⼀
x = torch.tensor([5.5, 3])
y = torch.rand(5, 3)
print(x + y)
加法形式⼆
print(torch.add(x, y))
还可指定输出:
result = torch.empty(5, 3)
torch.add(x, y, out=result)
print(result)
加法形式三、inplace
# adds x to y
y.add_(x)
print(y)
索引
我们还可以使⽤用类似NumPy的索引操作来访问 Tensor 的⼀部分,需要注意的是:索引出来的结果与
原数据共享内存,也即修改一个,另一个会跟着修改。
y = x[0, :]
y += 1
print(y)
print(x[0, :]) # 源tensor也被改了了
除了常⽤的索引选择数据之外,PyTorch还提供了一些高级的选择函数:
改变形状
⽤用 view() 来改变 Tensor 的形状:
y = x.view(15)
z = x.view(-1, 5) # -1所指的维度可以根据其他维度的值推出来
print(x.size(), y.size(), z.size())
输出
torch.Size([5, 3]) torch.Size([15]) torch.Size([3, 5])
注意 view() 返回的新tensor与源tensor共享内存(其实是同⼀一个tensor),也即更更改其中的⼀一个,另
外⼀一个也会跟着改变。(顾名思义,view仅仅是改变了了对这个张量量的观察⻆角度)
x += 1
print(x)
print(y) # 也加了1
所以如果我们想返回⼀个真正新的副本(即不共享内存)该怎么办呢?Pytorch还提供了了⼀一
个 reshape() 可以改变形状,但是此函数并不不能保证返回的是其拷⻉贝,所以不不推荐使⽤用。推荐先
用 clone 创造一个副本然后再使⽤ view
x_cp = x.clone().view(15)
x -= 1
print(x)
print(x_cp)
使⽤用 clone 还有⼀一个好处是会被记录在计算图中,即梯度回传到副本时也会传到源 Tensor 。
另外⼀一个常⽤用的函数就是 item() , 它可以将⼀一个标量 Tensor 转换成⼀一个Python number:
x = torch.randn(1)
print(x)
print(x.item())
输出
tensor([2.3466])
2.3466382026672363
线性代数
另外,PyTorch还⽀支持一些线性函数,不用自⼰造轮⼦,具体用法参考官⽅
文档。如下表所示: