**一、where函数**
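`torch.where(condition, x, y)` selects each output element from `x` or `y` depending on `condition`:

$$
\text{out}_i =
\begin{cases}
x_i, & \text{if } \text{condition}_i \text{ is True (1)} \\
y_i, & \text{if } \text{condition}_i \text{ is False (0)}
\end{cases}
$$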
In [29]: cond = torch.rand(2,2)
In [30]: cond
Out[30]:
tensor([[0.1326, 0.4126],
[0.7093, 0.5339]])
In [31]: a = torch.zeros(2,2)
In [32]: a
Out[32]:
tensor([[0., 0.],
[0., 0.]])
In [33]: b = torch.ones(2,2)
In [34]: b
Out[34]:
tensor([[1., 1.],
[1., 1.]])
In [35]: torch.where(cond>0.5,a,b)
Out[35]:
tensor([[1., 1.],
        [0., 0.]])
**二、gather函数**
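`torch.gather(input, dim, index, out=None)` collects values from `input` along dimension `dim` at the positions given by `index`. For `dim=1`: $\text{out}[i][j] = \text{input}[i][\,\text{index}[i][j]\,]$. In the example below, the top-3 indices of each row of `prob` are used to look up the matching entries of `label`.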
In [1]: import torch
In [2]: prob = torch.randn(4,10)
In [3]: prob
Out[3]:
tensor([[ 0.8383, 0.2332, 1.1231, 0.4929, 2.1630, 0.9328, -0.9775, 2.9904,
1.7534, 0.2515],
[-1.1460, 0.3640, -0.6829, -2.0924, 0.2590, 1.8114, 0.8341, -0.6201,
-0.8322, -0.1316],
[ 2.3255, -0.4369, 0.6470, 1.0118, -0.4143, 0.5650, -0.2035, -0.1714,
-1.1866, -0.0068],
[ 0.9149, -0.3263, -0.3857, 2.1448, -0.2073, 1.5119, -1.1741, -1.6586,
0.1826, 1.0848]])
In [4]: idx = prob.topk(dim=1,k=3)
In [5]: idx = idx[1]
In [6]: idx
Out[6]:
tensor([[7, 4, 8],
[5, 6, 1],
[0, 3, 2],
[3, 5, 9]])
In [7]: label = torch.arange(10)+100
In [8]: label
Out[8]: tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
In [9]: torch.gather(label.expand(4,10),dim=1,index=idx.long())
Out[9]:
tensor([[107, 104, 108],
[105, 106, 101],
[100, 103, 102],
[103, 105, 109]])