BM 算法
BM 算法,全名 Berlekamp-Massey 算法,是一個可以 \(O(n^2)\) 求出一個數列的最短線性遞推式的算法。其主要思想(大概)是一項一項加入,若不符合當前猜測的遞推式則對其進行調整。
假設我們欲求數列 \({a_0,a_1,\cdots,a_n}\) 的最短線性遞推式,設 \(r^{(i)}\) 是 \(a_0,\cdots,a_i\) 的最短線性遞推式,\(l_i\) 為 \(r^{(i)}\) 的階數(可以直觀理解成項數 +1)。初始時令 \(r^{(-1)}=\{1\}\),然后每次考慮加入一個新的數,對每個前綴計算其最短線性遞推式。
引理
先給出一個引理,其給出了每個前綴的線性遞推式的最短長度:如果 \(r^{(i-1)}\) 不是 \(a_0,\cdots,a_i\) 的線性遞推式,則 \(l_i\ge\max(l_{i-1},i+1-l_{i-1})\)。
證明:\(l_i\ge l_{i-1}\) 顯然在任何情況下成立。
若有 \(l_i< i+1-l_{i-1}\),則對於所有 \(j\ge l_i\) 下式成立:
這也說明 \(r^{(i-1)}\) 是 \(a_0,\cdots,a_i\) 的遞推式,與假設矛盾,得證。
有限數列
這樣我們就拿到了下界,且由反證過程可以看到它是緊的。我們來考慮一下能不能構造到這個下界。
假設 \(A\) 是 \(a\) 的生成函數,\(R_i\) 是 \(r^{(i)}\) 的生成函數。
考慮新加入一個數 \(a_i\),若之前的遞推式仍然可行(即 \(AR_{i-1}\equiv S_i\pmod {x^{i+1}}\),其中 \(S_i\) 的次數 \(< l_{i-1}\)),則直接沿用之前的遞推式即可。反之,則會出現 \(AR_{i-1}\equiv S_{i-1}+cx^i\pmod {x^{i+1}}\),考慮怎么用原來的方案修正。考慮我們上一次修正(假設在 \(p\) 時刻)的時候有 \(AR_{p-1}\equiv S_{p-1}+dx^p\pmod{x^{p+1}}\),為了構造出 \(cx^i\),我們給等式左右兩邊同乘一個 \(x^{i-p}cd^{-1}\),即 \(x^{i-p}cd^{-1}AR_{p-1}\equiv x^{i-p}cd^{-1}S_{p-1}+cx^i\pmod{x^{i+1}}\);然后我們再將兩式相減,可以得到 \(A(R_{i-1}-x^{i-p}cd^{-1}R_{p-1})\equiv S_{i-1}-x^{i-p}cd^{-1}S_{p-1}\pmod{x^{i+1}}\),即可得到 \(R_i=R_{i-1}-x^{i-p}cd^{-1}R_{p-1},S_i=S_{i-1}-x^{i-p}cd^{-1}S_{p-1}\)。
我們來歸納證明其取到了下界。假設有 \(l_p=\max(l_{p-1},p+1-l_{p-1})\),則由上面我們算出來的 \(R_i\) 的值有 \(l_i=\max(l_{i-1},i-p+l_{p-1})\)。若 \(l_p=l_{p-1}\),則繼續往下歸納即可;否則有 \(l_p=p+1-l_{p-1}\),\(l_i=\max(l_{i-1},i-p+l_{p-1})=\max(l_{i-1},i+1-l_{p})\),得證。
無限數列
如果數列是無限數列,但我們知道 \(l_{\infty}\le s\),那么我們可以僅計算 \(a_0,\cdots,a_{2s}\) 的遞推式,因為若我們在 \(t>2s\) 處更新了遞推式,則會有 \(l_t=\max(l_{t-1},t+1-l_{t-1})>s\),與條件矛盾。
行/列向量列,矩陣列
假設欲計算 \(n\) 維行向量列 \(v_0,v_1,\cdots\) 的最短線性遞推式,可以在模 \(p\) 意義下隨機一個 \(n\) 維列向量 \(l\),則有 \(1-\frac{n}{p}\) 的概率使得 \(v_0l,v_1l,\cdots\) 的最短線性遞推式就是 原行向量列的最短線性遞推式。列向量同理。
矩陣列則可以隨機一個 \(n\) 維列向量 \(u\) 與列向量 \(v\),計算 \(ub_0v,ub_1v,\cdots\) 的最短線性遞推式即可,有 \(1-\frac{n+m}{p}\) 的概率出錯。
概率好像是根據 Schwartz-Zippel 引理推知的,但是是怎么推來的嘛...不知道(
求矩陣最小多項式
\(n \times n\) 的矩陣 \(B\) 的最小多項式是次數最小的使得 \(f(B) = 0\) 的多項式 \(f\)。
對矩陣列 \(I,B,B^2,B^3,\cdots,B^{2n}\) 計算線性遞推式即可,因為 \(B\) 的特征多項式滿足 \(f(B)=0\),所以其最小多項式次數必定 \(\le n\)。
具體而言,我們需要計算 \(uIv,uBv,\cdots,uB^{2n}v\)。因為矩陣乘向量是 \(n^2\) 的,所以我們可以先 \(n^3\) 算出 \(uI,uB,\cdots,uB^{2n}\),再用 \(n^3\) 把 \(v\) 一一乘上去。
優化一類 dp
假設一類 dp 可以用矩陣快速冪計算(即形如 \(dp_{i,j}=\sum dp_{i-1,k}\cdot f_{k,j}\)),最后要求 \(\sum k_idp_{n,i}\) 或其他類似的東西,有一種萬能的辦法:把它的前 2k 項扔進 BM 算出它的遞推式,然后直接 \(k^2\log n\) 或者 \(k\log k\log n\) 計算即可。因為矩陣是有最小多項式的,所以這個 dp 本質上也是一個線性遞推,由 \(B\) 的最小多項式定義,所以其遞推項數不超過 \(k\)。
知道了某個遞推式,怎么快速算某一項?
code:
#include <bits/stdc++.h>
using namespace std;
int n , m , l[11000] , p;
long long a[11000] , r[11000] , rr[11000] , sav[11000] , las , noww , ans;
const int mod = 998244353;
long long exp( int a , int b )
{
long long ans = 1 , t = a;
while(b)
{
if(b & 1) ans = ans * t % mod;
t = t * t % mod; b >>= 1;
}
return ans;
}
struct poly
{
long long a[11000]; int len;
poly operator * ( const poly &x ) const
{
poly ans; memset(ans.a , 0 , sizeof(ans.a)); ans.len = len + x.len;
for(int i = 0 ; i <= len ; i++ )
{
for(int j = 0 ; j <= x.len ; j++ )
{
(ans.a[i + j] += a[i] * x.a[j] % mod) %= mod;
}
}
return ans;
}
poly operator % ( const poly &x ) const
{
poly ans; ans = (*this);
long long coe = exp(x.a[x.len] , mod - 2);
for(int i = ans.len ; i >= x.len ; i-- )
{
long long qwq = ans.a[i] * coe % mod;
for(int j = 0 ; j <= x.len ; j++ )
(ans.a[i - x.len + j] += mod - qwq * x.a[j] % mod) %= mod;
}
ans.len = min(ans.len , x.len - 1);
return ans;
}
} f , g , e;
poly exp( poly a , int b , poly p )
{
poly ans = e , t = a;
while(b)
{
if(b & 1) ans = (ans * t) % p;
t = (t * t) % p; b >>= 1;
}
return ans;
}
int main()
{
// freopen("1.in" , "r" , stdin);
// freopen("1.out" , "w" , stdout);
scanf("%d%d" , &n , &m);
for(int i = 0 ; i < n ; i++ ) scanf("%lld" , &a[i]); r[0] = 1; p = -1;
for(int i = 0 ; i < n ; i++ )
{
noww = 0; l[i] = l[i - 1];
for(int j = 0 ; j <= l[i] ; j++ ) (noww += r[j] * a[i - j] % mod) %= mod;
if(!noww) continue;
// cerr << las << ' ' << noww << endl;
long long coe = exp(las , mod - 2) * noww % mod;
if(!las)
{
memcpy(rr , r , sizeof(rr));
l[i] = i + 1; r[i + 1] = 1;
las = noww; p = i;
}
else
{
memcpy(sav , rr , sizeof(sav)); memcpy(rr , r , sizeof(rr));
l[i] = max(l[i - 1] , i + 1 - l[i - 1]);
for(int j = 0 ; j <= l[p - 1] ; j++ )
(r[i - p + j] += mod - coe * sav[j] % mod) %= mod;
p = i; las = noww;
}
// for(int j = 0 ; j <= l[i] ; j++ ) printf("%lld " , r[j]); printf("\n");
}
// for(int i = 0 ; i < n ; i++ ) cerr << l[i] << ' '; cerr << endl;
for(int i = 1 ; i <= l[n - 1] ; i++ ) printf("%lld " , (mod - r[i]) % mod); printf("\n");
e.a[0] = 1; g.a[1] = g.len = 1; f.len = l[n - 1];
for(int i = 0 ; i <= l[n - 1] ; i++ ) f.a[l[n - 1] - i] = r[i];
g = exp(g , m , f);
// for(int i = 0 ; i <= g.len ; i++ ) cerr << g.a[i] << ' '; cerr << endl;
for(int i = 0 ; i <= l[n - 1] ; i++ ) (ans += g.a[i] * a[i] % mod) %= mod;
printf("%lld" , ans);
return 0;
}
/*
*/