AI手寫輸入法 - pytorch從入門到入道(二)


本章承接上一篇的手寫數字識別,利用訓練好的模型,結合pyqt畫板,實現簡易手寫輸入法,為"hello world"例子增添樂趣。

PyQt是開發圖形界面的框架,可以百度查找相關資料,了解安裝及基礎用法。我搭建的環境是PyCharm + PyQt5 + Qt Designer,配置好之後的界面長這樣:

在左邊的項目中右鍵某個文件,也可以打開qt菜單

具體怎麼畫界面就不展開了,直接看代碼:

  1 # coding: utf-8
  2 from PyQt5.QtWidgets import *
  3 from PyQt5.QtGui import *
  4 from PyQt5.QtCore import *
  5 import sys
  6 sys.path.append(r'../ml/torch')
  7 from digit_recog import Net
  8 import torch
  9 import os
 10 import numpy as np
 11 import matplotlib.pyplot as plt
 12 from PIL import Image
 13 
 14 
 15 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 16 net = Net().to(device)
 17 # 加載參數
 18 nn_state = torch.load(os.path.join('../ml/torch/model/', 'net.pth'))
 19 # 參數加載到指定模型
 20 net.load_state_dict(nn_state)
 21 net.eval()
 22 
 23 
 24 def predict(img):
 25     # 讀取圖片並重設尺寸
 26     image = Image.open(img).resize((28, 28))
 27     # 灰度圖
 28     gray_image = image.convert('L')
 29     # plt.imshow(gray_image)
 30     # plt.show()
 31     # 圖片數據處理
 32     im_data = np.array(gray_image)
 33     im_data = torch.from_numpy(im_data).float()
 34     im_data = im_data.view(1, 1, 28, 28)
 35     # 神經網絡運算
 36     outputs = net(im_data)
 37     # 取最大預測值
 38     _, pred = torch.max(outputs, 1)
 39     return pred.item()
 40 
 41 
 42 class SimpleDrawingBoard(QWidget):
 43     win = ''
 44     wins = []
 45 
 46     @classmethod
 47     def showWin(cls):
 48         # 聚焦到已有窗口
 49         if not cls.win:
 50             cls.win = cls()
 51             cls.win.show()
 52         else:
 53             cls.win.activateWindow()
 54 
 55     def __init__(self, parent=None):
 56         super(SimpleDrawingBoard, self).__init__(parent)
 57 
 58         self.setWindowTitle(u"手寫數字識別")
 59         self.setWindowFlags(Qt.WindowStaysOnTopHint)
 60         self.size = (400, 350)
 61         self.resize(*self.size)
 62         self.setWindowFlag(Qt.FramelessWindowHint)  # 隱藏邊框
 63         # self.setWindowOpacity(0.9)  # 設置窗口透明度
 64         # self.setAttribute(Qt.WA_TranslucentBackground)  # 設置窗口背景透明
 65 
 66         self.canvasSize = (280, 350)
 67         self.sizeOffset = [a - b for a, b in zip(self.size, self.canvasSize)]
 68         self.canvas = QPixmap(*self.canvasSize)
 69         self.canvas.fill(Qt.black)
 70         self.tempCanvas = QPixmap()
 71         self.lastPoint = QPoint()
 72         self.endPoint = QPoint()
 73         self.isDrawing = False
 74         self.penSize = 15
 75 
 76         self.initUI()
 77 
 78     def initUI(self):
 79         self.penSizeLabel = QLabel(u'畫筆粗細')
 80         self.penSizeSpinBox = QSpinBox()
 81         self.penSizeSpinBox.setValue(self.penSize)
 82         self.penSizeSpinBox.valueChanged.connect(self.penSizeSpinBox_valueChanged)
 83         self.penSizeSpinBox.setFixedWidth(80)
 84 
 85         self.clearButton = QPushButton(u'清空')
 86         self.clearButton.setFixedWidth(80)
 87         self.clearButton.clicked.connect(self.clearPainter)
 88 
 89         self.closeButton = QPushButton(u'關閉')
 90         self.closeButton.setFixedWidth(80)
 91         self.closeButton.clicked.connect(self.close)
 92 
 93         self.inputLabel = QLabel(self)
 94         self.inputLabel.setFixedSize(80, 200)
 95         self.inputLabel.setAutoFillBackground(True)
 96         self.inputLabel.setAlignment(Qt.AlignCenter)
 97         self.inputLabel.setStyleSheet('''QLabel{background:#F76677;border-radius:5px;font-size:60px;font-weight:bolder;}''')
 98 
 99         mainLayout = QVBoxLayout(self)
100 
101         toolbarLayout = QGridLayout()
102         # toolbarLayout.setSpacing(20)
103         toolbarLayout.addWidget(self.penSizeLabel, 0, 0, 1, 1)
104         toolbarLayout.addWidget(self.penSizeSpinBox, 1, 0, 1, 1)
105         toolbarLayout.addWidget(self.clearButton, 2, 0, 1, 1)
106         toolbarLayout.addWidget(self.closeButton, 3, 0, 1, 1)
107         toolbarLayout.addWidget(self.inputLabel, 4, 0, 1, 1)
108 
109         toolbarLayout.setAlignment(Qt.AlignLeft)
110 
111         mainLayout.addLayout(toolbarLayout)
112         mainLayout.addStretch(1)
113 
114     def penSizeSpinBox_valueChanged(self):
115         # 設置畫筆粗細
116         self.penSize = self.penSizeSpinBox.value()
117 
118     def paintEvent(self, event):
119         pp = QPainter(self.canvas)
120         pen = QPen(QColor(255, 255, 255), self.penSize)
121         pp.setPen(pen)
122         if self.lastPoint != self.endPoint:
123             pp.drawLine(self.lastPoint - QPoint(*self.sizeOffset), self.endPoint - QPoint(*self.sizeOffset))
124         painter = QPainter(self)
125         painter.drawPixmap(self.sizeOffset[0], self.sizeOffset[1], self.canvas)
126         self.lastPoint = self.endPoint
127 
128     def clearPainter(self):
129         print('clear...')
130         self.canvas.fill(Qt.black)
131         painter = QPainter(self)
132         painter.drawPixmap(self.sizeOffset[0], self.sizeOffset[1], self.canvas)
133         self.lastPoint = self.endPoint
134         self.update()
135         self.inputLabel.clear()
136 
137     def mousePressEvent(self, event):
138         # 按下左鍵
139         if event.button() == Qt.LeftButton:
140             self.lastPoint = event.pos()
141             self.endPoint = self.lastPoint
142             self.isDrawing = True
143 
144     def mouseMoveEvent(self, event):
145         if self.isDrawing:
146             self.update()
147             self.endPoint = event.pos()
148 
149     def mouseReleaseEvent(self, event):
150         if event.button() == Qt.LeftButton:
151             self.isDrawing = False
152             self.endPoint = event.pos()
153             self.update()
154             self.canvas.toImage().save('input.png')
155             input = predict('input.png')
156             self.inputLabel.setText(str(input))
157             print('你輸入的是{}'.format(input))
158 
159 
if __name__ == '__main__':
    # Reuse the running QApplication if there is one, otherwise create it.
    app = QApplication.instance() or QApplication(sys.argv)
    SimpleDrawingBoard.showWin()
    # Enter the Qt event loop.
    app.exec_()

上面的代碼引入了前一章訓練好的模型;由於模型代碼位於不同的文件夾內,導入之前需要先加上這一行代碼:

sys.path.append(r'../ml/torch')

看下運行效果:

上面寫了兩個數字,識別輸出正確!

helloworld例子比較枯燥,通過動手參與、與AI交互,可以增強信心與樂趣。信心是一步步建立起來的,大的突破亦是如此。後面會持續圍繞簡單的例子,深入發掘AI的樂趣與應用場景。

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM