torch.max()
1.
torch.max()简单来说是返回一个tensor中的最大值。
例如:
1
2
3
4
5
6
7
8
9
|
>>> si = torch.randn( 4 , 5 )
>>> print (si)
tensor([[ 1.1659 , - 1.5195 , 0.0455 , 1.7610 , - 0.2064 ],
[ - 0.3443 , 2.0483 , 0.6303 , 0.9475 , 0.4364 ],
[ - 1.5268 , - 1.0833 , 1.6847 , 0.0145 , - 0.2088 ],
[ - 0.8681 , 0.1516 , - 0.7764 , 0.8244 , - 1.2194 ]])
>>> print (torch. max (si))
tensor( 2.0483 )
|
2.
这个函数的参数中还有一个dim参数,使用方法为re = torch.max(Tensor,dim),返回的re为一个二维向量,其中re[0]为最大值的Tensor,re[1]为最大值对应的index的Tensor。
例如:
1
2
|
>>> print (torch. max (si, 0 )[ 0 ])
tensor([ 1.1659 , 2.0483 , 1.6847 , 1.7610 , 0.4364 ])
|
注意,Tensor的维度从0开始算起。在torch.max()中指定了dim之后,比如对于一个3x4x5的Tensor,指定dim为0后,得到的结果是维度为0的“每一行”对应位置求最大的那个值,此时输出的Tensor的维度是4x5.
对于简单的二维Tensor,如上面例子的这个4x5的Tensor。指定dim为0,则给出的结果是4行做比较之后的最大值;如果指定dim为1,则给出的结果是5列做比较之后的最大值,且此处做比较时是按照位置分别做比较,得到一个新的Tensor。
Tensor.view()
简单说就是一个把tensor 进行reshape的操作。
1
2
3
4
|
>>> a = torch.randn( 3 , 4 , 5 , 7 )
>>> b = a.view( 1 , - 1 )
>>> print (b.size())
torch.Size([ 1 , 420 ])
|
其中参数-1表示剩下的值的个数一起构成一个维度。如上例中,第一个参数1将第一个维度的大小设定成1,后一个-1就是说第二个维度的大小=元素总数目/第一个维度的大小,此例中为3*4*5*7/1=420.
1
2
3
4
5
6
7
8
9
|
>>> d = a.view(a.size( 0 ),a.size( 1 ), - 1 )
>>> print (d.size())
torch.Size([ 3 , 4 , 35 ])
>>> e = a.view( 4 , - 1 , 5 )
>>> print (e.size())
torch.Size([ 4 , 21 , 5 ])
|
以上这篇pytorch中torch.max和Tensor.view函数用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_38469553/article/details/85290207