When training a network with model.train in MindSpore, it is difficult to run intermittent custom tasks in the middle of training. MindSpore's Callback mechanism addresses this: a Callback can perform custom operations at the end of every training step of model.train.
The Callback class:
from mindspore.train.callback import Callback
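Before the full example, here is a minimal sketch of what a custom callback looks like (not from the original tutorial; the class name and print format are illustrative): subclass Callback and override the hook you need.

from mindspore.train.callback import Callback

class PrintLossCallback(Callback):
    """Hypothetical minimal callback: print the loss after every step."""
    def step_end(self, run_context):
        cb_params = run_context.original_args()
        # net_outputs holds the loss value of the step that just finished
        print("epoch:", cb_params.cur_epoch_num,
              "step:", cb_params.cur_step_num,
              "loss:", cb_params.net_outputs)

An instance of such a class is passed in the callbacks list of model.train, exactly as StepLossAccInfo is in the full example below.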
In the official documentation, a Callback is typically used to record the loss of every step, or to evaluate the model after a fixed number of training steps.
Official tutorial:
https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/quick_start/quick_start.html
The complete code:
Reference: https://www.cnblogs.com/devilmaycry812839668/p/14971668.html

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import os

import mindspore.nn as nn
from mindspore.nn import Accuracy
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.common.initializer import Normal
from mindspore import Tensor, Model
from mindspore.train.callback import Callback
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Create dataset for train or test.

    Args:
        data_path (str): Data path
        batch_size (int): The number of data records in each group
        repeat_size (int): The number of replicated data records
        num_parallel_workers (int): The number of parallel workers
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # parameters for resizing, rescaling and normalization
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # build the corresponding transform operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # use map to apply the operations to the dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # shuffle, batch and repeat the processed dataset
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet network structure."""
    # define the required operators
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    # use the preceding operators to construct the network
    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# custom callback class
class StepLossAccInfo(Callback):
    def __init__(self, model, eval_dataset, steps_loss, steps_eval):
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.steps = 0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num
        self.steps = self.steps + 1
        cur_step = self.steps
        # record the loss of the step that just finished
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        # evaluate on the test set every 125 steps
        if cur_step % 125 == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])


def train_model(_model, _epoch_size, _repeat_size, _mnist_path, _model_path):
    ds_train = create_dataset(os.path.join(_mnist_path, "train"), 32, _repeat_size)
    eval_dataset = create_dataset(os.path.join(_mnist_path, "test"), 32)

    # save the network model and parameters for subsequent fine-tuning
    config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=_model_path, config=config_ck)

    steps_loss = {"step": [], "loss_value": []}
    steps_eval = {"step": [], "acc": []}
    # collect the step, loss and accuracy information
    step_loss_acc_info = StepLossAccInfo(_model, eval_dataset, steps_loss, steps_eval)

    # train with the _model argument; dataset_sink_mode=False so step_end fires after every step
    _model.train(_epoch_size, ds_train,
                 callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info],
                 dataset_sink_mode=False)

    return steps_loss, steps_eval


epoch_size = 1
repeat_size = 1
mnist_path = "./datasets/MNIST_Data"
model_path = "./models/ckpt/mindspore_quick_start/"

# clean up files left over from previous runs (Linux)
os.system('rm -f {0}*.ckpt {0}*.meta {0}*.pb'.format(model_path))

lr = 0.01
momentum = 0.9

# create the network
network = LeNet5()
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define the model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

steps_loss, steps_eval = train_model(model, epoch_size, repeat_size, mnist_path, model_path)
print(steps_loss, steps_eval)
Run result: (output screenshot omitted)

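ModelCheckpoint above writes a checkpoint every 375 steps under model_path. As a hedged sketch of how one of those files can be reused later, shown below (the exact file name depends on the run; "checkpoint_lenet-1_1875.ckpt" is an assumption following ModelCheckpoint's "<prefix>-<epoch>_<step>.ckpt" naming):

from mindspore import load_checkpoint, load_param_into_net

# load the parameter dictionary from a saved checkpoint (hypothetical file name)
param_dict = load_checkpoint("./models/ckpt/mindspore_quick_start/checkpoint_lenet-1_1875.ckpt")
# restore the parameters into a freshly constructed LeNet5
network = LeNet5()
load_param_into_net(network, param_dict)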
Core code:
from mindspore.train.callback import Callback

# custom callback class
class StepLossAccInfo(Callback):
    def __init__(self, model, eval_dataset, steps_loss, steps_eval):
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.steps = 0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num
        self.steps = self.steps + 1
        cur_step = self.steps
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        if cur_step % 125 == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])
As the code shows, once we inherit from the Callback class we can define a new functional class of our own; all we have to implement is the step_end method.
From the run_context argument passed to step_end, the number of the step that just finished and the current epoch can be obtained like this:
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
cur_step = (cur_epoch - 1) * 1875 + cb_params.cur_step_num
Here cb_params.cur_epoch_num is the current epoch number,
and cb_params.cur_step_num is the step number within the current epoch.
Note that cb_params.cur_step_num is not the cumulative step count of the whole run but the step count inside the current epoch, which is why the formula above offsets it by 1875 steps for every finished epoch (60000 MNIST training images / batch size 32 = 1875 steps per epoch).
The loss of the current training step can also be obtained:
cb_params.net_outputs holds the loss value of the step that just finished. A sketch combining these fields follows.
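Putting these fields together, here is a minimal sketch (illustrative, not from the original post; it assumes the MNIST setup above with 1875 steps per epoch) of recovering a run-wide step counter and reading the loss inside step_end:

from mindspore.train.callback import Callback

class GlobalStepLogger(Callback):
    """Hypothetical callback: derive a global step index and read the loss."""
    def __init__(self, steps_per_epoch=1875):  # 60000 / 32 for this MNIST setup
        self.steps_per_epoch = steps_per_epoch

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        # cur_step_num restarts every epoch, so offset it by the finished epochs
        global_step = ((cb_params.cur_epoch_num - 1) * self.steps_per_epoch
                       + cb_params.cur_step_num)
        loss = cb_params.net_outputs  # loss value of the step that just ended
        print("global step:", global_step, "loss:", loss)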
=========================================================
The same code as above, with plotting of the collected loss and accuracy added:

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import os

import mindspore.nn as nn
from mindspore.nn import Accuracy
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.common.initializer import Normal
from mindspore import Tensor, Model
from mindspore.train.callback import Callback
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Create dataset for train or test.

    Args:
        data_path (str): Data path
        batch_size (int): The number of data records in each group
        repeat_size (int): The number of replicated data records
        num_parallel_workers (int): The number of parallel workers
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    # parameters for resizing, rescaling and normalization
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # build the corresponding transform operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # use map to apply the operations to the dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # shuffle, batch and repeat the processed dataset
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet network structure."""
    # define the required operators
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    # use the preceding operators to construct the network
    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# custom callback class
class StepLossAccInfo(Callback):
    def __init__(self, model, eval_dataset, steps_loss, steps_eval):
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.steps = 0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # cur_step = (cur_epoch-1)*1875 + cb_params.cur_step_num
        self.steps = self.steps + 1
        cur_step = self.steps
        # record the loss of the step that just finished
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        # evaluate on the test set every 125 steps
        if cur_step % 125 == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])


def train_model(_model, _epoch_size, _repeat_size, _mnist_path, _model_path):
    ds_train = create_dataset(os.path.join(_mnist_path, "train"), 32, _repeat_size)
    eval_dataset = create_dataset(os.path.join(_mnist_path, "test"), 32)

    # save the network model and parameters for subsequent fine-tuning
    config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=_model_path, config=config_ck)

    steps_loss = {"step": [], "loss_value": []}
    steps_eval = {"step": [], "acc": []}
    # collect the step, loss and accuracy information
    step_loss_acc_info = StepLossAccInfo(_model, eval_dataset, steps_loss, steps_eval)

    # train with the _model argument; dataset_sink_mode must stay False so that
    # step_end fires after every step (the per-step curves below depend on it)
    _model.train(_epoch_size, ds_train,
                 callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info],
                 dataset_sink_mode=False)

    return steps_loss, steps_eval


epoch_size = 1
repeat_size = 1
mnist_path = "./datasets/MNIST_Data"
model_path = "./models/ckpt/mindspore_quick_start/"

# clean up files left over from previous runs (Linux)
os.system('rm -f {0}*.ckpt {0}*.meta {0}*.pb'.format(model_path))

lr = 0.01
momentum = 0.9

# create the network
network = LeNet5()
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define the model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

steps_loss, steps_eval = train_model(model, epoch_size, repeat_size, mnist_path, model_path)

# plot the per-step loss curve collected by the callback
steps = steps_loss["step"]
loss_value = steps_loss["loss_value"]
steps = list(map(int, steps))
loss_value = list(map(float, loss_value))
plt.plot(steps, loss_value, color="red")
plt.xlabel("Steps")
plt.ylabel("Loss_value")
plt.title("Change chart of model loss value")
plt.show()


def eval_show(steps_eval):
    # plot the accuracy recorded every 125 steps
    plt.xlabel("step number")
    plt.ylabel("Model accuracy")
    plt.title("Model accuracy variation chart")
    plt.plot(steps_eval["step"], steps_eval["acc"], "red")
    plt.show()


eval_show(steps_eval)
