title: >-
The 2021 ICPC Asia Shanghai Regional Programming Contest - B. Strange
Permutations
date: 2021-12-13 15:27:09
tags: [inclusion-exclusion, combinatorics, math, FFT, team training, merge]
题意
给定一个长度为 \(n\) 的排列 \(P\),求满足 \(\forall i \in \{1, 2, \cdots, n - 1\}, Q_{i+1} \neq P_{Q_i}\) 的排列 \(Q\) 的个数(对 \(998244353\) 取模)
思路
抽象题意:取编号 \(1\) ~ \(n\) 的 \(n\) 个点,每个点 \(i\) 上记录一个值 \(P_i\),表示有向边 \(i \to P_i\) 不能使用。所求即为:不使用任何禁止边、且恰好经过每个顶点 \(1\) 次的有向路径(哈密顿路径)的条数
(图中橙色边表示不可连,蓝色边表示可连)
考虑容斥:枚举强制违反 \(i\) 个条件,即强制路径经过选定的 \(i\) 条橙色边
由于 \(P\) 是排列,所有橙色边 \(i \to P_i\) 恰好构成若干个不相交的环(含自环)。每个 \(k\) 元环可贡献 \(0\) ~ \(k-1\) 条橙色边(因为每个点只经过一次,选满 \(k\) 条会形成回路)。从中任选 \(j < k\) 条都恰好构成若干条链,共有 \(C_k^j\) 种选法,则贡献的生成函数为
\[\begin{aligned} &1 + C_k^1 \cdot x + C_k^2 \cdot x^2 + \cdots + C_k^{k-1} \cdot x^{k-1} \\ =\ & (1 + x) ^k - x^k \end{aligned} \]
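然后找出所有环,启发式合并这些多项式(每次取次数最小的两个相乘),得到 \(F(x)\)。选定 \(i\) 条橙色边后,\(n\) 个点被连成 \(n - i\) 条链,将这些链任意排列共有 \((n-i)!\) 种路径,故答案为 \[\sum_{i=0}^{n} (-1)^i \, [x^i]F(x) \cdot (n-i)!\]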
复杂度 \(O(n \log^2 n)\)
代码
#include <bits/stdc++.h>
// Shorthand integer / polynomial types. These are type aliases rather than
// #define macros so they are scoped and type-checked; every use in this file
// is as a type name, so the meaning is unchanged.
using ll = long long;
using ull = unsigned long long;
using i64 = long long;
using poly = std::vector<int>;  // coefficient vector: poly[i] is the x^i coefficient
// Template notes kept from the original (NOTE(review): `fastpow` is not
// defined in this file):
// dont visit a[m] when a.size() <= m
// (a = fastpow(c,n-m+1,m+1)).resize(m+1);
// i64 res = a[m] - b[m];
// (b = fastpow(d,n-m+1,m+1)).resize(m+1);
constexpr int MOD = 998244353;  // NTT-friendly prime modulus
namespace Poly { // remember to resize
const int N = (1 << 21), g = 3;
inline int power(int x, int p) {
int res = 1;
for (; p; p >>= 1, x = (ll)x * x % MOD)
if (p & 1)
res = (ll)res * x % MOD;
return res;
}
inline int fix(const int x) { return x >= MOD ? x - MOD : x; }
void dft(poly& A, int n) {
static ull W[N << 1], *H[30], *las = W, mx = 0;
for (; mx < n; mx++) {
H[mx] = las;
ull w = 1, wn = power(g, (MOD - 1) >> (mx + 1));
for(int i=0;i<1<<n;++i) *las++ = w, w = w * wn % MOD;
}
if (A.size() != (1 << n))
A.resize(1 << n);
static ull a[N];
for (int i = 0, j = 0; i < (1 << n); ++i) {
a[i] = A[j];
for (int k = 1 << (n - 1); (j ^= k) < k; k >>= 1);
}
for (int k = 0, d = 1; k < n; k++, d <<= 1)
for (int i = 0; i < (1 << n); i += (d << 1)) {
ull *l = a + i, *r = a + i + d, *w = H[k], t;
for (int j = 0; j < d; j++, l ++, r++) {
t = (*r) * (*w++) % MOD;
*r = *l + MOD - t, *l += t;
}
}
for(int i=0;i<1<<n;++i) A[i] = a[i] % MOD;
}
void idft(poly &a, int n) {
a.resize(1 << n), reverse(a.begin() + 1, a.end());
dft(a, n);
int inv = power(1 << n, MOD - 2);
for(int i=0;i<1<<n;++i) a[i] = (ll)a[i] * inv % MOD;
}
poly FIX(poly a) {
while (!a.empty() && !a.back()) a.pop_back();
return a;
}
// remember to resize
poly mul(poly a, poly b, int t = 1) {
if (t == 1 && a.size() + b.size() <= 24) {
poly c(a.size() + b.size(), 0);
for(int i=0;i<a.size();++i) for(int j=0;j<b.size();++j) c[i + j] = (c[i + j] + (ll)a[i] * b[j]) % MOD;
return FIX(c);
}
int n = 1, aim = a.size() * t + b.size();
while ((1<<n) <= aim) n++;
dft(a, n); dft(b, n);
if (t == 1)
for(int i=0;i<1<<n;++i) a[i] = (ll) a[i] * b[i] % MOD;
else
for(int i=0;i<1<<n;++i) a[i] = (ll) a[i] * a[i] % MOD * b[i] % MOD;
idft(a, n); a.resize(aim);
return FIX(a);
}
int Merge(std::vector<poly>&a) { // return index
std::priority_queue<std::pair<int,int> > H; // <-size, index>
int n = a.size();
for(int i=0;i<n;++i) {
H.emplace(-a[i].size(), i);
}
while(H.size()>=2) {
int o1 = H.top().second; H.pop();
int o2 = H.top().second; H.pop();
poly res = mul(a[o1], a[o2]);
a[o1].clear(); a[o2].clear();
for(int i=0;i<res.size();++i) a[o1].push_back(res[i]);
H.emplace(-a[o1].size(), o1);
}
return H.top().second; // index
}
};
// Shifts x by one MOD toward [0, MOD): adds MOD when negative, subtracts MOD
// when x >= MOD. Inputs in [-MOD, 2*MOD) end up fully reduced.
void norm(int& x) {
    if (x < 0) {
        x += MOD;
    } else if (x >= MOD) {
        x -= MOD;
    }
}
// Returns (a * b) % MOD, forming the product in 64 bits to avoid int overflow.
int mul(int a, int b) {
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % MOD);
}
int main(int argc, char const *argv[])
{
std::ios_base::sync_with_stdio(false);
std::cin.tie(nullptr); std::cout.tie(nullptr);
int n;
std::cin >> n;
std::vector<int> p(n);
for(int i=0;i<n;++i) {
std::cin >> p[i];
--p[i];
}
std::vector<int> vis(n, false); // bool
int circles = 0;
std::vector<int> cnt;
for(int i=0;i<n;++i) {
if(!vis[i]) {
cnt.push_back(0);
for(int j=i;!vis[j];j=p[j]) {
vis[j] = true;
++cnt[circles];
}
++circles;
}
}
std::vector<poly> ps(circles, poly());
std::vector<int> fac(n+1),ifac(n+1),inv(n+1);
fac[0] = fac[1] = ifac[0] = ifac[1] = inv[0] = inv[1] = 1;
for(int i=2;i<=n;++i) {
fac[i] = mul(i, fac[i - 1]);
inv[i] = mul(inv[MOD % i], MOD - MOD/i);
ifac[i] = mul(inv[i], ifac[i - 1]);
}
auto C = [&](int n, int m) {
return mul( fac[n], mul(ifac[m], ifac[n - m]) );
};
for(int i=0;i<circles;++i) {
poly &thiz = ps[i];
thiz.resize(cnt[i]);
for(int j=0;j<cnt[i];++j) {
thiz[j] = C(cnt[i],j);
}
}
int thiz = Poly::Merge(ps);
poly &ans = ps[thiz];
ans.resize(n+1);
int res = 0;
for(int i=0;i<=n;++i) {
int thiz = mul(ans[i], fac[n - i]);
norm(
res += ((i & 1) ? MOD - thiz : thiz)
);
}
std::cout << res;
return 0;
}