# handwriting_predict.py
import time

import cv2
import numpy as np
from paddle.inference import Config
from paddle.inference import create_predictor
from PIL import Image

from handwriting_util import ResizeImage,CropImage,ToTensor,NormalizeImage
  8. # paddle模型文件
  9. MODEL_FILE = "output_2/inference/inference.pdmodel"
  10. MODEL_PARAM = "output_2/inference/inference.pdiparams"
  11. # 分配显存
  12. GPU_MEM = 8000
  13. # 前k个候选
  14. top_k = 5
  15. # 测试文件
  16. img_file = "data/test/test5.png"
  17. #标签文件
  18. label_file = "data/trade/labels.txt"
  19. # 建立Paddle预测器
  20. def create_paddle_perdictor():
  21. config=Config(MODEL_FILE, MODEL_PARAM)
  22. config.enable_use_gpu(GPU_MEM, 0)
  23. config.set_cpu_math_library_num_threads(10)
  24. config.disable_glog_info()
  25. config.switch_ir_optim(True)
  26. config.enable_memory_optim()
  27. # use zero copy
  28. config.switch_use_feed_fetch_ops(False)
  29. predictor = create_predictor(config)
  30. return predictor
  31. # 图像预处理,转为Tensor张量
  32. def preprocess(img):
  33. resize_op = ResizeImage(resize_short=256)
  34. img = resize_op(img)
  35. crop_op = CropImage(size=(224, 224))
  36. img = crop_op(img)
  37. img_mean = [0.485, 0.456, 0.406]
  38. img_std = [0.229, 0.224, 0.225]
  39. img_scale = 1.0 / 255.0
  40. normalize_op = NormalizeImage(
  41. scale=img_scale, mean=img_mean, std=img_std)
  42. img = normalize_op(img)
  43. tensor_op = ToTensor()
  44. img = tensor_op(img)
  45. return img
  46. # 结果分析处理
  47. def postprocess(output):
  48. output = output.flatten()
  49. classes = np.argpartition(output, -top_k)[-top_k:]
  50. classes = classes[np.argsort(-output[classes])]
  51. scores = output[classes]
  52. return classes, scores
  53. # 读取标签文件
  54. def readLabels():
  55. index = 0
  56. labels=[]
  57. with open(label_file) as file_obj:
  58. for line in file_obj:
  59. labels.append(line.strip())
  60. index=index+1
  61. return labels
  62. def dealTransform():
  63. img=Image.open('data/test/test.png')
  64. img=transparence2white(img) # 将图片传入,改变背景色后,返回
  65. img=img.resize((64,64))
  66. img.save(img_file) # 保存图片
  67. def transparence2white(img):
  68. # img=img.convert('RGBA') # 此步骤是将图像转为灰度(RGBA表示4x8位像素,带透明度掩模的真彩色;CMYK为4x8位像素,分色等),可以省略
  69. sp=img.size
  70. width=sp[0]
  71. height=sp[1]
  72. print(sp)
  73. for yh in range(height):
  74. for xw in range(width):
  75. dot=(xw,yh)
  76. color_d=img.getpixel(dot) # 与cv2不同的是,这里需要用getpixel方法来获取维度数据
  77. if(color_d[3]==0):
  78. color_d=(255,255,255,255)
  79. img.putpixel(dot,color_d) # 赋值的方法是通过putpixel
  80. return img
  81. # 减少图像白边
  82. def cutImg(img):
  83. dst = 255- img
  84. gray = cv2.cvtColor(dst,cv2.COLOR_BGR2GRAY)
  85. ret, binary = cv2.threshold(gray,127,255,cv2.THRESH_BINARY)
  86. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  87. minx = 65523
  88. miny = 65523
  89. maxx = 0
  90. maxy = 0
  91. for contour in contours:
  92. boundRect = cv2.boundingRect(contour)
  93. (x,y,w,h)=boundRect
  94. if x<minx:
  95. minx=x
  96. if y<miny:
  97. miny=y
  98. if x+w>maxx:
  99. maxx=x+w
  100. if y+h>maxy:
  101. maxy=y+h
  102. cutw=maxx-minx
  103. cuth=maxy-miny
  104. # w>h
  105. # if cutw>cuth:
  106. # cuty1=int((maxy+miny)/2-cutw/2)
  107. # cuty2=int((maxy+miny)/2+cutw/2)
  108. # img = img[cuty1:cuty2,minx:maxx]
  109. # if cutw<cuth:
  110. # cutx1=int((maxx+minx)/2-cuth/2)
  111. # cutx2=int((maxx+minx)/2+cuth/2)
  112. # img = img[miny:maxy,cutx1:cutx2]
  113. img = img[miny:maxy,minx:maxx]
  114. return img
  115. all_label = readLabels()
  116. # dealTransform()
  117. predictor = create_paddle_perdictor();
  118. input_names = predictor.get_input_names()
  119. input_tensor = predictor.get_input_handle(input_names[0])
  120. output_names = predictor.get_output_names()
  121. output_tensor = predictor.get_output_handle(output_names[0])
  122. img = cv2.imread(img_file)[:, :, ::-1]
  123. img=cutImg(img)
  124. cv2.imwrite("dealcut.png",img)
  125. img = cv2.resize(img, (64,64))
  126. inputs = preprocess(img)
  127. inputs = np.expand_dims(
  128. inputs, axis=0).repeat(
  129. 1, axis=0).copy()
  130. input_tensor.copy_from_cpu(inputs)
  131. predictor.run()
  132. output = output_tensor.copy_to_cpu()
  133. classes, scores = postprocess(output)
  134. print(classes)
  135. for index in classes:
  136. print(all_label[index])