tensorflow=2.0+
在使用tensorflow加載模型的時候有時候需要查看這個模型某一層的輸出。
搭建一個簡單的神經網絡,識別cifar數據集:
點擊查看代碼
model = tf.keras.models.Sequential()
model.add(Flatten())
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dense(1024, activation='relu', name="dense_test1"))
model.add(Dense(2048, activation='relu', name="dense_test2"))
model.add(Dense(2048, activation='relu', name="dense_test3"))
model.add(Dense(2048, activation='relu', name="dense_test4"))
model.add(Dense(2048, activation='relu', name="dense_test5"))
model.add(Dense(10, activation='softmax', name="dense_test6"))
搭建好后進行訓練可以得到一個訓練好的神經網絡
此時通過模型去識別cifar測試集
model.predict(x_test)#x_test是cifar數據集
可以得到預測結果,此時若想查看中間某一層的輸出,比如全連接的第二層——dense_test2的輸出怎么辦呢。
這時候,可以直接截取此模型的子模型,直接將 dense_test2作為最后一層輸出。
sub_model = tf.keras.Model(inputs = model.input, outputs = model.get_layer(dense_test2).output)
sub_model.predict(x_test)
這樣就可以獲取中間某一層的輸出了。