linux-基於tensorflow2.x的手寫數字識別-基於MNIST數據集


數據集

數據集下載🔗MNIST
首先讀取數據集, 並打印相關信息
包括

  • 圖像的數量, 形狀
  • 像素的最大, 最小值
  • 以及看一下第一張圖片
# Load the MNIST arrays from the local NumPy archive.
path = 'MNIST/mnist.npz'
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code; only use it with an archive from a trusted source.
with np.load(path, allow_pickle=True) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']

# Number and shape of the training images/labels, and the pixel value range.
print(f'dataset info: shape: {x_train.shape}, {y_train.shape}')
print(f'dataset info: max: {x_train.max()}')
print(f'dataset info: min: {x_train.min()}')

# Show the first training image, enlarged 10x so it is visible on screen.
print("A sample:")
print("y_train: ", y_train[0])
show_pic = x_train[0].copy()
show_pic = cv2.resize(show_pic, (28 * 10, 28 * 10))
cv2.imshow("A image sample", show_pic)
# Block until a key is pressed, then always close the window. Previously the
# window was destroyed only for 'q', so any other key left it open while the
# script carried on.
key = cv2.waitKey(0)
cv2.destroyAllWindows()
print("show demo over")

轉換為tf 數據集的格式, 並進行歸一化

# Convert the images to float32 tensors and scale them by 1/255
# (assuming 8-bit pixel values, this maps them into [0, 1]).
# True division is required: floor division (`//`) would turn every pixel
# below 255 into 0, destroying almost all of the image content.
x_train = tf.convert_to_tensor(x_train, dtype=tf.float32) / 255.
x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) / 255.
dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# NOTE(review): the repeat count reuses class_num as the number of passes over
# the data; presumably a dedicated epoch count was intended - confirm.
dataset_train = dataset_train.batch(batch_size).repeat(class_num)

定義網絡

在這裡定義一個簡單的全連接網絡

def build_simple_net():
    """Create and build the fully connected classifier.

    The network is three 256-unit ReLU layers followed by a linear layer with
    ``class_num`` outputs, built for flattened inputs of 28 * 28 values.

    Returns:
        The built ``Sequential`` model.
    """
    hidden_layers = [layers.Dense(256, activation='relu') for _ in range(3)]
    output_layer = layers.Dense(class_num)
    net = Sequential(hidden_layers + [output_layer])
    # None leaves the batch dimension unspecified.
    net.build(input_shape=(None, 28 * 28))
    return net

訓練

使用 SGD 優化器進行訓練

def train(print_info_step=250):
    """Train the fully connected network with SGD and save its weights.

    The loss is the squared error between the network output and the one-hot
    label, summed over the batch and divided by the number of samples in it.
    Every ``print_info_step`` steps the loss and the accuracy accumulated
    since the previous report are printed. Training stops early once that
    windowed accuracy reaches 0.90. Weights are written to ``save_path`` every
    500 steps, on early stop, and when the dataset is exhausted.

    Args:
        print_info_step: Number of steps between progress reports; also the
            length of the window over which accuracy is accumulated.
    """
    net = build_simple_net()
    # `learning_rate` is the documented argument name; `lr` is a deprecated alias.
    optimizer = optimizers.SGD(learning_rate=0.01)
    # Accuracy accumulated over the current reporting window.
    acc = metrics.Accuracy()
    # Newer Keras releases name the reset method `reset_state`; fall back to
    # the older `reset_states` when it is not available.
    reset_acc = getattr(acc, 'reset_state', None) or acc.reset_states

    for step, (x, y) in enumerate(dataset_train):
        with tf.GradientTape() as tape:
            # [b, 28, 28] => [b, 784]
            x = tf.reshape(x, (-1, 28 * 28))
            # [b, 784] => [b, 10]
            out = net(x)
            # [b] => [b, 10]
            y_onehot = tf.one_hot(y, depth=class_num)
            # Divide by the actual number of rows in this batch, so a short
            # final batch is not under-weighted by the configured batch_size.
            samples = tf.cast(tf.shape(x)[0], out.dtype)
            loss = tf.reduce_sum(tf.square(out - y_onehot)) / samples

        # Record accuracy for this batch, then back-propagate.
        acc.update_state(tf.argmax(out, axis=1), y)
        grads = tape.gradient(loss, net.trainable_variables)
        optimizer.apply_gradients(zip(grads, net.trainable_variables))

        if step % print_info_step == 0:
            print(f'step: {step}, loss: {loss}, acc: {acc.result().numpy()}')
            # Decide on early stopping only at the end of a reporting window.
            # Checking on every step let a single lucky batch right after a
            # reset (or at step 0) end training prematurely.
            if step != 0 and acc.result() >= 0.90:
                net.save_weights(save_path)
                print(f'final acc: {acc.result()}, total step: {step}')
                break
            reset_acc()

        if step % 500 == 0 and step != 0:
            print('save model')
            net.save_weights(save_path)
    else:
        # Dataset exhausted without reaching the target accuracy: keep the
        # final weights instead of losing the steps since the last checkpoint.
        net.save_weights(save_path)

驗證

驗證模型在測試集上的效果, 這裡僅取出第一張圖片進行驗證

def test_dataset():
    """Predict the first test image with the saved model and print the result.

    Builds a fresh network, loads the weights stored at ``save_path``, runs the
    first image of ``x_test`` through it and prints the predicted class next to
    the label ``y_test[0]``.
    """
    model = build_simple_net()
    model.load_weights(save_path)
    # Flatten the first test image into a one-row batch: [28, 28] => [1, 784]
    first_image = tf.reshape(x_test[0], (-1, 28 * 28))
    scores = model.predict(first_image)
    predicted = tf.argmax(scores, axis=1).numpy()
    label = y_test[0]
    print(f'pred: {predicted}, label: {label}')

應用

分割手寫數字, 並進行逐一識別

  • 先將圖像二值化
  • 找到輪廓
  • 得到數字的坐標
  • 轉為模型需要的輸入格式, 並進行識別
  • 顯示
def split_number(img):
    """Locate each handwritten digit in ``img``, classify it and show the result.

    The image is converted to grayscale and binarised, contours are extracted,
    and the bounding box of each contour is cropped, resized to 28x28 and
    classified with the saved network. A green box is drawn around every
    detected digit, the annotated image is shown until a key is pressed, and
    the predictions are printed.

    Args:
        img: Colour image in OpenCV's BGR channel order (as returned by
            ``cv2.imread``). The boxes are drawn onto it.

    Returns:
        The predicted digits as a list, in reverse contour order (the same
        order that is printed).
    """
    net = build_simple_net()
    # Load the trained weights.
    net.load_weights(save_path)

    # cv2.imread yields BGR data, so BGR2GRAY (not RGB2GRAY) gives the correct
    # channel weighting.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy); index -2 is the contour list in both.
    contours = cv2.findContours(
        thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

    result = []
    # The last contour is skipped, as in the original code - presumably it is
    # the outline of the whole image; confirm against real inputs.
    for cnt in contours[:-1]:
        x, y, w, h = cv2.boundingRect(cnt)

        # Crop from the single-channel image: the network takes 28 * 28 values
        # per sample, and a 3-channel crop would be reshaped into three
        # meaningless rows of 784.
        # NOTE(review): the crop is fed with its original polarity; confirm it
        # matches the polarity of the training images.
        digit = cv2.resize(gray[y:y + h, x:x + w], (28, 28))

        pred_image = tf.convert_to_tensor(digit, dtype=tf.float32) / 255.
        pred_image = tf.reshape(pred_image, (-1, 28 * 28))
        pred = net.predict(pred_image)
        result.append(tf.argmax(pred, axis=1).numpy()[0])
        img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)

    # Predictions are reported in reverse contour order, as before.
    result.reverse()

    cv2.imshow("demo", img)
    print(result)
    # Wait for any key press, then close the window.
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return result

效果

單數字

多數字

附錄

所有代碼, 文件 tf2_mnist.py

import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Sequential, optimizers, metrics

# Filter out TensorFlow's informational and warning log messages.
# NOTE(review): this is set after `import tensorflow` above; it is usually set
# before the import - confirm it still takes effect here.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Number of images per batch.
batch_size = 2
# Number of classes.
class_num = 10
# Path the model weights are saved to and loaded from.
save_path = "./models/mnist.ckpt"
# Whether to print dataset info and show a sample image.
show_demo = False
# Whether to evaluate on the test set (checked only when run_train is False).
evaluate_dataset = False
# Whether to run training.
run_train = False
# Image path handed to split_number() by the __main__ block; a falsy value
# skips recognition.
image_path = 'images/36.png'

# Load the MNIST arrays from the local NumPy archive.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code; only use it with an archive from a trusted source.
path = 'MNIST/mnist.npz'
with np.load(path, allow_pickle=True) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']

if show_demo:
    # Number and shape of the training images/labels, and the pixel value range.
    print(f'dataset info: shape: {x_train.shape}, {y_train.shape}')
    print(f'dataset info: max: {x_train.max()}')
    print(f'dataset info: min: {x_train.min()}')

    # Show the first training image, enlarged 10x so it is visible on screen.
    print("A sample:")
    print("y_train: ", y_train[0])
    show_pic = x_train[0].copy()
    show_pic = cv2.resize(show_pic, (28 * 10, 28 * 10))
    cv2.imshow("A image sample", show_pic)
    # Block until a key is pressed, then always close the window. Previously
    # the window was destroyed only for 'q', so any other key left it open
    # while the script carried on.
    key = cv2.waitKey(0)
    cv2.destroyAllWindows()
    print("show demo over")

# Convert the images to float32 tensors and scale them by 1/255
# (assuming 8-bit pixel values, this maps them into [0, 1]).
# True division is required: floor division (`//`) would turn every pixel
# below 255 into 0, destroying almost all of the image content.
x_train = tf.convert_to_tensor(x_train, dtype=tf.float32) / 255.
x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) / 255.
dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# NOTE(review): the repeat count reuses class_num as the number of passes over
# the data; presumably a dedicated epoch count was intended - confirm.
dataset_train = dataset_train.batch(batch_size).repeat(class_num)


def build_simple_net():
    """Create and build the fully connected classifier.

    The network is three 256-unit ReLU layers followed by a linear layer with
    ``class_num`` outputs, built for flattened inputs of 28 * 28 values.

    Returns:
        The built ``Sequential`` model.
    """
    hidden_layers = [layers.Dense(256, activation='relu') for _ in range(3)]
    output_layer = layers.Dense(class_num)
    net = Sequential(hidden_layers + [output_layer])
    # None leaves the batch dimension unspecified.
    net.build(input_shape=(None, 28 * 28))
    return net


def train(print_info_step=250):
    """Train the fully connected network with SGD and save its weights.

    The loss is the squared error between the network output and the one-hot
    label, summed over the batch and divided by the number of samples in it.
    Every ``print_info_step`` steps the loss and the accuracy accumulated
    since the previous report are printed. Training stops early once that
    windowed accuracy reaches 0.90. Weights are written to ``save_path`` every
    500 steps, on early stop, and when the dataset is exhausted.

    Args:
        print_info_step: Number of steps between progress reports; also the
            length of the window over which accuracy is accumulated.
    """
    net = build_simple_net()
    # `learning_rate` is the documented argument name; `lr` is a deprecated alias.
    optimizer = optimizers.SGD(learning_rate=0.01)
    # Accuracy accumulated over the current reporting window.
    acc = metrics.Accuracy()
    # Newer Keras releases name the reset method `reset_state`; fall back to
    # the older `reset_states` when it is not available.
    reset_acc = getattr(acc, 'reset_state', None) or acc.reset_states

    for step, (x, y) in enumerate(dataset_train):
        with tf.GradientTape() as tape:
            # [b, 28, 28] => [b, 784]
            x = tf.reshape(x, (-1, 28 * 28))
            # [b, 784] => [b, 10]
            out = net(x)
            # [b] => [b, 10]
            y_onehot = tf.one_hot(y, depth=class_num)
            # Divide by the actual number of rows in this batch, so a short
            # final batch is not under-weighted by the configured batch_size.
            samples = tf.cast(tf.shape(x)[0], out.dtype)
            loss = tf.reduce_sum(tf.square(out - y_onehot)) / samples

        # Record accuracy for this batch, then back-propagate.
        acc.update_state(tf.argmax(out, axis=1), y)
        grads = tape.gradient(loss, net.trainable_variables)
        optimizer.apply_gradients(zip(grads, net.trainable_variables))

        if step % print_info_step == 0:
            print(f'step: {step}, loss: {loss}, acc: {acc.result().numpy()}')
            # Decide on early stopping only at the end of a reporting window.
            # Checking on every step let a single lucky batch right after a
            # reset (or at step 0) end training prematurely.
            if step != 0 and acc.result() >= 0.90:
                net.save_weights(save_path)
                print(f'final acc: {acc.result()}, total step: {step}')
                break
            reset_acc()

        if step % 500 == 0 and step != 0:
            print('save model')
            net.save_weights(save_path)
    else:
        # Dataset exhausted without reaching the target accuracy: keep the
        # final weights instead of losing the steps since the last checkpoint.
        net.save_weights(save_path)


def test_dataset():
    """Predict the first test image with the saved model and print the result.

    Builds a fresh network, loads the weights stored at ``save_path``, runs the
    first image of ``x_test`` through it and prints the predicted class next to
    the label ``y_test[0]``.
    """
    model = build_simple_net()
    model.load_weights(save_path)
    # Flatten the first test image into a one-row batch: [28, 28] => [1, 784]
    first_image = tf.reshape(x_test[0], (-1, 28 * 28))
    scores = model.predict(first_image)
    predicted = tf.argmax(scores, axis=1).numpy()
    label = y_test[0]
    print(f'pred: {predicted}, label: {label}')

def split_number(img):
    """Locate each handwritten digit in ``img``, classify it and show the result.

    The image is converted to grayscale and binarised, contours are extracted,
    and the bounding box of each contour is cropped, resized to 28x28 and
    classified with the saved network. A green box is drawn around every
    detected digit, the annotated image is shown until a key is pressed, and
    the predictions are printed.

    Args:
        img: Colour image in OpenCV's BGR channel order (as returned by
            ``cv2.imread``). The boxes are drawn onto it.

    Returns:
        The predicted digits as a list, in reverse contour order (the same
        order that is printed).
    """
    net = build_simple_net()
    # Load the trained weights.
    net.load_weights(save_path)

    # cv2.imread yields BGR data, so BGR2GRAY (not RGB2GRAY) gives the correct
    # channel weighting.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy); index -2 is the contour list in both.
    contours = cv2.findContours(
        thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

    result = []
    # The last contour is skipped, as in the original code - presumably it is
    # the outline of the whole image; confirm against real inputs.
    for cnt in contours[:-1]:
        x, y, w, h = cv2.boundingRect(cnt)

        # Crop from the single-channel image: the network takes 28 * 28 values
        # per sample, and a 3-channel crop would be reshaped into three
        # meaningless rows of 784.
        # NOTE(review): the crop is fed with its original polarity; confirm it
        # matches the polarity of the training images.
        digit = cv2.resize(gray[y:y + h, x:x + w], (28, 28))

        pred_image = tf.convert_to_tensor(digit, dtype=tf.float32) / 255.
        pred_image = tf.reshape(pred_image, (-1, 28 * 28))
        pred = net.predict(pred_image)
        result.append(tf.argmax(pred, axis=1).numpy()[0])
        img = cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)

    # Predictions are reported in reverse contour order, as before.
    result.reverse()

    cv2.imshow("demo", img)
    print(result)
    # Wait for any key press, then close the window.
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return result


if __name__ == '__main__':
    # Exactly one mode runs, in priority order: training, evaluation on the
    # test set, then digit recognition on the configured image.
    if run_train:
        train()
    elif evaluate_dataset:
        test_dataset()
    elif image_path:
        image = cv2.imread(image_path)
        # cv2.imread returns None instead of raising when the file is missing
        # or cannot be decoded; fail here with a clear message rather than
        # with an obscure error inside split_number().
        if image is None:
            raise FileNotFoundError(f'cannot read image: {image_path}')
        split_number(image)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM