【數值分析】Python實現Lagrange插值


一直想把這幾個插值公式用代碼實現一下,今天閑着沒事,嘗試嘗試。

先從最簡單的拉格朗日插值開始!關於拉格朗日插值公式的基礎知識就不贅述,百度上一搜一大堆。

基本思路是首先從文件讀入給出的樣本點,根據輸入的插值次數和想要預測的點的x選擇合適的樣本點區間,最后計算基函數得到結果。直接看代碼!(注:這里說樣本點不是很准確,實在詞窮找不到一個更好的描述。。。)

str2double

一個小問題就是怎樣將python中的str類型轉換成float類型,畢竟我們給出的樣本點不一定總是整數,而且也需要做一些容錯處理,比如多個+、多個-等等,也應該能識別為正確的數。所以實現了一個str2double方法。

import re
def str2double(str_num):
    pattern = re.compile(r'^((\+*)|(\-*))?(\d+)(.(\d+))?$')
    m = pattern.match(str_num)
    if m is None:
        return m
    else:
        sign = 1 if str_num[0] == '+' or '0' <= str_num[0] <= '9' else -1
        num = re.sub(r'(\++)|(\-+)', "", m.group(0))
        matchObj = re.match(r'^\d+$', num)
        if matchObj is not None:
            num = sign * int(matchObj.group(0))
        else:
            matchObj = re.match(r'^(\d+).(\d+)$', num)
            if matchObj is not None:
                integer = int(matchObj.group(1))
                fraction = int(matchObj.group(2)) * pow(10, -1*(len(matchObj.group(2))))
                num = sign * (integer + fraction)
        return num

我使用了正則表達式來實現,pattern = re.compile(r'^((\+*)|(\-*))?(\d+)(.(\d+))?$')可以匹配我上面提到的所有類型的整數和浮點數,之后進行匹配,匹配成功,如果是整數,直接return整數部分,這個用(int)強制轉換即可;如果是浮點數,那么用(\d+)這個正則表達式再次匹配,分別得到整數部分和小數部分,整數部分的處理和上面類似,小數部分則用乘以pow(10, -小數位數)得到,之后直接相加即可。這里為了支持多個+或者-,使用re.sub方法將符號去掉,所以就需要用sign來記錄數字的正負,在最后return時乘上sign即可。

def binary_search(point_set, n, x):
    first = 0
    length = len(point_set)
    last = length
    while first < last:
        mid = (first + last) // 2
        if point_set[mid][0] < x:
            first = mid + 1
        elif point_set[mid][0] == x:
            return mid
        else:
            last = mid
    last =  last if last != length else last-1

    head = last - 1
    tail = last
    while n > 0:
        if head != -1:
            n -= 1
            head -= 1
        if tail != length:
            n -= 1
            tail += 1
    return [head+1, tail-1] if n == 0 else [head+1, tail-2]

這里point_set是全部樣本點的集合,n是輸入的插值次數,x是輸入的預測點。返回合適的插值區間,即盡可能地把x包在里面。

因為要根據輸入得到合適的插值區間,所以就涉及查找方面的知識。這里使用了二分查找,先對樣本點集合point_set進行排序(升序),找到第一個大於需要預測點的樣本點,在它的兩側擴展區間,直到滿足插值次數要求。這里我的實現有些問題,可能會出現n=-1因為tail多加了一次,就在while循環外又進行了一次判斷,n=-1tail-2,這個實現的確不好,可能還會有bug。。。

最后,剩下的內容比較好理解,直接放上全部代碼。

import re
import matplotlib.pyplot as plt
import numpy as np

def str2double(str_num):
    pattern = re.compile(r'^((\+*)|(\-*))?(\d+)(.(\d+))?$')
    m = pattern.match(str_num)
    if m is None:
        return m
    else:
        sign = 1 if str_num[0] == '+' or '0' <= str_num[0] <= '9' else -1
        num = re.sub(r'(\++)|(\-+)', "", m.group(0))
        matchObj = re.match(r'^\d+$', num)
        if matchObj is not None:
            num = sign * int(matchObj.group(0))
        else:
            matchObj = re.match(r'^(\d+).(\d+)$', num)
            if matchObj is not None:
                integer = int(matchObj.group(1))
                fraction = int(matchObj.group(2)) * pow(10, -1*(len(matchObj.group(2))))
                num = sign * (integer + fraction)
        return num

def preprocess():
    f = open("input.txt", "r")
    lines = f.readlines()
    lines = [line.strip('\n') for line in lines]
    point_set = list()
    for line in lines:
        point = list(filter(None, line.split(" ")))
        point = [str2double(pos) for pos in point]
        point_set.append(point)
    return point_set

def lagrangeFit(point_set, x):
    res = 0
    for i in range(len(point_set)):
        L = 1
        for j in range(len(point_set)):
            if i == j:
                continue
            else:
                L = L * (x - point_set[j][0]) / (point_set[i][0] - point_set[j][0])
        L = L * point_set[i][1]
        res += L
    return res

def showbasis(point_set):
    print("Lagrange Basis Function:\n")
    for i in range(len(point_set)):
        top = ""
        buttom = ""
        for j in range(len(point_set)):
            if i == j:
                continue
            else:
                top += "(x-{})".format(point_set[j][0])
                buttom += "({}-{})".format(point_set[i][0], point_set[j][0])
        print("Basis function{}:".format(i))
        print("\t\t{}".format(top))
        print("\t\t{}".format(buttom))

def binary_search(point_set, n, x):
    first = 0
    length = len(point_set)
    last = length
    while first < last:
        mid = (first + last) // 2
        if point_set[mid][0] < x:
            first = mid + 1
        elif point_set[mid][0] == x:
            return mid
        else:
            last = mid
    last =  last if last != length else last-1

    head = last - 1
    tail = last
    while n > 0:
        if head != -1:
            n -= 1
            head -= 1
        if tail != length:
            n -= 1
            tail += 1
    return [head+1, tail-1] if n == 0 else [head+1, tail-2]

if __name__ == '__main__':
    pred_x = input("Predict x:")
    pred_x = float(pred_x)
    n = input("Interpolation times:")
    n = int(n)
    point_set = preprocess()
    point_set = sorted(point_set, key=lambda a: a[0])
    span = binary_search(point_set, n+1, pred_x)
    print("Chosen points: {}".format(point_set[span[0]:span[1]+1]))
    showbasis(point_set[span[0]:span[1]+1])

    X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
    S = np.sin(X)
    L = [lagrangeFit(point_set, x) for x in X]
    L1 = [lagrangeFit(point_set[span[0]:span[1]+1], x) for x in X]
    
    plt.figure(figsize=(8, 4))
    plt.plot(X, S, label="$sin(x)$", color="red", linewidth=2)
    plt.plot(X, L, label="$LagrangeFit-all$", color="blue", linewidth=2)
    plt.plot(X, L1, label="$LagrangeFit-special$", color="green", linewidth=2)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title("$sin(x)$ and Lagrange Fit")
    plt.legend()
    plt.show()

About Input

使用了input.txt進行樣本點讀入,每一行一個點,中間有一個空格。

結果

感覺挺好玩的hhh,過幾天試試牛頓插值!掰掰!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM