文章目录
前言
本文旨在记录pytorch的API如何影响Tensor运算的‘内存共享性’和‘内存连续性’。’内存共享‘可以理解为浅拷贝;’内存连续’就是Tensor在信息区的内存空间上的连续性。 本文会结合代码介绍pytorch中的op是如何影响这两个性质的。
1、前置基础知识
1.1.Tensor的结构
因为涉及到Tensor的性质,因此,本节先简单回顾下Tensor的数据结构,Tensor包含信息区和存储区。信息区包含Tensor的一些维度信息(比如一个Tensor的shape=(2,3),变成(3,2),张量内容没变,变得只是我们看待这个张量的视角);存储区则是存储着数据。
深拷贝自然会同时拷贝两个区的内容;而维度变换操作往往仅影响信息区的内容,是为了减少张量计算中频繁的拷贝操作。
1.2.内存共享和内存连续API介绍
大家可先扫一眼下面的代码:这里简单介绍两个API,is_contiguous()能够判断一个Tensor的**信息区**上是否‘内存连续’;.data_ptr()能够返回张量在内存空间上的地址,可用于判断两个张量是否‘内存共享’。
# case 1: share contiguous And deepCopy?
x = torch.tensor([1,2,3], dtype=torch.float32)
y = x # shallow copy
print(x.is_contiguous(), y.is_contiguous()) # True, True
print(x.data_ptr() == y.data_ptr()) # True
# 若y发生额外的运算,此时pytorch会额外开辟新的内存,即转化成深拷贝!
y = y + 1 # x = [1,2,3], y = [2,3,4]
print(x.data_ptr() == y.data_ptr()) # False
我说下结论:说到底是python的语言特性。1)大多数赋值操作 = 全是浅拷贝,比如(y = x),因此,张量x和y内存连续且内存共享。也就是说:由于发生的是浅拷贝,即当我们对y做了某些op
后,对应的x的值也会发生变化。2)但千万不能对y做运算(比如y = y +1),此时就由浅拷贝转化成了深拷贝,即python内部会自动开辟一块新的内存来存储y,即此时x和y各自内存连续但已经不共享内存了。(可能会forward的计算图产生影响)。
本文会在第2部分介绍一些pytorch中哪些op会对Tensor的内存连续性产生影响;在第3部分介绍pytorch中哪些op会对Tensor的内存共享性产生影响。
2、内存连续性
2.1.维度变换操作(transpose, permute)
# -------------- transpose op ------------------ #
# transpose op
x = torch.arange(0,6).view(2,3)
print(x.is_contiguous()) # True
y = x # shallow copy
y = y.transpose(0,1) # 张量的信息区发生变更,但存储区没发生变更,x和y共用一块存储区
# True: dont destropy ;False: transpose op destroy share contigouse
print(x.is_contiguous, y.is_contiguous())
print(x.data_ptr() == y.data_ptr()) # True : share meomery
# -------------- permute op -------------------- #
# permute op
x = torch.arange(0,6).view(2,3)
print(x.is_contiguous()) # True
y = x # shallow copy
y = y.permute(1,0)
# True: dont destropy ;False: transpose op destroy share contigouse
print(x.is_contiguous, y.is_contiguous())
print(x.data_ptr() == y.data_ptr()) # True : share meomery
上述代码是pytorch中两个常用的维度变换op:transpose和permute。从上述代码可以看出,二者都会破坏了原始张量的内存连续性,更准确的说是破坏了信息区的内存连续性。但由于y是由x浅拷贝过来的,所以y和x共用一块存储区。
例外!!!这里有个例外就是存在dim=1的Tensor:若某Tensor的shape=(1,2,3),则调用transpose()/permute()时只有在交换后的维度的非0相对dim没变情况下才不会破坏信息区的内存连续性,即is_contiguous() == True;若破坏了非0dim的相对位置,则is_contiguous() == False。举个例子:比如交换后的shape变成(1,3,2)/(3,2,1)/(3,1,2),则破坏了内存连续性;若交换后shape变成(2,3,1)/(2,1,3),则依旧内存连续。
2.2.view和reshape
pytorch中另外两个常用的维度变化操作就是:view和reshape。先贴两段code,看是如何影响内存连续性和内存共享性的。
# ------- view op need contiguous----- #
x = torch.arange(0,6).view(2,3)
y = x.permute(1,0)
# Error: permute导致信息区的内存不连续,view操作会报错
y = y.view(2,3)
# -------- reshape op ---------------- #
x = torch.arange(0,6).view(2,3)
y = x.transpose(0,1) # y的信息区不连续
y = y.reshape(2,3) # 效果 == y.contiguous().view(2,3)
print(x.is_contiguous(), y.is_contiguous()) # true, true
print(x.data_ptr() == y.data_ptr()) # false
长话短说:在调用transpose和permute操作后,会破坏张量在信息区的内存连续性。而view操作需要张量的内存连续,否则会报错!而reshape则可以无脑使用:1)若Tensor本来内存连续,则调用reshape操作相当于调用view,并不会深拷贝源张量;2)若Tensor内存不连续,则reshape操作会首先深拷贝一份张量使其连续,然后在进行view操作。其效果等同于.contiguous().view(2,3)。
总的来说:view op不会深拷贝张量但需要内存连续;reshape op在张量内存不连续情况下会发生深拷贝!还有别忘了:.contiguous()方法会对张量进行深拷贝。
2.3.维度拼接:cat和stack op
# -------- torch.cat op ----------- #
x = torch.tensor([[1,2,3]], dtype=torch.float32)
y = torch.tensor([[4,5,6]], dtype=torch.float32)
z = torch.cat((x, y), dim=0)
print(z.data_ptr() == x.data_ptr()) # False
# -------- torch.stack op --------- #
v = torch.stack((x,y),dim = 0)
print(v.data_ptr() == x.data_ptr()) # False
这两个比较容易理解,拼接产生新的张量自然会开辟新的内存,且内存连续。
2.4. squeeze()和unsqueeze()
# -------- torch.squeeze op ----------- #
x = torch.tensor([[1,2,3]])
y = x.squeeze()
print(y.data_ptr() == x.data_ptr()) # True
# -------- torch.unsqueeze op --------- #
z = x.unsqueeze(0)
print(z.data_ptr() == x.data_ptr()) # True
一句话:squeeze()和unsqueeze()共享内存。而内存连续性则和源张量保持一致,即x是内存连续,则y和z也是内存连续;x不连续则y和z也不一致。
2.5. expand 和 repeat
x = torch.arange(0,6).view(1,2,3)
y = x.permute(0,2,1)
print(y.is_contiguous()) # False
y = y.expand(size=(2,3,2))
print(y.is_contiguous()) # False
print(y.data_ptr() == x.data_ptr()) # True
有了前面基础,这两个op就容易了,都是复制张量。expand内存共享;而repeat会深拷贝,内存不共享。内存连续性和源张量保持一致。
这里注意下expand,即内存共享,也就是说,pytorch调用expand后实际上并没有在内存中开辟新的内存存储数据。你将y的值进行修改的结果会同时把x的值也更改掉。
# ------------- expand op ----------------- #
x = torch.arange(0,6).view(1,2,3)
y = x
y = y.expand(size=(2,2,3))
print(y.data_ptr() == x.data_ptr()) # True
y[1][0][0] = 100
print(x[0][0]) # [100,1,2]
print(y[0][0]) # [100,1,2]
# ------------ repeat op ------------------ #
x = torch.arange(0,6).view(1,2,3)
y = x
y = y.repeat(repeats=(2,1,1))
print(y.data_ptr() == x.data_ptr()) # True
y[1][0][0] = 100
print(x[0][0]) # [0,1,2]
print(y[1][0]) # [100,1,2]
2.6. numpy和from_numpy内存共享
2.7.切片
x = torch.arange(0,6).view(1,2,3)
y = x[0][0]
print(y.data_ptr() == x.data_ptr()) # True
y[0] = 100
print(x) # [100,1,2]
浅拷贝,改变y的值会同时改变x的值。
总结
写这种文章阅读量不高,但是这些看似不起眼的知识往往会造成意想不到的错误。本文讲的是一些Tensor的偏底层的知识,而在用pytorch搭建神经网络过程中何时采用深浅拷贝,比如clone和detach等op,会对网络的梯度训练产生何种影响呢?敬请期待后续文章。