123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- import socketserver
- import struct
- import base64
- import numpy as np
- import cv2
- import time
- from PIL import Image
- import torch
- import torch.nn as nn
- from torch.utils.data import DataLoader
- import torchvision.transforms as transforms
- from model import MobileNetV3_large
- # 分包接收长度
- BUFFER_SIZE = 1024
- # pytorch模型文件
- weights = 'weights/best.pkl'
- # 前k个候选
- top_k = 10
- # 均值、方差
- handwrite_mean=(0.877, 0.877, 0.877)
- handwrite_std=(0.200, 0.200, 0.200)
- #标签文件
- label_file = "data/labels.txt"
- # 读取标签文件
- def readLabels():
- index = 0
- labels=[]
- with open(label_file) as file_obj:
- for line in file_obj:
- labels.append(line.strip())
- index=index+1
- return labels
- # 减少图像白边
- def cutImg(img):
- dst = 255- img
- gray = cv2.cvtColor(dst,cv2.COLOR_BGR2GRAY)
- ret, binary = cv2.threshold(gray,127,255,cv2.THRESH_BINARY)
-
- contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
- minx = 65523
- miny = 65523
- maxx = 0
- maxy = 0
- for contour in contours:
- boundRect = cv2.boundingRect(contour)
- (x,y,w,h)=boundRect
- if x<minx:
- minx=x
- if y<miny:
- miny=y
- if x+w>maxx:
- maxx=x+w
- if y+h>maxy:
- maxy=y+h
- cutw=maxx-minx
- cuth=maxy-miny
- img = img[miny:maxy,minx:maxx]
- return img
- all_label = readLabels()
- net = MobileNetV3_large(num_classes = 3755)
- net.eval()
- if torch.cuda.is_available():
- net.cuda()
- net.load_state_dict(torch.load(weights))
- transform = transforms.Compose([
- transforms.Resize((224, 224)),
- transforms.ToTensor(),
- transforms.Normalize(mean = handwrite_mean, std = handwrite_std)
- ])
- class MyTCPHandler(socketserver.BaseRequestHandler):
- def handle(self):
- try:
-
- self.data_len=self.request.recv(4)
- (byte_len,) = struct.unpack(">I", self.data_len)
- print("{} received byte len:".format(self.client_address),byte_len)
- # 分包接收
- content = b''
- while byte_len >0:
- if byte_len>=BUFFER_SIZE:
- tmp = self.request.recv(BUFFER_SIZE)
- content += tmp
- byte_len -= BUFFER_SIZE
- else:
- tmp = self.request.recv(byte_len)
- content += tmp
- byte_len = 0
- break
-
- base64content = content
- # print("{} received image:".format(self.client_address), base64content)
-
- # 进行预测
- imgString = base64.b64decode(base64content)
- nparr = np.fromstring(imgString,np.uint8)
- img_recv = cv2.imdecode(nparr,cv2.IMREAD_UNCHANGED)
- #make mask of where the transparent bits are
- trans_mask = img_recv[:,:,3] == 0
- #replace areas of transparency with white and not transparent
- img_recv[trans_mask] = [255, 255, 255, 255]
- img_save = cv2.cvtColor(img_recv, cv2.COLOR_BGRA2BGR)
- # cv2.imwrite("data/test/receive.png", img_save)
- img_save = cutImg(img_save)
- image = Image.fromarray(cv2.cvtColor(img_save,cv2.COLOR_BGR2RGB))
- img_tensor = transform(image).unsqueeze(0)
- if torch.cuda.is_available():
- img_tensor=img_tensor.cuda()
- net_output = net(img_tensor)
- print(net_output)
- _, predicted = torch.max(net_output.data, 1)
- result = predicted[0].item()
- print("预测的结果为:",result)
- top = torch.topk(input=net_output.data, k=top_k)
- topk = top[1].cpu().numpy()[0]
- result=""
- for index in topk:
- label=all_label[index]
- print(label)
- result = result+label
- print(result)
- self.request.sendall(bytes(result, encoding = "utf8"))
-
- except Exception as e:
- print("出现如下异常%s"%e)
- finally:
- print(self.client_address,"连接断开")
- self.request.close()
- def setup(self):
- print("before handle,连接建立:",self.client_address)
- def finish(self):
- print("finish run after handle")
- if __name__=="__main__":
- HOST,PORT = "0.0.0.0",8123
- server=socketserver.TCPServer((HOST,PORT),MyTCPHandler)
- print('启动端口8123')
- server.serve_forever()
|