"""Export the ``MobileNetV3_large`` checkpoint in ``weights/best.pkl`` to ONNX.

The script builds a 3755-class ``MobileNetV3_large`` and loads the saved state
dict into it. It then traces the model with a random dummy input of shape
``(batch_size, 3, 224, 224)`` and writes the graph to
``onnx/handwriting.onnx``, with the graph input/output named ``"input"`` and
``"output"``. CUDA is used when available, otherwise the CPU.
"""
from pathlib import Path

import torch
import torch.nn as nn
from model import MobileNetV3_large

weight = 'weights/best.pkl'
export_file = "onnx/handwriting.onnx"
input_shape = (3, 224, 224)  # (channels, height, width) of the dummy input
batch_size = 1

# Choose the device once so the model, the loaded weights and the dummy input
# always live on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = MobileNetV3_large(num_classes=3755)
# Put the module in inference mode (affects layers such as dropout / batch
# norm, if the model has them) before it is traced for export.
net.eval()
net.to(device)
# map_location remaps stored tensors onto `device`. Without it, a checkpoint
# saved from GPU tensors cannot be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
net.load_state_dict(torch.load(weight, map_location=device))

# Random dummy input: export traces the model by running it once, so only the
# shape/dtype/device of this tensor matter, not its values.
x = torch.randn(batch_size, *input_shape, device=device)

# torch.onnx.export does not create missing directories; make sure "onnx/"
# exists so a fresh checkout does not fail with FileNotFoundError.
Path(export_file).parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(
    net,
    x,
    export_file,
    input_names=["input"],
    output_names=["output"],
)