MindSpore 初探:使用 LeNet 訓練 MNIST 數據集


如題所述,官網地址:

https://www.mindspore.cn/tutorial/zh-CN/r1.2/quick_start.html

 

 

數據集下載:

mkdir -p ./datasets/MNIST_Data/train ./datasets/MNIST_Data/test
wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte
wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte
wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte
wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte
tree ./datasets/MNIST_Data

 

 

 

 

 

 

 

 

個人整合後的代碼:

#!/usr/bin/env python
# encoding:UTF-8

"""" 對輸入的超參數進行處理 """
import os
import argparse

""" 設置運行的背景context """
from mindspore import context

""" 對數據集進行預處理 """
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

""" 構建神經網絡 """
import mindspore.nn as nn
from mindspore.common.initializer import Normal

""" 訓練時對模型參數的保存 """
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

""" 導入模型訓練需要的庫 """
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model


# Command-line options: only the device backend is configurable (defaults to CPU).
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])

# parse_known_args() returns (namespace, leftover_args); only the namespace is kept,
# so unrecognized command-line arguments are ignored instead of raising an error.
args = parser.parse_known_args()[0]

# Configure the MindSpore execution context: graph mode on the selected device.
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST input pipeline used for training and evaluation.

    Args:
        data_path: Directory handed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch (batching uses
            ``drop_remainder=True``).
        repeat_size: Value passed to ``repeat`` after batching.
        num_parallel_workers: Worker count passed to every ``map`` call.

    Returns:
        The dataset after the label cast, the image transforms, shuffle,
        batch and repeat have been applied, in that order.
    """
    dataset = ds.MnistDataset(data_path)

    target_size = (32, 32)
    pixel_scale, pixel_shift = 1.0 / 255.0, 0.0
    # Second rescale is algebraically (x - 0.1307) / 0.3081
    # (presumably the MNIST mean / std -- confirm against the tutorial).
    norm_scale = 1 / 0.3081
    norm_shift = -1 * 0.1307 / 0.3081

    # (column, operation) pairs; each pair becomes its own map() call,
    # applied top to bottom.
    steps = (
        ("label", C.TypeCast(mstype.int32)),
        ("image", CV.Resize(target_size, interpolation=Inter.LINEAR)),
        ("image", CV.Rescale(pixel_scale, pixel_shift)),
        ("image", CV.Rescale(norm_scale, norm_shift)),
        ("image", CV.HWC2CHW()),
    )
    for column, operation in steps:
        dataset = dataset.map(operations=operation, input_columns=column,
                              num_parallel_workers=num_parallel_workers)

    # Shuffle individual samples first, then batch, then repeat.
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.repeat(repeat_size)


class LeNet5(nn.Cell):
    """LeNet-5: two conv -> ReLU -> max-pool stages followed by three dense layers.

    ``fc1`` expects 16 * 5 * 5 flattened features. With 'valid' 5x5
    convolutions and 2x2/stride-2 pooling this corresponds to a 32x32 input
    (32 -> 28 -> 14 -> 10 -> 5), assuming the default convolution stride of 1.
    """

    def __init__(self, num_class=10, num_channel=1):
        """Create the layers.

        Args:
            num_class: Output width of the final dense layer.
            num_channel: Number of input channels of the first convolution.
        """
        super().__init__()

        # Convolutional layers (attribute names determine parameter names,
        # so they are kept stable for checkpoint compatibility).
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')

        # Dense layers, weights drawn from Normal(0.02).
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

        # Parameter-free operators shared by several stages.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Run the forward pass and return the output of ``fc3``."""
        # Feature extractor: conv -> ReLU -> pool, twice.
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))

        # Classifier head on the flattened features.
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# Instantiate the network
net = LeNet5()

# Loss function: softmax cross-entropy, constructed with sparse=True and reduction='mean'
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# Optimizer: Momentum over the network's trainable parameters
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Checkpoint settings:
# save the model parameters every 125 steps and keep at most 15 checkpoint files
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
# Checkpoint callback that applies the settings above; files use the prefix "checkpoint_lenet"
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)


def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Train ``model`` on the dataset stored under ``<data_path>/train``.

    Args:
        args: Not used by this function.
        model: Object whose ``train`` method is called.
        epoch_size: Number of epochs passed to ``model.train``.
        data_path: Root data directory; the "train" sub-directory is loaded.
        repeat_size: Repeat count forwarded to ``create_dataset``.
        ckpoint_cb: Checkpoint callback added to the training callbacks.
        sink_mode: Value passed as ``dataset_sink_mode``.
    """
    train_dir = os.path.join(data_path, "train")
    train_data = create_dataset(train_dir, batch_size=32, repeat_size=repeat_size)

    # Save checkpoints and report the loss every 125 steps.
    callbacks = [ckpoint_cb, LossMonitor(125)]
    model.train(epoch_size, train_data, callbacks=callbacks,
                dataset_sink_mode=sink_mode)


def test_net(network, model, data_path):
    """Evaluate ``model`` on the dataset under ``<data_path>/test`` and print the result.

    Args:
        network: Not used by this function.
        model: Object whose ``eval`` method is called.
        data_path: Root data directory; the "test" sub-directory is loaded.
    """
    eval_dir = os.path.join(data_path, "test")
    eval_data = create_dataset(eval_dir)
    metrics = model.eval(eval_data, dataset_sink_mode=False)
    print(metrics)


# Root directory of the MNIST data; train_net / test_net join "train" / "test" onto it.
mnist_path = "./datasets/MNIST_Data"
train_epoch = 1
# Passed to train_net as repeat_size.
dataset_size = 1
# Bundle network, loss and optimizer; the metric is reported under the key "Accuracy".
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# Train with dataset_sink_mode=False, then evaluate on the test split.
train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)
test_net(net, model, mnist_path)

 

 

 

訓練結果:

epoch: 1 step: 125, loss is 2.2982173
epoch: 1 step: 250, loss is 2.296105
epoch: 1 step: 375, loss is 2.3065567
epoch: 1 step: 500, loss is 2.3062077
epoch: 1 step: 625, loss is 2.3096561
epoch: 1 step: 750, loss is 2.2847052
epoch: 1 step: 875, loss is 2.284628
epoch: 1 step: 1000, loss is 1.8122461
epoch: 1 step: 1125, loss is 0.4140602
epoch: 1 step: 1250, loss is 0.25238502
epoch: 1 step: 1375, loss is 0.17819008
epoch: 1 step: 1500, loss is 0.3202765
epoch: 1 step: 1625, loss is 0.12312577
epoch: 1 step: 1750, loss is 0.11027573
epoch: 1 step: 1875, loss is 0.2680659
{'Accuracy': 0.9598357371794872}
View Code

 

 

 

 

 

 

為網絡導入模型參數,並進行預測:

本步驟與上面的訓練步驟相關,需要前面設置好的數據集,並且需要前面已經訓練好的網絡參數。

import os
import numpy as np

""" 構建神經網絡 """
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore import Tensor

# 導入模型參數
from mindspore.train.serialization import load_checkpoint, load_param_into_net

""" 對數據集進行預處理 """
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

""" 導入模型訓練需要的庫 """
from mindspore.nn import Accuracy
from mindspore import Model


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST input pipeline (same preprocessing as used for training).

    Args:
        data_path: Directory handed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch (batching uses
            ``drop_remainder=True``).
        repeat_size: Value passed to ``repeat`` after batching.
        num_parallel_workers: Worker count passed to every ``map`` call.

    Returns:
        The dataset after the label cast, the image transforms, shuffle,
        batch and repeat have been applied, in that order.
    """
    dataset = ds.MnistDataset(data_path)

    target_size = (32, 32)
    pixel_scale, pixel_shift = 1.0 / 255.0, 0.0
    # Second rescale is algebraically (x - 0.1307) / 0.3081
    # (presumably the MNIST mean / std -- confirm against the tutorial).
    norm_scale = 1 / 0.3081
    norm_shift = -1 * 0.1307 / 0.3081

    # (column, operation) pairs; each pair becomes its own map() call,
    # applied top to bottom.
    steps = (
        ("label", C.TypeCast(mstype.int32)),
        ("image", CV.Resize(target_size, interpolation=Inter.LINEAR)),
        ("image", CV.Rescale(pixel_scale, pixel_shift)),
        ("image", CV.Rescale(norm_scale, norm_shift)),
        ("image", CV.HWC2CHW()),
    )
    for column, operation in steps:
        dataset = dataset.map(operations=operation, input_columns=column,
                              num_parallel_workers=num_parallel_workers)

    # Shuffle individual samples first, then batch, then repeat.
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.repeat(repeat_size)


class LeNet5(nn.Cell):
    """LeNet-5: two conv -> ReLU -> max-pool stages followed by three dense layers.

    ``fc1`` expects 16 * 5 * 5 flattened features. With 'valid' 5x5
    convolutions and 2x2/stride-2 pooling this corresponds to a 32x32 input
    (32 -> 28 -> 14 -> 10 -> 5), assuming the default convolution stride of 1.
    """

    def __init__(self, num_class=10, num_channel=1):
        """Create the layers.

        Args:
            num_class: Output width of the final dense layer.
            num_channel: Number of input channels of the first convolution.
        """
        super().__init__()

        # Convolutional layers (attribute names determine parameter names,
        # so they must stay as-is for the saved checkpoint to load).
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')

        # Dense layers, weights drawn from Normal(0.02).
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

        # Parameter-free operators shared by several stages.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Run the forward pass and return the output of ``fc3``."""
        # Feature extractor: conv -> ReLU -> pool, twice.
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))

        # Classifier head on the flattened features.
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


# Instantiate the network
net = LeNet5()
# Loss function: softmax cross-entropy, constructed with sparse=True and reduction='mean'
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Optimizer: Momentum over the network's trainable parameters
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
# Build the model wrapper
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})


# Load the saved checkpoint to be used for prediction
# (presumably the file written by the training run above -- verify it exists in the working directory)
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
# Copy the loaded parameters into the network
load_param_into_net(net, param_dict)


_batch_size = 8
# Build the test dataset with batches of _batch_size images and fetch one batch
mnist_path = "./datasets/MNIST_Data"
ds_test = create_dataset(
    os.path.join(mnist_path, "test"),
    batch_size=_batch_size,
).create_dict_iterator()
data = next(ds_test)

# images holds the test pictures, labels their ground-truth classes
images = data["image"].asnumpy()
labels = data["label"].asnumpy()

# Predict with model.predict; the class is the index of the largest output per sample
output = model.predict(Tensor(data['image']))
predicted = np.argmax(output.asnumpy(), axis=1)

# Print the predicted class next to the actual class for every sample in the batch
for guess, truth in zip(predicted, labels):
    print(f'Predicted: "{guess}", Actual: "{truth}"')

 

 

 

 

運行結果:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM