轉置原理口胡
抄自:cy的WC2020課件、rqy的uoj博客和個人博客
本文沒有任何嚴謹證明,基本都是博主自己口胡的。
也不保證不會出鍋,因為博主是垃圾。
線性算法
見cy的WC2020課件。
轉化
有一個DAG,點數為\(n+k+m\),滿足所有邊形如u v a
\((n<u<v)\)。
DAG分為三部分:前\(n\)個點(\([1,n]\))是輸入,中間\(k\)個(\([n+1,n+k]\))是中間變量,后\(m\)個(\([n+k+1,n+k+m] \))是輸出。
在這個DAG上運行如下算法:
- 輸入前\(n\)個點的權值。
- 按編號順序遍歷后\(k+m\)個點,將每個點的權值設為\(w_v=\sum_{(u,v,a)\in E}a\times w_u\)。
- 輸出后\(m\)個點的權值。
可以發現這個模型可以實現任何線性算法。
容易發現,\(\forall i>n\),\(w_i\)是\(w_1,w_2,\cdots,w_n\)的線性組合。
所以存在一個\(n\times m\)的矩陣\(A\),\(A_{i,j}\)表示輸入\(w_i\)到輸出\(w_{n+k+j}\)的貢獻。
現在將\([n+k+1,n+k+m]\)看做輸入,\([1,n]\)看做輸出,邊全部反向,並保持邊權不變。
反向運行上述算法,容易發現,\(\forall i\le n+k\),\(w_i\)是\(w_{n+k+1},w_{n+k+2},\cdots,w_{n+k+m}\)的線性組合。
所以存在一個\(m\times n\)的矩陣\(B\),\(B_{i,j}\)表示輸入\(w_{n+k+i}\)到輸出\(w_j\)的貢獻。
上述兩個矩陣的“貢獻”其實就是路徑權值和(路徑的權值是所有邊權積)所以顯然有\(A_{i,j}=B_{j,i}\)即\(A=B^T\)。
原來的dag對應一個線性算法,新的dag對應另外一個線性算法,它們的計算次數完全相同。(至於加法次數的略微差別是因為第一次加法可以直接賦值,這個可以忽略)
實際的線性算法是可以復用空間的,但這里我懶得寫了。最后可以對應到cy的PPT中的構造。
多點求值
給定多項式\(F(x)=\sum_{i=0}^{n-1}f_ix^i\)
求\(ans_i=F(q_i),i\in[0,m)\)
問題是
看成\(uA=v\),考慮問題\(u'A^T=v'\),它的輸入是\(u'\)輸出是\(v'\)。即:
可以發現求出\(\sum_{j=0}^{m-1}\frac{g[j]}{1-xq_j}\)就完事兒了,這個很好做,分治就完事了
新問題的求解過程如下
- 新問題以\(g_0,g_1,\cdots,g_{m-1}\)作為輸入,\(b_0,b_1,\cdots,b_{n-1}\)作為輸出
- \(q\)始終是常量,不參與轉置
- 線段樹上維護兩個信息:\(P_x,Q_x\),分別表示這個節點對應區間的\(\sum\frac{g[j]}{1-xq_j}\)的分子和分母
- 由於\(q\)為常量\(Q\)可以直接確定,只需向上求\(P_x=P_{ls}Q_{rs}+P_{rs}Q_{ls}\)
- 求出線段樹根節點的\(P_1Q_1^{-1}\)的\(0\)到\(n-1\)項系數即為答案
多項式乘法
見cy的PPT
轉置回原問題,求解過程
分治的過程\(P_x=P_{ls}Q_{rs}+P_{rs}Q_{ls}\)可以看成是:
P2_ls=P_ls*Q_rs
P2_rs=P_rs*Q_ls
Px=Px+P2_ls
Px=Px+P2_rs
重寫之后是:
P2_rs=P2_rs+Px
P2_ls=P2_ls+Px
P_rs=P2_rs *^T Q_ls
P_ls=P2_ls *^T Q_rs
- 以\(f_0,f_1,\cdots,f_{n-1}\)作為輸入,\(ans_0,ans_1,\cdots,ans_{m-1}\)作為輸出
- 一開始先分治求出所有的\(Q\)(\(q\)始終是常量)
- 計算\(f\times^T Q_1^{-1}\),保留\(m\)到\(n+m-2\)項系數作為\(P_1\)
- 從上到下分治,令\(P_{ls}=P_{x}\times^TQ_{rs},P_{rs}=P_{x}\times^TQ_{ls}\)
- \(ans_i\)為第\(i\)個葉子的\(P\)中常數項
因為博主水平不行,必須要有\(m\ge n\)以保證復雜度(如果\(m<n\),因為\(P_1\)次數為\(n\)將導致所有的\(P\)次數增加\(n-m\))
#include<bits/stdc++.h>
typedef long long ll;
#define mod 998244353
#define poly std::vector<int>
ll gi(){
ll x=0,f=1;
char ch=getchar();
while(!isdigit(ch))f^=ch=='-',ch=getchar();
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f?x:-x;
}
std::mt19937 rnd(time(NULL));
#define rand rnd
#define pr std::pair<int,int>
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
template<class T>void cxk(T&a,T b){a=a>b?a:b;}
template<class T>void cnk(T&a,T b){a=a<b?a:b;}
#ifdef mod
int pow(int x,int y){
int ret=1;
while(y){
if(y&1)ret=1ll*ret*x%mod;
x=1ll*x*x%mod;y>>=1;
}
return ret;
}
template<class Ta,class Tb>void inc(Ta&a,Tb b){a=a+b>=mod?a+b-mod:a+b;}
template<class Ta,class Tb>void dec(Ta&a,Tb b){a=a>=b?a-b:a+mod-b;}
template<class Ta,class Tb>int sub(Ta&a,Tb b){return a>=b?a-b:a+mod-b;}
#endif
int coef[65539],qx[65539],Q[262147],rev[131113],A[131113],B[131113],N,lg,ans[65539];
void setN(int n){
lg=32-__builtin_clz(n),N=1<<lg;
for(int i=0;i<N;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
}
int getN(int n){return 1<<32-__builtin_clz(n);}
void ntt(int*A,int t){
for(int i=0;i<N;++i)if(i>rev[i])std::swap(A[i],A[rev[i]]);
for(int o=1,*qq=Q+o*2;o<N;o<<=1,qq=Q+o*2)
for(int*p=A;p!=A+N;p+=o<<1)
for(int i=0;i<o;++i){
int t=1ll*p[i+o]*qq[i]%mod;
p[i+o]=sub(p[i],t),inc(p[i],t);
}
if(!t){
std::reverse(A+1,A+N);
for(int i=0,iv=pow(N,mod-2);i<N;++i)A[i]=1ll*A[i]*iv%mod;
}
}
poly mul(const poly&x,const poly&y){
int len=x.size()+y.size()-1;setN(len);
memset(A,0,N<<2);memset(B,0,N<<2);
for(int i=0;i<x.size();++i)A[i]=x[i];
for(int i=0;i<y.size();++i)B[i]=y[i];
ntt(A,1),ntt(B,1);for(int i=0;i<N;++i)A[i]=1ll*A[i]*B[i]%mod;ntt(A,0);
poly z(len);for(int i=0;i<len;++i)z[i]=A[i];
return z;
}
poly mulT(const poly&x,const poly&y){
int len=x.size();setN(len);
memset(A,0,N<<2);memset(B,0,N<<2);
for(int i=0;i<x.size();++i)A[i]=x[i];
for(int i=0;i<y.size();++i)B[i]=y[i];
std::reverse(B,B+y.size());
ntt(A,1),ntt(B,1);for(int i=0;i<N;++i)A[i]=1ll*A[i]*B[i]%mod;ntt(A,0);
poly z(x.size()-y.size()+1);for(int i=0;i<z.size();++i)z[i]=A[i+y.size()-1];return z;
}
poly getinv(poly x){
if(x.size()==1)return{pow(x[0],mod-2)};
int n=x.size(),m=x.size()+1>>1;
poly y(x.begin(),x.begin()+m),_y;_y=y=getinv(y);
setN(x.size()*2+2);y.resize(N);x.resize(N);
ntt(&y[0],1);ntt(&x[0],1);
for(int i=0;i<N;++i)x[i]=1ll*x[i]*y[i]%mod*y[i]%mod;
ntt(&x[0],0);
for(int i=0;i<n;++i)x[i]=((i<m?2ll*_y[i]:0ll)-x[i]+mod)%mod;
x.resize(n);return x;
}
#define mid ((l+r)>>1)
poly qwq[262147],qaq[262147];
void divide1(int x,int l,int r){
if(l==r){qwq[x]={1,mod-qx[l]};return;}
divide1(x<<1,l,mid),divide1(x<<1|1,mid+1,r);
qwq[x]=mul(qwq[x<<1],qwq[x<<1|1]);
}
void divide2(int x,int l,int r){
if(l==r){ans[l]=qaq[x][0];return;}
setN(qaq[x].size());
{
poly&a=qaq[x],&y=qwq[x<<1|1],&z=qaq[x<<1];
memset(A,0,N<<2);memset(B,0,N<<2);
for(int i=0;i<a.size();++i)A[i]=a[i];
for(int i=0;i<y.size();++i)B[i]=y[i];
std::reverse(B,B+y.size());
ntt(A,1),ntt(B,1);for(int i=0;i<N;++i)B[i]=1ll*A[i]*B[i]%mod;ntt(B,0);
z.resize(a.size()-y.size()+1);
for(int i=0;i<z.size();++i)z[i]=B[i+y.size()-1];
}
{
poly&a=qaq[x],&y=qwq[x<<1],&z=qaq[x<<1|1];
memset(B,0,N<<2);
for(int i=0;i<y.size();++i)B[i]=y[i];
std::reverse(B,B+y.size());
ntt(B,1);for(int i=0;i<N;++i)B[i]=1ll*A[i]*B[i]%mod;ntt(B,0);
z.resize(a.size()-y.size()+1);
for(int i=0;i<z.size();++i)z[i]=B[i+y.size()-1];
}
divide2(x<<1,l,mid);
divide2(x<<1|1,mid+1,r);
}
int main(){
#ifdef LOCAL
freopen("in.in","r",stdin);
//freopen("out.out","w",stdout);
#endif
for(int o=1;o<=(1<<17);o<<=1){
int P=pow(19260817,mod/o);
Q[o]=1;for(int i=1;i<o;++i)Q[i+o]=1ll*Q[i+o-1]*P%mod;
}
int n=gi()+1,m=gi(),_m=m;
for(int i=0;i<n;++i)coef[i]=gi();
for(int i=0;i<m;++i)qx[i]=gi();
cxk(m,n);divide1(1,0,m-1);
poly sfm=getinv(qwq[1]);std::reverse(all(sfm));
poly s=mul(poly(coef,coef+n),sfm);
qaq[1]=poly(s.begin()+m,s.end());
qaq[1].resize(m+1);
divide2(1,0,m-1);
for(int i=0;i<_m;++i)printf("%d\n",ans[i]);
return 0;
}