使用LeNet算法实现手写数字识别(代码部分解释)


# 确认当前环境的版本

import mindspore

print(mindspore.__version__)

1. 数据集下载

MNIST是一个手写数字数据集,训练集包含60000张手写数字,测试集包含10000张手写数字,共10类。

从华为云OBS公共桶中下载。

import os

import moxing as mox

 

if not os.path.exists("./MNIST_Data.zip"):

    mox.file.copy("obs://modelarts-labs-bj4-v2/course/hwc_edu/python_module_framework/datasets/mindspore_data/MNIST_Data.zip", "./MNIST_Data.zip")

下载Minist数据集

!unzip -o MNIST_Data.zip -d ./

!tree ./MNIST_Data/

2、处理MNIST数据集

由于我们后面会采用LeNet这样的卷积神经网络对数据集进行训练,而采用LeNet在训练数据时,对数据格式是有所要求的,所以接下来的工作需要我们先查看数据集内的数据是什么样的,这样才能构造一个针对性的数据转换函数,将数据集数据转换成符合训练要求的数据形式。

 

步骤1 查看原始数据集数据:

from mindspore import context 导入context 模块,调用context的set_context方法

import matplotlib.pyplot as plt

import matplotlib

import numpy as np

import mindspore.dataset as ds

上面是调用画图所需要的库

# device_target 可选 CPU/GPU/Ascend, 当选择GPU时mindspore规格也需要切换到GPU device_target表示硬件信息,有三个选项(CPU、GPU、Ascend)

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

# device_id = int(os.getenv("DEVICE_ID"))

# context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)

train_data_path = "./MNIST_Data/train"    训练数据集路径

test_data_path = "./MNIST_Data/test" 测试数据集路径

mnist_ds = ds.MnistDataset(train_data_path)

print('The type of mnist_ds:', type(mnist_ds))

print("Number of pictures contained in the mnist_ds:", mnist_ds.get_dataset_size())

 

dic_ds = mnist_ds.create_dict_iterator()

item = next(dic_ds) 字典类型数据

img = item["image"].asnumpy()

label = item["label"].asnumpy()

print("The item of mnist_ds:", item.keys())  

print("Tensor of image in item:", img.shape)

print("The label of item:", label)

 

plt.imshow(np.squeeze(img))

plt.title("number:%s"% item["label"].asnumpy())  

plt.show()   显示图像

步骤2 数据处理:

数据集对于训练非常重要,好的数据集可以有效提高训练精度和效率,在加载数据集前,我们通常会对数据集进行一些处理。

定义数据集及数据操作

定义完成后,使用create_datasets对原始数据进行增强操作,并抽取一个batch的数据,查看数据增强后的变化。

定义一个creat_dataset函数,对原始数据进行增强(进行数据增强提升样本质量,避免局部最优等一些问题)

import mindspore.dataset.vision.c_transforms as CV

import mindspore.dataset.transforms.c_transforms as C

from mindspore.dataset.vision import Inter

from mindspore import dtype as mstype

 

 

def create_dataset(data_path, batch_size=32, repeat_size=1,

                   num_parallel_workers=1):

    """

    create dataset for train or test

    Args:

        data_path (str): Data path

        batch_size (int): The number of data records in each group

        repeat_size (int): The number of replicated data records

        num_parallel_workers (int): The number of parallel workers

    """

    # define dataset

    mnist_ds = ds.MnistDataset(data_path) 将原始的Minist数据集加载进来

 

# define some parameters needed for data enhancement and rough justification

定义数据增强和处理的一些参数

    resize_height, resize_width = 32, 32  图片大小32*32

    rescale = 1.0 / 255.0 图像缩放因子,使每个像素点在(0,255)范围内

    shift = 0.0   图像偏移因子

    rescale_nml = 1 / 0.3081   

    shift_nml = -1 * 0.1307 / 0.3081

 

    # according to the parameters, generate the corresponding data enhancement method

    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  使用给定的插值模式将输入图像调整为一定大小

    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)

rescale_op = CV.Rescale(rescale, shift) 使用给定的缩放和移位消除图片在不同位置带来的影响

hwc2chw_op = CV.HWC2CHW() 将输入图像从(H,W,C)转化为(C,H,W)

    type_cast_op = C.TypeCast(mstype.int32)

转化为Minist特定的数据类型mstype

# using map to apply operations to a dataset

将数据增强处理方法映射到(使用)在对应数据集的相应部分(image,label)

    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    

    # process the generated dataset

    buffer_size = 10000

    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)   将数据集进行打乱

    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)   数据集按照batch_size的大小分成若干批次,按照批次进行训练,drop_remainder=True是保留数据集整除batch_size的,多余的舍弃掉

    mnist_ds = mnist_ds.repeat(repeat_size)

将数据集重复repeat_size次

    return mnist_ds

展示数据集

ms_dataset = create_dataset(train_data_path)

print('Number of groups in the dataset:', ms_dataset.get_dataset_size())  查看一共分成多少组,将数据一共分成60000/32=1875个批次

步骤3 进一步查看增强后的数据:

  • 从1875组数据中取出一组数据查看其数据张量及label。将张量数据和下标对应的值进行可视化。

data = next(ms_dataset.create_dict_iterator(output_numpy=True))  利用next获取样本并查看单个样本格式

images = data["image"]

labels = data["label"]   获取图像张量

print('Tensor of image:', images.shape)

print('labels:', labels)

  • 将张量数据和下标对应的值进行可视化。

count = 1

for i in images:

    plt.subplot(4, 8, count)

    plt.imshow(np.squeeze(i))

    plt.title('num:%s'%labels[count-1])

    plt.xticks([])

    count += 1

    plt.axis("off")

