加法运算
1. 加号运算符
同型时,效果等同于点加
import torch
a = ([1,2])
b = ([3,4])
c=a+b
print(c)
#tensor([4,6])
a = ([[1,2],[3,4]])
b = ([[5,6],[7,8]])
c=a+b
print(c)
#tensor([[6,8],[10,12]])
不同型时,奇妙加法
(a,1)+(b) = (a,b)
import torch
a=(64,1)
b=(32)
c=a+b
()
#([64, 32])
#相当于对a为1的最小那一维,先扩充到32
#再每个元素加上b中对应元素。
->减法运算符 ' - ' 理同
2. 与python native sum()
import torch
a = ([2,2])
b = ([2,2])
#如果用concat和()
c1 =([a,b],0)
print(c1)
#tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]])
print((dim=0))
#tensor([4., 4.])
print(())
#tensor(8.)
#以上两例可以看出,对应tensorflow里面的reduce_sum,是一定要降维的。
#不可能保持同型
#如果使用python的原生sum
c2=sum([a,b])
print(c2)
#tensor([[2., 2.],
[2., 2.]])
#效果等于同形状的点加运算
#这个效果是做不到的!!!!!!
python native sum:
sum(iterable,start=0)
传入一个iterable的类型,比如我们传进去的是list of tensor。逐个访问,与start相加。
sum([a,b])
等价于
sum([a,b], 0)
等价于
start=0
for i in [a,b]:
start += i
#当a,b都是tensor时
#等价于 a+b
#所以才等价于 “点加”运算
个人经验:基于原生sum的可迭代特性,它可以处理比更多类型的数据。
比如我现在有一个list of tensor
lis = [a,b,c]
只能使用原生sum
out=sum(lis)
如果你使用out=(lis),会报错
'argument 'input' (position 1) must be Tensor, not list'
3.不同形状的tensor加法
import torch
a=([[[1,2,3,4],[5,6,7,8.0]]])
b=([[[50,60,70,80]],[[10,20,30,40]],[[15,25,35,45]]])
c=a+b
print(())
print(())
print(())
print(c)
结果,这两个shape都不等的Tensor,居然能相加
([1, 2, 4])
([3, 1, 4])
([3, 2, 4])
tensor([[[51., 62., 73., 84.],
[55., 66., 77., 88.]],
[[11., 22., 33., 44.],
[15., 26., 37., 48.]],
[[16., 27., 38., 49.],
[20., 31., 42., 53.]]])
换成Numpy,也是一样的
import torch
import numpy
a=([[[1,2,3,4],[5,6,7,8.0]]]).numpy()
b=([[[50,60,70,80]],[[10,20,30,40]],[[15,25,35,45]]]).numpy()
c=a+b
print()
print()
print()
print(c)
结果
(1, 2, 4)
(3, 1, 4)
(3, 2, 4)
[[[51. 62. 73. 84.]
[55. 66. 77. 88.]]
[[11. 22. 33. 44.]
[15. 26. 37. 48.]]
[[16. 27. 38. 49.]
[20. 31. 42. 53.]]]
为什么能相加?
乘法运算
1.乘法运算符 *
#(k,m)*(m) = (k,m)
#(k,m)*(m,1) ,报错
#(k,m)*(k,m) = (k,m)
#证明一个(4,3)*(3) = (4,3)
import torch
x = [
[1,1,1]
[1,1,1]
[1,1,1]
[1,1,1]
]
y=[1,2,3]
x= (x)
y= (y)
z = x*y
print(())
print(z)
#([4, 3])
#tensor([[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]])
原因解释:
If two tensors x, y are “broadcastable”, the resulting tensor size is calculated as follows:
If the number of dimensions of x and y are not equal, prepend 1 to the dimensions of the tensor with fewer dimensions to make them equal length.
Then, for each dimension size, the resulting dimension size is the max of the sizes of x and yalong that dimension.
---------------------
作者:lvhhh
原文:/lvhao92/article/details/79757031
2. ()函数
同型效果同点乘。
当x与y同型时
(y) == x*y
对应点相乘
不同型时,也会对较小者进行维度扩充。
3. ()函数
效果同矩阵运算
矩阵相乘
(y) , 矩阵大小需满足: (i, n)x(n, j)
则 返回结果 (i,j)
除法
整个除法的大前提:
除数不要有0,不然有时候会在对应位置返回nan,有时候会raise balabala;
两边的数据类型都要求至少是float,不然有时候会强制转换,有时候会raise balabala,搞不懂搞不懂;
而又由于pytorch的tensor同dtype才能运算的大前提,所以如果一边是float32,另一边也得是float32,若一边64另一边跟着64;
怎么转类型看这里 《Pytorch 数据类型与转换》
1.除法运算符 /
同型效果同点除
2. ()函数
同型效果同点除
#保证两边都是float才不容易报错
a = (([160, 110]).float(), 0.137)
拼接运算
1. (sequeltial,dim)
这个比较简单
import torch
a=([[1,2,3],[5,6,7]])
b=([[4,5,6],[9,8,7]])
c = ((a,b),dim=0)
d = ((a,b),dim=1)
print(c)
print(d)
#tensor([[1., 2., 3.],
[5., 6., 7.],
[4., 5., 6.],
[9., 8., 7.]])
#tensor([[1., 2., 3., 4., 5., 6.],
[5., 6., 7., 9., 8., 7.]])
需要注意的是,()不能对torch标量使用。
RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenatedtorch标量是指用小写tensor创建的0维标量。
或者由(),()等创建的0维标量。
a=(1)
b=(1)
print()
#[]
print()
#[(1)]
#
c=(b)
print()
#[]
()
这个就很神奇了!
怎么实现的很难解释,反正stack之后会提升一维。
两个(2,3) stack会变成(2,2,3),根据dim参数不同效果不同。
不好控制,推荐使用cat。
import torch
a=([[1,2,3],[5,6,7]]) #(2,3)
b=([[4,5,6],[9,8,7]]) #(2,3)
c = ((a,b),dim=0)
d = ((a,b),dim=1)
print(c)
print(d)
#tensor([[[1., 2., 3.],
[5., 6., 7.]],
[[4., 5., 6.],
[9., 8., 7.]]])
#tensor([[[1., 2., 3.],
[4., 5., 6.]],
[[5., 6., 7.],
[9., 8., 7.]]])
#解析
#dim=0,得到 = (2,2,3)
#dim=1, 得到 = (2,2,3)
取数运算
正好遇到一个需求。
我有m行k列的一个表a,和一个长为m的索引列表b。
b中储存着,取每行第几列的元素。
这种情况下,你用普通的索引是会失效的。
import torch
a= ([[1,2,3],[4,5,6]])
b= ([0,1])
c= a[b]
print(c)
#tensor([[1, 2, 3],
[4, 5, 6]])
#不满足要求
经过一番查找,发现我们可以用神奇的()函数
import torch
a= ([[1,2,3],[4,5,6]])
b= ([0,1]).view(2,1)
c= (input=a,dim=1,index=b)
print(c)
#tensor([[1],
[5]])
#成功满足需求
#要注意的是,b必须也是2维数组,有x维必须跟a一致。
维度调整
permute函数
这个跟reshape或者view是两码事
(3,5,7) -> view(3,7,5)
#和交换最后2维,是不一样的!
x = (2, 3, 5)
()
#([2, 3, 5])
(2, 0, 1).size()
#([5, 2, 3])