(代碼托管在我的Github上,如果有幫助記得點星星嗨!)
0 - 背景
0.0 - 概要
生存預測模型探索的是患者的各個屬性/特征與治療效果之間的關系。之前的生存預測模型,像linear Cox proportional hazards model需要有專業的醫學知識作為專業背景來構建特征工程,而另外的一些nonlinear survival methods,像neural networks/survival forests則沒有在有效的推薦系統中得到實踐證明。文中提出一種Cox proportional hazards deep neural network的生存模型DeepSurv,並且在模擬數據集及臨床數據集上進行實驗,證明了模型具有可比的或者最好的性能,此外還將該DeepSurv應用到治療效果推薦系統上提供個性化推薦。
0.1 - 相關概念
生存數據:一般生存數據由三部分組成:患者的基線數據$x$,死亡事件時間$T$,事件指標$E$。如果死亡事件發生了,則$T$代表的就是基線數據$x$與死亡事件發生之間的時間間隔,此時$E=1$。如果死亡事件沒有發生,則$T$表示基線數據$x$與患者最后一次采集數據的時間間隔,此時$E=0$,這部分數據稱為右刪失(right-censored)。
生存函數(Survival Function):生存函數可以定義為$S(t)=Pr(T>t)$,其表示的是個體在時刻$t$生存的概率,其可以通過下式進行估計,
$$\hat{S}(t)=\frac{number\ of\ patients\ surviving\ longer\ than\ t}{total\ number\ of\ patients}.$$
密度函數(Density Function):密度函數可以定義為$f(t)=\lim_{\delta \rightarrow 0}\frac{Pr\left(t\leq T< t+\delta | T\geq t \right )}{\delta}$,其表示已經處在生存時間$T$的短暫時刻發生事件的概率,其估計方法為,
$$\hat{f}(t)=\frac{number\ of\ patients\ dying\ in\ the\ interval\ beginning\ at\ time\ t}{(total\ number\ of\ patients)\times(interval\ width)}.$$
風險函數(Hazard Function):風險函數用來衡量當前個體在時刻$t$之前沒有發生任何事件的情況下,時刻$t$發生事件的概率,其可以定義為$\lambda(t)=\lim_{\delta \rightarrow 0}\frac{Pr\left(t\leq T< t+\delta\right)}{\delta}$,其估計方法為,
$$\hat{\lambda}(t)=\frac{number\ of\ patients\ dying\ in\ the\ interval\ beginning\ at\ time\ t}{(number\ of\ patients\ surviving\ at\ t)\times(interval\ width)}.$$
Cox比例風險回歸模型:Cox比例風險回歸模型是一種常用的方法,用於在給定基線數據$x$的情況下對個體的生存風險進行建模。該模型由兩部分組成:只與時間相關的基線風險函數$\lambda_0(t)$和只與患者數據$x$相關的函數$h(x)$。該模型表示為$\lambda(t|x)=\lambda_0(t)\cdot e^{h(x)}.$
C-index:這是生存預測的一個評價指標,英文全稱為concordance index,因為對於存在刪失的生存數據,一些標准的評估方法,例如均方誤差等,是不合適的。其計算方式是,(1)將所有樣本兩兩配對,例如有$N$個樣本,則一共可以組成$N\times (N-1)/2$對;(2)排除其中無法判斷誰先出現感興趣事件的配對(兩個實例都沒有發生事件),得到剩余的對數$M$;(3)在剩下的$M$對中,預測結果與實際結果一致的配對數$K$,即預測的生存$S(X_A)<S(X_B)$(或者說風險率$R(X_A)>R(X_B)$),實際的$T_A<T_B$,即為一致;(4)則$C-index=\frac{K}{M}$。其可以形式化為如下公式,
$$\frac{1}{M}\sum_{i:\delta_i=1}\sum_{j:T_i<T_j}I\left[S(T_i,X_i)<S(T_j,X_j) \right ],$$
其中$I[C]$表示若$C$為真,則$I[C]=1$,否則$I[C]=0$。$\delta_i=1$表示至少要有一個實例發生了事件,$T_i<T_j$表示對$i$和$j$配對的要求,即防止$i$和$j$顛倒算了兩次。
0.2 - Linear Survial Models
線性生存模型是把cox模型中的$h(x)$采用線性函數$\hat{h}_{\beta}(x)=\beta^Tx$進行建模,可以定義為,
$$L_c(\beta)=\prod_{i:E_i=1}\frac{exp(\hat{h}_{\beta}(x_i))}{\sum_{j\in \Re(T_i)}exp(\hat{h}_{\beta}(x_j))},$$
其中$T_i,E_i,x_i$分別表示事件事件、事件指標、第$i$個基准數據。上述式子定義在一組可觀察到事件發生的患者上$E_i=1$,風險集合$\Re (t)=\{i:T_i\geq t\}$表示在時刻$t$仍然處於風險的患者集合。
0.3 - NonLinear Survial Models
即$\hat{h}_{\theta}(x)$由非線性模型進行建模。
1 - 方法
1.0 - DeepSurv
DeepSurv是一個多層感知機,模型的預測輸出是一個值,代表患者的健康風險,其損失函數定義為,
$$l(\theta):=-\sum_{i:E_i=1}\left(\hat{h}_{\theta}(x_i)-log\sum_{j\in\Re(T_i)}e^{\hat{h}_{\theta}(x_j)} \right ),$$
文中將DeepSurv設計成了一個深度結構(可能有多層隱藏層),並且加入了權重衰減正則化、ReLU激活、batch normalization、SELU、dropout、SGD、Adam、梯度裁剪、學習率調整策略等當時比較新的技術。
1.1 - 治療推薦系統
在一項臨床研究中,患者根據其相關的預后特征和所接受的治療具有不同程度的風險。文中把這個假設概括為,假設研究中的患者被分到$n$個治療組$\tau \in \{0,1,\cdots,n-1\}$中的一個,每一個治療方案$i$具有獨立的風險函數$h_i(x)$。總的來說,風險函數變成了,
$$\lambda(t;x|\tau=i)=\lambda_0(t)\cdot e^{h_i(x)},$$
基於上述的假設,每一個個體擁有一樣的初始風險函數$\lambda_0(t)$,我們可以用采用不同資料方案的風險率的對數來對比同一個體接受兩種治療方案的對比,其推導為,
$$rec_{ij}(x)=log\left(\frac{\lambda(t;x|\tau=i)}{\lambda(t;x|\tau=j)} \right )=log\left(\frac{\lambda_0(t)\cdot e^{h_i(x)}}{\lambda_0(t)\cdot e^{h_j(x)}} \right )=h_i(x)-h_j(x),$$
如果$rec_{ij}>0$,則說明$i$方案比$j$方案風險高,應該選擇$j$方案,反之則反。
2 - 結果
我復現了文章中第4節在幾個數據集上的結果,模型和訓練的參數有稍微的調整,參數配置可以自己在配置文件里面修改,結果如下表所示。(代碼托管在我的Github上,如果有幫助記得點星星嗨
Simulated Linear | Simulated Nonlinear | WHAS | SUPPORT | METRABRIC | Simulated Treatment | Rotterdam & GBSG | |
Paper | 0.774019 | 0.648902 | 0.862620 | 0.618308 | 0.643374 | 0.582774 | 0.668402 |
Ours | 0.778607 | 0.652048 | 0.841484 | 0.618107 | 0.643453 | 0.552648 | 0.673290 |
3 - 參考資料
https://github.com/czifan/DeepSurv.pytorch