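You really should learn to read the source code...
Google has been unreachable lately, so I'll use Bing as the example. After going through all sorts of example code, I noticed that torchvision's transforms module has a ToTensor, so I searched for it directly and opened the documentation. Searching the docs for to_tensor turns up an explanation of the function, and from there you can view its source, which looks like this: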
def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    See ``ToTensor`` for more details.
    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
    Returns:
        Tensor: Converted image.
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img
    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)
    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
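As you can see, for a PIL image it mostly comes down to getting the pixel data into a tensor: torch.from_numpy() on top of np.array() for the modes I, I;16, F and 1, and a torch.ByteTensor built from the image's raw bytes for everything else. The tensor is then reshaped with view to (height, width, channels), transposed into channels-first (CHW) order, and finally, if it is a byte tensor, converted to float and divided by 255.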
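To make the view / transpose / divide-by-255 sequence concrete, here is a minimal sketch of the same rearrangement in plain Python, so it runs without torch or numpy. The function name hwc_bytes_to_chw_floats is made up for illustration and is not part of torchvision; it only mimics the layout change, not the library's actual implementation.

```python
def hwc_bytes_to_chw_floats(buf, height, width, nchannel):
    """Rearrange a flat HWC byte buffer into nested CHW lists scaled to [0, 1].

    Hypothetical helper for illustration only; not a torchvision function.
    """
    if len(buf) != height * width * nchannel:
        raise ValueError("buffer size does not match height * width * nchannel")
    out = []
    for c in range(nchannel):
        plane = []
        for y in range(height):
            row = []
            for x in range(width):
                # In HWC order, channel c of pixel (y, x) sits at this flat offset.
                offset = (y * width + x) * nchannel + c
                # Scale the 0..255 byte value to a float in [0, 1].
                row.append(buf[offset] / 255)
            plane.append(row)
        out.append(plane)
    return out


# A 1x2 RGB image: one pure red pixel followed by one pure blue pixel.
pixels = bytes([255, 0, 0, 0, 0, 255])
chw = hwc_bytes_to_chw_floats(pixels, height=1, width=2, nchannel=3)
print(chw)
```

The result is indexed as chw[channel][row][column], which is the channels-first layout that to_tensor returns. The real function gets there by reinterpreting the buffer and then copying it into contiguous memory, rather than by looping over pixels.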
PS: an Easter egg. The latest torchvision source has a rather fun comment in it: "# yikes, this transpose takes 80% of the loading time/CPU". So this code still has room for improvement :)
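As an alternative to searching the docs online, the source of an installed package can usually be read locally as well. Below is a minimal sketch using the standard library's inspect module. The helper show_source is a made-up name, and json.dumps stands in for to_tensor so the snippet runs without torchvision; with torchvision installed, passing torchvision.transforms.functional.to_tensor should work the same way.

```python
import inspect
import json


def show_source(obj, max_lines=15):
    """Return the first max_lines lines of obj's source code as one string.

    Hypothetical helper for illustration only.
    """
    source = inspect.getsource(obj)
    return "\n".join(source.splitlines()[:max_lines])


# json.dumps is a stand-in here; any function written in Python works.
snippet = show_source(json.dumps)
print(snippet)
```

This only works for objects defined in Python source files; functions implemented in C have no Python source for inspect.getsource to return.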