在 MindSpore 中使用 model.train 訓練網絡時,訓練循環由框架內部執行,我們難以在訓練過程中穿插處理一些間歇性的任務(例如每隔若干步記錄 loss 或評估模型);為此,可以使用 MindSpore 提供的 Callback 機制。
通過 Callback,我們可以在 model.train 每一步(step)訓練結束後執行自定義的操作。
Callback 類的導入方式:
from mindspore.train.callback import Callback
官方文檔中一般使用 Callback 來記錄每一步的 loss,或在訓練一定步數後對模型進行評估:
官網地址:
https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/quick_start/quick_start.html
具體使用的代碼:
參考:https://www.cnblogs.com/devilmaycry812839668/p/14971668.html
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import os
import glob

import mindspore.nn as nn
from mindspore.nn import Accuracy
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.common.initializer import Normal
from mindspore import Tensor, Model
from mindspore.train.callback import Callback
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """Create the MNIST dataset pipeline used for training or testing.

    Labels are cast to int32. Images are resized to 32x32, multiplied by
    1/255, normalised as (x - 0.1307) / 0.3081 and converted from HWC to
    CHW. The result is shuffled, batched (the last incomplete batch is
    dropped) and repeated.

    Args:
        data_path (str): Data path of the MNIST files.
        batch_size (int): The number of data records in each batch.
        repeat_size (int): The number of times the dataset is repeated.
        num_parallel_workers (int): The number of parallel workers per map.

    Returns:
        The processed dataset pipeline.
    """
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Together these compute (x - 0.1307) / 0.3081.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Order matters: the 1/255 rescale must run before the normalisation.
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet-5 network structure.

    Args:
        num_class (int): Number of output classes.
        num_channel (int): Number of input image channels.

    The input must be 32x32. Two 5x5 'valid' convolutions and two 2x2
    max-pools turn 32x32 into 16x5x5, which is the input size of fc1.
    """

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Run the forward pass and return the un-normalised class scores."""
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class StepLossAccInfo(Callback):
    """Callback recording the loss of every step and evaluating periodically.

    Args:
        model (Model): Model whose ``eval`` is called for the periodic evaluation.
        eval_dataset: Dataset passed to ``model.eval``.
        steps_loss (dict): Dict with "step" and "loss_value" lists. One entry
            (stored as ``str``) is appended to each list after every step.
        steps_eval (dict): Dict with "step" and "acc" lists. One entry is
            appended to each list after every evaluation.
        eval_interval (int): Evaluate every ``eval_interval`` steps; a value
            <= 0 disables the evaluation.
    """

    def __init__(self, model, eval_dataset, steps_loss, steps_eval, eval_interval=125):
        super(StepLossAccInfo, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.eval_interval = eval_interval
        # Own counter, incremented on every step_end call and never reset,
        # so it keeps counting across epochs.
        self.steps = 0

    def step_end(self, run_context):
        """Record the just-finished step's loss; evaluate every ``eval_interval`` steps."""
        cb_params = run_context.original_args()
        self.steps += 1
        cur_step = self.steps
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        if self.eval_interval > 0 and cur_step % self.eval_interval == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            # The key must match the metric name given to Model(metrics=...).
            self.steps_eval["acc"].append(acc["Accuracy"])


def train_model(_model, _epoch_size, _repeat_size, _mnist_path, _model_path):
    """Train ``_model`` on MNIST while collecting per-step loss and accuracy.

    Args:
        _model (Model): The model to train (and evaluate).
        _epoch_size (int): Number of training epochs.
        _repeat_size (int): Repeat count for the training dataset.
        _mnist_path (str): Directory containing the "train" and "test" subdirs.
        _model_path (str): Directory where checkpoints are written.

    Returns:
        tuple: ``(steps_loss, steps_eval)`` dicts filled by StepLossAccInfo.
    """
    ds_train = create_dataset(os.path.join(_mnist_path, "train"), 32, _repeat_size)
    eval_dataset = create_dataset(os.path.join(_mnist_path, "test"), 32)
    # Save the network parameters for later fine-tuning.
    config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=_model_path, config=config_ck)
    steps_loss = {"step": [], "loss_value": []}
    steps_eval = {"step": [], "acc": []}
    step_loss_acc_info = StepLossAccInfo(_model, eval_dataset, steps_loss, steps_eval)
    # Use the model passed in, not a same-named global.
    _model.train(_epoch_size, ds_train,
                 callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info],
                 dataset_sink_mode=False)
    return steps_loss, steps_eval


epoch_size = 1
repeat_size = 1
mnist_path = "./datasets/MNIST_Data"
model_path = "./models/ckpt/mindspore_quick_start/"
# Remove files from a previous run (portable replacement for `rm -f`,
# no shell string involved).
for pattern in ("*.ckpt", "*.meta", "*.pb"):
    for stale_file in glob.glob(os.path.join(model_path, pattern)):
        os.remove(stale_file)

lr = 0.01
momentum = 0.9
network = LeNet5()
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# "Accuracy" is the key StepLossAccInfo reads from the eval result.
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
steps_loss, steps_eval = train_model(model, epoch_size, repeat_size, mnist_path, model_path)
print(steps_loss, steps_eval)
運行結果:

核心代碼:
from mindspore.train.callback import Callback


class StepLossAccInfo(Callback):
    """Callback recording the loss of every step and evaluating periodically.

    Args:
        model (Model): Model whose ``eval`` is called for the periodic evaluation.
        eval_dataset: Dataset passed to ``model.eval``.
        steps_loss (dict): Dict with "step" and "loss_value" lists. One entry
            (stored as ``str``) is appended to each list after every step.
        steps_eval (dict): Dict with "step" and "acc" lists. One entry is
            appended to each list after every evaluation.
        eval_interval (int): Evaluate every ``eval_interval`` steps; a value
            <= 0 disables the evaluation.
    """

    def __init__(self, model, eval_dataset, steps_loss, steps_eval, eval_interval=125):
        super(StepLossAccInfo, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.eval_interval = eval_interval
        # Own counter, incremented on every step_end call and never reset,
        # so it keeps counting across epochs.
        self.steps = 0

    def step_end(self, run_context):
        """Record the just-finished step's loss; evaluate every ``eval_interval`` steps."""
        cb_params = run_context.original_args()
        self.steps += 1
        cur_step = self.steps
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        if self.eval_interval > 0 and cur_step % self.eval_interval == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            # The key must match the metric name given to Model(metrics=...).
            self.steps_eval["acc"].append(acc["Accuracy"])
可以看到,繼承 Callback 類後,只需實現 step_end 方法,就可以定義具有新功能的回調類。
通過框架傳入 step_end 方法的參數 run_context,可以用以下方式獲得剛結束的 step 數和當前的 epoch 數:
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
# cb_params.batch_num is presumably the number of steps per epoch; it
# replaces the hard-coded 1875 (60000 MNIST training images / batch_size 32),
# which only held for that dataset and batch size.
# NOTE(review): this formula assumes cur_step_num restarts every epoch.
# MindSpore's built-in LossMonitor appears to treat cur_step_num as a running
# total ((cur_step_num - 1) % batch_num + 1) - confirm for your version.
cur_step = (cur_epoch - 1) * cb_params.batch_num + cb_params.cur_step_num
其中,cb_params.cur_epoch_num 為當前的epoch數,
cb_params.cur_step_num 為在當前epoch中的當前步數,
需要注意的是,cb_params.cur_step_num 步數不是總共的計算步數,而是在當前epoch中的計算步數。
當前 step 的損失值也可以獲得,具體如下:
cb_params.net_outputs 即當前 step 的損失值。
=========================================================
在上述代碼的基礎上加入繪圖功能後的代碼:
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import os
import glob

import mindspore.nn as nn
from mindspore.nn import Accuracy
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.common.initializer import Normal
from mindspore import Tensor, Model
from mindspore.train.callback import Callback
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """Create the MNIST dataset pipeline used for training or testing.

    Labels are cast to int32. Images are resized to 32x32, multiplied by
    1/255, normalised as (x - 0.1307) / 0.3081 and converted from HWC to
    CHW. The result is shuffled, batched (the last incomplete batch is
    dropped) and repeated.

    Args:
        data_path (str): Data path of the MNIST files.
        batch_size (int): The number of data records in each batch.
        repeat_size (int): The number of times the dataset is repeated.
        num_parallel_workers (int): The number of parallel workers per map.

    Returns:
        The processed dataset pipeline.
    """
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Together these compute (x - 0.1307) / 0.3081.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Order matters: the 1/255 rescale must run before the normalisation.
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet-5 network structure.

    Args:
        num_class (int): Number of output classes.
        num_channel (int): Number of input image channels.

    The input must be 32x32. Two 5x5 'valid' convolutions and two 2x2
    max-pools turn 32x32 into 16x5x5, which is the input size of fc1.
    """

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Run the forward pass and return the un-normalised class scores."""
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class StepLossAccInfo(Callback):
    """Callback recording the loss of every step and evaluating periodically.

    Args:
        model (Model): Model whose ``eval`` is called for the periodic evaluation.
        eval_dataset: Dataset passed to ``model.eval``.
        steps_loss (dict): Dict with "step" and "loss_value" lists. One entry
            (stored as ``str``) is appended to each list after every step.
        steps_eval (dict): Dict with "step" and "acc" lists. One entry is
            appended to each list after every evaluation.
        eval_interval (int): Evaluate every ``eval_interval`` steps; a value
            <= 0 disables the evaluation.
    """

    def __init__(self, model, eval_dataset, steps_loss, steps_eval, eval_interval=125):
        super(StepLossAccInfo, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.eval_interval = eval_interval
        # Own counter, incremented on every step_end call and never reset,
        # so it keeps counting across epochs.
        self.steps = 0

    def step_end(self, run_context):
        """Record the just-finished step's loss; evaluate every ``eval_interval`` steps."""
        cb_params = run_context.original_args()
        self.steps += 1
        cur_step = self.steps
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        if self.eval_interval > 0 and cur_step % self.eval_interval == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            # The key must match the metric name given to Model(metrics=...).
            self.steps_eval["acc"].append(acc["Accuracy"])


def train_model(_model, _epoch_size, _repeat_size, _mnist_path, _model_path):
    """Train ``_model`` on MNIST while collecting per-step loss and accuracy.

    Args:
        _model (Model): The model to train (and evaluate).
        _epoch_size (int): Number of training epochs.
        _repeat_size (int): Repeat count for the training dataset.
        _mnist_path (str): Directory containing the "train" and "test" subdirs.
        _model_path (str): Directory where checkpoints are written.

    Returns:
        tuple: ``(steps_loss, steps_eval)`` dicts filled by StepLossAccInfo.
    """
    ds_train = create_dataset(os.path.join(_mnist_path, "train"), 32, _repeat_size)
    eval_dataset = create_dataset(os.path.join(_mnist_path, "test"), 32)
    # Save the network parameters for later fine-tuning.
    config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=_model_path, config=config_ck)
    steps_loss = {"step": [], "loss_value": []}
    steps_eval = {"step": [], "acc": []}
    step_loss_acc_info = StepLossAccInfo(_model, eval_dataset, steps_loss, steps_eval)
    # Non-sink mode, as in the plain training script: StepLossAccInfo needs
    # step_end after every step. In sink mode MindSpore fires step callbacks
    # once per data sink, which would leave the plots nearly empty
    # (see Model.train docs).
    _model.train(_epoch_size, ds_train,
                 callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info],
                 dataset_sink_mode=False)
    return steps_loss, steps_eval


