使用LeNet算法實現手寫數字識別(代碼部分解釋)


# 確認當前環境的版本

import mindspore

print(mindspore.__version__)

1. 數據集下載

MNIST是一個手寫數字數據集,訓練集包含60000張手寫數字,測試集包含10000張手寫數字,共10類。

從華為雲OBS公共桶中下載。

import os

import moxing as mox

 

if not os.path.exists("./MNIST_Data.zip"):

    mox.file.copy("obs://modelarts-labs-bj4-v2/course/hwc_edu/python_module_framework/datasets/mindspore_data/MNIST_Data.zip", "./MNIST_Data.zip")

下載Minist數據集

!unzip -o MNIST_Data.zip -d ./

!tree ./MNIST_Data/

2、處理MNIST數據集

由於我們后面會采用LeNet這樣的卷積神經網絡對數據集進行訓練,而采用LeNet在訓練數據時,對數據格式是有所要求的,所以接下來的工作需要我們先查看數據集內的數據是什么樣的,這樣才能構造一個針對性的數據轉換函數,將數據集數據轉換成符合訓練要求的數據形式。

 

步驟1 查看原始數據集數據:

from mindspore import context 導入context 模塊,調用context的set_context方法

import matplotlib.pyplot as plt

import matplotlib

import numpy as np

import mindspore.dataset as ds

上面是調用畫圖所需要的庫

# device_target 可選 CPU/GPU/Ascend, 當選擇GPU時mindspore規格也需要切換到GPU device_target表示硬件信息,有三個選項(CPU、GPU、Ascend)

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

# device_id = int(os.getenv("DEVICE_ID"))

# context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)

train_data_path = "./MNIST_Data/train"    訓練數據集路徑

test_data_path = "./MNIST_Data/test" 測試數據集路徑

mnist_ds = ds.MnistDataset(train_data_path)

print('The type of mnist_ds:', type(mnist_ds))

print("Number of pictures contained in the mnist_ds:", mnist_ds.get_dataset_size())

 

dic_ds = mnist_ds.create_dict_iterator()

item = next(dic_ds) 字典類型數據

img = item["image"].asnumpy()

label = item["label"].asnumpy()

print("The item of mnist_ds:", item.keys())  

print("Tensor of image in item:", img.shape)

print("The label of item:", label)

 

plt.imshow(np.squeeze(img))

plt.title("number:%s"% item["label"].asnumpy())  

plt.show()   顯示圖像

步驟2 數據處理:

數據集對於訓練非常重要,好的數據集可以有效提高訓練精度和效率,在加載數據集前,我們通常會對數據集進行一些處理。

定義數據集及數據操作

定義完成后,使用create_datasets對原始數據進行增強操作,並抽取一個batch的數據,查看數據增強后的變化。

定義一個creat_dataset函數,對原始數據進行增強(進行數據增強提升樣本質量,避免局部最優等一些問題)

import mindspore.dataset.vision.c_transforms as CV

import mindspore.dataset.transforms.c_transforms as C

from mindspore.dataset.vision import Inter

from mindspore import dtype as mstype

 

 

def create_dataset(data_path, batch_size=32, repeat_size=1,

                   num_parallel_workers=1):

    """

    create dataset for train or test

    Args:

        data_path (str): Data path

        batch_size (int): The number of data records in each group

        repeat_size (int): The number of replicated data records

        num_parallel_workers (int): The number of parallel workers

    """

    # define dataset

    mnist_ds = ds.MnistDataset(data_path) 將原始的Minist數據集加載進來

 

# define some parameters needed for data enhancement and rough justification

定義數據增強和處理的一些參數

    resize_height, resize_width = 32, 32  圖片大小32*32

    rescale = 1.0 / 255.0 圖像縮放因子,使每個像素點在(0,255)范圍內

    shift = 0.0   圖像偏移因子

    rescale_nml = 1 / 0.3081   

    shift_nml = -1 * 0.1307 / 0.3081

 

    # according to the parameters, generate the corresponding data enhancement method

    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  使用給定的插值模式將輸入圖像調整為一定大小

    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)

rescale_op = CV.Rescale(rescale, shift) 使用給定的縮放和移位消除圖片在不同位置帶來的影響

hwc2chw_op = CV.HWC2CHW() 將輸入圖像從(H,W,C)轉化為(C,H,W)

    type_cast_op = C.TypeCast(mstype.int32)

轉化為Minist特定的數據類型mstype

# using map to apply operations to a dataset

將數據增強處理方法映射到(使用)在對應數據集的相應部分(image,label)

    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    

    # process the generated dataset

    buffer_size = 10000

    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)   將數據集進行打亂

    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)   數據集按照batch_size的大小分成若干批次,按照批次進行訓練,drop_remainder=True是保留數據集整除batch_size的,多余的舍棄掉

    mnist_ds = mnist_ds.repeat(repeat_size)

將數據集重復repeat_size次

    return mnist_ds

展示數據集

ms_dataset = create_dataset(train_data_path)

print('Number of groups in the dataset:', ms_dataset.get_dataset_size())  查看一共分成多少組,將數據一共分成60000/32=1875個批次

步驟3 進一步查看增強后的數據:

  • 從1875組數據中取出一組數據查看其數據張量及label。將張量數據和下標對應的值進行可視化。

data = next(ms_dataset.create_dict_iterator(output_numpy=True))  利用next獲取樣本並查看單個樣本格式

images = data["image"]

labels = data["label"]   獲取圖像張量

print('Tensor of image:', images.shape)

print('labels:', labels)

  • 將張量數據和下標對應的值進行可視化。

count = 1

for i in images:

    plt.subplot(4, 8, count)

    plt.imshow(np.squeeze(i))

    plt.title('num:%s'%labels[count-1])

    plt.xticks([])

    count += 1

    plt.axis("off")

plt.show()

3、 定義模型

在對手寫字體識別上,通常采用卷積神經網絡架構(CNN)進行學習預測,最經典的屬1998年由Yann LeCun創建的LeNet5架構,
結構示意如下圖:

 

import mindspore.nn as nn

from mindspore.common.initializer import Normal

 

class LeNet5(nn.Cell):

    """Lenet network structure."""

    # define the operator required

    def __init__(self, num_class=10, num_channel=1):

        super(LeNet5, self).__init__()   #繼承父類nn.cell的__init__方法

        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')

        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')   pad_mode是卷積方式 ‘valid’是pad_mode = 0

        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))

        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))

        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

nn.Dense為致密連接層,它的第一個參數為輸入層的維度,第二個參數為輸出的維度,第三個參數為神經網絡可訓練參數W權重矩陣的初始化方式,默認為normal

        self.relu = nn.ReLU()

        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)  將圖片的寬度和高度都縮小一半

        self.flatten = nn.Flatten()   nn.Flatten為輸入展成平圖層,即去掉那些空的維度

 

    # use the preceding operators to construct networks

    def construct(self, x):        使用前饋操作構建神經網絡

        x = self.max_pool2d(self.relu(self.conv1(x)))

        x = self.max_pool2d(self.relu(self.conv2(x)))

        x = self.flatten(x)

        x = self.relu(self.fc1(x))

        x = self.relu(self.fc2(x))

        x = self.fc3(x)

        return x

    

network = LeNet5()

print("layer conv1:", network.conv1)

print("*"*40)

print("layer fc1:", network.fc1)

4、搭建訓練網絡並進行訓練

構建完成神經網絡后,就可以着手進行訓練網絡的構建,模型訓練函數為Model.train

此步驟案例中使用epoch=1,使用CPU訓練大概耗時15分鍾,為實現快速訓練,可選用更高規格的資源訓練。

import os

from mindspore import Tensor, Model

from mindspore import load_checkpoint, load_param_into_net

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor

from mindspore.nn.metrics import Accuracy

from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

 

lr = 0.01   learingrate,學習率,可以使梯度下降的幅度變小,從而可以更好的訓練參數

momentum = 0.9

epoch_size = 1     每個epoch需要遍歷完成圖片的batch數,這里是只要遍歷一次

mnist_path = "./MNIST_Data"

model_path = "./models/ckpt/"

 

# clean up old run files before in Linux

os.system('rm -f {}*.ckpt {}*.meta {}*.pb'.format(model_path, model_path, model_path))    清理舊文件

 

# create the network

network = LeNet5()    創建神經網絡

 

# define the optimizer

net_opt = nn.Momentum(network.trainable_params(), lr, momentum)   使用了Momentum優化器進行優化

 

# define the loss function

net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

相當於softmax分類器      sparse指定標簽(label)是否使用稀疏模式,默認為false,reduction為損失的減少類型:mean表示平均值,一般情況下都是選擇平均地減少

# define the model

model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()} )

調用Model高級API   metrics 參數是指訓練和測試期

# save the network model and parameters for subsequence fine-tuning

config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)     保存訓練好的模型參數的路徑

# group layers into an object with training and evaluation features

ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=model_path, config=config_ck)    保存訓練好的模型參數的路徑

 

 

eval_dataset = create_dataset("./MNIST_Data/test")  測試數據集

 

step_loss = {"step": [], "loss_value": []}   回調函數中的數據格式

steps_eval = {"step": [], "acc": []}     回調函數中的數據格式

(收集step對應模型精度值accuracy的信息)

repeat_size = 1   

ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)

 

model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=False)   調用Model類的train方法進行訓練

5、測試數據驗證模型精度

# testing relate modules

def test_net(network, model, mnist_path):     測試神經網絡函數

    """Define the evaluation method."""

    print("============== Starting Testing ==============")   開始測試

    # load the saved model for evaluation

    param_dict = load_checkpoint("./models/ckpt/checkpoint_lenet-1_1875.ckpt")     加載保存的模型以進行評估

    # load parameter to the network

    load_param_into_net(network, param_dict)     向網絡加載參數

    # load testing dataset

    ds_eval = create_dataset(os.path.join(mnist_path, "test"))   加載測試數據集

    acc = model.eval(ds_eval, dataset_sink_mode=False)  獲取測試模型的精確度

    print("============== Accuracy:{} ==============".format(acc))

 

test_net(network, model, mnist_path)

6、推理(訓練后)

ds_test = create_dataset(test_data_path).create_dict_iterator()     測試數據集

data = next(ds_test)  

images = data["image"].asnumpy()     獲取數據張量

labels = data["label"].asnumpy()       獲取數據labels標簽

 

output = model.predict(Tensor(data['image']))

pred = np.argmax(output.asnumpy(), axis=1)

err_num = []     錯誤數字

index = 1

for i in range(len(labels)):

    plt.subplot(4, 8, i+1)    圖片按照4*8的方式擺放

    color = 'blue' if pred[i] == labels[i] else 'red'    對於圖片標簽顏色的不同,預測正確是藍色,錯誤是紅色

    plt.title("pre:{}".format(pred[i]), color=color)  輸出預測標簽

    plt.imshow(np.squeeze(images[i]))  顯示圖片

    plt.axis("off")

    if color == 'red':    判斷數字是否錯誤

        index = 0

        print("Row {}, column {} is incorrectly identified as {}, the correct value should be {}".format(int(i/8)+1, i%8+1, pred[i], labels[i]), '\n')

if index:   判斷是否所有數字都預測正確

    print("All the figures in this group are predicted correctly!")  

print(pred, "<--Predicted figures")    預測目標數字標簽

print(labels, "<--The right number")   說明數字正確標簽

plt.show()  顯示圖像


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM