Pytorch:张量的形状操作

时间:2024-04-18 07:06:31

文章目录

      • 一、维度改变
        • 1.flatten展开
          • a.函数的基本用法
          • b.示例
        • 2.unsqueeze增维
          • a.函数的基本用法
          • b.示例
        • 3.squeeze降维
          • a.函数的基本用法
          • b.示例
      • 二、张量变形
        • 1.view()
          • a.函数的基本用法
          • b.参数:
          • c.注意事项
          • d.示例
        • 2.reshape()
          • a.注意事项
          • b.示例
        • 3.reshape_as()
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意
      • 三、维度重排
        • 1.permute
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意
        • 2.transpose
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意

维度改变和张量变形都不改变内存中存储的结构,因此改变后的张量的值顺序和没改变前是一样的。

一、维度改变

1.flatten展开
  • torch.flatten(tensor)
  • tensor.flatten()

torch.flatten() 是一个在 PyTorch 中常用于张量(tensor)处理的函数,它将输入张量展开成一个一维张量。该函数通常用于准备数据,将多维数据转换为一维,以便用于机器学习模型,特别是在模型的全连接层(fully connected layers)之前。
常用于展开成一维

a.函数的基本用法

只给定一个张量,将直接展开成一维。
torch.flatten(input, start_dim=0, end_dim=-1) 的参数解释如下:

  • input: 输入的张量。
  • start_dim: 开始展开的维度,默认为 0。这意味着从哪个维度开始将张量展开。
  • end_dim: 结束展开的维度,默认为 -1,即最后一个维度。这意味着展开将持续到哪个维度。
b.示例

考虑一个三维张量,例如形状为 (2, 3, 4) 的张量。如果使用 torch.flatten() 将其展开,可以有多种方式处理:

  1. 完全展开: 将整个张量展开成一维数组。

    import torch
    x = torch.randn(2, 3, 4)
    flat_x = torch.flatten(x)
    # 结果形状为 [24]
    
  2. 从特定维度开始展开: 指定从哪个维度开始展开。例如,从第一维(索引为 0 的维度)开始展开。

    flat_x = torch.flatten(x, start_dim=1)
    # 结果形状为 [2, 12],保留了第一个维度,其余维度被展开
    
2.unsqueeze增维
  • torch.unsqueeze(tensor)
  • tensor.unsqueeze()

torch.unsqueeze() 是 PyTorch 中用来增加张量的维度的函数。该函数可以在张量的指定位置插入一个维度,它非常有用于调整张量的形状,以满足特定操作或模型的需求,例如在单样本张量上应用需要批处理的模型。
常用于在第0个维度上增加大小为1的维度

a.函数的基本用法

torch.unsqueeze(input, dim) 的参数解释如下:

  • input: 输入的张量。
  • dim: 要插入新维度的索引位置。这个位置遵循 Python 的索引规则,支持负索引。
b.示例

假设有一个二维张量 x 形状为 (3, 4),表示一个包含3个样本,每个样本4个特征的数据集。如果需要在特定维度增加一个维度,可以使用 torch.unsqueeze() 如下:

import torch
x = torch.randn(3, 4)

# 在第0维增加一个维度
x_unsqueezed = x.unsqueeze(0)
print(x_unsqueezed.shape)
# 输出: torch.Size([1, 3, 4])

# 在第1维增加一个维度
x_unsqueezed = torch.unsqueeze(x, 1)
print(x_unsqueezed.shape)
# 输出: torch.Size([3, 1, 4])

# 使用负索引,在最后一个维度后增加一个维度
x_unsqueezed = torch.unsqueeze(x, -1)
print(x_unsqueezed.shape)
# 输出: torch.Size([3, 4, 1])
3.squeeze降维
  • torch.squeeze(tensor)
  • tensor.squeeze()

torch.squeeze() 是 PyTorch 中的一个函数,用于减少张量的维度,特别是去除那些维度大小为1的维度。这个函数非常有用于去除由于某些操作(比如 unsqueeze)产生的单一维度,从而使张量的形状更加紧凑。

a.函数的基本用法

只给定一个张量,将直接去掉所有大小为1的维度。
torch.squeeze(input, dim=None) 的参数解释如下:

  • input: 输入的张量。
  • dim: 指定要压缩的维度。如果指定的维度大小为1,则该维度会被去除如果大小不为1,则该维度不会被压缩如果不指定 dim 参数,那么所有大小为1的维度都会被去除。
b.示例

考虑一个张量 x,其形状包括一些大小为1的维度。以下是如何使用 torch.squeeze() 来去除这些维度的示例:

import torch
x = torch.randn(1, 3, 1, 5)

# 去除所有大小为1的维度
squeezed_x = x.squeeze()
print(squeezed_x.shape)
# 输出: torch.Size([3, 5])

# 只压缩第0维(大小为1)
squeezed_x = x.squeeze(0)
print(squeezed_x.shape)
# 输出: torch.Size([3, 1, 5])

# 只压缩第2维(大小为1)
squeezed_x = torch.squeeze(x, 2)
print(squeezed_x.shape)
# 输出: torch.Size([1, 3, 5])

# 尝试压缩一个不是大小为1的维度(没有变化)
squeezed_x = torch.squeeze(x, 1)
print(squeezed_x.shape)
# 输出: torch.Size([1, 3, 1, 5])

二、张量变形

1.view()

在 PyTorch 中,.view() 方法是一个非常重要且常用的功能,用于改变张量的形状而不改变其数据内容。此方法提供了一种高效的方式来重新排列张量的维度,使其适应不同的需求,例如输入到一个模型或对数据进行不同的操作。
view是共享内存的!

a.函数的基本用法

.view() 方法的基本用法是 tensor.view(*shape),其中 *shape 是希望张量拥有的新形状,由一组维度大小组成。

b.参数:
  • shape: 新的形状,是一个由整数构成的元组,其中的每个整数指定相应维度的大小。你也可以在某个位置使用 -1,让 PyTorch 自动计算该维度的大小。(注意某个位置是任意的某个位置,但是只能有一个)
c.注意事项
  1. 连续性.view() 要求张量在内存中是连续的(即一维数组中的元素顺序与多维视图中的顺序相同)。如果张量不是连续的,你可能需要首先调用 .contiguous() 方法来使其连续。

  2. 自动计算维度:使用 -1 作为形状参数的一部分,PyTorch 将自动计算该维度的正确大小,以便保持元素总数不变。

  3. 大小不变.view()要求张量变换形状之后的大小和变换之前的大小是一样的。即维度大小之积相等。比如tensor.Size([2,4])tensor.Size([8])是一样的。

d.示例
import torch
x = torch.randn(4, 4)  # 创建一个 4x4 的张量

# 改变形状为 2x8
y = x.view(2, 8)
print(y.shape)
# 输出: torch.Size([2, 8])

# 改变形状为 16(一维)
z = x.view(-1)#z = x.view(16)
print(z.shape)
# 输出: torch.Size([16])

# 使用 -1 自动计算维度
w = x.view(-1, 8)
print(w.shape)
# 输出: torch.Size([2, 8])
import torch
x = torch.randn(2, 1)  # 创建一个 2×1 的张量

# 改变形状为 2x8
y = x.view(2)
print(y)
# 输出: torch.Size([2, 8])
x[0][0]=2 #共享内存,y也会变
print(x)
print(y)
tensor([-0.5001,  0.5409])
tensor([[2.0000],
        [0.5409]])
tensor([2.0000, 0.5409])
2.reshape()

在 PyTorch 中,.reshape() 方法用于改变张量的形状而不改变其数据内容。
这一方法与 .view() 类似,都允许您重新排列张量的维度,但它们在处理非连续张量时的行为不同。
只有当非连续张量时,才会导致和.view不一样,如果是连续的,同样也是共享内存的。

a.注意事项
  1. 数据连续性:与 .view() 相比,.reshape() 可以处理非连续张量,如果必要,它会自动处理数据的内存复制。因此,如果原始张量不连续,而你尝试用 .view() 改变其形状可能会导致错误,但 .reshape() 会自动解决这个问题。

  2. 自动计算维度:使用 -1 作为形状参数的一部分时,PyTorch 会自动计算该维度的大小,以确保总元素数量与原张量相同。

b.示例
import torch
x = torch.randn(2, 3, 4)  # 创建一个 2x3x4 的张量

# 改变形状为 6x4
y = x.reshape(6, 4)
print(y.shape)
# 输出: torch.Size([6, 4])

# 改变形状为 1x24
z = x.reshape(1, 24)
print(z.shape)
# 输出: torch.Size([1, 24])

# 使用 -1 自动计算维度
w = x.reshape(-1, 2)
print(w.shape)
# 输出: torch.Size([12, 2])
import torch
x = torch.randn(2, 2)  # 创建一个 2x1 的张量
x=x.transpose(0,1)
# 改变形状为 2x8
y = x.reshape(4)#转置后的x不是连续的,使用reshape产生复制,此时不能用.view()
print(y)
# 输出: torch.Size([2, 8])
x[0][0]=100
print(x)
print(y)
tensor([-0.5386, -0.3646, -0.1661, -0.2516])
tensor([[100.0000,  -0.1661],
        [ -0.3646,  -0.2516]])
tensor([-0.5386, -0.3646, -0.1661, -0.2516])
3.reshape_as()

在 PyTorch 中,.reshape_as() 是一个方便的方法,用于将一个张量重新塑形为与另一个张量相同的形状。这个方法实质上是 .reshape() 方法的一个简化版本,它以另一个张量的形状为目标形状。
换句话说,.reshape_as()相当于是省略了自指定参数的.reshape(),而可以直接用目标张量形状作为形状。

a.函数的基本用法

.reshape_as() 的基本用法非常直接:tensor1.reshape_as(tensor2)。这会将 tensor1 的形状修改为与 tensor2 相同的形状。

b.参数:
  • tensor2: 这是模型张量,tensor1 将改变形状以匹配 tensor2 的形状。
c.示例
import torch
x = torch.randn(2, 3, 4)  # 原始张量,形状为 2x3x4
y = torch.randn(6, 4)     # 目标张量,形状为 6x4

# 将 x 的形状改变为与 y 相同
z = x.reshape_as(y)
print(z.shape)
# 输出: torch.Size([6, 4])
d.注意

虽然 .reshape_as() 很方便,但使用它时应确保两个张量具有相同的元素总数,因为改变形状的操作不会改变数据的总量。如果两个张量的总元素数量不匹配,尝试使用 .reshape_as() 将抛出错误。此外,如果原始张量在内存中是非连续的,.reshape_as() 会像 .reshape() 一样处理,可能需要在内部进行数据复制以确保连续性。

三、维度重排

permute方法可以按照指定顺序重新排列维度,而transpose方法可以交换张量的两个维度。用于需要进行维度重排或转置操作。如矩阵转置。

1.permute

在 PyTorch 中,.permute() 方法用于重新排列张量的维度,这是处理多维数据时一个非常有用的功能,尤其在需要对维度进行特定的重排序操作时。

a.函数的基本用法

.permute() 方法的调用格式为 tensor.permute(*dims),其中 *dims 是一个整数序列,代表新的维度排列顺序。

b.参数:
  • dims: 这个参数定义了张量的每个维度应该如何重新排列。序列中的每个整数都代表原始张量中一个维度的索引,这些索引的排列顺序确定了输出张量的形状。
c.示例
import torch
x = torch.randn(2, 3, 5)  # 创建一个形状为 [2, 3, 5] 的张量

# 改变维度的排列顺序为 [2, 0, 1]
y = x.permute(2, 0, 1)
print(y.shape)
# 输出: torch.Size([5, 2, 3])

# 将维度的排列顺序改为 [1, 2, 0]
z = x.permute(1, 2, 0)
print(z.shape)
# 输出: torch.Size([3, 5, 2])
d.注意
import torch
x = torch.tensor([[1,2,3,4],[2,4,2,4],[5,6,7,8]]) 
x = x.permute(1,0)
'''
tensor([[1, 2, 5],
        [2, 4, 6],
        [3, 2, 7],
        [4, 4, 8]])
'''

在 PyTorch 中,当使用 .permute() 方法重排张量维度时,张量的数据实际上在内存中的位置并没有改变。更准确地说.permute() 改变的是张量访问这些数据的方式,通过调整形状(shape)步长(stride) 的元信息,而不是数据本身。

  • 步长(Stride)
    • 步长是一个定义在每一维上的整数数组,表示为了在数据中从当前维度的一个元素移动到下一个元素,需要跨过的内存位置数。对于一个连续的张量,步长决定了元素在内存中的布局。

形状(Shape)和步长的调整当调用 .permute(1,0) 时,你实际上是告诉 PyTorch 以一个新的顺序来解释原始数据的内存布局。例如:

x = torch.tensor([[1, 2, 3, 4],
                  [2, 4, 2, 4],
                  [5, 6, 7, 8]])

原始的 x 的形状为 (3, 4),即有 3 行和 4 列。在 PyTorch 中,这意味着其步长为 (4, 1),其中 4 表示要从一行的开始移动到下一行的开始,在内存中需要跨过 4 个元素位置;1 表示在同一行中从一个元素移动到下一个元素,只需要跨过 1 个元素位置。

当你调用 x.permute(1, 0) 时,你是在指示 PyTorch 将原来的列视为行,将原来的行视为列。这就改变了形状为 (4, 3)。这时,步长变为 (1, 4)。这意味着:

  • 要从列的一个元素到下一个元素(现在变成了“行”移动),你只需要移动一个数据位置(原来的行移动)。
  • 要从一行移动到下一行(现在是原来的列跨行移动),你需要跨过 4 个数据位置。
2.transpose

在 PyTorch 中,.transpose() 方法用于交换张量中的两个维度,这是处理多维数组时一个常用的功能,尤其是在需要对特定的维度进行转置操作时。

a.函数的基本用法

.transpose() 方法的调用格式为 tensor.transpose(dim0, dim1),其中 dim0dim1 是要交换的维度的索引。

b.参数:
  • dim0: 第一个要交换的维度的索引。
  • dim1: 第二个要交换的维度的索引。
c.示例
import torch
x = torch.randn(2, 3, 5)  # 创建一个形状为 [2, 3, 5] 的张量

# 交换维度 0 和 1
y = x.transpose(0, 1)
print(y.shape)
# 输出: torch.Size([3, 2, 5])

# 交换维度 1 和 2
z = x.transpose(1, 2)
print(z.shape)
# 输出: torch.Size([2, 5, 3])
d.注意

.permute() 类似,.transpose() 也是返回原始数据的一个新视图,并不复制数据。因此,输出张量与输入张量共享同一块内存空间,只是它们的形状和步长(stride)不同。同样,.transpose() 会导致张量在内存中可能变为非连续,因此在某些情况下,可能需要调用 .contiguous() 来使张量在内存中连续。