# torch_train.py — MobileNetV3 training script for handwritten-character classification.
  1. import os
  2. import random
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from torch.utils.data import DataLoader
  7. import torchvision.transforms as transforms
  8. import torch.optim as optim
  9. from matplotlib import pyplot as plt
  10. from torchvision.transforms.functional import scale
  11. from model import MobileNetV3_large
  12. from model import MobileNetV3_small
  13. import torchvision
  14. from torch.autograd import Variable
  15. from dataloader import CustomDataset
  16. from tqdm import tqdm
  17. #宏定义一些数据,如epoch数,batchsize等
  18. MAX_EPOCH=10
  19. BATCH_SIZE=64
  20. LR=0.0001
  21. log_interval=3
  22. val_interval=1
  23. cls_num = 3755
  24. # ============================ step 1/5 数据 ============================
  25. split_dir=os.path.join("/home/xiongweixp/data/handwriting/","pytorch_data")
  26. train_dir=os.path.join(split_dir,"train")
  27. valid_dir=os.path.join(split_dir,"validate")
  28. handwrite_mean=(0.877, 0.877, 0.877)
  29. handwrite_std=(0.200, 0.200, 0.200)
  30. mean = (0.485,0.456,0.406)
  31. std = (0.229,0.224,0.225)
  32. #对训练集所需要做的预处理
  33. train_transform=transforms.Compose([
  34. transforms.Resize((224,224)),
  35. # transforms.RandomResizedCrop(224),
  36. transforms.ToTensor(),
  37. transforms.Normalize(mean = handwrite_mean , std = handwrite_std)
  38. ])
  39. #对验证集所需要做的预处理
  40. valid_transform=transforms.Compose([
  41. transforms.Resize((224,224)),
  42. transforms.ToTensor(),
  43. transforms.Normalize(mean = handwrite_mean, std = handwrite_std)
  44. ])
  45. # 构建MyDataset实例
  46. train_data=CustomDataset(data_dir=train_dir,transform=train_transform)
  47. valid_data=CustomDataset(data_dir=valid_dir,transform=valid_transform)
  48. # 构建DataLoader
  49. # 训练集数据最好打乱
  50. # DataLoader的实质就是把数据集加上一个索引号,再返回
  51. train_loader=DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)
  52. valid_loader=DataLoader(dataset=valid_data,batch_size=BATCH_SIZE)
  53. # ============================ step 2/5 模型 ============================
  54. net=MobileNetV3_large(num_classes=cls_num)
  55. if torch.cuda.is_available():
  56. net.cuda()
  57. # ============================ step 3/5 损失函数 ============================
  58. criterion=nn.CrossEntropyLoss()
  59. # ============================ step 4/5 优化器 ============================
  60. optimizer=optim.Adam(net.parameters(),lr=LR, betas=(0.9, 0.99))# 选择优化器
  61. # ============================ step 5/5 训练 ============================
  62. # 记录每一次的数据,方便绘图
  63. train_curve=list()
  64. valid_curve=list()
  65. net.train()
  66. accurancy_global=0.0
  67. for epoch in range(MAX_EPOCH):
  68. loss_mean=0.
  69. correct=0.
  70. total=0.
  71. running_loss = 0.0
  72. process = tqdm(train_loader)
  73. for i,data in enumerate(process):
  74. img,label=data
  75. img = Variable(img)
  76. label = Variable(label)
  77. if torch.cuda.is_available():
  78. img=img.cuda()
  79. label=label.cuda()
  80. # 前向传播
  81. out=net(img)
  82. optimizer.zero_grad() # 归0梯度
  83. loss=criterion(out,label)#得到损失函数
  84. print_loss=loss.data.item()
  85. loss.backward()#反向传播
  86. optimizer.step()#优化
  87. if (i+1)%log_interval==0:
  88. # print('epoch:{},loss:{:.4f}'.format(epoch+1,loss.data.item()))
  89. process.set_description('epoch:{},loss:{:.4f}'.format(epoch+1,loss.data.item()))
  90. _, predicted = torch.max(out.data, 1)
  91. total += label.size(0)
  92. # print("============================================")
  93. # print("源数据标签:",label)
  94. # print("============================================")
  95. # print("预测结果:",predicted)
  96. # print("相等的结果为:",predicted == label)
  97. correct += (predicted == label).sum()
  98. print("============================================")
  99. accurancy=correct / total
  100. if accurancy>accurancy_global:
  101. torch.save(net.state_dict(), 'output3/weights/best.pkl')
  102. print("准确率由:", accurancy_global, "上升至:", accurancy, "已更新并保存权值为weights/best.pkl")
  103. accurancy_global=accurancy
  104. print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, 100*accurancy))
  105. torch.save(net.state_dict(), 'output3/weights/epoch%d.pkl'%(epoch))
  106. print("训练完毕,权重已保存为:weights/last.pkl")