PTA L3-023 計算圖 (dfs+數學推導)


“計算圖”(computational graph)是現代深度學習系統的基礎執行引擎,提供了一種表示任意數學表達式的方法,例如用有向無環圖表示的神經網絡。 圖中的節點表示基本操作或輸入變量,邊表示節點之間的中間值的依賴性。 例如,下圖就是一個函數 ( 的計算圖。

figure.png

現在給定一個計算圖,請你根據所有輸入變量計算函數值及其偏導數(即梯度)。 例如,給定輸入,,上述計算圖獲得函數值 (;並且根據微分鏈式法則,上圖得到的梯度 ∇。

知道你已經把微積分忘了,所以這里只要求你處理幾個簡單的算子:加法、減法、乘法、指數(ex​​,即編程語言中的 exp(x) 函數)、對數(ln,即編程語言中的 log(x) 函數)和正弦函數(sin,即編程語言中的 sin(x) 函數)。

友情提醒:

  • 常數的導數是 0;x 的導數是 1;ex​​ 的導數還是 ex​​;ln 的導數是 1;sin 的導數是 cos。
  • 回顧一下什么是偏導數:在數學中,一個多變量的函數的偏導數,就是它關於其中一個變量的導數而保持其他變量恆定。在上面的例子中,當我們對 x1​​ 求偏導數 / 時,就將 x2​​ 當成常數,所以得到 ln 的導數是 1,x1​​x2​​ 的導數是 x2​​,sin 的導數是 0。
  • 回顧一下鏈式法則:復合函數的導數是構成復合這有限個函數在相應點的導數的乘積,即若有 (,(,則 /。例如對 sin 求導,就得到 cos。

如果你注意觀察,可以發現在計算圖中,計算函數值是一個從左向右進行的計算,而計算偏導數則正好相反。

輸入格式:

輸入在第一行給出正整數 N(≤),為計算圖中的頂點數。

以下 N 行,第 i 行給出第 i 個頂點的信息,其中 ,。第一個值是頂點的類型編號,分別為:

  • 0 代表輸入變量
  • 1 代表加法,對應 x1​​+x2​​
  • 2 代表減法,對應 x1​​x2​​
  • 3 代表乘法,對應 x1​​×x2​​
  • 4 代表指數,對應 ex​​
  • 5 代表對數,對應 ln
  • 6 代表正弦函數,對應 sin

對於輸入變量,后面會跟它的雙精度浮點數值;對於單目算子,后面會跟它對應的單個變量的頂點編號(編號從 0 開始);對於雙目算子,后面會跟它對應兩個變量的頂點編號。

題目保證只有一個輸出頂點(即沒有出邊的頂點,例如上圖最右邊的 -),且計算過程不會超過雙精度浮點數的計算精度范圍。

輸出格式:

首先在第一行輸出給定計算圖的函數值。在第二行順序輸出函數對於每個變量的偏導數的值,其間以一個空格分隔,行首尾不得有多余空格。偏導數的輸出順序與輸入變量的出現順序相同。輸出小數點后 3 位。

輸入樣例:

7
0 2.0
0 5.0
5 0
3 0 1
6 1
1 2 3
2 5 4

輸出樣例:

11.652
5.500 1.716

天梯賽L3的第二題,反向建圖之后利用各種求導公式對每個變量分別跑一遍dfs求偏導就行了。場下30分鍾過掉,場上的我真是宛如一個智障,~QAQ~

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 typedef double db;
 5 const int N=5e4+10;
 6 int n,f[N],dg[N],s,nxt[N][2],vis[N],x;
 7 db a[N],f1[N],f2[N];
 8 vector<int> vec;
 9 vector<db> ans;
10 void dfs(int u) {
11     if(vis[u])return;
12     vis[u]=1;
13     if(f[u]==0)f1[u]=a[u],f2[u]=u==x?1:0;
14     else if(f[u]==1) {
15         int v1=nxt[u][0],v2=nxt[u][1];
16         dfs(v1),dfs(v2);
17         f1[u]=f1[v1]+f1[v2],f2[u]=f2[v1]+f2[v2];
18     } else if(f[u]==2) {
19         int v1=nxt[u][0],v2=nxt[u][1];
20         dfs(v1),dfs(v2);
21         f1[u]=f1[v1]-f1[v2],f2[u]=f2[v1]-f2[v2];
22     } else if(f[u]==3) {
23         int v1=nxt[u][0],v2=nxt[u][1];
24         dfs(v1),dfs(v2);
25         f1[u]=f1[v1]*f1[v2],f2[u]=f2[v1]*f1[v2]+f1[v1]*f2[v2];
26     } else if(f[u]==4) {
27         int v=nxt[u][0];
28         dfs(v),f1[u]=exp(f1[v]),f2[u]=exp(f1[v])*f2[v];
29     } else if(f[u]==5) {
30         int v=nxt[u][0];
31         dfs(v),f1[u]=log(f1[v]),f2[u]=f2[v]/f1[v];
32     } else if(f[u]==6) {
33         int v=nxt[u][0];
34         dfs(v),f1[u]=sin(f1[v]),f2[u]=cos(f1[v])*f2[v];
35     }
36 }
37 int main() {
38     scanf("%d",&n);
39     for(int i=0; i<n; ++i) {
40         scanf("%d",&f[i]);
41         if(f[i]==0) {
42             scanf("%lf",&a[i]);
43             vec.push_back(i);
44         } else if(f[i]>=1&&f[i]<=3) {
45             int u,v;
46             scanf("%d%d",&u,&v);
47             nxt[i][0]=u,nxt[i][1]=v,dg[u]++,dg[v]++;
48         } else if(f[i]>=4&&f[i]<=6) {
49             int u;
50             scanf("%d",&u);
51             nxt[i][0]=u,dg[u]++;
52         }
53     }
54     for(int i=0; i<n; ++i)if(!dg[i])s=i;
55     for(int i:vec)x=i,memset(vis,0,sizeof vis),dfs(s),ans.push_back(f2[s]);
56     printf("%.3f\n",f1[s]);
57     for(int i=0; i<ans.size(); ++i)printf("%.3f%c",ans[i]," \n"[i==ans.size()-1]);
58     return 0;
59 }

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM