淺談 BM 算法


BM 算法

BM 算法,全名 Berlekamp-Massey 算法,是一個可以 \(O(n^2)\)​ 求出一個數列的最短線性遞推式的算法。其主要思想(大概)是一項一項加入,若不符合當前猜測的遞推式則對其進行調整。

假設我們欲求數列 \({a_0,a_1,\cdots,a_n}\)​ 的最短線性遞推式,設 \(r^{(i)}\)​ 是 \(a_0,\cdots,a_i\)​ 的最短線性遞推式,\(l_i\)​ 為 \(r^{(i)}\)​ 的階數(可以直觀理解成項數 +1)。初始時令 \(r^{(-1)}=\{1\}\)​,然后每次考慮加入一個新的數,對每個前綴計算其最短線性遞推式。

引理

先給出一個引理,其給出了每個前綴的線性遞推式的最短長度:如果 \(r^{(i-1)}\) 不是 \(a_0,\cdots,a_i\) 的線性遞推式,則 \(l_i\ge\max(l_{i-1},i+1-l_{i-1})\)

證明:\(l_i\ge l_{i-1}\) 顯然在任何情況下成立。

若有 \(l_i< i+1-l_{i-1}\),則對於所有 \(j\ge l_i\) 下式成立:

\[\begin{aligned} a_i&=-\sum_{j=1}^{l_{i-1}}p_ja_{i-j}=\sum_{j=1}^{l_{i-1}}p_j\sum_{k=1}^{l_i}q_ka_{i-j-k}\ (l_i< i+1-l_{i-1}) \\&=\sum_{j=1}^{l_i}q_j\sum_{k=1}^{l_{i-1}}p_ka_{i-j-k}=-\sum_{j=1}^{l_i}q_ja_{i-j} \end{aligned} \]

這也說明 \(r^{(i-1)}\)​ 是 \(a_0,\cdots,a_i\)​​ 的遞推式,與假設矛盾,得證。

有限數列

這樣我們就拿到了下界,且由反證過程可以看到它是緊的。我們來考慮一下能不能構造到這個下界。

假設 \(A\)\(a\) 的生成函數,\(R_i\)\(r^{(i)}\) 的生成函數。

考慮新加入一個數 \(a_i\)​,若之前的遞推式仍然可行(即 \(AR_{i-1}\equiv S_i\pmod {x^{i+1}}\)​,其中 \(S_i\)​ 的次數 \(< l_{i-1}\)​),則直接沿用之前的遞推式即可。反之,則會出現 \(AR_{i-1}\equiv S_{i-1}+cx^i\pmod {x^{i+1}}\)​,考慮怎么用原來的方案修正。考慮我們上一次修正(假設在 \(p\)​ 時刻)的時候有 \(AR_{p-1}\equiv S_{p-1}+dx^p\pmod{x^{p+1}}\)​,為了構造出 \(cx^i\)​,我們給等式左右兩邊同乘一個 \(x^{i-p}cd^{-1}\)​,即 \(x^{i-p}cd^{-1}AR_{p-1}\equiv x^{i-p}cd^{-1}S_{p-1}+cx^i\pmod{x^{i+1}}\)​;然后我們再將兩式相減,可以得到 \(A(R_{i-1}-x^{i-p}cd^{-1}R_{p-1})\equiv S_{i-1}-x^{i-p}cd^{-1}S_{p-1}\pmod{x^{i+1}}\)​,即可得到 \(R_i=R_{i-1}-x^{i-p}cd^{-1}R_{p-1},S_i=S_{i-1}-x^{i-p}cd^{-1}S_{p-1}\)​。

我們來歸納證明其取到了下界。假設有 \(l_p=\max(l_{p-1},p+1-l_{p-1})\),則由上面我們算出來的 \(R_i\) 的值有 \(l_i=\max(l_{i-1},i-p+l_{p-1})\)。若 \(l_p=l_{p-1}\),則繼續往下歸納即可;否則有 \(l_p=p+1-l_{p-1}\)\(l_i=\max(l_{i-1},i-p+l_{p-1})=\max(l_{i-1},i+1-l_{p})\)​,得證。

無限數列

如果數列是無限數列,但我們知道 \(l_{\infty}\le s\),那么我們可以僅計算 \(a_0,\cdots,a_{2s}\) 的遞推式,因為若我們在 \(t>2s\) 處更新了遞推式,則會有 \(l_t=\max(l_{t-1},t+1-l_{t-1})>s\),與條件矛盾。

行/列向量列,矩陣列

假設欲計算 \(n\) 維行向量列 \(v_0,v_1,\cdots\) 的最短線性遞推式,可以在模 \(p\) 意義下隨機一個 \(n\) 維列向量 \(l\),則有 \(1-\frac{n}{p}\) 的概率使得 \(v_0l,v_1l,\cdots\)​​​ 的最短線性遞推式就是 原行向量列的最短線性遞推式。列向量同理。

矩陣列則可以隨機一個 \(n\)​ 維列向量 \(u\)​ 與列向量 \(v\)​,計算 \(ub_0v,ub_1v,\cdots\)​ 的最短線性遞推式即可,有 \(1-\frac{n+m}{p}\)​​ 的概率出錯。

概率好像是根據 Schwartz-Zippel 引理推知的,但是是怎么推來的嘛...不知道(

求矩陣最小多項式

\(n \times n\) 的矩陣 \(B\) 的最小多項式是次數最小的使得 \(f(B) = 0\) 的多項式 \(f\)

對矩陣列 \(I,B,B^2,B^3,\cdots,B^{2n}\)​ 計算線性遞推式即可,因為​ \(B\) 的特征多項式滿足 \(f(B)=0\),所以其最小多項式次數必定 \(\le n\)

具體而言,我們需要計算 \(uIv,uBv,\cdots,uB^{2n}v\)。因為矩陣乘向量是 \(n^2\) 的,所以我們可以先 \(n^3\) 算出 \(uI,uB,\cdots,uB^{2n}\),再用 \(n^3\)\(v\) 一一乘上去。

優化一類 dp

假設一類 dp 可以用矩陣快速冪計算(即形如 \(dp_{i,j}=\sum dp_{i-1,k}\cdot f_{k,j}\)),最后要求 \(\sum k_idp_{n,i}\) 或其他類似的東西,有一種萬能的辦法:把它的前 2k 項扔進 BM 算出它的遞推式,然后直接 \(k^2\log n\) 或者 \(k\log k\log n\) 計算即可。因為矩陣是有最小多項式的,所以這個 dp 本質上也是一個線性遞推,由 \(B\) 的最小多項式定義,所以其遞推項數不超過 \(k\)

知道了某個遞推式,怎么快速算某一項?

Cayley-Hamilton 定理

code:

#include <bits/stdc++.h>
using namespace std;
int n , m , l[11000] , p;
long long a[11000] , r[11000] , rr[11000] , sav[11000] , las , noww , ans;
const int mod = 998244353;
long long exp( int a , int b )
{
	long long ans = 1 , t = a;
	while(b)
	{
		if(b & 1) ans = ans * t % mod;
		t = t * t % mod; b >>= 1;
	}
	return ans;
}
struct poly
{
	long long a[11000]; int len;
	poly operator * ( const poly &x ) const
	{
		poly ans; memset(ans.a , 0 , sizeof(ans.a)); ans.len = len + x.len;
		for(int i = 0 ; i <= len ; i++ )
		{
			for(int j = 0 ; j <= x.len ; j++ )
			{
				(ans.a[i + j] += a[i] * x.a[j] % mod) %= mod;
			}
		}
		return ans;
	}
	poly operator % ( const poly &x ) const
	{
		poly ans; ans = (*this);
		long long coe = exp(x.a[x.len] , mod - 2);
		for(int i = ans.len ; i >= x.len ; i-- )
		{
			long long qwq = ans.a[i] * coe % mod;
			for(int j = 0 ; j <= x.len ; j++ ) 
				(ans.a[i - x.len + j] += mod - qwq * x.a[j] % mod) %= mod;
		}
		ans.len = min(ans.len , x.len - 1);
		return ans;
	} 
} f , g , e;
poly exp( poly a , int b , poly p )
{
	poly ans = e , t = a;
	while(b)
	{
		if(b & 1) ans = (ans * t) % p;
		t = (t * t) % p; b >>= 1;
	}
	return ans;
}
int main() 
{
//	freopen("1.in" , "r" , stdin);
//	freopen("1.out" , "w" , stdout);
	scanf("%d%d" , &n , &m);
	for(int i = 0 ; i < n ; i++ ) scanf("%lld" , &a[i]); r[0] = 1; p = -1;
	for(int i = 0 ; i < n ; i++ )
	{
		noww = 0; l[i] = l[i - 1];
		for(int j = 0 ; j <= l[i] ; j++ ) (noww += r[j] * a[i - j] % mod) %= mod;
		if(!noww) continue;
//		cerr << las << ' ' << noww << endl;
		long long coe = exp(las , mod - 2) * noww % mod;
		if(!las)
		{
			memcpy(rr , r , sizeof(rr));
			l[i] = i + 1; r[i + 1] = 1;
			las = noww; p = i;
		} 
		else
		{
			memcpy(sav , rr , sizeof(sav)); memcpy(rr , r , sizeof(rr));
			l[i] = max(l[i - 1] , i + 1 - l[i - 1]);
			for(int j = 0 ; j <= l[p - 1] ; j++ )
				(r[i - p + j] += mod - coe * sav[j] % mod) %= mod;
			p = i; las = noww;
		}
//		for(int j = 0 ; j <= l[i] ; j++ ) printf("%lld " , r[j]); printf("\n");
	}
//	for(int i = 0 ; i < n ; i++ ) cerr << l[i] << ' '; cerr << endl;
	for(int i = 1 ; i <= l[n - 1] ; i++ ) printf("%lld " , (mod - r[i]) % mod); printf("\n");
	e.a[0] = 1; g.a[1] = g.len = 1; f.len = l[n - 1];
	for(int i = 0 ; i <= l[n - 1] ; i++ ) f.a[l[n - 1] - i] = r[i];
	g = exp(g , m , f);
//	for(int i = 0 ; i <= g.len ; i++ ) cerr << g.a[i] << ' '; cerr << endl;
	for(int i = 0 ; i <= l[n - 1] ; i++ ) (ans += g.a[i] * a[i] % mod) %= mod;
	printf("%lld" , ans);
    return 0; 
}
/*
*/


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM