import socketserver
import struct
import base64
import numpy as np
import cv2
import time
from PIL import Image
from paddle.inference import Config
from paddle.inference import create_predictor
from handwriting_util import ResizeImage, CropImage, ToTensor, NormalizeImage

# Maximum number of bytes requested per recv() call while receiving a payload.
BUFFER_SIZE = 1024
# Paddle inference model files.
MODEL_FILE = "output_2/inference/inference.pdmodel"
MODEL_PARAM = "output_2/inference/inference.pdiparams"
# GPU memory pool size passed to Config.enable_use_gpu (Paddle documents this
# argument in MB).
GPU_MEM = 8000
# Number of top-scoring candidate classes returned to the client.
top_k = 10
# Sample test image path (not referenced by the server code in this file).
img_file = "data/test/test_deal.png"
# Label file: one label per line; the line index is the class index.
label_file = "data/trade/labels.txt"


def create_paddle_perdictor():
    """Build a Paddle inference predictor for MODEL_FILE / MODEL_PARAM.

    Runs on GPU device 0 with a GPU_MEM memory pool, IR optimisation and
    memory optimisation enabled, and feed/fetch ops disabled so inputs and
    outputs are exchanged through zero-copy tensor handles.
    """
    config = Config(MODEL_FILE, MODEL_PARAM)
    config.enable_use_gpu(GPU_MEM, 0)
    config.set_cpu_math_library_num_threads(10)
    config.disable_glog_info()
    config.switch_ir_optim(True)
    config.enable_memory_optim()
    # Use zero-copy input/output handles.
    config.switch_use_feed_fetch_ops(False)
    return create_predictor(config)


def preprocess(img):
    """Turn an image into the model's input tensor (without batch dimension).

    Pipeline: ResizeImage(resize_short=256) -> CropImage(224x224) ->
    NormalizeImage(scale 1/255, ImageNet-style mean/std) -> ToTensor.
    The exact semantics of these ops live in handwriting_util; presumably the
    short side is resized to 256, a 224x224 centre crop is taken and ToTensor
    converts HWC to CHW -- TODO confirm.
    NOTE(review): the mean/std lists are in the usual RGB order, while images
    here come from OpenCV in BGR order -- confirm this matches training.
    """
    resize_op = ResizeImage(resize_short=256)
    img = resize_op(img)

    crop_op = CropImage(size=(224, 224))
    img = crop_op(img)

    img_mean = [0.485, 0.456, 0.406]
    img_std = [0.229, 0.224, 0.225]
    img_scale = 1.0 / 255.0
    normalize_op = NormalizeImage(scale=img_scale, mean=img_mean, std=img_std)
    img = normalize_op(img)

    tensor_op = ToTensor()
    img = tensor_op(img)
    return img


def postprocess(output):
    """Return (classes, scores) for the top_k highest entries of output.

    Both arrays are ordered by descending score. If output holds fewer than
    top_k values, all of them are returned instead of raising.
    """
    output = np.asarray(output).flatten()
    k = min(top_k, output.size)
    if k <= 0:
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=output.dtype)
    # argpartition finds the k largest in O(n); only those k are then sorted.
    classes = np.argpartition(output, -k)[-k:]
    classes = classes[np.argsort(-output[classes])]
    scores = output[classes]
    return classes, scores


def readLabels():
    """Read label_file and return its lines, stripped, as a list of labels."""
    with open(label_file) as file_obj:
        return [line.strip() for line in file_obj]


def cutImg(img):
    """Crop away the white margin around the ink in a BGR image.

    The image is inverted and thresholded, so pixels whose inverted grayscale
    value exceeds 127 count as content. The result is the bounding box of all
    detected contours. An image without any content is returned unchanged,
    because an empty crop would make the following cv2.resize fail.
    """
    inverted = 255 - img
    gray = cv2.cvtColor(inverted, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    # [-2] picks the contour list under both the OpenCV 3 (3-tuple) and
    # OpenCV 4 (2-tuple) return conventions.
    contours = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return img
    # The bounding rect of all contour points equals the union of the
    # per-contour bounding rects.
    x, y, w, h = cv2.boundingRect(np.vstack(contours))
    return img[y:y + h, x:x + w]


all_label = readLabels()
predictor = create_paddle_perdictor()


def _recv_exact(sock, size):
    """Receive exactly size bytes from sock, in chunks of at most BUFFER_SIZE.

    socket.recv() may return fewer bytes than requested, so progress is
    tracked by the bytes actually received.

    :raises ConnectionError: if the peer closes before size bytes arrive.
    """
    buf = bytearray()
    while len(buf) < size:
        chunk = sock.recv(min(BUFFER_SIZE, size - len(buf)))
        if not chunk:
            raise ConnectionError(
                "connection closed after {} of {} bytes".format(len(buf), size))
        buf += chunk
    return bytes(buf)


def _decode_image(base64content):
    """Decode base64 image bytes into a 3-channel BGR array.

    Fully transparent pixels of a 4-channel image are painted white before
    the alpha channel is dropped. Grayscale images are expanded to BGR.

    :raises ValueError: if OpenCV cannot decode the data.
    """
    raw = base64.b64decode(base64content)
    nparr = np.frombuffer(raw, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError("could not decode received image")
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if img.shape[2] == 4:
        # Mask of fully transparent pixels; replace them with opaque white.
        trans_mask = img[:, :, 3] == 0
        img[trans_mask] = [255, 255, 255, 255]
        return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    return img


class MyTCPHandler(socketserver.BaseRequestHandler):
    """Serve one handwriting-recognition request per connection.

    Protocol, as implemented here: the client sends a 4-byte big-endian
    unsigned length, then that many bytes of base64-encoded image data. The
    server replies with the UTF-8 encoded concatenation of the top_k
    predicted labels and closes the connection.
    """

    def handle(self):
        try:
            self.data_len = _recv_exact(self.request, 4)
            (byte_len,) = struct.unpack(">I", self.data_len)
            print("{} received byte len:".format(self.client_address), byte_len)

            # Receive the payload in chunks until byte_len bytes have arrived.
            base64content = _recv_exact(self.request, byte_len)

            # Run the prediction.
            input_names = predictor.get_input_names()
            input_tensor = predictor.get_input_handle(input_names[0])
            output_names = predictor.get_output_names()
            output_tensor = predictor.get_output_handle(output_names[0])

            img_save = _decode_image(base64content)
            img_save = cutImg(img_save)
            img = cv2.resize(img_save, (64, 64))
            inputs = preprocess(img)
            # Add a batch dimension of 1; copy() yields a C-contiguous array
            # that owns its data, for copy_from_cpu.
            inputs = np.expand_dims(inputs, axis=0).copy()
            input_tensor.copy_from_cpu(inputs)
            predictor.run()
            output = output_tensor.copy_to_cpu()

            classes, scores = postprocess(output)
            print(classes)
            labels = [all_label[index] for index in classes]
            for label in labels:
                print(label)
            result = "".join(labels)
            print(result)
            self.request.sendall(result.encode("utf8"))
        except Exception as e:
            # Boundary handler: report and drop this connection only.
            print("出现如下异常%s" % e)
        finally:
            print(self.client_address, "连接断开")
            self.request.close()

    def setup(self):
        print("before handle,连接建立:", self.client_address)

    def finish(self):
        print("finish run after handle")


if __name__ == "__main__":
    HOST, PORT = "0.0.0.0", 8123
    with socketserver.TCPServer((HOST, PORT), MyTCPHandler) as server:
        print('启动端口8123')
        server.serve_forever()