matplotlib---插值畫二維、三維圖


一、畫二維圖

1.原始數據(x,y)

import matplotlib.pyplot as plt
import numpy as np

#數據
X = np.array(list(i for i in range(6)))
Y = np.array([10,30,20,50,100,120])

2.先對橫坐標x進行擴充數據量,采用linspace

#插值
from scipy.interpolate import spline
X_new = np.linspace(X.min(),X.max(),300) #300 represents number of points to make between X.min and X.max

3.采用scipy.interpolate中的spline來對縱坐標數據y進行插值

由6個擴充到300個

smooth = spline(X,Y,X_new)
print(X_new.shape)  #(300,)
print(smooth.shape)  #(300,)

4.畫圖

#畫圖
plt.plot(X_new,smooth)
plt.show()

 

插值前 插值后

 

 

 

 

二、畫三維圖

1.載入數據

# 載入模塊
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import pandas as pd
import seaborn as sns
from scipy import interpolate

df_epsilon_alpha = pd.read_excel('實驗記錄_超參數.xlsx',sheet_name='epsilon_alpha')
#生成數據
epsilon = np.array(df_epsilon_alpha['epsilon'].values)
alpha = np.array(df_epsilon_alpha['alpha'].values)
Precision = np.array(df_epsilon_alpha['Precision'].values)

  

2.將x和y擴充到想要的大小

【兩種方法:np.arange和np.linspace】

xnew = np.arange(0.1, 1, 0.09) #左閉右閉每0.09間隔生成一個數
ynew = np.arange(0.1, 1, 0.09)  
或者
x = np.linspace(0.1,0.9,9)#0.1到0.9生成9個數
y = np.linspace(0.1,0.9,9)

 

3.對z插值

x,y原數據:

x = np.linspace(0.1,0.9,9)
y = np.linspace(0.1,0.9,9)
z = Precision

采用 scipy.interpolate.interp2d函數進行插值

f = interpolate.interp2d(x, y, z, kind='cubic')

x,y擴充數據:

xnew = np.arange(0.1, 1, 0.03)#(31,)
ynew = np.arange(0.1, 1, 0.03)#(31,)
znew = f(xnew, ynew)#(31,31) 

znew為插值后的z

 

4.畫圖

采用  from mpl_toolkits.mplot3d import Axes3D進行畫三維圖

Axes3D簡單用法:

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

比如采用plot_trisurf畫三維圖:

plot_trisurf(x,y,z)

plot_trisurf對數據要求是:x.shape = y.shape = z.shape,所以x和y的shape需要修改,采用np.meshgrid,且都為一維數據

修改x,y,z輸入畫圖函數前的shape

xx1, yy1 = np.meshgrid(xnew, ynew)#執行之后,xx1.shape=(31,31),yy1.shape=(31,31)
newshape = (xx1.shape[0])*(xx1.shape[0])
y_input = xx1.reshape(newshape)
x_input = yy1.reshape(newshape)
z_input = znew.reshape(newshape)

x_input.shape,y_input.shape,z_input.shape=((961,), (961,), (961,))

 

畫圖代碼

#畫圖
sns.set(style='ticks')
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_trisurf(x_input,y_input,z_input,cmap=cm.coolwarm)

plt.xlim((0.1,0.9))
plt.xticks([0.1,0.3,0.5,0.7,0.9])
plt.yticks([0.1,0.3,0.5,0.7,0.9])
ax.set_xlabel(r'$\alpha$',fontdict={'color': 'black',
                             'family': 'Times New Roman',
                             'weight': 'normal',
                             'size': 18})
ax.set_ylabel(r'$\epsilon$',fontdict={'color': 'black',
                             'family': 'Times New Roman',
                             'weight': 'normal',
                             'size': 18})
ax.set_zlabel('precision',fontdict={'color': 'black',
                             'family': 'Times New Roman',
                             'weight': 'normal',
                             'size': 18})

plt.tight_layout()
# plt.savefig('loc_svg/alpha_epsilon2.svg',dpi=600) #指定分辨率保存
plt.show()

 

 

插值前 插值后

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM