# export_onnx.py

import argparse
from pathlib import Path

import torch
import torch.nn as nn

from model import MobileNetV3_large
  4. weight = 'weights/best.pkl'
  5. net = MobileNetV3_large(num_classes = 3755)
  6. net.eval()
  7. if torch.cuda.is_available():
  8. net.cuda()
  9. net.load_state_dict(torch.load(weight))
  10. input_shape = (3, 224, 224)
  11. batch_size = 1
  12. x = torch.randn(batch_size, *input_shape) # 生成张量
  13. if torch.cuda.is_available():
  14. x = x.to(torch.device("cuda"))
  15. export_file="onnx/handwriting.onnx"
  16. torch.onnx.export(net,x, export_file,input_names=["input"],output_names=["output"])