Python可視化:Seaborn庫熱力圖使用進階


前言

在日常工作中,經常可以見到各種各種精美的熱力圖,熱力圖的應用非常廣泛,下面一起來學習下Python的Seaborn庫中熱力圖(heatmap)如何來進行使用。

本次運行的環境為:

  • windows 64位系統

  • python 3.5

  • jupyter notebook

1 構造數據

import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
% matplotlib inline
region = ['Albania', 'Algeria', 'Angola', 'Argentina', 'Armenia', 'Azerbaijan',
       'Bahamas', 'Bangladesh', 'Belize', 'Bhutan', 'Bolivia',
       'Bosnia and Herzegovina', 'Brazil', 'Burkina Faso', 'Burundi',
       'Cambodia', 'Cameroon', 'Cape Verde', 'Chile', 'China', 'Colombia',
       'Costa Rica', 'Cote d Ivoire', 'Cuba', 'Cyprus',
       "Democratic People's Republic of Korea",
       'Democratic Republic of the Congo', 'Dominican Republic', 'Ecuador',
       'Egypt', 'El Salvador', 'Equatorial Guinea', 'Ethiopia', 'Fiji',
       'Gambia', 'Georgia', 'Ghana', 'Guatemala', 'Guyana', 'Honduras']

kind = ['Afforestation & reforestation', 'Biofuels', 'Biogas',
        'Biomass', 'Cement', 'Energy efficiency', 'Fuel switch',
       'HFC reduction/avoidance', 'Hydro power',
        'Leak reduction', 'Material use', 'Methane avoidance',             
       'N2O decomposition', 'Other renewable energies',
       'PFC reduction and substitution','PV',
       'SF6 replacement', 'Transportation', 'Waste gas/heat utilization',
      'Wind power']
print(len(region))
print(len(kind))
40
20
np.random.seed(100)
arr_region = np.random.choice(region, size=(10000,))
list_region = list(arr_region)

arr_kind = np.random.choice(kind, size=(10000,))
list_kind = list(arr_kind)

values = np.random.randint(50, 1000, 10000)
list_values = list(values)

df = pd.DataFrame({'region':list_region,
                  'kind': list_kind,
                  'values':list_values})
df.head()

pt = df.pivot_table(index='kind', columns='region', values='values', aggfunc=np.sum)
pt.head()

f, ax = plt.subplots(figsize = (10, 4))
cmap = sns.cubehelix_palette(start = 1, rot = 3, gamma=0.8, as_cmap = True)
sns.heatmap(pt, cmap = cmap, linewidths = 0.05, ax = ax)
ax.set_title('Amounts per kind and region')
ax.set_xlabel('region')
ax.set_ylabel('kind')

f.savefig('sns_heatmap_normal.jpg', bbox_inches='tight')
# ax.set_xticklabels(ax.get_xticklabels(), rotation=-90)

2 Seaborn的heatmap各個參數介紹

seaborn.heatmap

seaborn.heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot=None, fmt=’.2g’, annot_kws=None, linewidths=0, linecolor=’white’, cbar=True, cbar_kws=None, cbar_ax=None, square=False, ax=None, xticklabels=True, yticklabels=True, mask=None, **kwargs)

  • data:矩陣數據集,可以使numpy的數組(array),如果是pandas的dataframe,則df的index/column信息會分別對應到heatmap的columns和rows
  • linewidths,熱力圖矩陣之間的間隔大小
  • vmax,vmin, 圖例中最大值和最小值的顯示值,沒有該參數時默認不顯示

2.1 cmap

  • cmap:matplotlib的colormap名稱或顏色對象;如果沒有提供,默認為cubehelix map (數據集為連續數據集時) 或 RdBu_r (數據集為離散數據集時)
f, (ax1,ax2) = plt.subplots(figsize = (10, 8),nrows=2)

# cubehelix map顏色
cmap = sns.cubehelix_palette(start = 1.5, rot = 3, gamma=0.8, as_cmap = True)
sns.heatmap(pt, linewidths = 0.05, ax = ax1, vmax=15000, vmin=0, cmap=cmap)
ax1.set_title('cubehelix map')
ax1.set_xlabel('')
ax1.set_xticklabels([]) #設置x軸圖例為空值
ax1.set_ylabel('kind')

