seaborn畫熱力圖注意的幾點問題


最近在使用注意力機制實現文本分類,我們需要觀察每一個樣本中,模型的重心放在哪里了,就是觀察到權重最大的token。這時我們需要使用熱力圖進行可視化。

我這里用到:seaborn

seaborn.heatmap

seaborn.heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot=None, fmt='.2g', annotkws=None, linewidths=0, linecolor='white', cbar=True, cbarkws=None, cbar_ax=None, square=False, ax=None, xticklabels=True, yticklabels=True, mask=None, **kwargs)

  • data:矩陣數據集,可以使numpy的數組(array),如果是pandas的dataframe,則df的index/column信息會分別對應到heatmap的columns和rows
  • linewidths,熱力圖矩陣之間的間隔大小
  • vmax,vmin, 圖例中最大值和最小值的顯示值,沒有該參數時默認不顯示

 

data就是我們注意力矩陣的數據。注意,由於注意力的整理數值都偏小,直接使用數據顯示的效果難以區分,我們可以將其放大100倍后來獲取更加的效果。 先上代碼吧!

fr = open('./pkl/attention_matrix.pkl', 'rb')
tokens, attention = pickle.load(fr)
plt.figure(figsize=(30,20))
sns.heatmap(attention, vamx=100, vmin=0)
plt.savefig('./log/attention_matrix.png')

# 獲取數據
import heapq
check_file = './log/check_attention_keywords.txt'
clean(check_file)
fw = open(check_file, 'a', encoding='utf8')
for t, a in zip(tokens, attention):
    temp = []
    max_num_index_list = map(list(a).index, heapq.nlargest(5, list(a))
    for index in max_num_index_list:
        word = t[index]
        print(word)
        temp.append(word)
    fw.write(str(temp)+'\n')

 

  我這里取出注意力值最大的前5個詞拿出來看的

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM