多项式快速插值


200+行的多项式板子题真爽啊

给定$n$个点的点值$(x_i,y_i)$,求这$n$个点确定的$n-1$次多项式

\(n\le 10^5\)

前置知识:

多项式多点求值

拉格朗日插值

微积分基础

首先我们有一个$n^2$的拉格朗日插值法

\(f(x)=\sum\limits_{i=1}^{n}y_i\prod\limits_{i\ne j}\frac{x-x_j}{x_i-x_j}\)

然后我们学习一个WC2017挑战就过了

考虑优化,我们知道这个形式它很死,把它变成重心插值

\(f(x)=\sum\limits_{i=1}^{n}\frac{y_i}{\prod\limits_{i\ne j}x_i-x_j}\prod\limits_{i\ne j}(x-x_j)\)

发现除了一个常数$y_i$剩下的就是$\prod\limits_{i\ne j}(x-x_i)$的形式了,我们设它为$g(x)$

那么前面那一项的分母可以表示为$\frac{g(x)}(x=x_i)$

发现$x=x_i$时分子分母都是$0$

根据洛必达法则

\(\lim\limits_{x→x_i}\frac{g(x_i)}{x-x_i}=\lim\limits_{x→x_i}\frac{g'(x_i)}{(x-x_i)'}=\lim\limits_{x→x_i}g'(x)\)

分治算出$g(x)$后多点求值算出$g'(x_i)$

然后拿$y_i$除一下,前一项就搞定了

\(f_{l,r}=\sum\limits_{i=l}^{r}\frac{y_i}{g'(x_i)}\prod\limits_{j=l,j\ne i}^{r}(x-x_j)\)

\(=\prod\limits_{j=mid+1}^{r}(x-x_j)\sum\limits_{i=l}^{mid}\frac{y_i}{g'(x_i)}\prod\limits_{j=l,j\ne i}^{mid}(x-x_j)+\prod\limits_{j=l}^{mid}(x-x_j)\sum\limits_{i=mid+1}^{r}\frac{y_i}{g'(x_i)}\prod\limits_{j=mid+1,j\ne i}^{r}(x-x_j)\)

\(=\prod\limits_{i=mid+1}^{r}(x-x_i)f_{l,mid}+\prod\limits_{i=l}^{mid}(x-x_i)f_{mid+1,r}\)

\(=g_{mid+1,r}f_{l,mid}+g_{l,mid}f_{mid+1,r}\)

分治求解

这板子让我再打一遍都打不出来

#include<bits/stdc++.h>
using namespace std;
namespace red{
#define int long long
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
#define eps (1e-8)
	inline int read()
	{
		int x=0;char ch,f=1;
		for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
		if(ch=='-') f=0,ch=getchar();
		while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
		return f?x:-x;
	}
	const int N=1e5+10,M=18,mod=998244353;
	int n,m,tot,limit,len;
	int xx[N],yy[N],val[N];
	int a[N],b[N],c[N],rr[N],cc[N<<2];
	int ra[N],rb[N<<2],irb[N<<2];
	int f[N*M<<1],stf[N<<2],enf[N<<2];
	int g[N*M<<1],stg[N<<2],eng[N<<2];
	int h[N],sth[N],enh[N],f3[N];
	int f1[N<<2],f2[N<<2],w[21][N<<2],pos[N<<2];
	inline int fast(int x,int k)
	{
		int ret=1;
		while(k)
		{
			if(k&1) ret=ret*x%mod;
			x=x*x%mod;
			k>>=1;
		}
		return ret;
	}
	inline int add(int x,const int &y)
	{
		x+=y;
		return x>mod?x-mod:x;
	}
	inline int del(int x,const int &y)
	{
		x-=y;
		return x<0?x+mod:x;
	}
	inline void init(int x)
	{
		limit=1,len=0;
		while(limit<x) limit<<=1,++len;
		for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
	}
	inline void ntt(int limit,int *a,int inv)
	{
		for(int i=0;i<limit;++i)
			if(i<pos[i]) swap(a[i],a[pos[i]]);
		for(int mid=1,t=1;mid<limit;mid<<=1,++t)
		{
			for(int r=mid<<1,j=0;j<limit;j+=r)
			{
				for(int k=0;k<mid;++k)
				{
					int x=a[j+k],y=w[t][k]*a[j+k+mid]%mod;
					a[j+k]=add(x,y);
					a[j+k+mid]=del(x,y);
				}
			}
		}
		if(inv) return;
		inv=fast(limit,mod-2);reverse(a+1,a+limit);
		for(int i=0;i<limit;++i) a[i]=a[i]*inv%mod;
	}
	inline void NTT(int *a,int *b,int limit)
	{
		ntt(limit,a,1);ntt(limit,b,1);
		for(int i=0;i<limit;++i) a[i]=a[i]*b[i]%mod;
		ntt(limit,a,0);
	}
	inline void poly_inv(int limit,int len,int *a,int *b)
	{
		if(limit==1){b[0]=fast(a[0],mod-2);return;}
		poly_inv(limit>>1,len-1,a,b);
		for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
		for(int i=0;i<(limit>>1);++i) cc[i]=a[i];
		for(int i=limit>>1;i<limit;++i) cc[i]=0;
		ntt(limit,cc,1);ntt(limit,b,1);
		for(int i=0;i<limit;++i) b[i]=((2-cc[i]*b[i]%mod)+mod)%mod*b[i]%mod;
		ntt(limit,b,0);
		for(int i=limit>>1;i<limit;++i) b[i]=0;
	}
	inline void make(int l,int r,int p)
	{
		if(l==r)
		{
			g[stg[p]=++tot]=mod-xx[l];
			g[eng[p]=++tot]=1;
			return;
		}
		int mid=(l+r)>>1;
		make(l,mid,ls(p));
		make(mid+1,r,rs(p));
		int na=eng[ls(p)]-stg[ls(p)]+1;
		int nb=eng[rs(p)]-stg[rs(p)]+1;
		init(na+nb);
		for(int i=0;i<na;i++) f1[i]=g[stg[ls(p)]+i];
		for(int i=na;i<limit;i++) f1[i]=0;
		for(int i=0;i<nb;i++) f2[i]=g[stg[rs(p)]+i];
		for(int i=nb;i<limit;i++) f2[i]=0;
		NTT(f1,f2,limit);
		stg[p]=tot+1;
		na+=nb-1;
		for(int i=0;i<na;++i) g[++tot]=f1[i];
		eng[p]=tot;
	}
	inline void solve(int l,int r,int p,int fa)
	{
		int na=enf[fa]-stf[fa],nb=eng[p]-stg[p];
		if(na>=nb)
		{
			int nc=na-nb;
			for(int i=0;i<=na;++i) a[i]=f[stf[fa]+i];
			for(int i=0;i<=nb;i++) b[i]=g[stg[p]+i];
			for(int i=0;i<=nc;i++) ra[i]=a[na-i];
			for(int i=0;i<=nb;i++) rb[i]=b[nb-i];
			for(int i=nc+1;i<=nb;i++) rb[i]=0;
			init(nc*2+2);
			for(int i=nb+1;i<limit;i++) rb[i]=0;
			for(int i=0;i<limit;i++) irb[i]=0,f1[i]=0;
			poly_inv(limit,len,rb,irb);
			for(int i=0;i<=nc;i++) f1[i]=ra[i],f2[i]=irb[i];
			for(int i=nc+1;i<limit;i++) f1[i]=f2[i]=0;
			NTT(f1,f2,limit);
			for(int i=0;i<=nc;i++) c[nc-i]=f1[i];
			for(int i=nc+1;i<nb;i++) c[i]=0;
			init(nb<<1);
			for(int i=0;i<nb;i++) f1[i]=b[i],f2[i]=c[i];
			for(int i=nb;i<limit;i++) f1[i]=0,f2[i]=0;
			NTT(f1,f2,limit);
			for(int i=0;i<nb;i++) rr[i]=(a[i]-f1[i]+mod)%mod;
			while(nb>1 && !rr[nb-1]) nb--;
			stf[p]=tot+1;
			for(int i=0;i<nb;i++) f[++tot]=rr[i];
			enf[p]=tot;
		}
		else
		{
			stf[p]=tot+1;
			for(int i=stf[fa];i<=enf[fa];++i) f[++tot]=f[i];
			enf[p]=tot;
		}
		if(l==r)
		{
			val[l]=f[stf[p]];
			return;
		}
		int mid=(l+r)>>1;
		solve(l,mid,ls(p),p);
		solve(mid+1,r,rs(p),p);
	}
	inline void work(int l,int r,int p)
	{
		if(l==r) return;
		int mid=(l+r)>>1;
		work(l,mid,ls(p));
		work(mid+1,r,rs(p));
		int na=enh[l]-sth[l]+1,nb=eng[rs(p)]-stg[rs(p)]+1;
		
		init(na+nb);
		for(int i=0;i<na;i++) f1[i]=h[sth[l]+i];
		for(int i=na;i<limit;i++) f1[i]=0;
		for(int i=0;i<nb;i++) f2[i]=g[stg[rs(p)]+i];
		for(int i=nb;i<limit;i++) f2[i]=0;
		NTT(f1,f2,limit);
		
		na+=nb-1;
		for(int i=0;i<na;i++) f3[i]=f1[i];
		for(int i=na;i<limit;i++) f3[i]=0;
		
		na=enh[mid+1]-sth[mid+1]+1,nb=eng[ls(p)]-stg[ls(p)]+1;
		for(int i=0;i<na;i++) f1[i]=h[sth[mid+1]+i];
		for(int i=na;i<limit;i++) f1[i]=0;
		for(int i=0;i<nb;i++) f2[i]=g[stg[ls(p)]+i];
		for(int i=nb;i<limit;i++) f2[i]=0;
		NTT(f1,f2,limit);
		na+=nb-1;
		for(int i=0;i<na;i++) h[sth[l]+i]=f3[i]+f1[i]>=mod?f3[i]+f1[i]-mod:f3[i]+f1[i];
		enh[l]=sth[l]+na-1;
	}
	inline void main()
	{
		n=read();
		for(int mid=1,t=1;mid<400000;mid<<=1,++t)
		{
			w[t][0]=1;int Wn=fast(3,(mod-1)/(mid<<1));
			for(int k=1;k<mid;++k)
			{
				w[t][k]=w[t][k-1]*Wn%mod;
			}
		}
		for(int i=1;i<=n;++i) xx[i]=read(),yy[i]=read();
		make(1,n,1);
		m=eng[1]-stg[1];
		for(int i=0;i<=m;++i) f1[i]=g[stg[1]+i];
		for(int i=0;i<m;++i) f1[i]=f1[i+1]*(i+1)%mod;
		f1[m--]=0;
		stf[tot=0]=1;
		for(int i=0;i<=m;++i) f[enf[0]=++tot]=f1[i];
		solve(1,n,1,0);
		for(int i=1;i<=n;i++) val[i]=yy[i]*fast(val[i],mod-2)%mod;
		tot=0;
		for(int i=1;i<=n;i++)
		{
			sth[i]=enh[i]=++tot;
			h[tot]=val[i];
		}
		work(1,n,1);
		for(int i=0;i<n;++i) printf("%lld ",h[sth[1]+i]);
	}
}
signed main()
{
	red::main();
	return 0;
}


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM