keras Model 3 共享的層


1 入門

2 多個輸入和輸出

3 共享層

考慮這樣的一個問題:我們要判斷連個tweet是否來源於同一個人。

首先我們對兩個tweet進行處理,然后將處理的結構拼接在一起,之后跟一個邏輯回歸,輸出這兩條tweet來自同一個人概率。

因為我們對兩條tweet的處理是相同的,所以對第一條tweet的處理的模型,可以被重用來處理第二個tweet。我們考慮用LSTM進行處理。

假設我們的輸入是兩條 280*256的向量

首先定義輸入:

import keras
from keras.layers import Input, LSTM, Dense
from keras.models import Model

tweet_a = Input(shape=(280, 256))
tweet_b = Input(shape=(280, 256))

然后我們共享LSTM。共享層很簡單,只要實例化層一次,然后在你想處理的tensor上調用你想要應用的次數即可(翻譯無力,看代碼)

# This layer can take as input a matrix
# and will return a vector of size 64
shared_lstm = LSTM(64)

# When we reuse the same layer instance
# multiple times, the weights of the layer
# are also being reused
# (it is effectively *the same* layer)
encoded_a = shared_lstm(tweet_a)
encoded_b = shared_lstm(tweet_b)

# We can then concatenate the two vectors:
merged_vector = keras.layers.concatenate([encoded_a, encoded_b], axis=-1)

# And add a logistic regression on top
predictions = Dense(1, activation='sigmoid')(merged_vector)

# We define a trainable model linking the
# tweet inputs to the predictions
model = Model(inputs=[tweet_a, tweet_b], outputs=predictions)

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.fit([data_a, data_b], labels, epochs=10)

其實,簡單點說,對一個層的多次調用,就是在共享這個層。這里有一個層的節點的概念

當你在一個輸入tensor上調用一個層時,就會生成一個輸出tensor,就會在這個層上添加一個節點,這個節點連接着這兩個tensor(輸入tensor和輸出tensor)。當你多次調用同一個層的時,

這個層生成的節點就會按照0 ,1, 2, 。。以此類推編號。

那么當一個層有多個節點的時候,我們怎么獲取它的輸出呢?

如果直接通過output獲取會出錯:

a = Input(shape=(280, 256))
b = Input(shape=(280, 256))

lstm = LSTM(32)
encoded_a = lstm(a)
encoded_b = lstm(b)

lstm.output
>> AttributeError: Layer lstm_1 has multiple inbound nodes,
hence the notion of "layer output" is ill-defined.
Use `get_output_at(node_index)` instead.

這時候應該通過索引進行調用:

assert lstm.get_output_at(0) == encoded_a
assert lstm.get_output_at(1) == encoded_b

對於輸入,也是同樣的

a = Input(shape=(32, 32, 3))
b = Input(shape=(64, 64, 3))

conv = Conv2D(16, (3, 3), padding='same')
conved_a = conv(a)

# Only one input so far, the following will work:
assert conv.input_shape == (None, 32, 32, 3)

conved_b = conv(b)
# now the `.input_shape` property wouldn't work, but this does:
assert conv.get_input_shape_at(0) == (None, 32, 32, 3)
assert conv.get_input_shape_at(1) == (None, 64, 64, 3)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM