PyTorch：用 tensor 作为下标（可与 `:` 切片混合）进行索引（index）操作

时间:2024-06-02 08:34:02
(Pdb) tmp = torch.randn(3,5)
(Pdb) indx = torch.tensor([1,0]).long()
(Pdb) temp(indx)
*** NameError: name 'temp' is not defined
(Pdb) tmp(indx)
*** TypeError: 'Tensor' object is not callable
(Pdb) tmp[indx]
tensor([[ 0.1633,  0.9389,  1.2806, -0.2525,  0.2817],
        [ 0.6204,  0.5973, -1.7741,  0.3721, -0.5338]])
(Pdb) tmp
tensor([[ 0.6204,  0.5973, -1.7741,  0.3721, -0.5338],
        [ 0.1633,  0.9389,  1.2806, -0.2525,  0.2817],
        [ 0.4279, -0.2156,  2.4653,  0.3173, -0.0719]])
(Pdb) indx
tensor([1, 0])
(Pdb) indx2= torch.tensor([[1,0]]).long()
(Pdb) index2
*** NameError: name 'index2' is not defined
(Pdb) indx2
tensor([[1, 0]])
(Pdb) indx2.shape
torch.Size([1, 2])
(Pdb) tmp[indx2]
tensor([[[ 0.1633,  0.9389,  1.2806, -0.2525,  0.2817],
         [ 0.6204,  0.5973, -1.7741,  0.3721, -0.5338]]])
(Pdb) tmp[indx2].shape
torch.Size([1, 2, 5])
(Pdb) tmp[:,indx2].shape
torch.Size([3, 1, 2])
(Pdb) tmp[:,indx2]
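tensor([[[ 0.5973,  0.6204]],

        [[ 0.9389,  0.1633]],

        [[-0.2156,  0.4279]]])

说明：用整数（long）tensor 作为下标时，结果形状 = 下标 tensor 的形状 + 未被索引的剩余维度。
- `tmp[indx]` 按 1、0 的顺序取出第 1 行和第 0 行，形状为 [2] + [5] = [2, 5]。
- `tmp[indx2]` 的下标形状为 [1, 2]，结果形状为 [1, 2] + [5] = [1, 2, 5]。
- `tmp[:, indx2]` 对第 0 维保留全部（`:`），在第 1 维（列）上取第 1、0 列，结果形状为 [3] + [1, 2] = [3, 1, 2]。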