如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
|
# -*- coding: utf-8 -*-
# @Time : 2018/1/17 16:37
# @Author : Zhiwei Zhong
# @Site :
# @File : Numpy_Pytorch.py
# @Software: PyCharm
import torch
import numpy as np
np_data = np.arange( 6 ).reshape(( 2 , 3 ))
# numpy 转为 pytorch格式
torch_data = torch.from_numpy(np_data)
print (
'\n numpy' , np_data,
'\n torch' , torch_data,
)
'''
numpy [[0 1 2]
[3 4 5]]
torch
0 1 2
3 4 5
[torch.LongTensor of size 2x3]
'''
# torch 转为numpy
tensor2array = torch_data.numpy()
print (tensor2array)
"""
[[0 1 2]
[3 4 5]]
"""
# 运算符
# abs 、 add 、和numpy类似
data = [[ 1 , 2 ], [ 3 , 4 ]]
tensor = torch.FloatTensor(data) # 转为32位浮点数,torch接受的都是Tensor的形式,所以运算前先转化为Tensor
print (
'\n numpy' , np.matmul(data, data),
'\n torch' , torch.mm(tensor, tensor) # torch.dot()是点乘
)
'''
numpy [[ 7 10]
[15 22]]
torch
7 10
15 22
[torch.FloatTensor of size 2x2]
'''
|
以上这篇浅谈pytorch和Numpy的区别以及相互转换方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_34535410/article/details/79088952