epoch_size = 1
repeat_size = 1
mnist_path = "./datasets/MNIST_Data"
model_path = "./models/ckpt/mindspore_quick_start/"
# Remove files from a previous run (portable replacement for `rm -f`,
# no shell string involved).
for pattern in ("*.ckpt", "*.meta", "*.pb"):
    for stale_file in glob.glob(os.path.join(model_path, pattern)):
        os.remove(stale_file)

lr = 0.01
momentum = 0.9
network = LeNet5()
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# "Accuracy" is the key StepLossAccInfo reads from the eval result.
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
steps_loss, steps_eval = train_model(model, epoch_size, repeat_size, mnist_path, model_path)

# StepLossAccInfo stores steps and losses as strings; convert back for plotting.
steps = [int(step) for step in steps_loss["step"]]
loss_value = [float(loss) for loss in steps_loss["loss_value"]]
plt.plot(steps, loss_value, color="red")
plt.xlabel("Steps")
plt.ylabel("Loss_value")
plt.title("Change chart of model loss value")
plt.show()


def eval_show(steps_eval):
    """Plot the accuracy recorded at each evaluation step.

    Args:
        steps_eval (dict): Dict with "step" and "acc" lists, as filled by
            StepLossAccInfo.
    """
    plt.xlabel("step number")
    plt.ylabel("Model accuracy")
    plt.title("Model accuracy variation chart")
    plt.plot(steps_eval["step"], steps_eval["acc"], "red")
    plt.show()


eval_show(steps_eval)


