關鍵代碼（與 checkpoint 保存相關的設定）:
# tflearn.DNN wraps the graph `net` for training; checkpoints are written
# under the 'model_resnet_cifar10' prefix and at most 10 are kept.
# clip_gradients=0. -- per the tflearn docs a non-positive value turns
# gradient clipping off (default is 5.0); confirm for the installed version.
tflearn.DNN(net, checkpoint_path='model_resnet_cifar10', max_checkpoints=10, tensorboard_verbose=0, clip_gradients=0.)
# Keyword argument passed to model.fit() (shown out of context here):
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
我的完整 demo（建立模型、嘗試恢復訓練、訓練並保存模型）:
def get_model(width, height, classes=40, n=2):
    """Build a small ResNet image classifier wrapped in a ``tflearn.DNN``.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        classes: Number of output classes (size of the softmax layer).
        n: Residual depth knob. The CIFAR-10 ResNet reference uses n=5 for
            32 layers, n=9 for 56 layers and n=18 for 110 layers. Must be >= 1.

    Returns:
        A ``tflearn.DNN`` that writes checkpoints under the
        ``model_resnet_cifar10`` prefix and keeps at most 10 of them.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be >= 1, got %r" % (n,))

    # RGB input. image_preloader resizes each image to (width, height) via PIL
    # and converts it to a NumPy array laid out as (height, width, 3), so the
    # placeholder must be [batch, height, width, 3] to match non-square images.
    network = input_data(shape=[None, height, width, 3])

    # Residual trunk: 16 -> 32 -> 64 channels, downsampling between stages.
    net = tflearn.conv_2d(network, 16, 3, regularizer='L2', weight_decay=0.0001)
    net = tflearn.residual_block(net, n, 16)
    net = tflearn.residual_block(net, 1, 32, downsample=True)
    net = tflearn.residual_block(net, n - 1, 32)
    net = tflearn.residual_block(net, 1, 64, downsample=True)
    net = tflearn.residual_block(net, n - 1, 64)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, 'relu')
    net = tflearn.global_avg_pool(net)

    # Classification head.
    net = tflearn.fully_connected(net, classes, activation='softmax')

    # Learning rate 0.01, multiplied by 0.1 every 2000 steps (staircase decay).
    # (The CIFAR-10 reference used 0.1 with decay_step=32000.)
    mom = tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000, staircase=True)
    net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')

    # clip_gradients=0. disables gradient clipping in tflearn.
    model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
                        max_checkpoints=10, tensorboard_verbose=0,
                        clip_gradients=0.)
    return model


def _load_images(path, width, height):
    """Load a folder-per-class image dataset with one-hot labels scaled to [0, 1]."""
    return image_preloader(path, image_shape=(width, height, 3), mode='folder',
                           categorical_labels=True, normalize=True)


def main():
    """Train (or resume training) the handwriting ResNet, then save it twice.

    NOTE(review): ``width`` and ``height`` are read from module scope and are
    not defined in this snippet -- confirm they are set before calling.

    A full model is saved to ``tflearn_resnet/model.tflearn``. An inference
    copy, saved after the graph's TRAIN_OPS collection is cleared, goes to
    ``tflearn_resnet/model-infer.tflearn``.
    """
    import os

    trainX, trainY = _load_images("data/train", width, height)
    testX, testY = _load_images("data/test", width, height)
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    model = get_model(width, height, classes=3755)

    # Best-effort resume from a specific DNN checkpoint; if it cannot be
    # loaded, report why and train from scratch. (To resume from the last
    # saved model instead, load 'tflearn_resnet/model.tflearn'.)
    resume_checkpoint = "model_resnet_cifar10-195804"
    try:
        model.load(resume_checkpoint)
        print("Model loaded OK. Resume training!")
    except Exception as err:
        print("Could not load %s (%s); training from scratch."
              % (resume_checkpoint, err))

    # EarlyStoppingCallback is defined elsewhere; it ends fit() by raising
    # StopIteration -- presumably once validation accuracy reaches 0.94
    # (confirm against its definition).
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.94)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500,
                  shuffle=True,
                  snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=1024,
                  callbacks=early_stopping_cb, run_id='cnn_handwrite')
    except StopIteration:
        print("OK, stop iterate!Good!")

    save_dir = 'tflearn_resnet'
    # TF 1.x's Saver raises when the parent directory of the save path is missing.
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    model.save(os.path.join(save_dir, 'model.tflearn'))

    # Clearing tflearn's TRAIN_OPS collection before the second save is the
    # usual tflearn recipe for an inference-only model file (no optimizer ops).
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
    model.save(os.path.join(save_dir, 'model-infer.tflearn'))