pytorch中的unsqueeze以及squeeze用法举例

时间:2024-03-24 18:42:21

unsqueeze:简单来说就是添加tensor的维度

举例说明:

代码

import torch
x = torch.tensor([1, 2, 3])#dim=1,即(3)
print('x: ', x)
print('x.size: ', x.size())
x1 = torch.unsqueeze(x, 1)#x1变为(3,1)的矩阵
print('x1: ', x1)
print('x1.size: ', x1.size())
x2 = torch.unsqueeze(x, 0)#x2变为(1,3)的矩阵
print('x2: ', x2)
print('x2.size: ', x2.size())

结果

pytorch中的unsqueeze以及squeeze用法举例

解析:我们初始的张量为Tensor([1,2,3]),输出size为[3]。而我们进行unsqueeze操作,即torch.unsqueeze(x, 1),得到x1的size为[3,1]。当我们进行torch.unsqueeze(x, 0)时,x2的size为[1,3]。

 

squeeze:简单来说就是删减tensor的维度(只能dim=1的维度)

举例说明:

存在一个tensor X 的size为[1,2,3,4,1,5,1]

torch.squeeze(x).size:删除了所有dim=1的维度,即size = [2,3,4,5]

torch.squeeze(x, 0).size:删除了第一维度的维度,由于值为1,所以删除成功即size = [2,3,4,1,5,1]

torch.squeeze(x, 1).size:删除了第二维度的维度,由于值为2,不等于1 所以删除失败即size = [1,2,3,4,1,5,1]