plt.show()

3、 定义模型

在对手写字体识别上,通常采用卷积神经网络架构(CNN)进行学习预测,最经典的属1998年由Yann LeCun创建的LeNet5架构,
结构示意如下图:

 

import mindspore.nn as nn

from mindspore.common.initializer import Normal

 

class LeNet5(nn.Cell):

    """Lenet network structure."""

    # define the operator required

    def __init__(self, num_class=10, num_channel=1):

        super(LeNet5, self).__init__()   #继承父类nn.cell的__init__方法

        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')

        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')   pad_mode是卷积方式 ‘valid’是pad_mode = 0

        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))

        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))

        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

nn.Dense为致密连接层,它的第一个参数为输入层的维度,第二个参数为输出的维度,第三个参数为神经网络可训练参数W权重矩阵的初始化方式,默认为normal

        self.relu = nn.ReLU()

        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)  将图片的宽度和高度都缩小一半

        self.flatten = nn.Flatten()   nn.Flatten为输入展成平图层,即去掉那些空的维度

 

    # use the preceding operators to construct networks

    def construct(self, x):        使用前馈操作构建神经网络

        x = self.max_pool2d(self.relu(self.conv1(x)))

        x = self.max_pool2d(self.relu(self.conv2(x)))

        x = self.flatten(x)

        x = self.relu(self.fc1(x))

        x = self.relu(self.fc2(x))

        x = self.fc3(x)

        return x

    

network = LeNet5()

print("layer conv1:", network.conv1)

print("*"*40)

print("layer fc1:", network.fc1)

4、搭建训练网络并进行训练

构建完成神经网络后,就可以着手进行训练网络的构建,模型训练函数为Model.train

此步骤案例中使用epoch=1,使用CPU训练大概耗时15分钟,为实现快速训练,可选用更高规格的资源训练。

import os

from mindspore import Tensor, Model

from mindspore import load_checkpoint, load_param_into_net

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor

from mindspore.nn.metrics import Accuracy

from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

 

lr = 0.01   learingrate,学习率,可以使梯度下降的幅度变小,从而可以更好的训练参数

momentum = 0.9

epoch_size = 1     每个epoch需要遍历完成图片的batch数,这里是只要遍历一次

mnist_path = "./MNIST_Data"

model_path = "./models/ckpt/"

 

# clean up old run files before in Linux

os.system('rm -f {}*.ckpt {}*.meta {}*.pb'.format(model_path, model_path, model_path))    清理旧文件

 

# create the network

network = LeNet5()    创建神经网络

 

# define the optimizer

net_opt = nn.Momentum(network.trainable_params(), lr, momentum)   使用了Momentum优化器进行优化

 

# define the loss function

net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

相当于softmax分类器      sparse指定标签(label)是否使用稀疏模式,默认为false,reduction为损失的减少类型:mean表示平均值,一般情况下都是选择平均地减少

# define the model

model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()} )

调用Model高级API   metrics 参数是指训练和测试期

# save the network model and parameters for subsequence fine-tuning

config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)     保存训练好的模型参数的路径

# group layers into an object with training and evaluation features

ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=model_path, config=config_ck)    保存训练好的模型参数的路径

 

 

eval_dataset = create_dataset("./MNIST_Data/test")  测试数据集

 

step_loss = {"step": [], "loss_value": []}   回调函数中的数据格式

steps_eval = {"step": [], "acc": []}     回调函数中的数据格式

(收集step对应模型精度值accuracy的信息)

repeat_size = 1   

ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)

 

model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=False)   调用Model类的train方法进行训练

5、测试数据验证模型精度

# testing relate modules

def test_net(network, model, mnist_path):     测试神经网络函数

    """Define the evaluation method."""

    print("============== Starting Testing ==============")   开始测试

    # load the saved model for evaluation

    param_dict = load_checkpoint("./models/ckpt/checkpoint_lenet-1_1875.ckpt")     加载保存的模型以进行评估

    # load parameter to the network

    load_param_into_net(network, param_dict)     向网络加载参数

    # load testing dataset

    ds_eval = create_dataset(os.path.join(mnist_path, "test"))   加载测试数据集

    acc = model.eval(ds_eval, dataset_sink_mode=False)  获取测试模型的精确度

    print("============== Accuracy:{} ==============".format(acc))

 

test_net(network, model, mnist_path)

6、推理(训练后)

ds_test = create_dataset(test_data_path).create_dict_iterator()     测试数据集

data = next(ds_test)  

images = data["image"].asnumpy()     获取数据张量

labels = data["label"].asnumpy()       获取数据labels标签

 

output = model.predict(Tensor(data['image']))

pred = np.argmax(output.asnumpy(), axis=1)

err_num = []     错误数字

index = 1

for i in range(len(labels)):

    plt.subplot(4, 8, i+1)    图片按照4*8的方式摆放

    color = 'blue' if pred[i] == labels[i] else 'red'    对于图片标签颜色的不同,预测正确是蓝色,错误是红色

    plt.title("pre:{}".format(pred[i]), color=color)  输出预测标签

    plt.imshow(np.squeeze(images[i]))  显示图片

    plt.axis("off")

    if color == 'red':    判断数字是否错误

        index = 0

        print("Row {}, column {} is incorrectly identified as {}, the correct value should be {}".format(int(i/8)+1, i%8+1, pred[i], labels[i]), '\n')

if index:   判断是否所有数字都预测正确

    print("All the figures in this group are predicted correctly!")  

print(pred, "<--Predicted figures")    预测目标数字标签

print(labels, "<--The right number")   说明数字正确标签

plt.show()  显示图像


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM