本章承接上一篇的手寫數字識別,利用訓練好的模型,結合pyqt畫板,實現簡易手寫輸入法,為"hello world"例子增添樂趣。
PyQt是開發圖形界面的框架,可以上網查找相關資料,了解其安裝及基礎用法。我搭建的環境是PyCharm + PyQt5 + Qt Designer,配置好之後的界面長這樣:

在左邊的項目中右鍵某個文件,也可以打開qt菜單
具體怎麼畫界面就不展開了,直接看代碼:
# coding: utf-8
"""Simple handwriting input pad built on PyQt5.

The user draws a digit with the mouse on a black canvas; when the left
button is released the canvas is saved to ``input.png`` and classified by
the network loaded from ``../ml/torch/model/net.pth``. The predicted digit
is shown in a label next to the canvas.
"""
from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *
import sys
# digit_recog lives in a sibling folder, so make it importable first.
sys.path.append(r'../ml/torch')
from digit_recog import Net
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net().to(device)
# Load the saved parameters. map_location remaps the stored tensors onto the
# device selected above, so a checkpoint written on a GPU machine still loads
# on a CPU-only machine.
nn_state = torch.load(os.path.join('../ml/torch/model/', 'net.pth'),
                      map_location=device)
# Copy the parameters into the model and switch it to inference mode.
net.load_state_dict(nn_state)
net.eval()


def predict(img):
    """Classify the image file *img* with the module-level network.

    The image is resized to 28x28, converted to 8-bit grayscale and fed to
    ``net`` as a float tensor of shape (1, 1, 28, 28).

    Args:
        img: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        int: Index of the largest network output.
    """
    # The context manager closes the underlying file; resize() returns an
    # independent in-memory image, so it stays valid afterwards.
    with Image.open(img) as image:
        gray_image = image.resize((28, 28)).convert('L')
    # Debug aid: plt.imshow(gray_image); plt.show()

    im_data = torch.from_numpy(np.array(gray_image)).float()
    # The input must live on the same device as the model, otherwise the
    # forward pass raises when CUDA is in use.
    im_data = im_data.view(1, 1, 28, 28).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = net(im_data)
    # Index of the maximum score along the class dimension.
    _, pred = torch.max(outputs, 1)
    return pred.item()


class SimpleDrawingBoard(QWidget):
    """Frameless, always-on-top window with a toolbar and a drawing canvas.

    The toolbar column (pen size, clear, close, result label) sits on the
    left; the canvas pixmap is painted at ``sizeOffset`` inside the widget.
    """

    # Shared window instance used by showWin(); None until first shown.
    win = None
    wins = []

    @classmethod
    def showWin(cls):
        """Create the shared window on first use, then show and focus it."""
        if cls.win is None:
            cls.win = cls()
        # show() is a no-op for a visible window but brings back a window
        # that was closed through the close button.
        cls.win.show()
        cls.win.raise_()
        cls.win.activateWindow()

    def __init__(self, parent=None):
        """Set up window flags, the canvas pixmap and the drawing state.

        Args:
            parent: Optional parent widget passed to ``QWidget``.
        """
        super(SimpleDrawingBoard, self).__init__(parent)

        self.setWindowTitle(u"手寫數字識別")
        self.setWindowFlags(Qt.WindowStaysOnTopHint)
        # Stored as windowSize rather than size so that the inherited
        # QWidget.size() method is not shadowed by a tuple.
        self.windowSize = (400, 350)
        self.resize(*self.windowSize)
        self.setWindowFlag(Qt.FramelessWindowHint)  # hide the window frame
        # self.setWindowOpacity(0.9)  # window opacity
        # self.setAttribute(Qt.WA_TranslucentBackground)  # transparent background

        self.canvasSize = (280, 350)
        # Top-left corner of the canvas in widget coordinates:
        # window size minus canvas size, per axis.
        self.sizeOffset = [a - b for a, b in zip(self.windowSize, self.canvasSize)]
        self.canvas = QPixmap(*self.canvasSize)
        self.canvas.fill(Qt.black)
        self.tempCanvas = QPixmap()
        self.lastPoint = QPoint()
        self.endPoint = QPoint()
        self.isDrawing = False
        self.penSize = 15

        self.initUI()

    def initUI(self):
        """Build the toolbar widgets and lay them out in a left-hand column."""
        self.penSizeLabel = QLabel(u'畫筆粗細')
        self.penSizeSpinBox = QSpinBox()
        self.penSizeSpinBox.setValue(self.penSize)
        self.penSizeSpinBox.valueChanged.connect(self.penSizeSpinBox_valueChanged)
        self.penSizeSpinBox.setFixedWidth(80)

        self.clearButton = QPushButton(u'清空')
        self.clearButton.setFixedWidth(80)
        self.clearButton.clicked.connect(self.clearPainter)

        self.closeButton = QPushButton(u'關閉')
        self.closeButton.setFixedWidth(80)
        self.closeButton.clicked.connect(self.close)

        # Label that displays the recognised digit.
        self.inputLabel = QLabel(self)
        self.inputLabel.setFixedSize(80, 200)
        self.inputLabel.setAutoFillBackground(True)
        self.inputLabel.setAlignment(Qt.AlignCenter)
        self.inputLabel.setStyleSheet('''QLabel{background:#F76677;border-radius:5px;font-size:60px;font-weight:bolder;}''')

        mainLayout = QVBoxLayout(self)

        toolbarLayout = QGridLayout()
        # toolbarLayout.setSpacing(20)
        toolbarLayout.addWidget(self.penSizeLabel, 0, 0, 1, 1)
        toolbarLayout.addWidget(self.penSizeSpinBox, 1, 0, 1, 1)
        toolbarLayout.addWidget(self.clearButton, 2, 0, 1, 1)
        toolbarLayout.addWidget(self.closeButton, 3, 0, 1, 1)
        toolbarLayout.addWidget(self.inputLabel, 4, 0, 1, 1)

        toolbarLayout.setAlignment(Qt.AlignLeft)

        mainLayout.addLayout(toolbarLayout)
        mainLayout.addStretch(1)

    def penSizeSpinBox_valueChanged(self):
        """Copy the spin box value into ``penSize``."""
        self.penSize = self.penSizeSpinBox.value()

    def paintEvent(self, event):
        """Draw the pending stroke segment onto the canvas, then show it.

        The segment from ``lastPoint`` to ``endPoint`` is translated from
        widget to canvas coordinates before being drawn in white.
        """
        if self.lastPoint != self.endPoint:
            offset = QPoint(*self.sizeOffset)
            canvasPainter = QPainter(self.canvas)
            canvasPainter.setPen(QPen(QColor(255, 255, 255), self.penSize))
            canvasPainter.drawLine(self.lastPoint - offset, self.endPoint - offset)
            # Finish painting on the pixmap before it is blitted below.
            canvasPainter.end()

        painter = QPainter(self)
        painter.drawPixmap(self.sizeOffset[0], self.sizeOffset[1], self.canvas)
        painter.end()
        # The segment is consumed; the next one starts where this one ended.
        self.lastPoint = self.endPoint

    def clearPainter(self):
        """Reset the canvas to black and clear the displayed result."""
        print('clear...')
        self.canvas.fill(Qt.black)
        self.lastPoint = self.endPoint
        # A QPainter on the widget may only be opened inside paintEvent, so
        # just schedule a repaint; paintEvent redraws the emptied canvas.
        self.update()
        self.inputLabel.clear()

    def mousePressEvent(self, event):
        """Start a stroke at the cursor position on a left-button press."""
        if event.button() == Qt.LeftButton:
            self.lastPoint = event.pos()
            self.endPoint = self.lastPoint
            self.isDrawing = True

    def mouseMoveEvent(self, event):
        """Extend the current stroke to the cursor position while drawing."""
        if self.isDrawing:
            self.endPoint = event.pos()
            self.update()

    def mouseReleaseEvent(self, event):
        """Finish the stroke, save the canvas and display the prediction."""
        if event.button() == Qt.LeftButton:
            self.isDrawing = False
            self.endPoint = event.pos()
            # repaint() runs paintEvent synchronously, so the final segment
            # is on the canvas before it is saved (update() is deferred).
            self.repaint()
            self.canvas.toImage().save('input.png')
            digit = predict('input.png')
            self.inputLabel.setText(str(digit))
            print('你輸入的是{}'.format(digit))


if __name__ == '__main__':
    # Reuse an existing QApplication if one is already running.
    app = QApplication.instance()
    if not app:
        app = QApplication(sys.argv)
    SimpleDrawingBoard.showWin()
    app.exec_()
上面的代碼引入了前一章訓練好的模型。由於模型代碼位於不同的文件夾內,需要先加上這一行代碼,把該文件夾加入模塊搜索路徑:
sys.path.append(r'../ml/torch')
看下運行效果:


上面寫了兩個數字,識別輸出正確!
hello world例子比較枯燥,通過動手參與、與AI交互,可以增強信心和樂趣。信心是一步步建立起來的,大的突破亦是如此;後面會持續圍繞簡單的例子,深入發掘AI的樂趣與應用場景。
