PyTorch permute的用法

软件发布|下载排行|最新软件

当前位置:首页IT学院IT技术

PyTorch permute的用法

York1996   2022-06-03 我要评论

permute(dims)

将tensor的维度换位。

参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimension。

例:

import torch
import numpy    as np

a=np.array([[[1,2,3],[4,5,6]]])

unpermuted=torch.tensor(a)
print(unpermuted.size())  #  ——>  torch.Size([1, 2, 3])

permuted=unpermuted.permute(2,0,1)
print(permuted.size())     #  ——>  torch.Size([3, 1, 2])

 再比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。

利用这个函数permute(0,2,1)可以把Tensor([[[1,2,3],[4,5,6]]]) 转换成

tensor([[[1., 4.],
        [2., 5.],
        [3., 6.]]])

如果使用view,可以得到

tensor([[[1., 2.],
         [3., 4.],
         [5., 6.]]])

关于view的用法:参见PyTorch中view的用法 

附:permute(多维数组,[维数的组合])

比如:

a=rand(2,3,4);  %这是一个三维数组,各维的长度分别为:2,3,4

%现在交换第一维和第二维:

permute(A,[2,1,3])  %变成3*2*4的矩阵

import torch
import numpy    as np
 
a=np.array([[[1,2,3],[4,5,6]]])
 
unpermuted=torch.tensor(a)
print(unpermuted.size())  #  ——>  torch.Size([1, 2, 3])
 
tensor([[[1., 4.],
        [2., 5.],
        [3., 6.]]])
 
permuted=unpermuted.permute(2,0,1)
print(permuted.size())     #  ——>  torch.Size([3, 1, 2])
 
tensor([[[1., 2.],
         [3., 4.],
         [5., 6.]]])

总结

Copyright 2022 版权所有 软件发布 访问手机版

声明:所有软件和文章来自软件开发商或者作者 如有异议 请与本站联系 联系我们