pytorch 维度操作那些事

维度转置函数 transpose() 和 permute()

plaintext
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
x = torch.randn([2,3])
y = torch.randn([2,3,4])
print(f'x.shape:{x.shape},y.shape{y.shape}')
# transpose() torch.transpose(x)合法, x.transpose()合法。
## torch.transpose(input, dim0, dim1, out=None)
print(f'{x.transpose(0,1).shape,torch.transpose(x,0,1).shape}')
# transpose 一次性只能对两个维度进行操作。因此在进行维度变化的时候多采用permute()
print(f'{y.transpose(0,2).shape}')
x = torch.randn([2,3])
y = torch.randn([2,3,4])
print(f'x.shape:{x.shape},y.shape{y.shape}')
# permute() torch.permute(x)不合法,x.permute()合法。
print(f'{x.permute(0,1).shape}')
print(f'{y.permute(0,2,1).shape}')
# transpose() 和 permute() 两者大致相同,但是在内存处理方面具有差异

维度修改函数 squeeze()和 unsqueeze()

plaintext
1
2
3
4
5
6
7
8
9
# squeeze()
## 低维记为0 高维记为1
## 如果小括号里不是1怎么办?这个括号里的1是什么意思?
## 具体而言,如果一个张量有四个维度的,squeeze(index)会将张量中第index维度,
## 且大小为1的维度进行去除,从而减少张量的维度。如果index是负整数,那就是倒数第index个维度
print(f'{x.squeeze(-1).shape}')
# unsqueeze()
## 从左到右,高维记为0,最低维最大
print(f'{y.unsqueeze(0).shape}')

参考文献

https://blog.csdn.net/xinjieyuan/article/details/105232802
https://blog.csdn.net/guihaiyuan123/article/details/113455775