文章目录
- 更新时间
- 环境
- 流程介绍
- 代码(onnx,darknet)
- 代码(pytorch)
- 参考文章
更新时间
最近更新: 2022-09
环境
Python 3.8
opencv 4.5.2.54
onnxruntime 1.10.0
pytorch 1.10.2
cryptography 3.1.1
流程介绍
加密:
- 模型以二进制存储
- 以二进制形式读取
- 用cryptography对文件加密
- 保存加密后文件
解密:
- 读取加密文件
- 使用cryptography解密
- 转换为框架可读数据
代码(onnx,darknet)
# 生成密钥
from cryptography.fernet import Fernet
key = Fernet.generate_key()
#保存license
#with open('license', 'wb') as fw:
#(key)
print(key)
#解密时 key修改成自己的就行
f = Fernet(key)
#模型加密
#原始模型路径
model_file = ['./','./','./']
#加密模型,名称根据自己需求更改
new_model_file = ['./','./','./']
for i in range(2):
#二进创建打开生成文件
with open(new_model_file[i],'wb') as ew:
#二进制读取模型文件
content = open(model_file[i],'rb').read()
#根据密钥解密文件
encrypted_content = f.encrypt(content)
# print(encrypted_content)
#保存到新文件
ew.write(encrypted_content)
# 模型解密
#使用opencv dnn模块读取darknet模型测试
import cv2
#二进制读取加密后的文件
conf_file = open(new_model_file[1],'rb').read()
#解密
conf_file = f.decrypt(conf_file)
#转换数据格式
conf_file = bytearray(conf_file)
# 与上一致
weight_flie = open(new_model_file[0],'rb').read()
weight_flie = f.decrypt(weight_flie)
weight_flie = bytearray(weight_flie)
#读取模型
net = cv2.dnn.readNetFromDarknet(conf_file,weight_flie)
print(net)
#使用onnxruntime模块读取onnx模型测试
import onnxruntime
onnx_file = open(new_model_file[2],'rb').read()
onnx_file = f.decrypt(onnx_file)
#onnx不需要转换数据格式
session = onnxruntime.InferenceSession(onnx_file)
print(session)
代码(pytorch)
# 生成密钥
from cryptography.fernet import Fernet
key = Fernet.generate_key()
#保存license
#with open('license', 'wb') as fw:
#(key)
print(key)
#解密时 key修改成自己的就行
f = Fernet(key)
#模型存储格式预处理
import io
import torch
#读取pytorch模型存为格式
model = torch.load('./',map_location='cpu')
#print(model)
byte = io.BytesIO()
torch.save(model, byte)
byte.seek(0)
#加密
#二进制方式读取模型
pth_bytes = byte.read()
#加密
new_model = f.encrypt(pth_bytes)
#保存加密文件
with open('new_model.pth', 'wb') as fw:
fw.write(new_model)
#解密
#读取加密后的文件
with open('new_model.pth', 'rb') as fr:
new_model = fr.read()
# cryptography 解密
new_model = f.decrypt(new_model)
#解密后的byte数据转为数据格式
model_byte = io.BytesIO(new_model)
model_byte.seek(0)
# 读取解密后的格式数据
model = torch.load(model_byte)
print(model)
参考文章
Pytorch模型加密的方法
神经网络如何加密(Python)