# matplotlib colormap
sns.heatmap(pt, linewidths = 0.05, ax = ax2, vmax=15000, vmin=0, cmap='rainbow') 
# rainbow為 matplotlib 的colormap名稱
ax2.set_title('matplotlib colormap')
ax2.set_xlabel('region')
ax2.set_ylabel('kind')

f.savefig('sns_heatmap_cmap.jpg', bbox_inches='tight')

2.2 center

  • center:將數據設置為圖例中的均值數據,即圖例中心的數據值;通過設置center值,可以調整生成的圖像顏色的整體深淺;設置center數據時,如果有數據溢出,則手動設置的vmax、vmin會自動改變
f, (ax1,ax2) = plt.subplots(figsize = (10, 8),nrows=2)

cmap = sns.cubehelix_palette(start = 1.5, rot = 3, gamma=0.8, as_cmap = True)

sns.heatmap(pt, linewidths = 0.05, ax = ax1, vmax=15000, vmin=0, cmap=cmap, center=None )
# center為None時,由於最小值為0,最大值為15000,相當於center值為vamx和vmin的均值,即7500
ax1.set_title('center=None')
ax1.set_xlabel('')
ax1.set_xticklabels([]) #設置x軸圖例為空值
ax1.set_ylabel('kind')

sns.heatmap(pt, linewidths = 0.05, ax = ax2, vmax=15000, vmin=0, cmap=cmap, center=3000 ) 
# 由於均值為2000,當center設置為3000時,大部分數據會比7500大,所以center=3000時,生成的圖片顏色要深
# 設置center數據時,如果有數據溢出,則手動設置的vmax或vmin會自動改變
ax2.set_title('center=3000')
ax2.set_xlabel('region')
ax2.set_ylabel('kind')

f.savefig('sns_heatmap_center.jpg', bbox_inches='tight')

2.3 robust

f, (ax1,ax2) = plt.subplots(figsize = (10, 8),nrows=2)

cmap = sns.cubehelix_palette(start = 1.5, rot = 3, gamma=0.8, as_cmap = True)

sns.heatmap(pt, linewidths = 0.05, ax = ax1, cmap=cmap, center=None, robust=False )
# robust默認為False
ax1.set_title('robust=False')
ax1.set_xlabel('')
ax1.set_xticklabels([]) #設置x軸圖例為空值
ax1.set_ylabel('kind')

sns.heatmap(pt, linewidths = 0.05, ax = ax2, cmap=cmap, center=None, robust=True ) 
# If True and vmin or vmax are absent, the colormap range is computed with robust quantiles instead of the extreme values.
ax2.set_title('robust=True')
ax2.set_xlabel('region')
ax2.set_ylabel('kind')

f.savefig('sns_heatmap_robust.jpg', bbox_inches='tight')

2.4 mask

f, (ax1,ax2) = plt.subplots(figsize = (10, 8),nrows=2)

cmap = sns.cubehelix_palette(start = 1.5, rot = 3, gamma=0.8, as_cmap = True)

p1 = sns.heatmap(pt, linewidths = 0.05,ax=ax1, vmax=15000, vmin=0, cmap=cmap, center=None, robust=False, mask=None )
# robust默認為False
ax1.set_title('mask=None')
ax1.set_xlabel('')
ax1.set_xticklabels([]) #設置x軸圖例為空值
ax1.set_ylabel('kind')

p2 = sns.heatmap(pt, linewidths = 0.05, ax=ax2, vmax=15000, vmin=0, cmap=cmap, center=None, robust=False, annot=False,mask=pt<10000 ) 
# mask: boolean array or DataFrame

ax2.set_title('mask: boolean DataFrame')
ax2.set_xlabel('region')
ax2.set_ylabel('kind')

f.savefig('sns_heatmap_mask.jpg', bbox_inches='tight')

2.5 xticklabels, yticklabels

  • xticklabels: 如果是True,則繪制dataframe的列名。如果是False,則不繪制列名。如果是列表,則繪制列表中的內容作為xticklabels。 如果是整數n,則繪制列名,但每個n繪制一個label。 默認為True。
  • yticklabels: 如果是True,則繪制dataframe的行名。如果是False,則不繪制行名。如果是列表,則繪制列表中的內容作為yticklabels。 如果是整數n,則繪制列名,但每個n繪制一個label。 默認為True。默認為True。
f, (ax1,ax2) = plt.subplots(figsize = (10, 8),nrows=2)

cmap = sns.cubehelix_palette(start = 1.5, rot = 3, gamma=0.8, as_cmap = True)

p1 = sns.heatmap(pt, linewidths = 0.05,ax=ax1, vmax=15000, vmin=0, cmap=cmap, center=None, robust=False, mask=None, xticklabels=False )
# robust默認為False
ax1.set_title('xticklabels=None')
ax1.set_xlabel('')
# ax1.set_xticklabels([]) #設置x軸圖例為空值
ax1.set_ylabel('kind')

p2 = sns.heatmap(pt, linewidths = 0.05, ax=ax2, vmax=15000, vmin=0, cmap=cmap, center=None, robust=False, annot=False,mask=None,xticklabels=3, yticklabels=list(range(20)) ) 
# mask: boolean array or DataFrame

ax2.set_title('xticklabels=3, yticklabels is a list')
ax2.set_xlabel('region')
ax2.set_ylabel('kind')

f.savefig('sns_heatmap_xyticklabels.jpg', bbox_inches='tight')

2.6 annot

  • annotate的縮寫,annot默認為False,當annot為True時,在heatmap中每個方格寫入數據
  • annot_kws,當annot為True時,可設置各個參數,包括大小,顏色,加粗,斜體字等
np.random.seed(0)
x = np.random.randn(10, 10)
f, (ax1, ax2) = plt.subplots(figsize=(8,8),nrows=2)

sns.heatmap(x, annot=True, ax=ax1)
sns.heatmap(x, annot=True, ax=ax2, annot_kws={'size':9,'weight':'bold', 'color':'blue'})
# Keyword arguments for ax.text when annot is True.
# http://stackoverflow.com/questions/35024475/seaborn-heatmap-key-words

f.savefig('sns_heatmap_annot.jpg')

**關於annot_kws的設置,還有很多值得研究的地方,ax.text有很多屬性,有興趣的可以去研究下;

ax.text可參考官方文檔:http://matplotlib.org/api/text_api.html#matplotlib.text.Text

2.7 fmt

  • fmt,格式設置
np.random.seed(0)
x = np.random.randn(10, 10)
f, (ax1, ax2) = plt.subplots(figsize=(8,8),nrows=2)

sns.heatmap(x, annot=True, ax=ax1)
sns.heatmap(x, annot=True, fmt='.1f', ax=ax2)

f.savefig('sns_heatmap_fmt.jpg')

3 案例應用:突出顯示某些數據

3.1 method 1:利用mask來實現

f,ax=plt.subplots(figsize=(10,5))

x = np.random.randn(10, 10)
sns.heatmap(x, annot=True, ax=ax)
sns.heatmap(x, mask=x < 1, cbar=False, ax=ax,
            annot=True, annot_kws={"weight": "bold"})

f.savefig('sns_heatmap_eg1.jpg')

3.2 method 2:利用ax.texts來實現

f,ax=plt.subplots(figsize=(10,5))

flights = sns.load_dataset("flights")
flights = flights.pivot("month", "year", "passengers")
pic = sns.heatmap(flights, annot=True, fmt="d", ax=ax)

for text in pic.texts:
    text.set_size(8)
    if text.get_text() == '118':
        text.set_size(12)
        text.set_weight('bold')
        text.set_style('italic')

f.savefig('sns_heatmap_eg2.jpg')

你可能會發現本文中seaborn的heatmap中還有些參數沒有進行介紹,介於篇幅,這里就不在啰嗦了,建議各位小伙伴自己可以研究下其他參數如何使用。

如需轉載,請在公眾號留言進行授權事宜溝通。

轉載請注明文章來自微信公眾號“Python數據之道”。

更多精彩內容請關注微信公眾號:

“Python數據之道”


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM