(十三)二次样条插值


 

 1 #encoding=utf-8
 2 from numpy import *
 3 import  numpy as  np
 4 
 5 import matplotlib.pyplot as plt
 6 plt.close()
 7 fig=plt.figure()
 8 plt.grid(True)
 9 plt.xlabel("X")
10 plt.ylabel("Y")
11 plt.title(u"二次样条", fontproperties='SimHei')
12 #二次样条
13 #设已知n+1个点 设方程式为yi = ai*x**2+bi*x+ci,则需要求解3n个参数,则需要3n个条件。
14 # 由曲线过点得2n 个条件 ,由内点处一次求导相等得n-1个条件 ,令第一个节点处二阶导为0 得一个条件 a1 = 0
15 
16 
17 # #n+1个点
18 # xi = [3,4.5,7,9]#存储x的值
19 # n=len(xi)-1
20 # yi = [2.5,1,2.5,0.5]#存储y 的值
21 
22 # 输入相关数值
23 n=input("请输入取到样点数目:")-1
24 xi = []
25 yi = []
26 for t in range(0,n+1):
27     inputX = "请输入第"+str(t+1)+"个点 X:"
28     inputY = "请输入第"+str(t+1)+"个点 Y:"
29     xi.append(input(inputX))
30     yi.append(input(inputY))
31 
32 
33 
34 xi_2 = []# 存储 x**2 的值
35 #向 x**2 中存入数据
36 for i in range(0,len(xi)):
37     xi_2.append(xi[i]**2)
38 #将方程组转化为矩阵 使 mat1 * mat2 = mat3
39 # 可以得到行列为3n 的矩阵
40 m1 = [[0 for a in range(3*n)] for b in range(3*n)]
41 m3 = [[0 for a in range(1)] for b in range(3*n)]
42 #若 mat2 内的值定为 a1,b1 ,c1,a2,b2,c2,a3,b3,c2... 便可以确定mat1的值
43 #向 m1 中存入数据
44 #定义变量 p 用于记录向矩阵插入的行数,及后续插入
45 p=0
46 for j in range(0,n):
47 
48 #从0 到n-1 的 n 个点代入
49     m1[p][3*j]=xi_2[j]
50     m1[p][3*j+1]=xi[j]
51     m1[p][3*j+2]=1
52     m3[p][0]=yi[j]
53     p=p+1
54 #从 1 到 n 的 n 个点代入
55     m1[p][3*j]=xi_2[j+1]
56     m1[p][3*j+1]=xi[j+1]
57     m1[p][3*j+2]=1
58     m3[p][0]=yi[j+1]
59     p=p+1
60 # 中间节点的一阶导 相等 2ai + bi - 2a(i+1)-b(i+1) = 0
61 #从 1 到n-1 的 n-1 个 点处代入
62 for k in range(1,n):
63     m1[p][3*(k-1)]= 2*xi[k]
64     m1[p][3*(k-1)+1]= 1
65     m1[p][3*k]= -2*xi[k]
66     m1[p][3*k+1]= -1
67     p=p+1
68 #  代入条件 二阶导为零a1 = 0
69 m1[p][0] = 2
70 p=p+1
71 #将list转化为 矩阵
72 mat1 = np.matrix(m1)
73 mat3 = np.matrix(m3)
74 #求mat2
75 _mat1 = mat1.getI()
76 mat2=_mat1*mat3
77 # mat2=mat3/mat1
78 #将矩阵转化为list提取数据
79 m2=mat2.tolist()
80 #整理求得的曲线
81 line=[]
82 #用于收集区间
83 interval=[]
84 for q in range(0,n):
85     interval.append(np.linspace(xi[q],xi[q+1]))
86     a=m2[q*3]
87     b=m2[q*3+1]
88     c=m2[q*3+2]
89     line.append(a*interval[q]**2+b*interval[q]+c)
90     plt.plot(interval[q],line[q])
91 plt.show()

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM