THUWC2018 題解


2018清華冬令營

又一次由於接連而至的玄學現象跪慘,錯失良機,就不再公開提我這次慘痛的經歷了,寫點干貨……

day1

A 零食 (1s, 1G)

試題簡述

\(n\) 種物品1,\(m\) 種物品2,要求安排一個兩種物品的排列,當且僅當某個物品的有前一個物品前一個物品種類和它一樣時才能讓總和增加這個物品的權值。

現給出所有物品的權值,求最大總和。

輸入

第一行一個整數 \(T\),表示數據組數。

第二行一個正整數 \(n\),表示物品1的個數。

第三行 \(n\) 個整數,\(A_1, A_2, \cdots , A_n\),分別表示所有物品1的權值。

第四行一個正整數 \(m\),表示物品2的個數。

第五行 \(m\) 個整數,\(B_1, B_2, \cdots , B_n\),分別表示所有物品2的權值。

輸出

一個整數,表示最大總和。

輸入示例

2
5
2 3 3 -3 -3
5
6 6 6 -6 -6
2
1 -1
3
1 -1 1

輸出示例

26
3

數據規模及約定

\(1 \le n, m \le 10^6, |A_i|, |B_i| \le 10^9\)

題解

容易想到兩種物品“被消掉”(權值沒有被計入總和)的個數至多相差 \(1\),於是我們枚舉物品1被消掉的個數,物品2被消掉的個數可以隨之確定,而我們肯定是要貪心地消兩種物品中權值小的物品,所以最終給 \(A\)\(B\) 排個序掃一遍就好了。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define rep(i, s, t) for(int i = (s), mi = (t); i <= mi; i++)
#define dwn(i, s, t) for(int i = (s), mi = (t); i >= mi; i--)

const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
	if(Head == Tail) {
		int l = fread(buffer, 1, BufferSize, stdin);
		Tail = (Head = buffer) + l;
	}
	return *Head++;
}
int read() {
	int x = 0, f = 1; char c = Getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
	return x * f;
}

#define maxn 1000010
#define LL long long
#define ool (1ll << 60)

int n, m, A[maxn], B[maxn];
LL sA[maxn], sB[maxn];

LL sufA(int x) { return x <= n ? sA[x] : 0; }
LL sufB(int x) { return x <= m ? sB[x] : 0; }

void work() {
	n = read();
	rep(i, 1, n) A[i] = read();
	m = read();
	rep(i, 1, m) B[i] = read();
	
	sort(A + 1, A + n + 1);
	sort(B + 1, B + m + 1);
	
	sA[n+1] = 0; dwn(i, n, 1) sA[i] = sA[i+1] + A[i];
	sB[m+1] = 0; dwn(i, m, 1) sB[i] = sB[i+1] + B[i];
	LL ans = -ool;
	rep(i, 1, min(n, m)) {
		if(i <= n && i <= m) ans = max(ans, sufA(i + 1) + sufB(i + 1));
		if(i <= n && i + 1 <= m) ans = max(ans, sufA(i + 1) + sufB(i + 2));
		if(i + 1 <= n && i <= m) ans = max(ans, sufA(i + 2) + sufB(i + 1));
	}
	printf("%lld\n", ans);
	return ;
}

int main() {
	int T = read();
	
	while(T--) work();
	
	return 0;
}

B 城市規划 (1.5s, 2G)

試題簡述

\(n\) 個節點的樹,節點 \(i\) 有顏色 \(a_i\),求包含不超過兩種顏色的連通塊個數模 \(998244353\) 后的值。

輸入

第一行一個正整數 \(n\)

第二行 \(n\) 個正整數 \(a_1, a_2, \cdots , a_n\)

接下來 \(n-1\) 行,每行兩個正整數 \(u_i, v_i\),描述一條樹邊。

輸出

取模后的答案。

輸入示例

6
1 1 2 3 4 5
1 2
1 3
1 4
2 5
2 6

輸出示例

15

數據規模及約定

\(1 \le a_i, u_i, v_i \le n \le 10^5\)

題解

考慮一個暴力的 dp,設 \(f(i, c)\) 表示包含節點 \(i\)\(i\) 子樹中的節點,其中一種顏色為 \(a_i\),另一種顏色為 \(c\) 的連通塊個數。特別地,當 \(c=0\) 時,該連通塊只包含顏色 \(a_i\)(可以知道,只包含節點 \(i\) 的連通塊也會被統計到 \(f(i, 0)\) 中)。

分兩種轉移(令 \(A_i\) 表示節點 \(i\) 的 dp 值的顏色集合):

  • 對於一個 \(i\) 的兒子 \(son\),若 \(a_{son} \ne a_i\),那么 \(f(i, a_{son}) \leftarrow f(son, a_i) + f(son, 0) + 1\),其中 \(a \leftarrow b\) 表示將 \(b\) 累乘到 \(a\) 中。為什么要加 \(1\) 呢,因為不選擇這個子樹也是一種方案。
  • \(a_{son} = a_i\)\(\forall c \in A_i \cup A_{son}, f(i, c) \leftarrow f(son, c) + f(son, 0) + 1\)
  • 特殊地計算一下 \(f(i, 0) = \prod_{a_{son} \ne a_i} (f(son, 0) + 1)\),我就不把它算作“一種轉移”了。

但是這樣轉移完還不算結束,容易發現我們對於所有出現過的顏色 \(c\),都多統計了 \(f(i, 0)\) 的方案,也就是說,所有的 \(f(i, c)\) 要減去 \(f(i, 0)\) 才能得到真正的 dp 值。

不難發現一個節點 \(i\) 的 dp 值只需要存儲它子樹中出現過的顏色,而上面的第一種轉移相當於單點修改,第二種轉移相當於同類合並。那么我們寫一個線段樹合並就能均攤 \(O(n \log n)\) 支持所有的轉移操作了。

為什么線段樹合並格外好寫呢,因為它天然支持區間加和區間乘。

為了方便,我們可以對於所有出現過的 \(f(i, c)\) 最后不減去 \(f(i, 0)\),因為上面所有的轉移式子中其實都是形如 \(f(i, c) + f(i, 0) + 1\),如果我們不減去,式子就簡化成了 \(f(i, c) + 1\),下面的代碼也是這樣實現的。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <cassert>
using namespace std;
#define rep(i, s, t) for(int i = (s), mi = (t); i <= mi; i++)
#define dwn(i, s, t) for(int i = (s), mi = (t); i >= mi; i--)

const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
	if(Head == Tail) {
		int l = fread(buffer, 1, BufferSize, stdin);
		Tail = (Head = buffer) + l;
	}
	return *Head++;
}
int read() {
	int x = 0, f = 1; char c = Getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
	return x * f;
}

#define maxn 100010
#define maxm 200010
#define maxnode 4000010
#define MOD 998244353
#define LL long long

int n, m, head[maxn], nxt[maxm], to[maxm], col[maxn];

void AddEdge(int a, int b) {
	to[++m] = b; nxt[m] = head[a]; head[a] = m;
	swap(a, b);
	to[++m] = b; nxt[m] = head[a]; head[a] = m;
	return ;
}

int ToT, rt[maxnode], lc[maxnode], rc[maxnode], mulv[maxnode], addv[maxnode], sumv[maxnode], siz[maxnode];
int f0;
void multi(int& a, int b) {
	a = (LL)a * b % MOD;
	return ;
}
void incr(int& a, int b) {
	a += b; if(a >= MOD) a -= MOD;
	return ;
}
void pushdown(int o, int l, int r) {
	if((!addv[o] && mulv[o] == 1) || l == r){ mulv[o] = 1; addv[o] = 0; return ; }
	assert(l > 0);
	if(lc[o])
		multi(mulv[lc[o]], mulv[o]), multi(addv[lc[o]], mulv[o]), multi(sumv[lc[o]], mulv[o]), incr(addv[lc[o]], addv[o]), incr(sumv[lc[o]], (LL)addv[o] * siz[lc[o]] % MOD);
	if(rc[o])
		multi(mulv[rc[o]], mulv[o]), multi(addv[rc[o]], mulv[o]), multi(sumv[rc[o]], mulv[o]), incr(addv[rc[o]], addv[o]), incr(sumv[rc[o]], (LL)addv[o] * siz[rc[o]] % MOD);
	mulv[o] = 1; addv[o] = 0;
	return ;
}
void Mul(int& o, int l, int r, int p, int v) {
	bool isnew = 0;
	if(!o) mulv[o = ++ToT] = 1, isnew = 1;
	else pushdown(o, l, r);
	if(l == r) {
		siz[o] = 1;
		if(isnew) sumv[o] = v;
		else multi(sumv[o], v);
//		printf("sumv[%d] = %d\n", o, sumv[o]);
		return ;
	}
	int mid = l + r >> 1;
	if(p <= mid) Mul(lc[o], l, mid, p, v);
	else Mul(rc[o], mid + 1, r, p, v);
	sumv[o] = sumv[lc[o]] + sumv[rc[o]]; if(sumv[o] >= MOD) sumv[o] -= MOD;
	siz[o] = siz[lc[o]] + siz[rc[o]];
	return ;
}
int query(int o, int l, int r, int p) {
	if(!o) return f0;
	pushdown(o, l, r);
	if(l == r) return sumv[o];
	int mid = l + r >> 1;
	if(p <= mid) return query(lc[o], l, mid, p);
	return query(rc[o], mid + 1, r, p);
}
void update_add(int& o, int l, int r, int v) {
	assert(l > 0);
	pushdown(o, l, r);
	incr(sumv[o], (LL)v * siz[o] % MOD);
	incr(addv[o], v);
	return ;
}
void update_mul(int& o, int l, int r, int v) {
	assert(l > 0);
	pushdown(o, l, r);
	multi(sumv[o], v);
	multi(mulv[o], v);
	return ;
}
int Merge(int x, int y, int l, int r) {
	if(!x && !y) return 0;
	if(!x){ update_add(y, l, r, 1); return y; }
	if(!y){ update_mul(x, l, r, f0 + 1); return x; }
	pushdown(x, l, r); pushdown(y, l, r);
	if(l == r){ multi(sumv[x], (sumv[y] + 1)); return x; }
	int mid = l + r >> 1;
	lc[x] = Merge(lc[x], lc[y], l, mid); rc[x] = Merge(rc[x], rc[y], mid + 1, r);
	sumv[x] = sumv[lc[x]] + sumv[rc[x]]; if(sumv[x] >= MOD) sumv[x] -= MOD;
	siz[x] = siz[lc[x]] + siz[rc[x]];
	return x;
}

int ans;
void dp(int u, int fa) {
	bool has = 0;
	for(int e = head[u]; e; e = nxt[e]) if(to[e] != fa) dp(to[e], u), has = 1;
	Mul(rt[u], 0, n, 0, 1);
	if(has) {
		for(int e = head[u]; e; e = nxt[e]) if(to[e] != fa) {
			f0 = query(rt[to[e]], 0, n, 0);
			if(col[to[e]] != col[u]) Mul(rt[u], 0, n, col[to[e]], (query(rt[to[e]], 0, n, col[u]) + 1) % MOD);
			else rt[u] = Merge(rt[u], rt[to[e]], 0, n);
		}
		int tmp = sumv[rt[u]];
		incr(tmp, MOD - (LL)(siz[rt[u]] - 1) * query(rt[u], 0, n, 0) % MOD);
		incr(ans, tmp);
	}
	else incr(ans, 1);
	/*f0 = query(rt[u], 0, n, 0);
	rep(i, 0, n) printf("f[%d][%d] = %d\n", u, i, query(rt[u], 0, n, i) - (i ? f0 : 0)); // */
	return ;
}

int main() {
	n = read();
	rep(i, 1, n) col[i] = read();
	rep(i, 1, n - 1) {
		int a = read(), b = read();
		AddEdge(a, b);
	}
	
	dp(1, 0);
	
	printf("%d\n", ans);
	
	return 0;
}

C 字胡串 (3s, 512MB)

試題簡述

給出長度為 \(n\) 的串 \(A\),和 \(q\) 個詢問。每個詢問是一個串 \(B_i\),要求回答一個最小的 \(j\),使得 \(A[1..j] + B_i + A[j+1..n]\) 字典序最小。

輸入

第一行一個正整數 \(n\)

第二行一個數字串 \(A\)

第三行一個正整數 \(q\)

接下來 \(q\) 行,每行一個數字串 \(B_i\)

輸出

\(q\) 行,分別為每組詢問的答案

輸入示例1

6
000001
2
00
01

輸出示例1

0
5

輸入示例2

10
7676767982
1
7676

輸出示例2

0

數據規模及約定

\(1 \le n, q, \sum |B_i| \le 10^6\)\(1 \le \max\{ |B_i| \} \le n\),輸入的字符串均為字符集為全體數字的串。

題解

看到這題要理清思路,要求最小化字典序,那么自然是從前往后貪心地考慮每一位,並讓每一位盡量低,於是現在需要考慮清楚,插入后的串的一個前綴到底是由什么構成的?不難發現情況有三種:

  • \(A\) 的前綴,形式化地:\(A[1..i]\)
  • \(A\) 的前綴 + \(B\) 的前綴,形式化地:\(A[1..i] + B[1..j]\)
  • \(A\) 的前綴 + \(B\) + \(A\) 后面的一段,形式化地:\(A[1..i] + B + A[i+1..k]\)

情況1不需要考慮,因為這時候還未引入串 \(B\)

可以發現情況3也可以忽略,因為我們會在情況2中先找到“沖突”,而這時已經可以把 \(B\) 插入的位置確定下來了,沒必要再去看情況3。

解釋一下什么叫“沖突”,我們一定會找到一個 \((i, j)\) 滿足 \(A[1..i] + B[1..j] < A[1..i+j]\)(假設串 \(A\) 后面追加一個字符 \(10\),它比任何數字字符都要大,這樣就能夠保證一定能找到這樣一個位置)。不難發現我們希望 \(i+j\) 最小,因為字符變小的位置越靠前越好。

那么現在目標就是找到這個最小的 \(i+j\)。我們可以先枚舉 \(j\),由於在 \(i\) 確定的情況下 \(j\) 要最小,所以會有 \(B[j] < A[i+j]\),那么我們可以構造一下 \(A\) 的后綴自動機,並查找 \(B[1..j-1] + x\) 的最靠前的匹配位置(\(x\)\(B[j] + 1\)\(10\) 枚舉),這個最靠前的匹配位置就是 \(i\)。這樣所有的取個最小值就可以求出最小的 \(i+j\) 了。

但是僅僅找到最小的 \(i+j\) 並不能得到最小字典序的串,假設對於 \(j_1\)\(j_2\) 都有 \(A[1..i_1] + B[1..j_1] < A[1..i_1+j_1]\)\(A[1..i_2] + B[1..j_2] < A[1..i_2+j_2]\),那么 \(A[1..i_1] + B + A[i_1+1..n]\)\(A[1..i_2] + B + A[i_2+1..n]\) 哪個更小呢?注意到這里有 \(i_1+j_1 = i_2+j_2\),那么 \(A[1..i_1] + B[1..j_1-1] = A[1..i_2] + B[1..j_2-1]\),於是我們可以直觀地看下圖:

QAQ

綠色部分表示串 \(B\),容易發現,這時只用比較 \(B\) 的后綴 \(B[j_1..l]\)\(B[j_2..l]\)\(l = |B|\))誰更靠前就好了(所以要再打個后綴數組)。細心的讀者一定發現,這樣比較之后還有一個隱患:如果 \(B[j_2..l]\) 恰好是 \(B[j_1..l]\) 的前綴(不失一般性,這里假設 \(j_2 > j_1\)),怎么辦呢?可以證明,這樣的話,兩種方案得到的串是一樣的。如圖:

TAT

其中,兩個灰色刻度之間的部分是相同的,那么可以發現紅色部分是 \(B\) 的一個 border(若 \(|B|\) 小於紅色部分的長度則 \(B\) 是紅色部分的一個 border,不過不影響后面的結論),令紅色部分長度為 \(l'\),可以證明 \(B\) 是以 \(gcd(l, l')\) 為周期的一個周期串。那么圖中 \(紅串+綠串\) 和顛倒過來的 \(綠串+紅串\) 就是完全一樣的啦!

至此,我們還沒有做完。因為我們剛剛得到一個能使得插入之后字典序最小的插入位置,這個插入位置可能有很多,我們要求出最小的那個。不難發現,所有可能的插入位置就是串 \(B\) 往前跳,所以我們用 KMP 找到 \(B\) 的最小周期,然后倍增 + 哈希往前跳就好了。(哈希是因為我們需要判斷往前跳那么多步之后和跳之前的串完全相同)

縱觀本題,竟然有四個字符串工具:SAM + SA + KMP + hash。實現的話,細節是相當多的。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define rep(i, s, t) for(int i = (s), mi = (t); i <= mi; i++)
#define dwn(i, s, t) for(int i = (s), mi = (t); i >= mi; i--)

int read() {
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
	return x * f;
}

#define maxn 4000010
#define maxlog 25
#define maxnode 8000010
#define maxa 11
#define oo 2147483647

int n, q;
char str[maxn], tmp[maxn];

int ToT, to[maxnode][maxa], par[maxnode], lst, mx[maxnode], mnp[maxnode];
void extend(int i) {
	int c = str[i] - '0', now = ++ToT, v = lst; lst = now;
	mx[now] = i; mnp[now] = i;
	while(v && !to[v][c]) to[v][c] = now, v = par[v];
	if(!v) par[now] = 1;
	else {
		int u = to[v][c];
		if(mx[u] == mx[v] + 1) par[now] = u;
		else {
			int q = ++ToT; mx[q] = mx[v] + 1; mnp[q] = oo;
			par[q] = par[u];
			par[u] = par[now] = q;
			memcpy(to[q], to[u], sizeof(to[u]));
			while(v && to[v][c] == u) to[v][c] = q, v = par[v];
		}
	}
	return ;
}
bool cmp(int a, int b) { return mx[a] < mx[b]; }
int id[maxnode];
void build() {
	rep(i, 1, ToT) id[i] = i;
	sort(id + 1, id + ToT + 1, cmp);
	dwn(i, ToT, 1) {
		int u = id[i];
		mnp[par[u]] = min(mnp[par[u]], mnp[u]);
	}
	return ;
}

#define rate 233
#define UL unsigned long long

UL rpow[maxn], hval[maxn], htmp[maxn], jump[maxlog];
void hash_init() {
	rpow[0] = 1;
	rep(i, 1, n + 1) rpow[i] = rpow[i-1] * rate;
	rep(i, 1, n) hval[i] = hval[i-1] * rate + str[i];
	return ;
}
UL get_hash(int l, int r) {
	return hval[r] - hval[l-1] * rpow[r-l+1];
}

struct SA {
	char S[maxn];
	int N, height[maxn], rank[maxn], sa[maxn], Ws[maxn];
	
	bool diff(int *a, int p1, int p2, int len) {
		if(p1 + len > N && p2 + len > N) return a[p1] != a[p2];
		if(p1 + len > N || p2 + len > N) return 1;
		return a[p1] != a[p2] || a[p1+len] != a[p2+len];
	}
	void ssort(char *tS) {
		rep(i, 1, max(N, 10)) Ws[i] = 0;
		N = strlen(tS);
		rep(i, 1, N) S[i] = tS[i-1];
		int *x = height, *y = rank, m = 0;
		rep(i, 1, N) Ws[x[i] = S[i]-'0'+1]++, m = max(m, x[i]);
		rep(i, 1, m) Ws[i] += Ws[i-1];
		dwn(i, N, 1) sa[Ws[x[i]]--] = i;
		for(int j = 1, pos; j < N; j <<= 1, m = pos) {
			pos = 0;
			rep(i, N - j + 1, N) y[++pos] = i;
			rep(i, 1, N) if(sa[i] > j) y[++pos] = sa[i] - j;
			rep(i, 1, m) Ws[i] = 0;
			rep(i, 1, N) Ws[x[i]]++;
			rep(i, 1, m) Ws[i] += Ws[i-1];
			dwn(i, N, 1) sa[Ws[x[y[i]]]--] = y[i];
			swap(x, y); x[sa[1]] = pos = 1;
			rep(i, 2, N) x[sa[i]] = diff(y, sa[i], sa[i-1], j) ? ++pos : pos;
		}
		rep(i, 1, N) rank[sa[i]] = i;
		return ;
	}
} cmper;

int fa[maxn];

int main() {
	n = read();
	scanf("%s", str + 1);
	
	ToT = lst = 1; mnp[1] = oo;
	rep(i, 1, n) extend(i);
	rep(i, n + 1, n << 1) str[i] = 10 + '0', extend(i);
	n <<= 1;
	build();
	hash_init();
	
	int q = read();
	while(q--) {
		scanf("%s", tmp + 1);
		cmper.ssort(tmp + 1);
		int now = 1, bestl = 0, bestp = -1, len = strlen(tmp + 1);
//		rep(i, 1, len) printf("%d%c", cmper.rank[i], i < len ? ' ' : '\n');
		rep(i, 1, len) {
			rep(c, tmp[i] - '0' + 1, 10) if(to[now][c]) {
				int u = to[now][c];
				if(!bestl || bestp > mnp[u] || (bestp == mnp[u] && cmper.rank[i] < cmper.rank[bestl]))
					bestp = mnp[u], bestl = i;
			}
			now = to[now][tmp[i]-'0'];
			if(!now) break;
		}
		
		fa[1] = fa[2] = 1;
		rep(i, 2, len) {
			int j = fa[i];
			while(j > 1 && tmp[j] != tmp[i]) j = fa[j];
			fa[i+1] = tmp[j] == tmp[i] ? j + 1 : 1;
		}
		rep(i, 1, len) htmp[i] = htmp[i-1] * rate + tmp[i];
//		printf("bestp: %d %d\n", bestp, bestl);
		bestp -= bestl;
		int u = fa[len+1], ans = bestp;
		while(1) {
			int tlen = len + 1 - u;
			if(tlen && len % tlen == 0) {
//				printf("try tlen: %d\n", tlen);
				jump[0] = htmp[tlen];
				for(int i = 1; (1 << i) * tlen <= n; i++) jump[i] = jump[i-1] * rpow[(1<<i>>1)*tlen] + jump[i-1];
//				for(int i = 0; (1 << i) * tlen <= n; i++) printf("jump[%d]: %llu\n", i, jump[i]);
				int np = bestp;
				dwn(i, maxlog - 1, 0)
					if(np - (1ll << i) * tlen >= 0 && get_hash(np - (1 << i) * tlen + 1, np) == jump[i])
						np -= (1ll << i) * tlen;
				ans = min(ans, np);
				break;
			}
			if(u == 1) break;
			u = fa[u];
		}
		printf("%d\n", ans);
	}
	
	return 0;
}

day2

A 明天的太陽會照常升起 (7s, 512MB)

試題簡述

\(n\) 個城市從北往南依次編號為 \(1 \sim n\),城市 \(i\) 和城市 \(i+1\) 之間有一條長度為 \(l_i\) 的道路 \((1 \le i < n)\),城市 \(i\)\(1\) 單位油的價錢為 \(p_i\)\(1\) 單位油可以走 \(1\) 單位距離。

現有 \(m\) 組詢問,每次詢問形如 \((s_i, t_i, v_i)\),表示從城市 \(s_i\) 開車到 \(t_i\),初始時油量為 \(v_i\) 所需的最小花費。

注意:每組詢問中車都有恆定的油量上界 \(V\)

輸入

第一行三個正整數 \(n, m, V\)

第二行 \(n\) 個正整數 \(p_1, p_2, \cdots , p_n\)

第三行 \(n-1\) 個正整數 \(l_1, l_2, \cdots, l_{n-1}\)

接下來 \(m\) 行,每行三個正整數 \(s_i, t_i, v_i\),表示一組詢問。

輸出

\(m\) 行,分別為每組詢問的答案

輸入示例

7 2 9
3 2 5 6 7 4 1
2 5 7 7 3 4
1 4 2
2 6 5

輸出示例

33
82

數據規模及約定

\(1 \le m, p_i \le 10^6\)\(1 \le V \le 10^{18}\)\(1 \le l_i, v_i \le \min\{ 10^6, V \}\)\(1 \le s_i < t_i \le n \le 10^6\)

題解

首先考慮暴力每次詢問從 \(s\)\(t\) 線性掃一遍,顯然我們需要決策的就是在每個城市是否要加油,如果要,加多少。不難發現一個顯然的貪心策略:到一個城市 \(i\) 后,如果對於 \(j > i\)\(p_j < p_i\)\(j = t\) 的最小的 \(j\)\(\sum_{k=i}^{j-1} l_k \le V\),那么我們把油量加到 \(\sum_{k=i}^{j-1} l_k\),否則加滿(即加到 \(V\))。

那么現在的任務無非是利用數據結構優化這個暴力的過程。

只考慮暴力中會加油的那些城市,我們發現遵循這樣一個規則:若在城市 \(i\) 加油,那么下一個加油的城市可以唯一確定出來,分兩種情況確定:

  • 對於 \(j > i\)\(p_j < p_i\) 的最小的 \(j\),有 \(\sum_{k=i}^{j-1} l_k \le V\),那么下一個加油的城市就是 \(j\)
  • 否則下一個加油的城市是從 \(i\) 往后 \(V\) 的距離中 \(p_j\) 最小的城市 \(j\)

容易發現每個城市只有一個“下一個加油的城市”,所以如果把這個關系圖建出來,這就是一棵樹。樹上的操作就很好辦了,從一個點跳到另一個點,自然想到倍增。

但是兩個問題,一是如何解決初始油量,二是如何解決限定終點。

一下我們將節點分成兩類,下一個價格更便宜的城市到這個城市的距離 \(\le V\) 的稱作第一類,否則稱作第二類。

問題一很好解決,我們先倍增往上跳,當遇到一個第二類或者初始油量不夠往上跳的時候結束。以這時所在的節點 \(s'\) 為新的起點,並設這時剩下的油量為 \(v'\)。那么這時可以保證下一步跳所花費的價格可以直接減去 \(v' \cdot p_{s'}\)

對於問題二,我們可以倍增找到第一次遇到的可以“直接到達 \(t\) 的節點”。\(i\) 可以直接到達的含義是:若 \(i\) 是第一類城市,那么它到它父親節點的距離大於等於它到 \(t\) 的距離;否則 \(V\) 大於等於它到 \(t\) 的距離。找到這樣的節點后,一步跳到 \(t\) 即可。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define rep(i, s, t) for(int i = (s), mi = (t); i <= mi; i++)
#define dwn(i, s, t) for(int i = (s), mi = (t); i >= mi; i--)
#define LL long long

const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
	if(Head == Tail) {
		int l = fread(buffer, 1, BufferSize, stdin);
		Tail = (Head = buffer) + l;
	}
	return *Head++;
}
LL read() {
	LL x = 0, f = 1; char c = Getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
	return x * f;
}

#define maxn 1000010
#define maxlog 20

int n, q, pri[maxn];
LL lim, D[maxn];

LL cdist(int a, int b) { b = min(b, n); return D[b-1] - D[a-1]; }

int nxt[maxn], fa[maxn][maxlog], mnp[maxlog][maxn], Log[maxn];
bool type[maxn];
LL cost[maxn][maxlog], rest[maxn][maxlog];
int qmin(int l, int r) {
	int t = Log[r-l+1], a = mnp[t][l], b = mnp[t][r-(1<<t)+1];
	return pri[a] < pri[b] ? a : b;
}
void init() {
	pri[n+1] = 0;
	dwn(i, n, 1) {
		nxt[i] = i + 1;
		while(pri[nxt[i]] > pri[i]) nxt[i] = nxt[nxt[i]];
	}
	
	Log[1] = 0;
	rep(i, 2, n) Log[i] = Log[i>>1] + 1;
	rep(i, 1, n) mnp[0][i] = i;
	for(int j = 1; (1 << j) <= n; j++)
		rep(i, 1, n - (1 << j) + 1) {
			int a = mnp[j-1][i], b = mnp[j-1][i+(1<<j>>1)];
			mnp[j][i] = pri[a] < pri[b] ? a : b;
		}
	
	rep(i, 1, n)
		if(cdist(i, nxt[i]) <= lim) {
			fa[i][0] = nxt[i];
			cost[i][0] = cdist(i, nxt[i]) * pri[i];
			rest[i][0] = 0;
			type[i] = 0;
		}
		else {
			int l = i + 1, r = n + 1;
			while(r - l > 1) {
				int mid = l + r >> 1;
				if(cdist(i, mid) <= lim) l = mid; else r = mid;
			}
			l = qmin(i + 1, l);
			fa[i][0] = l;
			cost[i][0] = lim * pri[i];
			rest[i][0] = lim - cdist(i, l);
			type[i] = 1;
		}
	
	int now = n + 1;
	dwn(i, n, 1) {
		if(type[i]) now = i;
		nxt[i] = now;
	}
	rep(j, 1, maxlog - 1) rep(i, 1, n) {
		int m = fa[i][j-1];
		fa[i][j] = fa[m][j-1];
		cost[i][j] = cost[i][j-1] + cost[m][j-1] - rest[i][j-1] * pri[m];
		rest[i][j] = rest[m][j-1];
	}
	return ;
}

int num[100], cntn;
void putint(LL x) {
	cntn = 0;
	while(x) num[++cntn] = x % 10, x /= 10;
	dwn(i, cntn, 1) putchar(num[i] + '0'); putchar('\n');
	return ;
}

int main() {
	n = read(); q = read(); lim = read();
	rep(i, 1, n) pri[i] = read();
	rep(i, 1, n - 1) D[i] = D[i-1] + read();
	
	init();
	while(q--) {
		int s = read(), t = read(); LL v = read();
		if(cdist(s, t) <= v){ puts("0"); continue; }
		LL ans = 0;
		dwn(i, maxlog - 1, 0) if(fa[s][i] && nxt[s] >= fa[s][i] && cdist(s, fa[s][i]) <= v)
			v -= cdist(s, fa[s][i]), s = fa[s][i];
//		printf("get1: %d\n", s);
		dwn(i, maxlog - 1, 0) {
			int tmp = fa[s][i];
			if(!tmp || tmp > t) continue;
//			if(i == 0) printf("tmp: %d %lld\n", tmp, cdist(tmp, t));
			if((type[tmp] && lim >= cdist(tmp, t)) || (!type[tmp] && fa[tmp][0] && fa[tmp][0] >= t)) continue;
			ans += cost[s][i] - v * pri[s];
			v = rest[s][i]; s = fa[s][i];
		}
		if((type[s] && lim >= cdist(s, t)) || (!type[s] && fa[s][0] && fa[s][0] >= t));
		else ans += cost[s][0] - v * pri[s], v = rest[s][0], s = fa[s][0];
//		printf("get2: %d %lld  %lld  (%lld - %lld) * %d = %lld\n", s, v, ans, cdist(s, t), v, pri[s], (cdist(s, t) - v) * pri[s]);
		ans += (cdist(s, t) - v) * pri[s];
		putint(ans);
	}
	
	return 0;
}

據說標程很短,不知怎么做到的。

B 小球序列 (5s, 512MB)

試題簡述

\(k\) 種顏色的球,第 \(i\) 中顏色的球有 \(a_i\) 個,現在要你將它們排成一排,要求對於任意非空前綴、后綴都滿足 \(k\) 中顏色的小球個數不同。

求排列方案數對 \(998244353\) 取模后的結果。

輸入

第一行一個正整數 \(k\)

第二行 \(k\) 個正整數 \(a_1, a_2, \cdots , a_k\)

輸出

取模后的答案。

輸入示例

3
1 2 1

輸出示例

2

數據規模及約定

\(1 \le k \le 100, 1 \le a_i \le 2 \times 10^5\)

題解

這種題還是得從暴力 dp 入手。我們發現問題可以轉化成 \(k\) 維空間中,不能經過兩條直線上的點,每次只能沿某個維度的正方向走,問從 \((0, 0, \cdots , 0)\)\((a_1, a_2, \cdots , a_k)\) 有多少條路徑。

以下令 \(n = \min_{i \in [1, k]} a_i\)

這個問題 \(O(n^2)\) 顯然可以將所有被挖掉的點按照坐標排序,然后容斥 dp 一下做出來。

但由於這題的點的坐標非常特殊,我們可以利用這個條件進行優化。

先推一個小式子,從 \((x_1, x_2, \cdots , x_k)\)\((y_1, y_2, \cdots , y_k)\)(保證 \(x_i \le y_i\))的方案數可以用組合數計算,類比二維情況,將路徑看成每個方向的序列,求序列有多少種。那么方案數顯然是下面這個式子

\[\prod_{i=1}^k C_{\sum_{j=i}^k y_j - x_j}^{y_j - x_j} \]

方便起見,令 \(\Delta x_i = y_i - x_i\),我們將上式的組合數展開,得到

\[\prod_{i=1}^k \frac{(\sum_{j=i}^k \Delta x_j)!}{\Delta x_i! \cdot (\sum_{j=i+1}^k \Delta x_j)!} \\ = \frac{(\Delta x_1 + \Delta x_2 + \cdots + \Delta x_k)!}{\Delta x_1! (\Delta x_2 + \Delta x_3 + \cdots + \Delta x_k)!} \cdot \frac{(\Delta x_2 + \Delta x_3 + \cdots + \Delta x_k)!}{\Delta x_2! (\Delta x_3 + \Delta x_4 + \cdots + \Delta x_k)!} \cdots \frac{\Delta x_k!}{\Delta x_k!} \\ = \frac{(\sum_{i=1}^k \Delta x_i)!}{\prod_{i=1}^k \Delta x_i!} \]

變成了一個非常簡潔的形式!

下面還是考慮容斥,並嘗試利用“坐標非常特殊”這種條件。注意:接下來將大量用到上面推到過的式子。

\(f(i)\) 表示只考慮 \((1, 1, \cdots , 1), (2, 2, \cdots, 2), \cdots , (n, n, \cdots, n)\) 那些點不能經過,從原點到 \((i, i, \cdots , i)\) 的方案數,那么 \(f(i) = \frac{(ki)!}{(i!)^k} - \sum_{j=1}^{i-1} { f(j) \frac{[k(i-j)]!}{[(i-j)!]^k} }\)

\(g(i)\) 表示從原點到 \((a_1+i-n, a_2+i-n, \cdots , a_k+i-n)\),不經過所有不合法點的方案數,令 \(A = \sum_{i=1}^k a_i\),那么 \(g(i) = \frac{[A+k(i-n)]!}{\prod_{j=1}^k (a_j + i - n)!} - \sum_{j=1}^i { f(j) \frac{[A+k(i-j-n)]!}{\prod_{t=1}^k (a_t + i - j - n)!} } - \sum_{j=0}^{i-1} {g(j) h(i-j)}\)。注意,這里的容斥其實是所有方案依次減去經過的最小不合法點編號為 \(i\) 的方案,並且現在的編號是過原點的那條直線在前,過終點的那條直線在后,所以這就解釋了那個函數 \(h(i)\) 是要干什么用的:它就是要求過終點的那條直線上的兩個點之間的,不經過過原點那條直線上的點的路徑數。

現在再考慮一下 \(h(i)\) 怎么求,其實就是一個各維度棱長為 \(i\) 的超立方體,從左下角到右上角,不經過 \((m-n+i, m-n+i, \cdots , m-n+i)\)\(m = \max \{ a_i \}\)) 這種點的路徑數。這個東西需要再套一個容斥(霧):\(h(i) = \frac{(ki)!}{(i!)^k} - \sum_{j=0}^i { t(j) \frac{[A+k(i-j-n)]!}{\prod_{t=1}^k (a_t + i-j - n)!} }\)

\(t(i)\) 就是從 \((0, 0, \cdots , 0)\)\((m-n+i, m-n+i, \cdots , m-n+i)\),不經過 \((m-n+j, m-n+j, \cdots , m-n+j), j < i\) 的方案數。\(t(i) = \frac{[k(i+n)-A]!}{\prod_{j=1}^k (i+n-a_j)!} - \sum_{j=1}^{i=1} { t(j) \frac{[k(i-j)]!}{[(i-j)!]^k} }\),特別地,當 \(i < m-n\) 時,\(t(i) = 0\)

以上式子看不懂屬於正常現象,這種東西自己手推比較好。核心就是容斥。

於是發現上面四種 dp 值的轉移都是卷積的形式,那么可以用分治 FFT 或多項式逆元來求了。個人認為多項式逆元好一些,只需要推下式子,好寫,復雜度還低。

下面令 \(F(x), G(x), H(x), T(x)\) 分別為 \(f(i), g(i), h(i), t(i)\) 的生成函數;並令 \(T_1(x) = \sum_{i=0}^n { \frac{(ki)!}{(i!)^k} x^i }, T_2(x) = \sum_{i=0}^n { \frac{[A+k(i-n)]!}{\prod_{j=1}^k (a_j+i-n)!} x^i }, T_3(x) = \sum_{i=0}^n { \frac{[k(i+n)-A]!}{\prod_{j=1}^k (i+n-a_j)!} [i \ge m-n] x^i }\)(對應着三個廣義組合數)。那么可以退一下式子求解出 \(F(x), G(x), H(x), T(x)\)

\[F(x): \\ T_1(x) = F(x)T_1(x) + 1 \\ F(x) = \frac{T_1(x) - 1}{T_1(x)} \]

\[G(x): \\ T_2(x) = F(x)T_2(x) + G(x)H(x) \\ G(x) = \frac{[1-F(x)]T_2(x)}{H(x)} \]

\[T(x): \\ T(x) = \frac{T_3(x)}{T_1(x)} \]

\[H(x): \\ H(x) = T_1(x) - T(x)T_2(x) \]

最后 \(g(n)\) 就是答案。

force:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define rep(i, s, t) for(int i = (s), mi = (t); i <= mi; i++)
#define dwn(i, s, t) for(int i = (s), mi = (t); i >= mi; i--)

int read() {
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
	return x * f;
}

#define maxd 110
#define maxn 2010
#define maxtot 201010
#define oo 2147483647
#define MOD 998244353
#define LL long long

int k, n, a[maxd], f[maxn], g[maxn], t[maxn], h[maxn], T[maxn], T2[maxn], T3[maxn], fac[maxtot], ifac[maxtot];

int Pow(int a, int b) {
	int ans = 1, t = a;
	while(b) {
		if(b & 1) ans = (LL)ans * t % MOD;
		t = (LL)t * t % MOD; b >>= 1;
	}
	return ans;
}

int main() {
	k = read(); n = oo;
	int A = 0, mxa = 0;
	rep(i, 1, k) a[i] = read(), A += a[i], n = min(n, a[i]), mxa = max(mxa, a[i]);
	
	ifac[1] = 1;
	rep(i, 2, A) ifac[i] = (LL)(MOD - MOD / i) * ifac[MOD%i] % MOD;
	fac[0] = ifac[0] = 1;
	rep(i, 1, A) fac[i] = (LL)fac[i-1] * i % MOD, ifac[i] = (LL)ifac[i] * ifac[i-1] % MOD;
	rep(i, 0, n) {
		T[i] = (LL)fac[k*i] * Pow(ifac[i], k) % MOD;
		T2[i] = fac[A+k*(i-n)];
		rep(j, 1, k) T2[i] = (LL)T2[i] * ifac[a[j]+i-n] % MOD;
		T3[i] = fac[k*(i+n)-A];
		rep(j, 1, k) T3[i] = (LL)T3[i] * ifac[i+n-a[j]] % MOD;
//		printf("Ts[%d] %d %d %d\n", i, T[i], T2[i], T3[i]);
	}
	
	rep(i, 1, n) {
		f[i] = T[i];
		rep(j, 1, i - 1) {
			f[i] -= (LL)f[j] * T[i-j] % MOD;
			if(f[i] < 0) f[i] += MOD;
		}
	}
	rep(i, 0, n) {
		h[i] = T[i];
		if(i >= mxa - n) {
			t[i] = T3[i];
			rep(j, 1, i - 1) {
				t[i] -= (LL)t[j] * T[i-j] % MOD;
				if(t[i] < 0) t[i] += MOD;
			}
			rep(j, mxa - n, i) {
				h[i] -= (LL)t[j] * T2[i-j] % MOD;
				if(h[i] < 0) h[i] += MOD;
			}
		}
		else t[i] = 0;
	}
	rep(i, 0, n) {
		g[i] = T2[i];
		rep(j, 1, i) {
			g[i] -= (LL)f[j] * T2[i-j] % MOD;
			if(g[i] < 0) g[i] += MOD;
		}
		rep(j, 0, i - 1) {
			g[i] -= (LL)g[j] * h[i-j] % MOD;
			if(g[i] < 0) g[i] += MOD;
		}
	}
	/*rep(i, 0, n) printf("f[%d] = %d\n", i, f[i]);
	rep(i, 0, n) printf("h[%d] = %d\n", i, h[i]);
	rep(i, 0, n) printf("t[%d] = %d\n", i, t[i]);
	rep(i, 0, n) printf("g[%d] = %d\n", i, g[i]); // */
	
	printf("%d\n", g[n]);
	
	return 0;
}

100pts:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define rep(i, s, t) for(int i = (s), mi = (t); i <= mi; i++)
#define dwn(i, s, t) for(int i = (s), mi = (t); i >= mi; i--)

const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
	if(Head == Tail) {
		int l = fread(buffer, 1, BufferSize, stdin);
		Tail = (Head = buffer) + l;
	}
	return *Head++;
}
int read() {
	int x = 0, f = 1; char c = Getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
	return x * f;
}

#define maxn 524288
#define MOD 998244353
#define Groot 3
#define LL long long

int Pow(int a, int b) {
	int ans = 1, t = a;
	while(b) {
		if(b & 1) ans = (LL)ans * t % MOD;
		t = (LL)t * t % MOD; b >>= 1;
	}
	return ans;
}

int brev[maxn];
void FFT(int *a, int len, int tp) {
	int n = 1 << len;
	rep(i, 0, n - 1) if(i < brev[i]) swap(a[i], a[brev[i]]);
	rep(i, 1, len) {
		int wn = Pow(Groot, MOD - 1 >> i);
		if(tp < 0) wn = Pow(wn, MOD - 2);
		for(int j = 0; j < n; j += 1 << i) {
			int w = 1;
			rep(k, 0, (1 << i >> 1) - 1) {
				int la = a[j+k], ra = (LL)w * a[j+k+(1<<i>>1)] % MOD;
				a[j+k] = (la + ra) % MOD;
				a[j+k+(1<<i>>1)] = (la - ra + MOD) % MOD;
				w = (LL)w * wn % MOD;
			}
		}
	}
	if(tp < 0) {
		int invn = Pow(n, MOD - 2);
		rep(i, 0, n - 1) a[i] = (LL)a[i] * invn % MOD;
	}
	return ;
}

void Mul(int *A, int *B, int n, int m, bool recover = 0) {
	int N = 1, len = 0;
	while(N <= n + m) N <<= 1, len++;
	rep(i, 0, N - 1) brev[i] = (brev[i>>1] >> 1) | ((i & 1) << len >> 1);
	FFT(A, len, 1); FFT(B, len, 1);
	rep(i, 0, N - 1) A[i] = (LL)A[i] * B[i] % MOD;
	FFT(A, len, -1); if(recover) FFT(B, len, -1);
	return ;
}

int tmp[maxn];
void inverse(int *f, int *g, int n) {
	if(n == 1) return (void)(f[0] = Pow(g[0], MOD - 2));
	inverse(f, g, n + 1 >> 1);
	rep(i, 0, n - 1) tmp[i] = g[i];
	int N = 1, len = 0;
	while(N <= (n << 1)) N <<= 1, len++;
	rep(i, 0, N - 1) brev[i] = (brev[i>>1] >> 1) | ((i & 1) << len >> 1);
	rep(i, n, N - 1) tmp[i] = f[i] = 0;
	FFT(f, len, 1); FFT(tmp, len, 1);
	rep(i, 0, N - 1) f[i] = (LL)f[i] * (2ll - (LL)tmp[i] * f[i] % MOD + MOD) % MOD;
	FFT(f, len, -1); rep(i, n, N - 1) f[i] = 0;
	return ;
}

#define maxd 110
#define maxtot 20000010
#define oo 2147483647

int k, a[maxd], fac[maxtot], ifac[maxtot];
int F[maxn], G[maxn], T[maxn], H[maxn], T1[maxn], T2[maxn], T3[maxn];

int main() {
	int n = oo, A = 0, mxa = 0;
	k = read();
	rep(i, 1, k) a[i] = read(), A += a[i], n = min(n, a[i]), mxa = max(mxa, a[i]);
	
	ifac[1] = 1;
	rep(i, 2, A) ifac[i] = (LL)(MOD - MOD / i) * ifac[MOD%i] % MOD;
	fac[0] = ifac[0] = 1;
	rep(i, 1, A) fac[i] = (LL)fac[i-1] * i % MOD, ifac[i] = (LL)ifac[i] * ifac[i-1] % MOD;
	rep(i, 0, n) {
		T1[i] = (LL)fac[k*i] * Pow(ifac[i], k) % MOD;
		T2[i] = fac[A+k*(i-n)];
		rep(j, 1, k) T2[i] = (LL)T2[i] * ifac[a[j]+i-n] % MOD;
		T3[i] = i < mxa - n ? 0 : fac[k*(i+n)-A];
		rep(j, 1, k) T3[i] = (LL)T3[i] * ifac[i+n-a[j]] % MOD;
	}
	
	inverse(F, T1, n + 1);
	T1[0]--;
	Mul(F, T1, n, n, 1);
	T1[0]++;
	rep(i, n + 1, n << 1) F[i] = 0;
	
	inverse(T, T1, n + 1);
	Mul(T, T3, n, n, 1);
	rep(i, n + 1, n << 1) T[i] = 0;
	
	memcpy(H, T, sizeof(T));
	Mul(H, T2, n, n, 1);
	rep(i, n + 1, n << 1) H[i] = 0;
	rep(i, 0, n) H[i] = (T1[i] - H[i] + MOD) % MOD;
	
	inverse(G, H, n + 1);
	rep(i, 0, n) F[i] = MOD - F[i]; F[0]++;
	Mul(F, T2, n, n);
	rep(i, n + 1, n << 1) F[i] = 0;
	Mul(G, F, n, n);
	
	printf("%d\n", G[n]);
	
	return 0;
}

C 角點檢測

亂搞題,不是用來 AC 的,有興趣的同學可以自學圖像處理。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM