import socketserver
import struct
import base64
import numpy as np
import cv2
import time
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from model import MobileNetV3_large

# Maximum number of bytes requested from the socket per recv() call.
BUFFER_SIZE = 1024
# Path of the PyTorch weights file.
weights = 'weights/best.pkl'
# Number of top candidates sent back to the client.
top_k = 10
# Mean and std values passed to transforms.Normalize.
handwrite_mean = (0.877, 0.877, 0.877)
handwrite_std = (0.200, 0.200, 0.200)
# Path of the label file (one label per line).
label_file = "data/labels.txt"


def readLabels():
    """Read ``label_file`` and return its lines as a list of stripped strings.

    The file is opened with the platform default encoding.
    """
    with open(label_file) as file_obj:
        return [line.strip() for line in file_obj]


def cutImg(img):
    """Crop away the border around the dark content of a BGR image.

    The image is inverted, converted to grayscale and thresholded at 127;
    the result is cropped to the union of the bounding rectangles of all
    contours found in that binary mask.

    Args:
        img: BGR image array as produced by OpenCV.

    Returns:
        The cropped view of ``img``, or ``img`` unchanged when no contour is
        found (e.g. a completely blank image), so callers never receive an
        empty array.
    """
    inverted = 255 - img
    gray = cv2.cvtColor(inverted, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    # findContours returns 2 or 3 values depending on the OpenCV version;
    # contours and hierarchy are always the last two.
    contours, _ = cv2.findContours(
        binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
    if len(contours) == 0:
        return img

    minx = 65523
    miny = 65523
    maxx = 0
    maxy = 0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if x < minx:
            minx = x
        if y < miny:
            miny = y
        if x + w > maxx:
            maxx = x + w
        if y + h > maxy:
            maxy = y + h
    return img[miny:maxy, minx:maxx]


all_label = readLabels()

# Device used for both the model and the input tensors.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = MobileNetV3_large(num_classes=3755)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
net.load_state_dict(torch.load(weights, map_location=_device))
net.to(_device)
net.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=handwrite_mean, std=handwrite_std),
])


def _decode_image(base64content):
    """Decode a base64-encoded image payload into a BGR array.

    Fully transparent pixels of a 4-channel image are replaced by opaque
    white before the alpha channel is dropped. Single-channel images are
    expanded to BGR; 3-channel images are returned as decoded.

    Raises:
        ValueError: if the payload cannot be decoded as an image.
    """
    raw = base64.b64decode(base64content)
    buf = np.frombuffer(raw, np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError("received data is not a decodable image")
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if img.shape[2] == 4:
        # Mask of the fully transparent pixels; paint them opaque white.
        trans_mask = img[:, :, 3] == 0
        img[trans_mask] = [255, 255, 255, 255]
        return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    return img


def _run_model(img_bgr):
    """Run ``net`` on a BGR image and return the raw output tensor."""
    image = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    img_tensor = transform(image).unsqueeze(0).to(_device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return net(img_tensor)


class MyTCPHandler(socketserver.BaseRequestHandler):
    """Handle one recognition request per TCP connection.

    Protocol as implemented here: the client sends a 4-byte big-endian
    unsigned length followed by that many bytes of base64-encoded image
    data; the server answers with the concatenated labels of the top
    candidates, UTF-8 encoded, and closes the connection.
    """

    def _recv_exact(self, size):
        """Read exactly ``size`` bytes from the connection.

        recv() may return fewer bytes than requested, so the remaining count
        is reduced by the number of bytes actually received.

        Raises:
            ConnectionError: if the peer closes the connection early.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.request.recv(min(remaining, BUFFER_SIZE))
            if not chunk:
                raise ConnectionError(
                    "connection closed with %d bytes still expected" % remaining)
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def handle(self):
        """Receive one image, classify it and send back the top labels.

        Any exception is printed rather than propagated; the connection is
        always closed at the end.
        """
        try:
            self.data_len = self._recv_exact(4)
            (byte_len,) = struct.unpack(">I", self.data_len)
            print("{} received byte len:".format(self.client_address), byte_len)
            # NOTE(review): byte_len comes from the client and is not capped;
            # consider enforcing an upper bound.
            base64content = self._recv_exact(byte_len)

            # Run the prediction.
            img_save = cutImg(_decode_image(base64content))
            net_output = _run_model(img_save)
            print(net_output)
            _, predicted = torch.max(net_output.data, 1)
            print("预测的结果为:", predicted[0].item())

            # Never ask topk for more candidates than there are classes.
            k = min(top_k, net_output.shape[1])
            top = torch.topk(input=net_output.data, k=k)
            topk = top[1].cpu().numpy()[0]
            labels = [all_label[index] for index in topk]
            for label in labels:
                print(label)
            result = "".join(labels)
            print(result)
            self.request.sendall(bytes(result, encoding="utf8"))
        except Exception as e:
            print("出现如下异常%s" % e)
        finally:
            print(self.client_address, "连接断开")
            self.request.close()

    def setup(self):
        """Log the client address before ``handle`` runs."""
        print("before handle,连接建立:", self.client_address)

    def finish(self):
        """Log that handling has finished."""
        print("finish run after handle")


if __name__ == "__main__":
    HOST, PORT = "0.0.0.0", 8123
    # The context manager closes the listening socket on exit.
    with socketserver.TCPServer((HOST, PORT), MyTCPHandler) as server:
        print('启动端口8123')
        server.serve_forever()