"""handwriting_trade.py

Train a MobileNetV3-small image classifier on a handwriting dataset
using PaddleX, evaluating against a validation list during training.
"""
import os

# This assignment precedes the paddlex import; keep that order. Presumably it
# is meant to be in place before the framework initializes -- confirm before
# moving it.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from paddlex import transforms as T  # noqa: E402
import paddlex as pdx  # noqa: E402

# Root directory of the dataset. The train list, the validation list and the
# label list all live under this same directory.
DATA_DIR = '/home/xiongweixp/data/handwriting/train'
LABEL_LIST = os.path.join(DATA_DIR, 'labels.txt')

# Define the transforms used for training and validation.
# API reference:
# https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/cls_transforms.html
# Alternative training pipeline with augmentation (currently disabled):
# train_transforms = T.Compose(
#     [T.RandomCrop(crop_size=224), T.RandomHorizontalFlip(), T.Normalize()])
train_transforms = T.Compose([
    T.Resize(target_size=224),
    T.Normalize(),
])
eval_transforms = T.Compose([
    T.Resize(target_size=224),
    T.Normalize(),
])

# Define the datasets used for training and validation.
# API reference:
# https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-imagenet
train_dataset = pdx.datasets.ImageNet(
    data_dir=DATA_DIR,
    file_list=os.path.join(DATA_DIR, 'train_list.txt'),
    label_list=LABEL_LIST,
    transforms=train_transforms,
    shuffle=True,
)
eval_dataset = pdx.datasets.ImageNet(
    data_dir=DATA_DIR,
    file_list=os.path.join(DATA_DIR, 'val_list.txt'),
    label_list=LABEL_LIST,
    transforms=eval_transforms,
)

# Initialize the model and train it.
# Training metrics can be viewed with VisualDL; see
# https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
# Alternative backbone (currently disabled):
# model = pdx.cls.MobileNetV2(num_classes=len(train_dataset.labels))
model = pdx.cls.MobileNetV3_small(num_classes=len(train_dataset.labels))

# API reference:
# https://paddlex.readthedocs.io/zh_CN/develop/apis/models/classification.html#train
# Parameter descriptions and tuning notes:
# https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
    num_epochs=10,
    train_dataset=train_dataset,
    train_batch_size=64,
    eval_dataset=eval_dataset,
    lr_decay_epochs=[4, 6, 8],
    learning_rate=0.025,
    save_dir='output/mobilenetv3_small',
    use_vdl=True,
    # To continue from a saved checkpoint, pass e.g.:
    # resume_checkpoint='output/resnet/epoch_9',
)