通常,我們使用BERT做文本分類,其泛化性好、表現優秀。在做文本相似度計算時,常見做法是先在語料上訓練詞向量(如word2vec),再將詞向量聚合成文本向量(embedding),最後計算相似度;但word2vec是靜態詞向量,表徵能力有限。此時,可以利用已在特定任務上微調過的BERT模型,抽取其[CLS]向量作為整個句子的表徵向量,供下游任務使用——這可以說是分類模型的一個附加產物。主要流程如下:
1)加載ckpt模型
2)確定輸出tensor名稱:在BERT中,[CLS]句向量對應的節點名稱為bert/pooler/dense/Tanh(而不是分類頭的Softmax輸出)
3)存儲為pb model
主代碼:
def extract_bert_vector():
    """Export the fine-tuned BERT classifier as a frozen .pb graph whose
    output is the pooled [CLS] sentence vector.

    Rebuilds the model graph, restores the latest checkpoint found under
    ``output``, freezes the variables into constants keeping only the
    subgraph needed for ``bert/pooler/dense/Tanh``, and writes the result
    to ``pb_model/bert_encoder.pb`` (the directory is created if missing).

    :return: None
    :raises ValueError: if no checkpoint is found under ``output``.
    """
    import os

    OUTPUT_GRAPH = 'pb_model/bert_encoder.pb'
    # Pooler output ([CLS] representation), not the classifier's softmax.
    output_node = ["bert/pooler/dense/Tanh"]
    ckpt_model = r'output'
    bert_config_file = r'chinese_L-12_H-768_A-12/bert_config.json'
    max_seq_length = 200

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True

    # The label file is only used for its length (num_labels below), so the
    # classifier head is rebuilt with the shape stored in the checkpoint.
    with open(r'data/file_dict.json', 'r', encoding='utf-8') as fr:
        label_list = json.load(fr)

    # A dedicated graph keeps repeated calls from clashing with variables
    # already in the default graph; the `with` closes the session when done.
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph, config=gpu_config) as sess:
        print("going to restore checkpoint")
        input_ids_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_ids")
        input_mask_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_mask")
        bert_config = modeling.BertConfig.from_json_file(bert_config_file)
        # Called only to build the graph; the returned tensors are not needed.
        create_model(
            bert_config=bert_config,
            is_training=False,
            input_ids=input_ids_p,
            input_mask=input_mask_p,
            segment_ids=None,
            labels=None,
            num_labels=len(label_list),
            use_one_hot_embeddings=False)
        ckpt_path = tf.train.latest_checkpoint(ckpt_model)
        if ckpt_path is None:
            # saver.restore(sess, None) would also raise ValueError, but with
            # a much less helpful message.
            raise ValueError("no checkpoint found in %r" % ckpt_model)
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_path)
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_node)

    os.makedirs(os.path.dirname(OUTPUT_GRAPH), exist_ok=True)
    with tf.gfile.GFile(OUTPUT_GRAPH, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())
    print('extract vector pb model saved!')
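768維明顯偏高,因此採用白化(whitening)處理,將維度從768降到256:先減去均值$\mu$,再對協方差矩陣做SVD分解$\Sigma = U\Lambda U^\top$,以$\tilde{x} = (x - \mu)\,U\Lambda^{-1/2}$進行變換並只保留前256維,最後對每個向量做L2歸一化。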
代碼如下:
def compute_kernel_bias(vecs, n_components=256):
    """Compute the whitening kernel and bias for a set of sentence vectors.

    The transform to apply afterwards is ``y = (x + bias).dot(kernel)``.

    Args:
        vecs: array-like of shape (num_samples, embedding_size).
        n_components: number of whitened dimensions to keep (e.g. 768 -> 256).
            Values larger than embedding_size keep all dimensions.

    Returns:
        (kernel, bias): ``kernel`` has shape (embedding_size, k) with
        k = min(n_components, embedding_size); ``bias`` is the negated
        column mean with shape (1, embedding_size).
    """
    vecs = np.asarray(vecs)
    mu = vecs.mean(axis=0, keepdims=True)
    # rowvar=False: columns are features. atleast_2d because np.cov squeezes
    # a single-feature input down to a 0-d array, which svd rejects.
    cov = np.atleast_2d(np.cov(vecs, rowvar=False))
    u, s, _ = np.linalg.svd(cov)
    # Column-scaling the kept columns equals (u @ diag(1/sqrt(s)))[:, :k],
    # without a dense diagonal matmul or 1/sqrt of discarded singular values.
    kernel = u[:, :n_components] * (1 / np.sqrt(s[:n_components]))
    return kernel, -mu


def transform_and_normalize(vecs, kernel=None, bias=None):
    """Optionally apply the whitening transform, then L2-normalize each row.

    Args:
        vecs: array-like of shape (num_samples, dim).
        kernel, bias: as returned by ``compute_kernel_bias``. The transform
            ``(vecs + bias).dot(kernel)`` is applied only when BOTH are given.

    Returns:
        ndarray whose rows have unit L2 norm; all-zero rows stay zero
        instead of becoming NaN.
    """
    vecs = np.asarray(vecs)
    if kernel is not None and bias is not None:
        vecs = (vecs + bias).dot(kernel)
    norms = (vecs ** 2).sum(axis=1, keepdims=True) ** 0.5
    # Avoid 0/0 -> NaN (and a RuntimeWarning) for all-zero rows.
    norms = np.where(norms == 0, 1, norms)
    return vecs / norms