Problem
題意大意:
一棵樹,點權\(w_i\),每次玩家可以在樹上行走,一條邊需要\(1\)的時間,只能往兒子走。每次游戲需要從\(s\)到\(t\)。
玩家有一個總死亡次數,初始為\(0\)。如果走到\(i\)的時候,當前總的死亡次數小於\(w_i\),那么玩家就會立刻死亡並回到起點 \(s\),且該死亡過程不需要時間,多組詢問從\(s\)到\(t\)的最短時間
\(n,m\leq 3\times 10^5,w_i\leq 10^9\)
Solution
不難發現每次死亡都互相獨立,所以對於每次死亡都單獨計算
我們設\(f[x][i]\)表示節點\(x\)往下走,已經死亡\(i-1\)次,死亡第\(i\)次所花費的最短時間,假如從\(s\)到\(t\)之間的最大點權為\(W\),則答案為\(\sum_{i=1}^Wf[s][i]+dis(s,t)\)
解釋一下,如果路徑上最大點權為\(W\),則從\(s\)到\(t\)的過程中一定恰好死\(W\)次(因為如果沒死\(W\)次,則無法通過,一旦死亡\(W\)次,則接下來的點都可以通過且不會死亡),則答案記為\(\sum_{i=1}^Wf[s][i]\),最后還要來一次從\(s\)到\(t\)的暢通無阻的旅行,答案加上\(dis(s,t)\)
如何得到\(f[x][i]\)?可以簡單\(Dp\),時間空間復雜度都為\(O(nW)\),我們可以拿到\(10\)分的好成績,離散化一下可以多拿\(10\)分 (。・∀・)ノ゙
優化轉移,發現\(f[x][i]\)可能有一大段是相同的,而且\(f[x][i]\)一定依照\(i\)嚴格遞增,我們只需要記錄轉折點了
比如說對於\(f[x][i]\)我們記錄了\(<f_1,f_4,f_{12},f_{40}>\),則對於\(i\in [5,12]\),\(f[x][i]=f_{12}\)
這樣的話我們拿平衡樹維護這個\(<f_i>\),答案可以維護
答案轉移可以平衡樹啟發式合並,具體就是所有子樹相應位置取\(\min\),然后整體加\(1\),詢問完答案后,再將位置\([1,w_i]\)之間的值賦為\(0\)
我:“好了,我們開始打吧”
(一個小時后)
我:“這太惡心了,渾身難受,不想繼續,我們還是用正解的方法吧”
同桌:“這怎么行呢,代碼再長,忍忍就過去了,而且你就是要培養這種調試的能力”
我:“好吧”(開始硬着頭皮開始打)
(20分鍾后)
同桌:“啊,我不打了,這太惡心了,我還沒打完主程序就有300行了,我們還是看正解吧”
我:“不行不行,你要堅持,怎么能放棄呢,代碼再長,忍忍就過去了,而且就是要培養這種調試的能力嘛”
同桌:“算了算了,我放棄了,我們還是用正解做法吧”
如上,我們發現用上述方法會導致渾身藍瘦
所以我們有種奇妙的代碼優化,就是
隊列啟發式合並
具體怎么做呢,就是整體開一個長隊列,按照\(dfs\)序分配隊列空間,這樣合並一棵子樹的時候由於\(dfs\)序是連續的,所以在隊列中的空間也是連續的,這樣合並起來的話,繼續用合並起來的空間,這就顯得很巧妙了
合並兩個隊列的時候將短的隊列加入長的隊列,而隊列的長度上限取決於這個子樹內的最長鏈,如果我們采用長鏈剖分,則每條長鏈只會被遍歷一次,時間復雜度是\(O(n)\)的,詢問的時候二分查找第一個大於等於\(W\)的值,將前面的整體部分記上,再加上后面多出來的
如圖,橫坐標為\(depth\),縱坐標為\(w\),我們求的值就是下面的面積(其中隊列里存的節點為染成藍綠色的\(A,C,E,G\))(感謝@zjp_shadow提醒弱智博主修復)

整體復雜度\(O((n+m)\log n)\),具體做法詳見代碼
Code
#include <bits/stdc++.h>
typedef long long ll;
template <typename _tp> inline void read(_tp&x){
char ch=getchar(),ob=0;x=0;
while(ch!='-'&&!isdigit(ch))ch=getchar();if(ch=='-')ob=1,ch=getchar();
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();if(ob)x=-x;
}
template <typename _tp> inline void cmax(_tp&A,_tp B) {A = A > B ? A : B;}
template <typename _tp> inline _tp max(_tp A,_tp B) {return A > B ? A : B;}
const int N=301000;
struct Edge{int v,nxt;}a[N+N];
int head[N],Head[N],dep[N];
int he[N],ta[N],dfn[N],son[N];
int Q_t[N],w[N],n,m,dfc=1,_;
ll Ans[N],sm[N];
inline void add(int*arr,int u,int v){a[++_].v = v, a[_].nxt = arr[u], arr[u] = _;}
struct vLCA{
int anc[N][20],ans_F[N][20],len[N];
inline vLCA(){memset(len,0,sizeof len); len[0] = -1;}
int query(int x,int y){
int res(0);
for(int i=19;~i;--i)
if(dep[anc[x][i]] > dep[y])
cmax(res,ans_F[x][i]), x = anc[x][i];
return res;
}
void dfs(int x){
for(int i=1;i<20;++i){
anc[x][i] = anc[anc[x][i-1]][i-1];
ans_F[x][i] = max(ans_F[x][i-1], ans_F[anc[x][i-1]][i-1]);
}
for(int i=head[x];i;i=a[i].nxt){
anc[a[i].v][0] = x;
ans_F[a[i].v][0] = w[x];
dep[a[i].v] = dep[x] + 1;
dfs(a[i].v);
if(len[a[i].v] > len[son[x]])
son[x] = a[i].v;
}
len[x] = len[son[x]] + 1;
}
}lca;
struct node{
int dep,w;
inline node(){}
inline node(const int&Dep,const int&W):dep(Dep),w(W){}
}p[N],tmp[N];
void ins(int x, node e) {
while(he[x] <= ta[x] and p[he[x]].w <= e.w)++he[x];
if(he[x] > ta[x] or p[he[x]].dep > e.dep){
sm[he[x]-1] = 0;
if(he[x] <= ta[x]) sm[he[x]-1] = (ll)(p[he[x]].w - e.w) * p[he[x]].dep + sm[he[x]];
p[--he[x]] = e;
}
}
void merge(int x,int y){
int tp=0;
while(he[x] <= ta[x] and p[he[x]].dep < p[ta[y]].dep)
tmp[++tp] = p[he[x]++];
while(tp and he[y] <= ta[y])
if(p[ta[y]].dep > tmp[tp].dep) ins(x,p[ta[y]--]);
else ins(x, tmp[tp--]);
while(tp) ins(x, tmp[tp--]);
while(he[y] <= ta[y]) ins(x,p[ta[y]--]);
}
void solve(int x,int id){
int y = Q_t[id], op; ll res = 0ll;
int l = he[x], r = ta[x], mid, w = lca.query(y,x);
while(l<r){
mid = l + r>> 1;
if(p[mid].w < w) l = mid + 1;
else r = mid;
}
op = bool(p[he[x]].w <= w);
if(op) res = sm[he[x]] - sm[l] + (ll)p[he[x]].w * p[he[x]].dep;
res += (ll)p[l].dep * (w-(op?p[l].w:0)) - (ll)dep[x]*w;
Ans[id] = res + dep[y] - dep[x];
}
void dfs(int x){
dfn[x] = ++dfc;
if(son[x]) dfs(son[x]), he[x] = he[son[x]], ta[x] = ta[son[x]];
else he[x] = dfc, ta[x] = dfc - 1;
for(int i=head[x];i;i=a[i].nxt)
if(a[i].v != son[x])
dfs(a[i].v), merge(x,a[i].v);
for(int i=Head[x];i;i=a[i].nxt)
solve(x,a[i].v);
ins(x,node(dep[x],w[x]));
}
void input();
void print();
int main(){
input(); lca.dfs(1);
dfs(1); print();
return 0;
}
void print(){
for(int i=1;i<=m;++i)
printf("%lld\n",Ans[i]);
}
void input(){
read(n), dep[1] = 1; int x;
for(int i=1;i<=n;++i) read(w[i]);
for(int i=2;i<=n;++i) read(x), add(head,x,i);
read(m);
for(int i=1;i<=m;++i){
read(x), read(Q_t[i]);
add(Head,x,i);
}
}
