首先注意到題目中的 \(a\) 是有序序列且元素互不相同,因此只需計算無序方案(即 \(a\) 遞增的方案)的答案,再乘上 \(n!\) 即可。
而無序方案的答案顯然為
\[Ans=[x^n](1+x)(1+2x)\cdots (1+Ax)=[x^n]\prod_{i=1}^A(1+ix) \]
取對數把乘法變加法,即
\[\prod_{i=1}^A(1+ix)=\exp(\sum_{i=1}^A\ln(1+ix)) \]
這裡有 \(\ln\) 的展開式
\[-\ln(1-x)=\sum_{i=1}^\infty\frac{x^i}{i} \]
故有
\[\ln(1+ix)\\=\ln(1-(-ix))\\=-\sum_{k=1}^\infty \frac{(-ix)^k}{k}\\=\sum_{k=1}^\infty \frac{(-1)^{k+1}i^k}{k}x^k \]
則
\[\sum_{i=1}^A \ln(1+ix)\\ =\sum_{k=1}^\infty \frac{(-1)^{k+1}\sum_{i=1}^Ai^k}{k}x^k \]
自然數冪和可以用插值、伯努利數等方法算出來(下面的代碼用伯努利數)。
最後再做一次多項式 exp,直接用 \(O(n^2)\) 的遞推計算即可。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 505;
int n, M, P, inv[N], Bo[N], C[N][N], a[N], b[N];
using poly = std::vector<int>;
poly F[N];

// Evaluates the polynomial `coef` (coef[i] is the coefficient of x^i) at the
// point x modulo the global modulus P, using Horner's rule.
// Assumes every coef[i] is already in [0, P) (true for the F[] tables built
// in main); under that assumption the result is in [0, P).
int calc(const poly& coef, int x)
{
    // Bring x into [0, P) so the result stays non-negative for negative x and
    // every intermediate product y * x stays below P * P.
    x %= P;
    if (x < 0) x += P;
    int y = 0;
    // Signed index: an empty polynomial must not rely on size() - 1 wrapping
    // around and being narrowed back to -1.
    for (int i = static_cast<int>(coef.size()) - 1; i >= 0; --i)
        y = static_cast<int>((1LL * y * x + coef[i]) % P);
    return y;
}
// Reads A (stored in M), n and the modulus P, and prints
//   n! * [x^n] prod_{t=1}^{A} (1 + t x)  (mod P).
// The product is evaluated as exp( sum_k (-1)^{k+1} S_k(A) / k * x^k ), where
// S_k(A) = 1^k + 2^k + ... + A^k comes from Faulhaber's formula.
// NOTE(review): assumes n + 1 < N (array bounds) and that 1..n+1 are
// invertible mod P (e.g. P prime with P > n + 1); neither is checked here.
int main()
{
scanf("%d%d%d",&M,&n,&P);
// Pascal's triangle: C[i][j] = binom(i, j) mod P for 0 <= j <= i <= n + 1.
for(int i = 0; i <= n + 1; i ++)
{
C[i][0] = 1;
for(int j = 1; j <= i; j ++)
C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % P;
}
// Modular inverses of 1..n+1 via inv[i] = -(P / i) * inv[P % i] (mod P).
inv[1] = 1;
for(int i = 2; i <= n + 1; i ++)
inv[i] = (ll) inv[P % i] * (P - P / i) % P;
// Bernoulli numbers in the B_1 = -1/2 convention:
// B_i = -1/(i+1) * sum_{j<i} binom(i+1, j) * B_j.
Bo[0] = 1;
for(int i = 1; i <= n; i ++)
{
int t = 0;
for(int j = 0; j < i; j ++)
t = (t + (ll)Bo[j] * C[i + 1][j]) % P;
Bo[i] = (ll)(P - inv[i + 1]) * t % P;
}
// F[i] is the polynomial in m with F[i](m) = 1^i + 2^i + ... + m^i, stored
// as coefficients F[i][j] of m^j (F[i][0] = 0):
//   F[i][j] = binom(i+1, j) * B_{i+1-j} / (i+1), sign-flipped when i+1-j is odd.
// Odd Bernoulli numbers beyond B_1 vanish, so the flip only turns B_1 = -1/2
// into +1/2, the convention for sums running from 1 to m.
// F[0](m) = m is set for completeness; the loops below start at i = 1.
F[0] = poly{0, 1};
for(int i = 1; i <= n; i ++)
{
F[i].resize(i + 2);
for(int j = 1; j <= i + 1; j ++)
{
F[i][j] = (ll) Bo[i + 1 - j] * C[i + 1][j] % P * inv[i + 1] % P;
if((i + 1 - j) & 1)
F[i][j] = (P - F[i][j]) % P;
}
}
// a[i] = (-1)^{i+1} * S_i(M) / i: the coefficient of x^i in
// sum_{t=1}^{M} ln(1 + t x). `~i&1` is true for even i (negate).
for(int i = 1; i <= n; i ++)
{
a[i] = (ll)inv[i] * calc(F[i], M) % P;
if(~i&1) a[i] = (P - a[i]) % P;
}
// Replace the log series by its formal derivative in place: a[i-1] = i * a[i].
// The loop must run with increasing i so a[i] is read before being overwritten.
for(int i = 1; i <= n; i ++)
a[i - 1] = (ll)i * a[i] % P;
// b = exp(log series) by the O(n^2) recurrence
//   i * b[i] = sum_{j<i} b[j] * a[i-1-j]   (a now holds the derivative).
b[0] = 1;
for(int i = 1; i <= n; i ++)
{
for(int j = 0; j < i; j ++)
b[i] = (b[i] + (ll)b[j] * a[i - j - 1]) % P;
b[i] = (ll)b[i] * inv[i] % P;
}
// b[n] is the answer for increasing sequences; multiply by n! for all orders.
int ans = b[n];
for(int i = 2; i <= n; i ++) ans = (ll)ans * i % P;
printf("%d", ans);
return 0;
}
注意到除了 exp(可以換成更快的多項式 exp)以外,複雜度瓶頸在於對所有 \(k\in [1,n]\) 預處理自然數冪和。
這東西的 EGF 為(把 \(i=0\) 也算進去,它只影響 \(k=0\) 的係數):
\[\sum_{k=0}^\infty \sum_{i=0}^A i^k \frac{x^k}{k!}\\ =\sum_{i=0}^A \sum_{k=0}^\infty \frac{(ix)^k}{k!}\\ = \sum_{i=0}^A e^{ix}\\= \frac{e^{(A+1)x}-1}{e^x-1} \]
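由於 \(e^x-1\) 沒有常數項,先把分子分母同除以 \(x\),再做多項式求逆即可,這一步是 \(O(n\log n)\)。下面代碼中的 exp 用分治 NTT 實現,複雜度為 \(O(n\log^2 n)\),所以整體是 \(O(n\log^2 n)\);若改用牛頓迭代的多項式 exp,整體可做到 \(O(n\log n)\)。
下面是模數為 \(998244353\) 時的代碼(對每個 \(1\le i\le n\) 都輸出長度為 \(i\) 的答案)。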
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<functional>
using namespace std;
typedef long long ll;
const int N = (1 << 20) | 3;
const int P = 998244353;

// Binary exponentiation: returns a^b mod P.
// b is expected to be non-negative.
int fpow(int a, int b)
{
    long long result = 1;
    long long power = a;
    while (b)
    {
        if (b & 1)
            result = result * power % P;
        power = power * power % P;
        b >>= 1;
    }
    return static_cast<int>(result);
}

// Global state for main: the input (V, n), factorials, inverse factorials,
// modular inverses, and three scratch polynomials.
int V, n;
int fac[N], ifac[N], inv[N];
int a[N], b[N], c[N];
namespace poly
{
int Len, sz, rev[N], w[N];
void prepare(int n)
{
for(Len = 1, sz = 0; Len <= n; Len <<= 1, sz ++);
for(int i = 0; i < Len; i ++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (sz - 1));
int wn = fpow(3, (P - 1) / Len);
w[Len / 2] = 1;
for(int i = Len / 2 + 1; i < Len; i ++)
w[i] = 1LL * w[i - 1] * wn % P;
for(int i = Len / 2 - 1; i >= 0; i --)
w[i] = w[i << 1];
}
void DFT(int n, int *a, int T)
{
static unsigned long long F[N];
int shift = sz - __builtin_ctz(n), x;
for(int i = 0; i < n; i ++)
F[rev[i] >> shift] = a[i];
for(int l = 1; l < n; l <<= 1)
for(int i = 0; i < n; i += l << 1)
for(int j = 0; j < l; j ++)
{
x = F[i + j + l] * w[j + l] % P;
F[i + j + l] = F[i + j] + P - x;
F[i + j] += x;
}
for(int i = 0; i < n; i ++)
a[i] = F[i] % P;
if(T)
{
x = fpow(n, P - 2);
for(int i = 0; i < n; i ++)
a[i] = (ll) a[i] * x % P;
reverse(a + 1, a + n);
}
}
void Inverse(int n, int *a, int *b)
{
if(n == 1)
{
*b = fpow(*a, P - 2);
return;
}
Inverse((n + 1) >> 1, a, b);
static int c[N], len;
for(len = 1; len < n << 1; len <<= 1);
for(int i = 0; i < len; i ++)
i < n ? c[i] = a[i] : c[i] = b[i] = 0;
DFT(len, b, 0);
DFT(len, c, 0);
for(int i = 0; i < len; i ++)
b[i] = 1LL * b[i] * (P + 2 - 1LL * b[i] * c[i] % P) % P;
DFT(len, b, 1);
for(int i = n; i < len; i ++) b[i] = 0;
}
void Exp(int n, int *a, int *b)
{
static int c[N], d[N];
for(int i = 1; i < n; i ++)
c[i - 1] = (ll) a[i] * i % P;
c[n - 1] = 0;
for(int i = 0; i < n; i ++)
d[i] = 0;
function<void(int,int)> solve = [&](int l, int r)
{
if(l == r)
{
if(!l)
d[l] = 1;
else
d[l] = (ll)d[l] * inv[l] % P;
return;
}
int mid = (l + r) / 2;
solve(l, mid);
static int A[N], B[N];
int L = 1;
while(L <= r - l) L <<= 1;
memset(A, 0, L << 2);
memset(B, 0, L << 2);
memcpy(A, d + l, (mid - l + 1) << 2);
memcpy(B, c, (r - l) << 2);
DFT(L, A, 0);
DFT(L, B, 0);
for(int i = 0; i < L; i ++)
A[i] = (ll) A[i] * B[i] % P;
DFT(L, A, 1);
for(int i = mid + 1; i <= r; i ++)
d[i] = (d[i] + A[i - l - 1]) % P;
solve(mid + 1, r);
};
solve(0, n - 1);
memcpy(b, d, n << 2);
}
}
// Reads V and n; for every 1 <= i <= n prints, one per line,
//   i! * [x^i] prod_{t=1}^{V} (1 + t x)  (mod P).
// The power sums S_k(V) = sum_{t=0}^{V} t^k for all k <= n are read off the EGF
// (e^{(V+1)x} - 1) / (e^x - 1). A polynomial exp then rebuilds the product.
// NOTE(review): `e = V + 1` overflows int if V == INT_MAX; input range unchecked.
int main()
{
scanf("%d %d", &V, &n);
// Transform tables large enough for the product of two degree-n polynomials.
poly::prepare(n + n);
// Factorials, modular inverses and inverse factorials for 0..n+1.
fac[0] = ifac[0] = 1;
for(int i = 1; i <= n + 1; i ++)
{
fac[i] = (ll)fac[i - 1] * i % P;
inv[i] = (i != 1 ?
(ll)inv[P % i] * (P - P / i) % P : 1);
ifac[i] = (ll)ifac[i - 1] * inv[i] % P;
}
// EGF of the power sums: ((e^x)^(V+1) - 1) / (e^x - 1). Numerator and
// denominator are both divided by x so the denominator has a nonzero constant
// term: a[i] = (V+1)^{i+1} / (i+1)!, b[i] = 1 / (i+1)!.
for(int i = 0, e = V + 1; i <= n; i ++, e = 1LL * e * (V + 1) % P)
{
a[i] = 1LL * e * ifac[i + 1] % P;
b[i] = ifac[i + 1];
}
// c = 1 / b mod x^{n+1}.
poly::Inverse(n + 1, b, c);
// a <- a * c, so a[i] = S_i(V) / i! for i <= n (the product has degree <= 2n < Len).
int L = poly::Len;
poly::DFT(L, a, 0);
poly::DFT(L, c, 0);
for(int i = 0; i != L; i ++) a[i] = 1LL * a[i] * c[i] % P;
poly::DFT(L, a, 1);
// Turn the EGF into the log series: drop the constant term and set
// a[i] = (-1)^{i+1} * S_i(V) / i, the coefficient of x^i in sum_t ln(1 + t x).
a[0] = 0;
for(int i = 1; i <= n; i ++)
{
a[i] = (ll)a[i] * fac[i] % P * inv[i] % P;
if((i + 1) & 1) a[i] = (P - a[i]) % P;
}
// b = exp(a) mod x^{n+1}; b[i] is the answer for increasing sequences of
// length i, so multiply by i! for all orders.
poly::Exp(n + 1, a, b);
for(int i = 1; i <= n; i ++)
printf("%d\n", int(1LL * b[i] * fac[i] % P));
return 0;
}