PyTorch 的 .pth 格式模型转 ONNX 格式模型 - Python 实现
#-*-coding:utf-8-*-
# date:2021-10-5
# Author: DataBall - XIAN
# function: pytorch model 2 onnx
import os
import argparse
import torch
import torch.nn as nn
import numpy as np
from network.resnet import resnet18,resnet50
# Architectures this script can actually build; anything else is rejected by
# argparse instead of failing later with an unbound `model_` variable.
SUPPORTED_MODELS = ('resnet_18', 'resnet_50')


def parse_img_size(text):
    """Parse an ``--img_size`` command-line value into ``(height, width)`` ints.

    Accepts ``"96"`` (square), ``"96,96"`` or ``"96x96"``. The previous
    ``type=tuple`` split the string into single characters, so any value given
    on the command line was unusable.

    Raises:
        argparse.ArgumentTypeError: if the value is not one or two positive ints.
    """
    parts = [p for p in text.lower().replace('x', ',').split(',') if p.strip()]
    if len(parts) == 1:
        parts = parts * 2  # a single number means a square input
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(
            "img_size must be 'N' or 'H,W', got {!r}".format(text))
    try:
        height, width = (int(p) for p in parts)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "img_size must contain integers, got {!r}".format(text)) from None
    if height <= 0 or width <= 0:
        raise argparse.ArgumentTypeError(
            "img_size must be positive, got {!r}".format(text))
    return height, width


def build_model(model_name, img_size):
    """Construct the backbone named *model_name*, passing ``img_size`` to it.

    Raises:
        ValueError: if *model_name* is not in ``SUPPORTED_MODELS``.
    """
    if model_name == 'resnet_50':
        return resnet50(img_size=img_size)
    if model_name == 'resnet_18':
        return resnet18(img_size=img_size)
    raise ValueError('unsupported model {!r}; expected one of: {}'.format(
        model_name, ', '.join(SUPPORTED_MODELS)))


def main():
    """Load a trained ``.pth`` checkpoint and export it to an ONNX file."""
    parser = argparse.ArgumentParser(description=' Project handpose x')
    parser.add_argument('--model_path', type=str,
                        default=os.path.join('ckpt', 'resnet_18_epoch-275-x96.pth'),
                        help='model_path')  # checkpoint (.pth) path
    parser.add_argument('--model', type=str, default='resnet_18',
                        choices=SUPPORTED_MODELS,
                        help='model : ' + ', '.join(SUPPORTED_MODELS))  # model type
    parser.add_argument('--GPUS', type=str, default='0',
                        help='GPUS')  # value for CUDA_VISIBLE_DEVICES
    parser.add_argument('--test_path', type=str, default='./image/',
                        help='test_path')  # test image folder (not used by the export)
    # argparse only applies `type` to string defaults, so the tuple default
    # is kept as-is while command-line values go through parse_img_size.
    parser.add_argument('--img_size', type=parse_img_size, default=(96, 96),
                        help='img_size: N or H,W')  # model input image size
    parser.add_argument('--opset_version', type=int, default=9,
                        help='ONNX opset version')
    parser.add_argument('--dynamic_batch', action='store_true',
                        help='export with a dynamic batch dimension')

    print('\n/******************* {} ******************/\n'.format(parser.description))
    ops = parser.parse_args()
    print('----------------------------------')
    for key, value in vars(ops).items():
        print('{} : {}'.format(key, value))

    # Must be set before the first CUDA call so only the chosen GPUs are visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS

    # ---------------------------------------------------------------- build model
    print('use model : %s' % (ops.model))
    input_size = ops.img_size[0]
    model_ = build_model(ops.model, input_size)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model_ = model_.to(device)
    model_.eval()  # inference mode for export

    # Exporting randomly initialised weights is never what the caller wants,
    # so a missing checkpoint is an error rather than a silent no-op.
    if not os.path.isfile(ops.model_path):
        parser.error('model_path not found: {}'.format(ops.model_path))
    chkpt = torch.load(ops.model_path, map_location=device)
    model_.load_state_dict(chkpt)
    print('load test model : {}'.format(ops.model_path))

    # Only img_size[0] is passed to the model constructor, so the dummy
    # input stays square, as in the original export.
    batch_size = 1
    input_shape = (3, input_size, input_size)
    print("input_size : ", input_size)
    x = torch.randn(batch_size, *input_shape).to(device)

    export_onnx_file = "{}_size-{}.onnx".format(ops.model, input_size)
    input_names = ["input"]
    output_names = ["output2d"]
    # Keys must match input_names/output_names exactly, otherwise the
    # dynamic axes are not applied.
    dynamic_axes = None
    if ops.dynamic_batch:
        dynamic_axes = {input_names[0]: {0: "batch_size"},
                        output_names[0]: {0: "batch_size"}}

    torch.onnx.export(model_,
                      x,
                      export_onnx_file,
                      opset_version=ops.opset_version,
                      do_constant_folding=True,  # constant-folding optimisation
                      input_names=input_names,
                      output_names=output_names,
                      dynamic_axes=dynamic_axes)
    print('export onnx : {}'.format(export_onnx_file))


if __name__ == "__main__":
    main()
脚本对应输出结果如下:
/******************* Project handpose x ******************/
----------------------------------
model_path : ckpt\resnet_18_epoch-275-x96.pth
model : resnet_18
GPUS : 0
test_path : ./image/
img_size : (96, 96)
use model : resnet_18
load test model : ckpt\resnet_18_epoch-275-x96.pth
input_size : 96
助力快速掌握数据集的信息和使用方式。
数据可以如此美好!