pytorch相关总结
以下全部是实例,通过例子更能体会相关函数的作用。
一、view()和expand()
x=torch.Tensor([[1,2],[3,4]])
H=x.shape[0]
W=x.shape[1]
x1 = x.view(H, 1, W, 1)
x1
输出:
tensor([[[[1.],
[2.]]],
[[[3.],
[4.]]]])
x2=x1.expand( H, 3, W, 3)
x2
输出:
tensor([[[[1., 1., 1.],
[2., 2., 2.]],
[[1., 1., 1.],
[2., 2., 2.]],
[[1., 1., 1.],
[2., 2., 2.]]],
[[[3., 3., 3.],
[4., 4., 4.]],
[[3., 3., 3.],
[4., 4., 4.]],
[[3., 3., 3.],
[4., 4., 4.]]]])
x3=x2.contiguous().view(H*3, W*3)
x3
输出:
tensor([[1., 1., 1., 2., 2., 2.],
[1., 1., 1., 2., 2., 2.],
[1., 1., 1., 2., 2., 2.],
[3., 3., 3., 4., 4., 4.],
[3., 3., 3., 4., 4., 4.],
[3., 3., 3., 4., 4., 4.]])
二、view()和repeat()
b=torch.linspace(1,11,3)
b
输出:
tensor([ 1., 6., 11.])
bb=b.repeat(3,1)
bb
输出:
tensor([[ 1., 6., 11.],
[ 1., 6., 11.],
[ 1., 6., 11.]])
bbb=bb.repeat(3,1)
bbb
输出:
tensor([[ 1., 6., 11.],
[ 1., 6., 11.],
[ 1., 6., 11.],
[ 1., 6., 11.],
[ 1., 6., 11.],
[ 1., 6., 11.],
[ 1., 6., 11.],
[ 1., 6., 11.],
[ 1., 6., 11.]])
bb=b.repeat(3,2)
bb
输出:
tensor([[ 1., 6., 11., 1., 6., 11.],
[ 1., 6., 11., 1., 6., 11.],
[ 1., 6., 11., 1., 6., 11.]])
所以说,他表示的意思就是对b进行多少次赋值,括号中有几个数字,那么变换之后就是几维的。
bb=b.repeat(2,2,2)
bb
输出:
tensor([[[ 1., 6., 11., 1., 6., 11.],
[ 1., 6., 11., 1., 6., 11.]],
[[ 1., 6., 11., 1., 6., 11.],
[ 1., 6., 11., 1., 6., 11.]]])
dim()
返回最大的维数
index_select()
a=torch.Tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
a
输出:
Out[73]:
tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]])
a.index_select(1,torch.LongTensor([0]))
输出:
tensor([[1.],
[2.],
[3.],
[4.]])
list的相加和torch.Tensor的相加
a=[[1,1],[2,2]]
b=[[3,3],[4,4]]
c=a+b
c
输出:
[[1, 1], [2, 2], [3, 3], [4, 4]]
a=np.array([[1,1],[2,2]])
b=np.array([[3,3],[4,4]])
c=a+b
c
输出:
array([[4, 4],
[6, 6]])
a=torch.Tensor([[1,1],[2,2]])
b=torch.Tensor([[3,3],[4,4]])
c=a+b
c
输出:
tensor([[4., 4.],
[6., 6.]])
使用matplotlib在图片上绘制矩形
import matplotlib.pyplot as plt
import matplotlib.patches as patches
img=plt.imread("./images/3.jpg")
fig, a = plt.subplots(1, 1)
a.imshow(img)
# Plot the bounding boxes and corresponding labels on top of the image
# for i in range(1):
# Set the postion and size of the bounding box. (x1, y2) is the pixel coordinate of the
# lower-left corner of the bounding box relative to the size of the image.
rect = patches.Rectangle((10, 10),
200, 200,
linewidth=2,
edgecolor='r',
facecolor='none')
# Draw the bounding box on top of the image
a.add_patch(rect)
plt.show()
输出: