最近在使用注意力機制實現文本分類,我們需要觀察每一個樣本中,模型的重心放在哪里了,就是觀察到權重最大的token。這時我們需要使用熱力圖進行可視化。
我這里用到:seaborn
seaborn.heatmap
seaborn.heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot=None, fmt='.2g', annotkws=None, linewidths=0, linecolor='white', cbar=True, cbarkws=None, cbar_ax=None, square=False, ax=None, xticklabels=True, yticklabels=True, mask=None, **kwargs)
- data:矩陣數據集,可以使numpy的數組(array),如果是pandas的dataframe,則df的index/column信息會分別對應到heatmap的columns和rows
- linewidths,熱力圖矩陣之間的間隔大小
- vmax,vmin, 圖例中最大值和最小值的顯示值,沒有該參數時默認不顯示
data就是我們注意力矩陣的數據。注意,由於注意力的整理數值都偏小,直接使用數據顯示的效果難以區分,我們可以將其放大100倍后來獲取更加的效果。 先上代碼吧!
fr = open('./pkl/attention_matrix.pkl', 'rb') tokens, attention = pickle.load(fr) plt.figure(figsize=(30,20)) sns.heatmap(attention, vamx=100, vmin=0) plt.savefig('./log/attention_matrix.png') # 獲取數據 import heapq check_file = './log/check_attention_keywords.txt' clean(check_file) fw = open(check_file, 'a', encoding='utf8') for t, a in zip(tokens, attention): temp = [] max_num_index_list = map(list(a).index, heapq.nlargest(5, list(a)) for index in max_num_index_list: word = t[index] print(word) temp.append(word) fw.write(str(temp)+'\n')
我這里取出注意力值最大的前5個詞拿出來看的