unsqueeze:简单来说就是添加tensor的维度
举例说明:
代码
import torch x = torch.tensor([1, 2, 3])#dim=1,即(3) print('x: ', x) print('x.size: ', x.size()) x1 = torch.unsqueeze(x, 1)#x1变为(3,1)的矩阵 print('x1: ', x1) print('x1.size: ', x1.size()) x2 = torch.unsqueeze(x, 0)#x2变为(1,3)的矩阵 print('x2: ', x2) print('x2.size: ', x2.size())
结果
解析:我们初始的张量为Tensor([1,2,3]),输出size为[3]。而我们进行unsqueeze操作,即torch.unsqueeze(x, 1),得到x1的size为[3,1]。当我们进行torch.unsqueeze(x, 0)时,x2的size为[1,3]。
squeeze:简单来说就是删减tensor的维度(只能dim=1的维度)
举例说明:
存在一个tensor X 的size为[1,2,3,4,1,5,1]
torch.squeeze(x).size:删除了所有dim=1的维度,即size = [2,3,4,5]
torch.squeeze(x, 0).size:删除了第一维度的维度,由于值为1,所以删除成功即size = [2,3,4,1,5,1]
torch.squeeze(x, 1).size:删除了第二维度的维度,由于值为2,不等于1 所以删除失败即size = [1,2,3,4,1,5,1]