Pytorch Tensor的奇妙运算

时间:2025-03-30 07:37:38

加法运算

1. 加号运算符 

同型时,效果等同于点加

import torch

a = ([1,2])
b = ([3,4])
c=a+b

print(c)
#tensor([4,6])


a = ([[1,2],[3,4]])
b = ([[5,6],[7,8]])
c=a+b

print(c)
#tensor([[6,8],[10,12]])

不同型时,奇妙加法

(a,1)+(b) = (a,b)

import torch

a=(64,1)
b=(32)
c=a+b
()
#([64, 32])

#相当于对a为1的最小那一维,先扩充到32
#再每个元素加上b中对应元素。

->减法运算符 ' - ' 理同

2. 与python native sum()

import torch

a = ([2,2])
b = ([2,2])

#如果用concat和()
c1 =([a,b],0)
print(c1)
#tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])

print((dim=0))
#tensor([4., 4.])

print(())
#tensor(8.)

#以上两例可以看出,对应tensorflow里面的reduce_sum,是一定要降维的。
#不可能保持同型


#如果使用python的原生sum
c2=sum([a,b])
print(c2)
#tensor([[2., 2.],
        [2., 2.]])

#效果等于同形状的点加运算
#这个效果是做不到的!!!!!!

python native sum:

sum(iterable,start=0)

传入一个iterable的类型,比如我们传进去的是list of tensor。逐个访问,与start相加。

sum([a,b]) 

等价于 

sum([a,b], 0)

等价于

start=0

for i in [a,b]:
    start += i

#当a,b都是tensor时
#等价于 a+b
#所以才等价于 “点加”运算

个人经验:基于原生sum的可迭代特性,它可以处理比更多类型的数据。

比如我现在有一个list of tensor

lis = [a,b,c]

只能使用原生sum

out=sum(lis)

如果你使用out=(lis),会报错

'argument 'input' (position 1) must be Tensor, not list'

 3.不同形状的tensor加法

import torch

a=([[[1,2,3,4],[5,6,7,8.0]]])
b=([[[50,60,70,80]],[[10,20,30,40]],[[15,25,35,45]]])
c=a+b

print(())
print(())
print(())
print(c)

结果,这两个shape都不等的Tensor,居然能相加

([1, 2, 4])
([3, 1, 4])
([3, 2, 4])
tensor([[[51., 62., 73., 84.],
         [55., 66., 77., 88.]],

        [[11., 22., 33., 44.],
         [15., 26., 37., 48.]],

        [[16., 27., 38., 49.],
         [20., 31., 42., 53.]]])

换成Numpy,也是一样的

import torch
import numpy

a=([[[1,2,3,4],[5,6,7,8.0]]]).numpy()
b=([[[50,60,70,80]],[[10,20,30,40]],[[15,25,35,45]]]).numpy()
c=a+b
print()
print()
print()
print(c)

结果

(1, 2, 4)
(3, 1, 4)
(3, 2, 4)
[[[51. 62. 73. 84.]
  [55. 66. 77. 88.]]

 [[11. 22. 33. 44.]
  [15. 26. 37. 48.]]

 [[16. 27. 38. 49.]
  [20. 31. 42. 53.]]]

为什么能相加?

乘法运算

1.乘法运算符 * 

#(k,m)*(m) = (k,m)

#(k,m)*(m,1) ,报错

#(k,m)*(k,m) = (k,m)

#证明一个(4,3)*(3) = (4,3)

import torch

x = [
[1,1,1]
[1,1,1]
[1,1,1]
[1,1,1]
]

y=[1,2,3]

x= (x)
y= (y)

z = x*y

print(())
print(z)

#([4, 3])
#tensor([[1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.]])

原因解释:

If two tensors x, y are “broadcastable”, the resulting tensor size is calculated as follows:

If the number of dimensions of x and y are not equal, prepend 1 to the dimensions of the tensor with fewer dimensions to make them equal length.
Then, for each dimension size, the resulting dimension size is the max of the sizes of x and yalong that dimension.
--------------------- 
作者:lvhhh 
原文:/lvhao92/article/details/79757031 

2. ()函数

同型效果同点乘。

当x与y同型时

(y) == x*y

对应点相乘

不同型时,也会对较小者进行维度扩充。

3. ()函数 

效果同矩阵运算

矩阵相乘

(y) , 矩阵大小需满足: (i, n)x(n, j)

则 返回结果 (i,j)

除法

整个除法的大前提:

除数不要有0,不然有时候会在对应位置返回nan,有时候会raise balabala;

两边的数据类型都要求至少是float,不然有时候会强制转换,有时候会raise balabala,搞不懂搞不懂;

而又由于pytorch的tensor同dtype才能运算的大前提,所以如果一边是float32,另一边也得是float32,若一边64另一边跟着64;

怎么转类型看这里 《Pytorch 数据类型与转换》

1.除法运算符 /

同型效果同点除

2. ()函数

同型效果同点除

#保证两边都是float才不容易报错
a = (([160, 110]).float(), 0.137)

拼接运算

1. (sequeltial,dim)

这个比较简单

import torch

a=([[1,2,3],[5,6,7]])
b=([[4,5,6],[9,8,7]])

c = ((a,b),dim=0)
d = ((a,b),dim=1)

print(c)
print(d)

#tensor([[1., 2., 3.],
        [5., 6., 7.],
        [4., 5., 6.],
        [9., 8., 7.]])

#tensor([[1., 2., 3., 4., 5., 6.],
        [5., 6., 7., 9., 8., 7.]])

需要注意的是,()不能对torch标量使用。

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

torch标量是指用小写tensor创建的0维标量。

或者由(),()等创建的0维标量。

a=(1) 
b=(1)

print()
#[]

print()
#[(1)]

#
c=(b)
print()
#[]

()

这个就很神奇了!

怎么实现的很难解释,反正stack之后会提升一维。

两个(2,3) stack会变成(2,2,3),根据dim参数不同效果不同。

不好控制,推荐使用cat。

import torch

a=([[1,2,3],[5,6,7]]) #(2,3)
b=([[4,5,6],[9,8,7]]) #(2,3)

c = ((a,b),dim=0)
d = ((a,b),dim=1)

print(c)
print(d)

#tensor([[[1., 2., 3.],
         [5., 6., 7.]],

        [[4., 5., 6.],
         [9., 8., 7.]]])


#tensor([[[1., 2., 3.],
         [4., 5., 6.]],

        [[5., 6., 7.],
         [9., 8., 7.]]])


#解析
#dim=0,得到 = (2,2,3)
#dim=1, 得到 = (2,2,3)

取数运算

正好遇到一个需求。

我有m行k列的一个表a,和一个长为m的索引列表b。

b中储存着,取每行第几列的元素。

这种情况下,你用普通的索引是会失效的。

import torch
a= ([[1,2,3],[4,5,6]])
b= ([0,1])

c= a[b]
print(c)

#tensor([[1, 2, 3],
        [4, 5, 6]])

#不满足要求

经过一番查找,发现我们可以用神奇的()函数

import torch

a= ([[1,2,3],[4,5,6]])
b= ([0,1]).view(2,1)

c= (input=a,dim=1,index=b)
print(c)
#tensor([[1],
        [5]])


#成功满足需求
#要注意的是,b必须也是2维数组,有x维必须跟a一致。

维度调整

permute函数

这个跟reshape或者view是两码事

(3,5,7) -> view(3,7,5)

#和交换最后2维,是不一样的!


x = (2, 3, 5)
()
#([2, 3, 5])
(2, 0, 1).size()
#([5, 2, 3])