"""Export the ``MobileNetV3_large`` handwriting model to ONNX.

Loads a trained ``state_dict`` into ``MobileNetV3_large``, traces the network
with a random dummy input and writes an ONNX graph whose single input is
named ``"input"`` and single output is named ``"output"``.

Run as a script:  ``python <this file>``
"""
import os

import torch
import torch.nn as nn
from model import MobileNetV3_large


def export_onnx(weight_path='weights/best.pkl',
                export_file='onnx/handwriting.onnx',
                num_classes=3755,
                input_shape=(3, 224, 224),
                batch_size=1):
    """Load trained weights into the network and export it to ONNX.

    Args:
        weight_path: Path of the saved ``state_dict`` loaded into the model.
        export_file: Destination path of the ``.onnx`` file. Its parent
            directory is created if it does not exist yet.
        num_classes: Output size passed to ``MobileNetV3_large``; must match
            the checkpoint.
        input_shape: Per-sample shape of the dummy tracing input. The original
            script uses ``(3, 224, 224)``; presumably (channels, height,
            width) -- confirm against the model definition.
        batch_size: Leading dimension of the dummy input. The exported graph
            is traced with this fixed batch size.

    Returns:
        The ``export_file`` path that was written.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = MobileNetV3_large(num_classes=num_classes)
    net.to(device)
    # map_location remaps tensors that were saved on another device (e.g. a
    # GPU checkpoint) onto the device chosen above. Without it, loading a
    # GPU-saved checkpoint fails on a CPU-only machine.
    net.load_state_dict(torch.load(weight_path, map_location=device))
    # Trace in evaluation mode so any layers with train/eval-specific
    # behaviour are exported with their inference behaviour.
    net.eval()

    # Generate a random dummy tensor used only to trace the graph.
    x = torch.randn(batch_size, *input_shape, device=device)

    # torch.onnx.export does not create missing parent directories.
    out_dir = os.path.dirname(export_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    torch.onnx.export(net, x, export_file,
                      input_names=["input"], output_names=["output"])
    return export_file


if __name__ == "__main__":
    export_onnx()