最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下*上的回答并加上自己的理解。
在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。
这些操作是:
narrow(),view(),expand()和transpose()
举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。
转置的tensor和原tensor的内存是共享的!
为了证明这一点,我们来看下面的代码:
1
2
3
4
5
|
x = torch.randn( 3 , 2 )
y = x.transpose(x, 0 , 1 )
x[ 0 , 0 ] = 233
print (y[ 0 , 0 ])
# print 233
|
可以看到,改变了y的元素的值的同时,x的元素的值也发生了变化。
也就是说,经过上述操作后得到的tensor,它内部数据的布局方式和从头开始创建一个这样的常规的tensor的布局方式是不一样的!于是…这就有contiguous()的用武之地了。
在上面的例子中,x是contiguous的,但y不是(因为内部数据不是通常的布局方式)。
注意不要被contiguous的字面意思“连续的”误解,tensor中数据还是在内存中一块区域里,只是布局的问题!
当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。
一般来说这一点不用太担心,如果你没在需要调用contiguous()的地方调用contiguous(),运行时会提示你:
RuntimeError: input is not contiguous
只要看到这个错误提示,加上contiguous()就好啦~
补充:pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax
gather
torch.gather(input,dim,index,out=None)。对指定维进行索引。比如4*3的张量,对dim=1进行索引,那么index的取值范围就是0~2.
input是一个张量,index是索引张量。input和index的size要么全部维度都相同,要么指定的dim那一维度值不同。输出为和index大小相同的张量。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
import torch
a = torch.tensor([[. 1 ,. 2 ,. 3 ],
[ 1.1 , 1.2 , 1.3 ],
[ 2.1 , 2.2 , 2.3 ],
[ 3.1 , 3.2 , 3.3 ]])
b = torch.LongTensor([[ 1 , 2 , 1 ],
[ 2 , 2 , 2 ],
[ 2 , 2 , 2 ],
[ 1 , 1 , 0 ]])
b = b.view( 4 , 3 )
print (a.gather( 1 ,b))
print (a.gather( 0 ,b))
c = torch.LongTensor([ 1 , 2 , 0 , 1 ])
c = c.view( 4 , 1 )
print (a.gather( 1 ,c))
|
输出:
1
2
3
4
5
6
7
8
9
10
11
12
|
tensor([[ 0.2000 , 0.3000 , 0.2000 ],
[ 1.3000 , 1.3000 , 1.3000 ],
[ 2.3000 , 2.3000 , 2.3000 ],
[ 3.2000 , 3.2000 , 3.1000 ]])
tensor([[ 1.1000 , 2.2000 , 1.3000 ],
[ 2.1000 , 2.2000 , 2.3000 ],
[ 2.1000 , 2.2000 , 2.3000 ],
[ 1.1000 , 1.2000 , 0.3000 ]])
tensor([[ 0.2000 ],
[ 1.3000 ],
[ 2.1000 ],
[ 3.2000 ]])
|
squeeze
将维度为1的压缩掉。如size为(3,1,1,2),压缩之后为(3,2)
1
2
3
4
|
import torch
a = torch.randn( 2 , 1 , 1 , 3 )
print (a)
print (a.squeeze())
|
输出:
1
2
3
4
|
tensor([[[[ - 0.2320 , 0.9513 , 1.1613 ]]],
[[[ 0.0901 , 0.9613 , - 0.9344 ]]]])
tensor([[ - 0.2320 , 0.9513 , 1.1613 ],
[ 0.0901 , 0.9613 , - 0.9344 ]])
|
expand
扩展某个size为1的维度。如(2,2,1)扩展为(2,2,3)
1
2
3
4
5
|
import torch
x = torch.randn( 2 , 2 , 1 )
print (x)
y = x.expand( 2 , 2 , 3 )
print (y)
|
输出:
1
2
3
4
5
6
7
8
9
10
|
tensor([[[ 0.0608 ],
[ 2.2106 ]],
[[ - 1.9287 ],
[ 0.8748 ]]])
tensor([[[ 0.0608 , 0.0608 , 0.0608 ],
[ 2.2106 , 2.2106 , 2.2106 ]],
[[ - 1.9287 , - 1.9287 , - 1.9287 ],
[ 0.8748 , 0.8748 , 0.8748 ]]])
|
sum
size为(m,n,d)的张量,dim=1时,输出为size为(m,d)的张量
1
2
3
4
|
import torch
a = torch.tensor([[[ 1 , 2 , 3 ],[ 4 , 8 , 12 ]],[[ 1 , 2 , 3 ],[ 4 , 8 , 12 ]]])
print (a. sum ())
print (a. sum (dim = 1 ))
|
输出:
1
2
3
|
tensor( 60 )
tensor([[ 5 , 10 , 15 ],
[ 5 , 10 , 15 ]])
|
contiguous
返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。
可以通过is_contiguous查看张量内存是否连续。
1
2
3
4
|
import torch
a = torch.tensor([[[ 1 , 2 , 3 ],[ 4 , 8 , 12 ]],[[ 1 , 2 , 3 ],[ 4 , 8 , 12 ]]])
print (a.is_contiguous)
print (a.contiguous().view( 4 , 3 ))
|
输出:
1
2
3
4
5
|
<built - in method is_contiguous of Tensor object at 0x7f4b5e35afa0 >
tensor([[ 1 , 2 , 3 ],
[ 4 , 8 , 12 ],
[ 1 , 2 , 3 ],
[ 4 , 8 , 12 ]])
|
softmax
假设数组V有C个元素。对其进行softmax等价于将V的每个元素的指数除以所有元素的指数之和。这会使值落在区间(0,1)上,并且和为1。
1
2
3
4
5
|
import torch
import torch.nn.functional as F
a = torch.tensor([[ 1. , 1 ],[ 2 , 1 ],[ 3 , 1 ],[ 1 , 2 ],[ 1 , 3 ]])
b = F.softmax(a,dim = 1 )
print (b)
|
输出:
1
2
3
4
5
|
tensor([[ 0.5000 , 0.5000 ],
[ 0.7311 , 0.2689 ],
[ 0.8808 , 0.1192 ],
[ 0.2689 , 0.7311 ],
[ 0.1192 , 0.8808 ]])
|
max
返回最大值,或指定维度的最大值以及index
1
2
3
4
5
6
7
|
import torch
a = torch.tensor([[. 1 ,. 2 ,. 3 ],
[ 1.1 , 1.2 , 1.3 ],
[ 2.1 , 2.2 , 2.3 ],
[ 3.1 , 3.2 , 3.3 ]])
print (a. max (dim = 1 ))
print (a. max ())
|
输出:
1
2
|
(tensor([ 0.3000 , 1.3000 , 2.3000 , 3.3000 ]), tensor([ 2 , 2 , 2 , 2 ]))
tensor( 3.3000 )
|
argmax
返回最大值的index
1
2
3
4
5
6
7
|
import torch
a = torch.tensor([[. 1 ,. 2 ,. 3 ],
[ 1.1 , 1.2 , 1.3 ],
[ 2.1 , 2.2 , 2.3 ],
[ 3.1 , 3.2 , 3.3 ]])
print (a.argmax(dim = 1 ))
print (a.argmax())
|
输出:
1
2
|
tensor([ 2 , 2 , 2 , 2 ])
tensor( 11 )
|
以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。如有错误或未考虑完全的地方,望不吝赐教。
原文链接:https://blog.csdn.net/gdymind/article/details/82662502