sns.heatmap() 熱地圖,包括傳統的,下三角,重點(挖空)相關性性圖


sns.heatmap() 熱地圖

我一般使用來畫特征相關系數的圖

seaborn.heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot=None, fmt='.2g', 
annotkws=None, linewidths=0, linecolor='white', cbar=True, cbarkws=None, cbar_ax=None, square=False,
ax=None, xticklabels=True, yticklabels=True, mask=None, **kwargs)

參數太多就不一一解釋了,我就用配置好的,以后都套着用,

注意,計算相關系數時,只計算數據型的的特征,object不能計算

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns

# 讀取數據文件
telcom=pd.read_csv('F:\\python\\電信用戶數據\\WA_Fn-UseC_-Telco-Customer-Churn.csv')

# 提取特征
charges=telcom.iloc[:,1:20]
# 對特征進行編碼
"""
離散特征的編碼分為兩種情況:
1、離散特征的取值之間沒有大小的意義,比如color:[red,blue],那么就使用one-hot編碼
2、離散特征的取值有大小的意義,比如size:[X,XL,XXL],那么就使用數值的映射{X:1,XL:2,XXL:3}
"""
corrDf = charges.apply(lambda x: pd.factorize(x)[0])
corrDf .head()

corrDf.info()


# 構造相關性矩陣
corr = corrDf.corr()
corr

# 使用熱地圖顯示相關系數
'''
heatmap    使用熱地圖展示系數矩陣情況
linewidths 熱力圖矩陣之間的間隔大小
annot      設定是否顯示每個色塊的系數值
'''
plt.figure(figsize=(20,16))
ax = sns.heatmap(corr, xticklabels=corr.columns, yticklabels=corr.columns, 
                 linewidths=0.2, cmap="YlGnBu",annot=True)
plt.title("Correlation between variables")

 

 

當然描述相關系數時還可以畫條形圖,特別是描述各特征和y值的相關程度時,還要注意,corr只對數據型數據進行計算

# 使用one-hot編碼
tel_dummies = pd.get_dummies(telcom.iloc[:,1:21])
tel_dummies.head()

tel_dummies.info()



# 電信用戶是否流失與各變量之間的相關性
plt.figure(figsize=(15,8))
tel_dummies.corr()['Churn'].sort_values(ascending = False).plot(kind='bar')
plt.title("Correlations between Churn and variables")

只畫下三角的地熱圖

fig = plt.figure(figsize = [15,10])
mask = np.zeros_like(finalTrain.corr(), dtype=np.bool)
mask[np.triu_indices_from(mask)] = True
sns.heatmap(finalTrain.corr(), cmap=sns.diverging_palette(150, 275, s=80, l=55, n=9), mask = mask, annot=True, center = 0)
plt.title("Correlation Matrix (HeatMap)", fontsize = 15)

 20210407補充重點相關性圖

下三角矩陣好像變得更加簡潔了,可還是有些多有些亂,雖然我們可以依據顏色的深淺來判別特征之間的強弱相關性,但還是不太方便。我們能不能找出我們所希望看到的那塊呢?比如:

  • 我們只想找到強相關的特征來可視化,其余的全部過濾掉。

其實是可以的,在seaborn數據包的heatmap函數中,還有一個mask函數,可以幫助我們篩選出我們希望看到的部分,例如我們只想看相關性大於0.5的部分

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.datasets import load_iris

iris=load_iris()
iris.data
iris.target
iris.feature_names
df=pd.DataFrame(iris.data,columns=iris.feature_names)
iris.target_names
df['taget']= iris.target

# Calculate pairwise-correlation
matrix = df.corr()
cmap = sns.diverging_palette(250, 15, s=75, l=40, n=9, center="light", as_cmap=True)
# mask掉上三角 & 小於某個閾值的值
mask1 = np.triu(np.ones_like(matrix, dtype=bool))
mask2 = np.abs(matrix) <= 0.5
mask  = mask1 | mask2

plt.figure(figsize=(12, 8)) 
sns.heatmap(matrix,  mask=mask, center=0, annot=True,fmt='.2f', square=True, cmap=cmap) 

 

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM