torch.stack
是 PyTorch 中用于将一系列张量沿一个新的维度堆叠的函数。与 torch.cat
不同的是,torch.stack
会在指定的维度上增加一个新的维度,而不是将张量直接拼接。
基本用法
语法:
torch.stack(tensors, dim=0)
-
tensors
: 一个张量列表,包含多个形状相同的张量(shape 必须相同)。 -
dim
: 新增维度的位置,默认是0
。
举例说明
假设有三个形状为 (2, 3)
的张量:
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8, 9], [10, 11, 12]])
c = torch.tensor([[13, 14, 15], [16, 17, 18]])
沿 dim=0
堆叠
stacked = torch.stack([a, b, c], dim=0)
print(stacked.shape) # torch.Size([3, 2, 3])
- 在维度
0
上增加一个新的维度,原始的(2, 3)
形状变成(3, 2, 3)
。 -
stacked
的第0
维度有3
个元素,对应原来的a
,b
,c
张量。
沿 dim=1
堆叠
stacked = torch.stack([a, b, c], dim=1)
print(stacked.shape) # torch.Size([2, 3, 3])
- 新的维度插入到原第
1
维的位置。 -
stacked
的第1
维度有3
个元素,对应原来的a
,b
,c
张量。
沿 dim=2
堆叠
stacked = torch.stack([a, b, c], dim=2)
print(stacked.shape) # torch.Size([2, 3, 3])
- 新的维度插入到原第
2
维的位置,形状变为(2, 3, 3)
。
torch.stack
的形状变化总结
假设堆叠前的每个张量形状是 (A, B, C)
,在 dim=0
、dim=1
和 dim=2
堆叠后的形状分别为:
-
dim=0
:(N, A, B, C)
-
dim=1
:(A, N, B, C)
-
dim=2
:(A, B, N, C)
其中 N
是堆叠的张量数量。
和torch.cat函数的区别:
cat
:在指定维度拼接多个张量。不增加维度。
c1 = torch.tensor([[1, 2], [3, 4]])
c2 = torch.tensor([[5, 6], [7, 8]])
c_cat = torch.cat([c1, c2], dim=0) # shape (4, 2)