handwriting_socket.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import socketserver
  2. import struct
  3. import base64
  4. import numpy as np
  5. import cv2
  6. import time
  7. from PIL import Image
  8. import torch
  9. import torch.nn as nn
  10. from torch.utils.data import DataLoader
  11. import torchvision.transforms as transforms
  12. from model import MobileNetV3_large
  13. # 分包接收长度
  14. BUFFER_SIZE = 1024
  15. # pytorch模型文件
  16. weights = 'weights/best.pkl'
  17. # 前k个候选
  18. top_k = 10
  19. # 均值、方差
  20. handwrite_mean=(0.877, 0.877, 0.877)
  21. handwrite_std=(0.200, 0.200, 0.200)
  22. #标签文件
  23. label_file = "data/labels.txt"
  24. # 读取标签文件
  25. def readLabels():
  26. index = 0
  27. labels=[]
  28. with open(label_file) as file_obj:
  29. for line in file_obj:
  30. labels.append(line.strip())
  31. index=index+1
  32. return labels
  33. # 减少图像白边
  34. def cutImg(img):
  35. dst = 255- img
  36. gray = cv2.cvtColor(dst,cv2.COLOR_BGR2GRAY)
  37. ret, binary = cv2.threshold(gray,127,255,cv2.THRESH_BINARY)
  38. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  39. minx = 65523
  40. miny = 65523
  41. maxx = 0
  42. maxy = 0
  43. for contour in contours:
  44. boundRect = cv2.boundingRect(contour)
  45. (x,y,w,h)=boundRect
  46. if x<minx:
  47. minx=x
  48. if y<miny:
  49. miny=y
  50. if x+w>maxx:
  51. maxx=x+w
  52. if y+h>maxy:
  53. maxy=y+h
  54. cutw=maxx-minx
  55. cuth=maxy-miny
  56. img = img[miny:maxy,minx:maxx]
  57. return img
  58. all_label = readLabels()
  59. net = MobileNetV3_large(num_classes = 3755)
  60. net.eval()
  61. if torch.cuda.is_available():
  62. net.cuda()
  63. net.load_state_dict(torch.load(weights))
  64. transform = transforms.Compose([
  65. transforms.Resize((224, 224)),
  66. transforms.ToTensor(),
  67. transforms.Normalize(mean = handwrite_mean, std = handwrite_std)
  68. ])
  69. class MyTCPHandler(socketserver.BaseRequestHandler):
  70. def handle(self):
  71. try:
  72. self.data_len=self.request.recv(4)
  73. (byte_len,) = struct.unpack(">I", self.data_len)
  74. print("{} received byte len:".format(self.client_address),byte_len)
  75. # 分包接收
  76. content = b''
  77. while byte_len >0:
  78. if byte_len>=BUFFER_SIZE:
  79. tmp = self.request.recv(BUFFER_SIZE)
  80. content += tmp
  81. byte_len -= BUFFER_SIZE
  82. else:
  83. tmp = self.request.recv(byte_len)
  84. content += tmp
  85. byte_len = 0
  86. break
  87. base64content = content
  88. # print("{} received image:".format(self.client_address), base64content)
  89. # 进行预测
  90. imgString = base64.b64decode(base64content)
  91. nparr = np.fromstring(imgString,np.uint8)
  92. img_recv = cv2.imdecode(nparr,cv2.IMREAD_UNCHANGED)
  93. #make mask of where the transparent bits are
  94. trans_mask = img_recv[:,:,3] == 0
  95. #replace areas of transparency with white and not transparent
  96. img_recv[trans_mask] = [255, 255, 255, 255]
  97. img_save = cv2.cvtColor(img_recv, cv2.COLOR_BGRA2BGR)
  98. # cv2.imwrite("data/test/receive.png", img_save)
  99. img_save = cutImg(img_save)
  100. image = Image.fromarray(cv2.cvtColor(img_save,cv2.COLOR_BGR2RGB))
  101. img_tensor = transform(image).unsqueeze(0)
  102. if torch.cuda.is_available():
  103. img_tensor=img_tensor.cuda()
  104. net_output = net(img_tensor)
  105. print(net_output)
  106. _, predicted = torch.max(net_output.data, 1)
  107. result = predicted[0].item()
  108. print("预测的结果为:",result)
  109. top = torch.topk(input=net_output.data, k=top_k)
  110. topk = top[1].cpu().numpy()[0]
  111. result=""
  112. for index in topk:
  113. label=all_label[index]
  114. print(label)
  115. result = result+label
  116. print(result)
  117. self.request.sendall(bytes(result, encoding = "utf8"))
  118. except Exception as e:
  119. print("出现如下异常%s"%e)
  120. finally:
  121. print(self.client_address,"连接断开")
  122. self.request.close()
  123. def setup(self):
  124. print("before handle,连接建立:",self.client_address)
  125. def finish(self):
  126. print("finish run after handle")
  127. if __name__=="__main__":
  128. HOST,PORT = "0.0.0.0",8123
  129. server=socketserver.TCPServer((HOST,PORT),MyTCPHandler)
  130. print('启动端口8123')
  131. server.serve_forever()