【Keras案例學習】 sklearn包裝器使用示范(mnist_sklearn_wrapper)


import numpy as np 
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils
# sklean接口的包裝器KerasClassifier,作為sklearn的分類器接口
from keras.wrappers.scikit_learn import KerasClassifier
# 窮搜所有特定的參數值選出最好的模型參數
from sklearn.grid_search import GridSearchCV
Using TensorFlow backend.
# Number of target classes (digits 0-9).
nb_classes = 10
# Input image dimensions.
img_rows, img_cols = 28, 28
# Channels-last input shape expected by the TensorFlow backend; consumed by
# the first layer of the network built in make_model.
input_shape = (img_rows, img_cols, 1)

# Load MNIST; the arrays arrive channel-less, i.e. shaped (num_samples, 28, 28).
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Append the channel axis (grayscale -> 1 channel), cast to single-precision
# float and scale the pixel values into [0, 1] — the usual normalisation for
# image data.
X_train = X_train.reshape((-1, img_rows, img_cols, 1)).astype('float32') / 255
X_test = X_test.reshape((-1, img_rows, img_cols, 1)).astype('float32') / 255

# One-hot encode the integer class labels.
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test = np_utils.to_categorical(y_test, nb_classes)
def make_model(dense_layer_sizes=(64,), nb_filters=8, nb_conv=3, nb_pool=2):
    """Build and compile a small convnet for MNIST classification.

    Used as the ``build_fn`` of a KerasClassifier; like any sklearn
    estimator, every parameter therefore carries a default value.

    Args:
        dense_layer_sizes: sizes of the fully-connected hidden layers
            appended after the convolutional base, e.g. ``[32, 32]``.
        nb_filters: number of filters in each convolutional layer.
        nb_conv: side length of the (square) convolution kernel.
        nb_pool: side length of the (square) max-pooling window.

    Returns:
        A compiled Keras ``Sequential`` model.
    """
    model = Sequential()
    # Convolutional base: two ReLU conv layers, max-pooling, dropout.
    model.add(Convolution2D(nb_filters, nb_conv, nb_conv,
                            border_mode='valid',
                            input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(Convolution2D(nb_filters, nb_conv, nb_conv))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
    model.add(Dropout(0.25))

    model.add(Flatten())
    # Dense head. Each hidden layer gets its own ReLU — the original code
    # activated only once after the loop, so stacked Dense layers collapsed
    # into a single linear transform.
    for layer_size in dense_layer_sizes:
        model.add(Dense(layer_size))
        model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adadelta',
                  metrics=['accuracy'])
    return model
# Candidate dense-head architectures for the grid search.
dense_size_candidates = [[32], [64], [32, 32], [64, 64]]

# Wrap the Keras model-building function as an sklearn classifier.
# KerasClassifier(build_fn, **sk_params):
#   * build_fn — a callable that builds, compiles and returns a Keras model.
#     Like any sklearn estimator it should provide defaults for all of its
#     parameters, so the estimator can be created without passing sk_params.
#   * sk_params — model parameters (i.e. arguments of build_fn) and/or
#     fitting parameters (fit/predict/predict_proba/score keywords such as
#     'nb_epoch' or 'batch_size'). Values passed directly to
#     fit/predict/... take precedence over sk_params, which in turn take
#     precedence over the Keras Sequential defaults.
# Here only the fitting parameter batch_size is fixed up front.
my_classifier = KerasClassifier(make_model, batch_size=32)

# Exhaustive grid search over the estimator's parameters. Any key that is
# legal in sk_params may be searched, so build_fn arguments (architecture)
# and fitting parameters (nb_epoch) can be tuned together.
# NOTE(review): scoring='log_loss' matches the legacy sklearn.grid_search
# module imported above; newer sklearn.model_selection spells it
# 'neg_log_loss'.
validator = GridSearchCV(my_classifier,
                         param_grid={'dense_layer_sizes': dense_size_candidates,
                                     'nb_epoch': [3, 6],
                                     'nb_filters': [8],
                                     'nb_conv': [3],
                                     'nb_pool': [2]},
                         scoring='log_loss')
# Train a model for every parameter combination on (X_train, y_train).
validator.fit(X_train, y_train)
# Report the best combination found (typo fixed: 'Yhe' -> 'The').
print('The parameters of the best model are: ')
print(validator.best_params_)
Epoch 1/3
40000/40000 [==============================] - 14s - loss: 0.8058 - acc: 0.7335    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.4620 - acc: 0.8545    
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.3958 - acc: 0.8747    
19776/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 11s - loss: 0.9589 - acc: 0.6804    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.5885 - acc: 0.8116    
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.5021 - acc: 0.8429    
19488/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 11s - loss: 0.9141 - acc: 0.6958    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.5716 - acc: 0.8136    
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.4515 - acc: 0.8547    
19584/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 11s - loss: 0.8968 - acc: 0.6983    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.5692 - acc: 0.8130    
Epoch 3/6
40000/40000 [==============================] - 10s - loss: 0.4600 - acc: 0.8494    
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.4091 - acc: 0.8694    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.3717 - acc: 0.8790    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.3461 - acc: 0.8898    
20000/20000 [==============================] - 1s     
Epoch 1/6
40000/40000 [==============================] - 11s - loss: 0.8089 - acc: 0.7310    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.4770 - acc: 0.8498    
Epoch 3/6
40000/40000 [==============================] - 10s - loss: 0.4086 - acc: 0.8704    
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.3657 - acc: 0.8860    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.3383 - acc: 0.8938    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.3164 - acc: 0.9027    
19520/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 11s - loss: 0.8393 - acc: 0.7214    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.5132 - acc: 0.8379    
Epoch 3/6
40000/40000 [==============================] - 10s - loss: 0.4331 - acc: 0.8635    
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.3813 - acc: 0.8808    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.3530 - acc: 0.8902    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.3278 - acc: 0.8986    
19936/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 11s - loss: 0.5975 - acc: 0.8099    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.3181 - acc: 0.9048    
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.2673 - acc: 0.9199    
19808/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 11s - loss: 0.6155 - acc: 0.8040    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.3500 - acc: 0.8951    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.2864 - acc: 0.9156    
20000/20000 [==============================] - 1s     
Epoch 1/3
40000/40000 [==============================] - 11s - loss: 0.7519 - acc: 0.7560    
Epoch 2/3
40000/40000 [==============================] - 10s - loss: 0.4660 - acc: 0.8580    
Epoch 3/3
40000/40000 [==============================] - 10s - loss: 0.3553 - acc: 0.8936    
19776/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 11s - loss: 0.5869 - acc: 0.8162    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.3279 - acc: 0.9014    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.2725 - acc: 0.9187    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.2366 - acc: 0.9291    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.2102 - acc: 0.9386    
Epoch 6/6
40000/40000 [==============================] - 16s - loss: 0.1954 - acc: 0.9423    
19840/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 11s - loss: 0.5526 - acc: 0.8262    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.2903 - acc: 0.9142    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.2361 - acc: 0.9302    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.2064 - acc: 0.9396    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.1886 - acc: 0.9443    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.1755 - acc: 0.9496    
19808/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 11s - loss: 0.7275 - acc: 0.7677    
Epoch 2/6
40000/40000 [==============================] - 10s - loss: 0.4141 - acc: 0.8772    
Epoch 3/6
40000/40000 [==============================] - 10s - loss: 0.3136 - acc: 0.9056    
Epoch 4/6
40000/40000 [==============================] - 10s - loss: 0.2651 - acc: 0.9210    
Epoch 5/6
40000/40000 [==============================] - 10s - loss: 0.2363 - acc: 0.9306    
Epoch 6/6
40000/40000 [==============================] - 10s - loss: 0.2092 - acc: 0.9380    
19552/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 12s - loss: 0.7849 - acc: 0.7334    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.4506 - acc: 0.8587    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.3741 - acc: 0.8813    
19872/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 12s - loss: 0.8744 - acc: 0.7068    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.5231 - acc: 0.8312    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.4305 - acc: 0.8635    
19552/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 12s - loss: 0.7567 - acc: 0.7473    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.4200 - acc: 0.8685    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.3604 - acc: 0.8887    
19712/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 12s - loss: 0.7111 - acc: 0.7676    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.4243 - acc: 0.8669    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.3638 - acc: 0.8873    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.3223 - acc: 0.8995    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.2994 - acc: 0.9073    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.2823 - acc: 0.9135    
20000/20000 [==============================] - 2s     
Epoch 1/6
40000/40000 [==============================] - 12s - loss: 0.7588 - acc: 0.7513    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.4568 - acc: 0.8570    
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.3757 - acc: 0.8819    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.3256 - acc: 0.8969    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.2996 - acc: 0.9060    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.2702 - acc: 0.9146    
19904/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 12s - loss: 0.7798 - acc: 0.7464    
Epoch 2/6
40000/40000 [==============================] - 11s - loss: 0.4625 - acc: 0.8571    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.3869 - acc: 0.8814    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.3429 - acc: 0.8959    
Epoch 5/6
40000/40000 [==============================] - 11s - loss: 0.3143 - acc: 0.9035    
Epoch 6/6
40000/40000 [==============================] - 11s - loss: 0.2889 - acc: 0.9122    
19840/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 13s - loss: 0.5828 - acc: 0.8161    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.3009 - acc: 0.9099    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.2393 - acc: 0.9291    
19680/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 12s - loss: 0.5584 - acc: 0.8246    
Epoch 2/3
40000/40000 [==============================] - 12s - loss: 0.2862 - acc: 0.9152    
Epoch 3/3
40000/40000 [==============================] - 11s - loss: 0.2334 - acc: 0.9319    
19488/20000 [============================>.] - ETA: 0sEpoch 1/3
40000/40000 [==============================] - 13s - loss: 0.6253 - acc: 0.8020    
Epoch 2/3
40000/40000 [==============================] - 11s - loss: 0.3054 - acc: 0.9093    
Epoch 3/3
40000/40000 [==============================] - 12s - loss: 0.2463 - acc: 0.9278    
19808/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 13s - loss: 0.5753 - acc: 0.8200    
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.2827 - acc: 0.9170    
Epoch 3/6
40000/40000 [==============================] - 11s - loss: 0.2217 - acc: 0.9339    
Epoch 4/6
40000/40000 [==============================] - 11s - loss: 0.1863 - acc: 0.9455    
Epoch 5/6
40000/40000 [==============================] - 12s - loss: 0.1663 - acc: 0.9516    
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.1535 - acc: 0.9550    
19680/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 13s - loss: 0.5670 - acc: 0.8247    
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.2728 - acc: 0.9204    
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.2134 - acc: 0.9383    
Epoch 4/6
40000/40000 [==============================] - 12s - loss: 0.1890 - acc: 0.9459    
Epoch 5/6
40000/40000 [==============================] - 12s - loss: 0.1695 - acc: 0.9501    
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.1570 - acc: 0.9535    
19712/20000 [============================>.] - ETA: 0sEpoch 1/6
40000/40000 [==============================] - 13s - loss: 0.6227 - acc: 0.7986    
Epoch 2/6
40000/40000 [==============================] - 12s - loss: 0.3322 - acc: 0.9007    
Epoch 3/6
40000/40000 [==============================] - 12s - loss: 0.2469 - acc: 0.9258    
Epoch 4/6
40000/40000 [==============================] - 12s - loss: 0.2029 - acc: 0.9409    
Epoch 5/6
40000/40000 [==============================] - 12s - loss: 0.1748 - acc: 0.9496    
Epoch 6/6
40000/40000 [==============================] - 12s - loss: 0.1558 - acc: 0.9542    
19872/20000 [============================>.] - ETA: 0sEpoch 1/6
60000/60000 [==============================] - 19s - loss: 0.4922 - acc: 0.8482    
Epoch 2/6
60000/60000 [==============================] - 24s - loss: 0.2342 - acc: 0.9318    
Epoch 3/6
60000/60000 [==============================] - 24s - loss: 0.1843 - acc: 0.9485    
Epoch 4/6
60000/60000 [==============================] - 25s - loss: 0.1556 - acc: 0.9549    
Epoch 5/6
60000/60000 [==============================] - 24s - loss: 0.1450 - acc: 0.9581    
Epoch 6/6
60000/60000 [==============================] - 25s - loss: 0.1312 - acc: 0.9624    
The parameters of the best model are: 
{'nb_conv': 3, 'nb_epoch': 6, 'nb_pool': 2, 'dense_layer_sizes': [64, 64], 'nb_filters': 8}
# validator.best_estimator_ is the sklearn-wrapped winner of the grid
# search; its .model attribute exposes the underlying (unwrapped) Keras
# model, which we evaluate on the held-out test set.
best_model = validator.best_estimator_.model

# evaluate() yields one value per entry of metrics_names
# (here: ['loss', 'acc']).
metric_values = best_model.evaluate(X_test, y_test)
metric_names = best_model.metrics_names

print()
for name, value in zip(metric_names, metric_values):
    print(name, ': ', value)
 9984/10000 [============================>.] - ETA: 0s
loss :  0.0550105490824
acc :  0.9826




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM