Keras LSTM text classification example


# A simple text classification task on the IMDB dataset

# One Embedding layer + one LSTM layer + one fully connected (Dense) layer

# Based on Keras 2.1.1 and TensorFlow 1.4.0

Code:

'''Trains an LSTM model on the IMDB sentiment classification task.
The dataset is actually too small for LSTM to be of any advantage
compared to simpler, much faster methods such as TF-IDF + LogReg.
# Notes
- RNNs are tricky. Choice of batch size is important,
choice of loss and optimizer is critical, etc.
Some configurations won't converge.
- LSTM loss decrease patterns during training can be quite different
from what you see with CNNs/MLPs/etc.
'''
from __future__ import print_function

from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense, Embedding
from keras.layers import LSTM
from keras.datasets import imdb

max_features = 20000  # vocabulary size: keep only the top 20,000 most frequent words
maxlen = 80  # cut texts after this number of words (among top max_features most common words)
batch_size = 32

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

print('Build model...')
model = Sequential()
model.add(Embedding(max_features, 128))  # map word indices to 128-dimensional vectors
model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation='sigmoid'))  # single sigmoid unit for binary sentiment
model.summary()

# try using different optimizers and different optimizer configs
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

print('Train...')
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=15,
          validation_data=(x_test, y_test))
score, acc = model.evaluate(x_test, y_test, batch_size=batch_size)
print('Test score:', score)
print('Test accuracy:', acc)
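
The sequences returned by imdb.load_data are lists of word indices, not raw text. A quick way to sanity-check what the model actually sees is to map one review back into words. A minimal sketch, assuming the default Keras IMDB offsets (0 = padding, 1 = start, 2 = unknown, real words shifted by 3):

# Sketch: decode one integer-encoded review back into words for inspection.
word_index = imdb.get_word_index()  # word -> rank (1-based)
index_to_word = {rank + 3: word for word, rank in word_index.items()}
index_to_word.update({0: '<pad>', 1: '<start>', 2: '<unk>'})

print(' '.join(index_to_word.get(i, '<unk>') for i in x_train[0]))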

Result:

Test accuracy: 0.81248
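
Once trained, the model can score new reviews that have been encoded with the same word index and padded to maxlen. A minimal sketch, reusing one padded sequence from x_test above purely for illustration:

# Sketch: score a single padded sequence with the trained model.
sample = x_test[:1]                 # shape (1, maxlen)
prob = model.predict(sample)[0][0]  # sigmoid output: probability of a positive review
print('Positive probability: %.3f' % prob)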

 

