pytorch 维度操作那些事
维度转置函数 transpose() 和 permute()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| x = torch.randn([2,3]) y = torch.randn([2,3,4]) print(f'x.shape:{x.shape},y.shape{y.shape}') # transpose() torch.transpose(x)合法, x.transpose()合法。 ## torch.transpose(input, dim0, dim1, out=None) print(f'{x.transpose(0,1).shape,torch.transpose(x,0,1).shape}') # transpose 一次性只能对两个维度进行操作。因此在进行维度变化的时候多采用permute() print(f'{y.transpose(0,2).shape}') x = torch.randn([2,3]) y = torch.randn([2,3,4]) print(f'x.shape:{x.shape},y.shape{y.shape}') # permute() torch.permute(x)不合法,x.permute()合法。 print(f'{x.permute(0,1).shape}') print(f'{y.permute(0,2,1).shape}') # transpose() 和 permute() 两者大致相同,但是在内存处理方面具有差异
|
维度修改函数 squeeze()和 unsqueeze()
1 2 3 4 5 6 7 8 9
| # squeeze() ## 低维记为0 高维记为1 ## 如果小括号里不是1怎么办?这个括号里的1是什么意思? ## 具体而言,如果一个张量有四个维度的,squeeze(index)会将张量中第index维度, ## 且大小为1的维度进行去除,从而减少张量的维度。如果index是负整数,那就是倒数第index个维度 print(f'{x.squeeze(-1).shape}') # unsqueeze() ## 从左到右,高维记为0,最低维最大 print(f'{y.unsqueeze(0).shape}')
|
参考文献
https://blog.csdn.net/xinjieyuan/article/details/105232802
https://blog.csdn.net/guihaiyuan123/article/details/113455775