Notes on "Deep Learning with Python" --- 6.1-2, Word Embeddings: Learning Word Embeddings with the Embedding Layer

I. Summary

One-sentence summary:

[Only the first 20 words of each review are used]: The validation accuracy is about 76%, which is quite good considering that we only look at the first 20 words of each review.
[Word relationships and sentence structure are ignored]: Note, however, that merely flattening the embedded sequences and training a single Dense layer on top makes the model treat each word in the input sequence separately, without considering inter-word relationships or sentence structure (for example, such a model would likely classify both "this movie is a bomb" and "this movie is the bomb" as negative reviews).
[Add a recurrent or 1D convolutional layer]: It is much better to add a recurrent layer or a 1D convolutional layer on top of the embedded sequences, so that features are learned over each sequence as a whole.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Flatten, Dense

# maxlen, x_train and y_train come from the IMDB data preparation shown in section II below

model = Sequential()
# We specify the maximum input length to our Embedding layer
# so we can later flatten the embedded inputs.
model.add(Embedding(10000, 8, input_length=maxlen))
# After the Embedding layer,
# our activations have shape `(samples, maxlen, 8)`.

# We flatten the 3D tensor of embeddings 
# into a 2D tensor of shape `(samples, maxlen * 8)`
model.add(Flatten())

# We add the classifier on top
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
model.summary()

history = model.fit(x_train, y_train,
                    epochs=10,
                    batch_size=32,
                    validation_split=0.2)

 

 

1. How to understand the Embedding layer?

[A dictionary: the Embedding layer is effectively a dictionary lookup]: The Embedding layer is best understood as a dictionary that maps integer indices (each standing for a specific word) to dense vectors. It takes integers as input, looks them up in an internal dictionary, and returns the associated vectors.
[word index --> Embedding layer --> corresponding word vector]
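A minimal sketch of this lookup behaviour (the toy layer sizes and indices below are made up for illustration): the vector returned for each index is simply the corresponding row of the layer's weight matrix.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Embedding

# A toy Embedding layer: 10 possible tokens, 4-dimensional vectors
layer = Embedding(input_dim=10, output_dim=4)

# Calling the layer on integer indices returns the associated vectors
indices = tf.constant([[3, 1, 3]])      # shape (1, 3): one sequence of 3 word indices
vectors = layer(indices)                # shape (1, 3, 4): one 4-d vector per index

# The lookup is just indexing into the layer's weight matrix
weights = layer.get_weights()[0]        # shape (10, 4)
print(np.allclose(vectors.numpy()[0, 0], weights[3]))   # True
print(np.allclose(vectors.numpy()[0, 1], weights[1]))   # True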

 

 

2. What is the input to the Embedding layer?

[A 2D integer tensor of shape (samples, sequence_length)]: The input to the Embedding layer is a 2D integer tensor of shape (samples, sequence_length), where each entry is a sequence of integers.
[(32, 10): a batch of 32 sequences of length 10]: It can embed sequences of variable length; for the Embedding layer in the previous example, you could feed it batches of shape (32, 10) (a batch of 32 sequences of length 10) or (64, 15) (a batch of 64 sequences of length 15).
[Shorter sequences are padded with 0, longer ones are truncated]: However, all sequences within one batch must have the same length (because they have to be packed into a single tensor), so shorter sequences should be padded with 0 and longer sequences should be truncated. A small example follows below.
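A minimal sketch of this padding/truncation step, using the same pad_sequences utility as the notebook below (the toy sequences and maxlen=4 are made up for illustration):

from tensorflow.keras import preprocessing

# Three sequences of different lengths
sequences = [[1, 2, 3],
             [4, 5],
             [6, 7, 8, 9, 10]]

# Pack them into one (samples, maxlen) tensor: short sequences are padded with 0,
# long ones are truncated (by default, padding and truncation happen at the front)
padded = preprocessing.sequence.pad_sequences(sequences, maxlen=4)
print(padded)
# [[ 0  1  2  3]
#  [ 0  0  4  5]
#  [ 7  8  9 10]]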

 

 

3. What is the output of the Embedding layer?

[A 3D float tensor of shape (samples, sequence_length, embedding_dimensionality), i.e. the input is extended by an embedding_dimensionality axis]: The Embedding layer returns a 3D floating-point tensor of shape (samples, sequence_length, embedding_dimensionality). This 3D tensor can then be processed by an RNN layer or a 1D convolutional layer.
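A minimal shape check (toy batches of zeros; the 8-dimensional embedding mirrors the IMDB example below): the same layer accepts batches with different sequence lengths and simply appends an embedding axis.

import tensorflow as tf
from tensorflow.keras.layers import Embedding

layer = Embedding(input_dim=10000, output_dim=8)

# A batch of 32 sequences of length 10 -> (32, 10, 8)
print(layer(tf.zeros((32, 10), dtype=tf.int32)).shape)

# A batch of 64 sequences of length 15 -> (64, 15, 8)
print(layer(tf.zeros((64, 15), dtype=tf.int32)).shape)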

 

 

4. Example: the word Embedding layer on the IMDB movie-review sentiment task?

[Restrict to the 10,000 most common words]: First we need to quickly prepare the data. Restrict the movie reviews to the top 10,000 most common words (as we did the first time we worked with this dataset), and cut the reviews off after only 20 words.
[Convert the input integer sequences (2D integer tensor) into embedded sequences (3D float tensor), then flatten that tensor to 2D]: For these 10,000 words the network will learn an 8-dimensional embedding for each word, turn the input integer sequences (a 2D integer tensor) into embedded sequences (a 3D float tensor), flatten that tensor to 2D, and finally train a single Dense layer on top for classification.

 

 

II. Word embedding: learning word embeddings with the Embedding layer


 

1. Instantiate an Embedding layer

In [1]:
from tensorflow.keras.layers import Embedding

# The Embedding layer takes at least two arguments:
# the number of possible tokens, here 1000 (1 + maximum word index),
# and the dimensionality of the embeddings, here 64.
embedding_layer = Embedding(1000, 64)
In [2]:
print(embedding_layer) 
<tensorflow.python.keras.layers.embeddings.Embedding object at 0x000001B71A106EC8>
In [3]:
print(dir(embedding_layer)) 
['_TF_MODULE_IGNORED_PROPERTIES', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_activity_regularizer', '_add_trackable', '_add_variable_with_custom_getter', '_auto_track_sub_layers', '_autocast', '_autographed_call', '_batch_input_shape', '_build_input_shape', '_call_accepts_kwargs', '_call_arg_was_passed', '_call_fn_arg_defaults', '_call_fn_arg_positions', '_call_fn_args', '_call_full_argspec', '_callable_losses', '_cast_single_input', '_checkpoint_dependencies', '_clear_losses', '_compute_dtype', '_compute_dtype_object', '_dedup_weights', '_default_training_arg', '_deferred_dependencies', '_dtype', '_dtype_defaulted_to_floatx', '_dtype_policy', '_dynamic', '_eager_losses', '_expects_mask_arg', '_expects_training_arg', '_flatten', '_flatten_layers', '_functional_construction_call', '_gather_children_attribute', '_gather_saveables_for_checkpoint', '_get_call_arg_value', '_get_existing_metric', '_get_input_masks', '_get_node_attribute_at_index', '_get_save_spec', '_get_trainable_state', '_handle_activity_regularization', '_handle_deferred_dependencies', '_handle_weight_regularization', '_inbound_nodes', '_infer_output_signature', '_init_call_fn_args', '_init_set_name', '_initial_weights', '_input_spec', '_is_layer', '_keras_api_names', '_keras_api_names_v1', '_keras_tensor_symbolic_call', '_layers', '_list_extra_dependencies_for_serialization', '_list_functions_for_serialization', '_lookup_dependency', '_losses', '_map_resources', '_maybe_build', '_maybe_cast_inputs', '_maybe_create_attribute', '_maybe_initialize_trackable', '_metrics', '_metrics_lock', '_must_restore_from_config', '_name', '_name_based_attribute_restore', '_name_based_restores', '_name_scope', '_no_dependency', '_non_trainable_weights', '_obj_reference_counts', '_obj_reference_counts_dict', '_object_identifier', '_outbound_nodes', '_preload_simple_restoration', '_restore_from_checkpoint_position', '_saved_model_inputs_spec', '_self_setattr_tracking', '_set_call_arg_value', '_set_connectivity_metadata', '_set_dtype_policy', '_set_mask_keras_history_checked', '_set_mask_metadata', '_set_save_spec', '_set_trainable_state', '_set_training_mode', '_setattr_tracking', '_should_cast_single_input', '_single_restoration_from_checkpoint_position', '_split_out_first_arg', '_stateful', '_supports_masking', '_symbolic_call', '_tf_api_names', '_tf_api_names_v1', '_thread_local', '_track_trackable', '_trackable_saved_model_saver', '_tracking_metadata', '_trainable', '_trainable_weights', '_unconditional_checkpoint_dependencies', '_unconditional_dependency_names', '_update_uid', '_updates', '_warn_about_input_casting', 'activity_regularizer', 'add_loss', 'add_metric', 'add_update', 'add_variable', 'add_weight', 'apply', 'build', 'built', 'call', 'compute_mask', 'compute_output_shape', 'compute_output_signature', 'count_params', 'dtype', 'dynamic', 'embeddings_constraint', 'embeddings_initializer', 'embeddings_regularizer', 'from_config', 'get_config', 'get_input_at', 'get_input_mask_at', 'get_input_shape_at', 'get_losses_for', 'get_output_at', 'get_output_mask_at', 'get_output_shape_at', 'get_updates_for', 'get_weights', 'inbound_nodes', 'input', 'input_dim', 'input_length', 'input_mask', 
'input_shape', 'input_spec', 'losses', 'mask_zero', 'metrics', 'name', 'name_scope', 'non_trainable_variables', 'non_trainable_weights', 'outbound_nodes', 'output', 'output_dim', 'output_mask', 'output_shape', 'set_weights', 'stateful', 'submodules', 'supports_masking', 'trainable', 'trainable_variables', 'trainable_weights', 'updates', 'variables', 'weights', 'with_name_scope']

2. Load the IMDB data and prepare it for the Embedding layer

In [4]:
from tensorflow.keras.datasets import imdb
from tensorflow.keras import preprocessing

# Number of words to consider as features
max_features = 10000
# Cut texts after this number of words
# (among top max_features most common words)
maxlen = 20

# Load the data as lists of integers.
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(x_train.shape)
print(x_test.shape)
print(x_train[0])
print(x_train[1])
print(x_train)
(25000,)
(25000,)
[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]
[1, 194, 1153, 194, 8255, 78, 228, 5, 6, 1463, 4369, 5012, 134, 26, 4, 715, 8, 118, 1634, 14, 394, 20, 13, 119, 954, 189, 102, 5, 207, 110, 3103, 21, 14, 69, 188, 8, 30, 23, 7, 4, 249, 126, 93, 4, 114, 9, 2300, 1523, 5, 647, 4, 116, 9, 35, 8163, 4, 229, 9, 340, 1322, 4, 118, 9, 4, 130, 4901, 19, 4, 1002, 5, 89, 29, 952, 46, 37, 4, 455, 9, 45, 43, 38, 1543, 1905, 398, 4, 1649, 26, 6853, 5, 163, 11, 3215, 2, 4, 1153, 9, 194, 775, 7, 8255, 2, 349, 2637, 148, 605, 2, 8003, 15, 123, 125, 68, 2, 6853, 15, 349, 165, 4362, 98, 5, 4, 228, 9, 43, 2, 1157, 15, 299, 120, 5, 120, 174, 11, 220, 175, 136, 50, 9, 4373, 228, 8255, 5, 2, 656, 245, 2350, 5, 4, 9837, 131, 152, 491, 18, 2, 32, 7464, 1212, 14, 9, 6, 371, 78, 22, 625, 64, 1382, 9, 8, 168, 145, 23, 4, 1690, 15, 16, 4, 1355, 5, 28, 6, 52, 154, 462, 33, 89, 78, 285, 16, 145, 95]
[list([1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32])
 list([1, 194, 1153, 194, 8255, 78, 228, 5, 6, 1463, 4369, 5012, 134, 26, 4, 715, 8, 118, 1634, 14, 394, 20, 13, 119, 954, 189, 102, 5, 207, 110, 3103, 21, 14, 69, 188, 8, 30, 23, 7, 4, 249, 126, 93, 4, 114, 9, 2300, 1523, 5, 647, 4, 116, 9, 35, 8163, 4, 229, 9, 340, 1322, 4, 118, 9, 4, 130, 4901, 19, 4, 1002, 5, 89, 29, 952, 46, 37, 4, 455, 9, 45, 43, 38, 1543, 1905, 398, 4, 1649, 26, 6853, 5, 163, 11, 3215, 2, 4, 1153, 9, 194, 775, 7, 8255, 2, 349, 2637, 148, 605, 2, 8003, 15, 123, 125, 68, 2, 6853, 15, 349, 165, 4362, 98, 5, 4, 228, 9, 43, 2, 1157, 15, 299, 120, 5, 120, 174, 11, 220, 175, 136, 50, 9, 4373, 228, 8255, 5, 2, 656, 245, 2350, 5, 4, 9837, 131, 152, 491, 18, 2, 32, 7464, 1212, 14, 9, 6, 371, 78, 22, 625, 64, 1382, 9, 8, 168, 145, 23, 4, 1690, 15, 16, 4, 1355, 5, 28, 6, 52, 154, 462, 33, 89, 78, 285, 16, 145, 95])
 list([1, 14, 47, 8, 30, 31, 7, 4, 249, 108, 7, 4, 5974, 54, 61, 369, 13, 71, 149, 14, 22, 112, 4, 2401, 311, 12, 16, 3711, 33, 75, 43, 1829, 296, 4, 86, 320, 35, 534, 19, 263, 4821, 1301, 4, 1873, 33, 89, 78, 12, 66, 16, 4, 360, 7, 4, 58, 316, 334, 11, 4, 1716, 43, 645, 662, 8, 257, 85, 1200, 42, 1228, 2578, 83, 68, 3912, 15, 36, 165, 1539, 278, 36, 69, 2, 780, 8, 106, 14, 6905, 1338, 18, 6, 22, 12, 215, 28, 610, 40, 6, 87, 326, 23, 2300, 21, 23, 22, 12, 272, 40, 57, 31, 11, 4, 22, 47, 6, 2307, 51, 9, 170, 23, 595, 116, 595, 1352, 13, 191, 79, 638, 89, 2, 14, 9, 8, 106, 607, 624, 35, 534, 6, 227, 7, 129, 113])
 ...
 list([1, 11, 6, 230, 245, 6401, 9, 6, 1225, 446, 2, 45, 2174, 84, 8322, 4007, 21, 4, 912, 84, 2, 325, 725, 134, 2, 1715, 84, 5, 36, 28, 57, 1099, 21, 8, 140, 8, 703, 5, 2, 84, 56, 18, 1644, 14, 9, 31, 7, 4, 9406, 1209, 2295, 2, 1008, 18, 6, 20, 207, 110, 563, 12, 8, 2901, 2, 8, 97, 6, 20, 53, 4767, 74, 4, 460, 364, 1273, 29, 270, 11, 960, 108, 45, 40, 29, 2961, 395, 11, 6, 4065, 500, 7, 2, 89, 364, 70, 29, 140, 4, 64, 4780, 11, 4, 2678, 26, 178, 4, 529, 443, 2, 5, 27, 710, 117, 2, 8123, 165, 47, 84, 37, 131, 818, 14, 595, 10, 10, 61, 1242, 1209, 10, 10, 288, 2260, 1702, 34, 2901, 2, 4, 65, 496, 4, 231, 7, 790, 5, 6, 320, 234, 2766, 234, 1119, 1574, 7, 496, 4, 139, 929, 2901, 2, 7750, 5, 4241, 18, 4, 8497, 2, 250, 11, 1818, 7561, 4, 4217, 5408, 747, 1115, 372, 1890, 1006, 541, 9303, 7, 4, 59, 2, 4, 3586, 2])
 list([1, 1446, 7079, 69, 72, 3305, 13, 610, 930, 8, 12, 582, 23, 5, 16, 484, 685, 54, 349, 11, 4120, 2959, 45, 58, 1466, 13, 197, 12, 16, 43, 23, 2, 5, 62, 30, 145, 402, 11, 4131, 51, 575, 32, 61, 369, 71, 66, 770, 12, 1054, 75, 100, 2198, 8, 4, 105, 37, 69, 147, 712, 75, 3543, 44, 257, 390, 5, 69, 263, 514, 105, 50, 286, 1814, 23, 4, 123, 13, 161, 40, 5, 421, 4, 116, 16, 897, 13, 2, 40, 319, 5872, 112, 6700, 11, 4803, 121, 25, 70, 3468, 4, 719, 3798, 13, 18, 31, 62, 40, 8, 7200, 4, 2, 7, 14, 123, 5, 942, 25, 8, 721, 12, 145, 5, 202, 12, 160, 580, 202, 12, 6, 52, 58, 2, 92, 401, 728, 12, 39, 14, 251, 8, 15, 251, 5, 2, 12, 38, 84, 80, 124, 12, 9, 23])
 list([1, 17, 6, 194, 337, 7, 4, 204, 22, 45, 254, 8, 106, 14, 123, 4, 2, 270, 2, 5, 2, 2, 732, 2098, 101, 405, 39, 14, 1034, 4, 1310, 9, 115, 50, 305, 12, 47, 4, 168, 5, 235, 7, 38, 111, 699, 102, 7, 4, 4039, 9245, 9, 24, 6, 78, 1099, 17, 2345, 2, 21, 27, 9685, 6139, 5, 2, 1603, 92, 1183, 4, 1310, 7, 4, 204, 42, 97, 90, 35, 221, 109, 29, 127, 27, 118, 8, 97, 12, 157, 21, 6789, 2, 9, 6, 66, 78, 1099, 4, 631, 1191, 5, 2642, 272, 191, 1070, 6, 7585, 8, 2197, 2, 2, 544, 5, 383, 1271, 848, 1468, 2, 497, 2, 8, 1597, 8778, 2, 21, 60, 27, 239, 9, 43, 8368, 209, 405, 10, 10, 12, 764, 40, 4, 248, 20, 12, 16, 5, 174, 1791, 72, 7, 51, 6, 1739, 22, 4, 204, 131, 9])]

Why is the reported shape (25000,)? Because x_train here is a collection of 25,000 Python lists of varying length, so NumPy can only store it as a 1-D object array and cannot report an inner shape.
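A minimal illustration of that behaviour (toy lists, not the IMDB data): a ragged collection of lists can only be stored as a 1-D object array.

import numpy as np

# Ragged lists of different lengths become a 1-D object array
ragged = np.array([[1, 2, 3], [4, 5]], dtype=object)
print(ragged.shape)   # (2,) -- one entry per list, no inner shape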


In [5]:
# This turns our lists of integers
# into a 2D integer tensor of shape `(samples, maxlen)`
x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen)
print(x_train.shape)
print(x_test.shape)
print(x_train[0])
print(x_train[1])
print(x_train)
(25000, 20)
(25000, 20)
[  65   16   38 1334   88   12   16  283    5   16 4472  113  103   32
   15   16 5345   19  178   32]
[  23    4 1690   15   16    4 1355    5   28    6   52  154  462   33
   89   78  285   16  145   95]
[[  65   16   38 ...   19  178   32]
 [  23    4 1690 ...   16  145   95]
 [1352   13  191 ...    7  129  113]
 ...
 [  11 1818 7561 ...    4 3586    2]
 [  92  401  728 ...   12    9   23]
 [ 764   40    4 ...  204  131    9]]

From the output you can see that this is indeed a truncation: each review now keeps only its last 20 word indices (shorter reviews would have been padded with 0 at the front).

3. Use an Embedding layer and a classifier on the IMDB data

In [6]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense

model = Sequential()
# We specify the maximum input length to our Embedding layer
# so we can later flatten the embedded inputs.
model.add(Embedding(10000, 8, input_length=maxlen))
# After the Embedding layer,
# our activations have shape `(samples, maxlen, 8)`.

# We flatten the 3D tensor of embeddings
# into a 2D tensor of shape `(samples, maxlen * 8)`
model.add(Flatten())

# We add the classifier on top
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
model.summary()

history = model.fit(x_train, y_train,
                    epochs=10,
                    batch_size=32,
                    validation_split=0.2)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_1 (Embedding)      (None, 20, 8)             80000     
_________________________________________________________________
flatten (Flatten)            (None, 160)               0         
_________________________________________________________________
dense (Dense)                (None, 1)                 161       
=================================================================
Total params: 80,161
Trainable params: 80,161
Non-trainable params: 0
_________________________________________________________________
Epoch 1/10
625/625 [==============================] - 2s 3ms/step - loss: 0.6654 - acc: 0.6302 - val_loss: 0.6112 - val_acc: 0.7034
Epoch 2/10
625/625 [==============================] - 2s 3ms/step - loss: 0.5369 - acc: 0.7495 - val_loss: 0.5245 - val_acc: 0.7310
Epoch 3/10
625/625 [==============================] - 2s 3ms/step - loss: 0.4599 - acc: 0.7895 - val_loss: 0.5001 - val_acc: 0.7496
Epoch 4/10
625/625 [==============================] - 2s 3ms/step - loss: 0.4215 - acc: 0.8086 - val_loss: 0.4934 - val_acc: 0.7530
Epoch 5/10
625/625 [==============================] - 2s 3ms/step - loss: 0.3961 - acc: 0.8226 - val_loss: 0.4923 - val_acc: 0.7586
Epoch 6/10
625/625 [==============================] - 2s 3ms/step - loss: 0.3747 - acc: 0.8329 - val_loss: 0.4967 - val_acc: 0.7564
Epoch 7/10
625/625 [==============================] - 2s 3ms/step - loss: 0.3568 - acc: 0.8450 - val_loss: 0.5009 - val_acc: 0.7580
Epoch 8/10
625/625 [==============================] - 2s 3ms/step - loss: 0.3396 - acc: 0.8555 - val_loss: 0.5062 - val_acc: 0.7578
Epoch 9/10
625/625 [==============================] - 2s 3ms/step - loss: 0.3235 - acc: 0.8649 - val_loss: 0.5149 - val_acc: 0.7578
Epoch 10/10
625/625 [==============================] - 2s 3ms/step - loss: 0.3071 - acc: 0.8730 - val_loss: 0.5216 - val_acc: 0.7552

The validation accuracy is about 76%, which is quite good considering that we only look at the first 20 words of each review. But note that merely flattening the embedded sequences and training a single Dense layer on top makes the model handle each word of the input sequence separately, without considering inter-word relationships or sentence structure (for example, such a model would likely classify both "this movie is a bomb" and "this movie is the bomb" as negative reviews). It is much better to add a recurrent layer or a 1D convolutional layer on top of the embedded sequences, so that features are learned over each sequence as a whole, as in the sketch below.
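A minimal sketch of that suggested improvement, assuming the same maxlen, x_train and y_train prepared above (the LSTM and its size of 32 units are an illustrative choice, not the book's exact configuration): replace Flatten with a recurrent layer so the whole sequence is read in order.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense

rnn_model = Sequential()
rnn_model.add(Embedding(10000, 8, input_length=maxlen))
# The LSTM reads the embedded sequence step by step,
# so word order and context can influence the prediction.
rnn_model.add(LSTM(32))
rnn_model.add(Dense(1, activation='sigmoid'))
rnn_model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])

rnn_history = rnn_model.fit(x_train, y_train,
                            epochs=10,
                            batch_size=32,
                            validation_split=0.2)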

