handwriting_socket.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import socketserver
  2. import struct
  3. import base64
  4. import numpy as np
  5. import cv2
  6. import time
  7. from PIL import Image
  8. from paddle.inference import Config
  9. from paddle.inference import create_predictor
  10. from handwriting_util import ResizeImage,CropImage,ToTensor,NormalizeImage
  11. # 分包接收长度
  12. BUFFER_SIZE = 1024
  13. # paddle模型文件
  14. MODEL_FILE = "output_2/inference/inference.pdmodel"
  15. MODEL_PARAM = "output_2/inference/inference.pdiparams"
  16. # 分配显存
  17. GPU_MEM = 8000
  18. # 前k个候选
  19. top_k = 10
  20. # 测试文件
  21. img_file = "data/test/test_deal.png"
  22. #标签文件
  23. label_file = "data/trade/labels.txt"
  24. # 建立Paddle预测器
  25. def create_paddle_perdictor():
  26. config=Config(MODEL_FILE, MODEL_PARAM)
  27. config.enable_use_gpu(GPU_MEM, 0)
  28. config.set_cpu_math_library_num_threads(10)
  29. config.disable_glog_info()
  30. config.switch_ir_optim(True)
  31. config.enable_memory_optim()
  32. # use zero copy
  33. config.switch_use_feed_fetch_ops(False)
  34. predictor = create_predictor(config)
  35. return predictor
  36. # 图像预处理,转为Tensor张量
  37. def preprocess(img):
  38. resize_op = ResizeImage(resize_short=256)
  39. img = resize_op(img)
  40. crop_op = CropImage(size=(224, 224))
  41. img = crop_op(img)
  42. img_mean = [0.485, 0.456, 0.406]
  43. img_std = [0.229, 0.224, 0.225]
  44. img_scale = 1.0 / 255.0
  45. normalize_op = NormalizeImage(
  46. scale=img_scale, mean=img_mean, std=img_std)
  47. img = normalize_op(img)
  48. tensor_op = ToTensor()
  49. img = tensor_op(img)
  50. return img
  51. # 结果分析处理
  52. def postprocess(output):
  53. output = output.flatten()
  54. classes = np.argpartition(output, -top_k)[-top_k:]
  55. classes = classes[np.argsort(-output[classes])]
  56. scores = output[classes]
  57. return classes, scores
  58. # 读取标签文件
  59. def readLabels():
  60. index = 0
  61. labels=[]
  62. with open(label_file) as file_obj:
  63. for line in file_obj:
  64. labels.append(line.strip())
  65. index=index+1
  66. return labels
  67. # 减少图像白边
  68. def cutImg(img):
  69. dst = 255- img
  70. gray = cv2.cvtColor(dst,cv2.COLOR_BGR2GRAY)
  71. ret, binary = cv2.threshold(gray,127,255,cv2.THRESH_BINARY)
  72. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  73. minx = 65523
  74. miny = 65523
  75. maxx = 0
  76. maxy = 0
  77. for contour in contours:
  78. boundRect = cv2.boundingRect(contour)
  79. (x,y,w,h)=boundRect
  80. if x<minx:
  81. minx=x
  82. if y<miny:
  83. miny=y
  84. if x+w>maxx:
  85. maxx=x+w
  86. if y+h>maxy:
  87. maxy=y+h
  88. cutw=maxx-minx
  89. cuth=maxy-miny
  90. img = img[miny:maxy,minx:maxx]
  91. return img
  92. all_label = readLabels()
  93. predictor = create_paddle_perdictor()
  94. class MyTCPHandler(socketserver.BaseRequestHandler):
  95. def handle(self):
  96. try:
  97. self.data_len=self.request.recv(4)
  98. (byte_len,) = struct.unpack(">I", self.data_len)
  99. print("{} received byte len:".format(self.client_address),byte_len)
  100. # 分包接收
  101. content = b''
  102. while byte_len >0:
  103. if byte_len>=BUFFER_SIZE:
  104. tmp = self.request.recv(BUFFER_SIZE)
  105. content += tmp
  106. byte_len -= BUFFER_SIZE
  107. else:
  108. tmp = self.request.recv(byte_len)
  109. content += tmp
  110. byte_len = 0
  111. break
  112. base64content = content
  113. # print("{} received image:".format(self.client_address), base64content)
  114. # 进行预测
  115. input_names = predictor.get_input_names()
  116. input_tensor = predictor.get_input_handle(input_names[0])
  117. output_names = predictor.get_output_names()
  118. output_tensor = predictor.get_output_handle(output_names[0])
  119. imgString = base64.b64decode(base64content)
  120. nparr = np.fromstring(imgString,np.uint8)
  121. img_recv = cv2.imdecode(nparr,cv2.IMREAD_UNCHANGED)
  122. #make mask of where the transparent bits are
  123. trans_mask = img_recv[:,:,3] == 0
  124. #replace areas of transparency with white and not transparent
  125. img_recv[trans_mask] = [255, 255, 255, 255]
  126. img_save = cv2.cvtColor(img_recv, cv2.COLOR_BGRA2BGR)
  127. # cv2.imwrite("data/test/receive.png", img_save)
  128. img_save = cutImg(img_save)
  129. img = cv2.resize(img_save,(64,64))
  130. inputs = preprocess(img)
  131. inputs = np.expand_dims(
  132. inputs, axis=0).repeat(
  133. 1, axis=0).copy()
  134. input_tensor.copy_from_cpu(inputs)
  135. predictor.run()
  136. output = output_tensor.copy_to_cpu()
  137. classes, scores = postprocess(output)
  138. print(classes)
  139. result=""
  140. for index in classes:
  141. label=all_label[index]
  142. print(label)
  143. result = result+label
  144. print(result)
  145. self.request.sendall(bytes(result, encoding = "utf8"))
  146. except Exception as e:
  147. print("出现如下异常%s"%e)
  148. finally:
  149. print(self.client_address,"连接断开")
  150. self.request.close()
  151. def setup(self):
  152. print("before handle,连接建立:",self.client_address)
  153. def finish(self):
  154. print("finish run after handle")
  155. if __name__=="__main__":
  156. HOST,PORT = "0.0.0.0",8123
  157. server=socketserver.TCPServer((HOST,PORT),MyTCPHandler)
  158. print('启动端口8123')
  159. server.serve_forever()