1
|
permute(dims)
|
将tensor的维度换位。
参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimension。
例:
1
2
3
4
5
6
7
|
import torch
import numpy as np
a = np.array([[[ 1 , 2 , 3 ],[ 4 , 5 , 6 ]]])
unpermuted = torch.tensor(a)
print (unpermuted.size()) # ——> torch.Size([1, 2, 3])
permuted = unpermuted.permute( 2 , 0 , 1 )
print (permuted.size()) # ——> torch.Size([3, 1, 2])
|
再比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。
利用这个函数permute(1,3,2)可以把Tensor([[[1,2,3],[4,5,6]]]) 转换成
1
2
3
|
tensor([[[ 1. , 4. ],
[ 2. , 5. ],
[ 3. , 6. ]]])
|
如果使用view(1,3,2),可以得到
1
2
3
|
tensor([[[ 1. , 2. ],
[ 3. , 4. ],
[ 5. , 6. ]]])
|
以上这篇PyTorch中permute的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_40231500/article/details/90606872