[深度應用]·首屆中國心電智能大賽初賽開源Baseline(基於Keras val_acc: 0.88)
個人網站–> http://www.yansongsong.cn
項目github地址:https://github.com/xiaosongshine/preliminary_challenge_baseline_keras
大賽簡介
為響應國家健康中國戰略,推送健康醫療和大數據的融合發展的政策,由清華大學臨床醫學院和數據科學研究院,天津市武清區京津高村科技創新園,以及多家重點醫院聯合主辦的首屆中國心電智能大賽正式啟動。自今日起至2019年3月31日24時,大賽開啟全球招募,預計大賽總獎金將高達百萬元!目前官方報名網站已上線,歡迎高校、醫院、創業團隊等有志於中國心電人工智能發展的人員踴躍參加。
首屆中國心電智能大賽官方報名網站>>http://mdi.ids.tsinghua.edu.cn
數據介紹
下載完整的訓練集和測試集,共1000例常規心電圖,其中訓練集中包含600例,測試集中共400例。該數據是從多個公開數據集中獲取。參賽團隊需要利用有正常/異常兩類標簽的訓練集數據設計和實現算法,並在沒有標簽的測試集上做出預測。
該心電數據的采樣率為500 Hz。為了方便參賽團隊用不同編程語言都能讀取數據,所有心電數據的存儲格式為MAT格式。該文件中存儲了12個導聯的電壓信號。訓練數據對應的標簽存儲在txt文件中,其中0代表正常,1代表異常。
賽題分析
簡單分析一下,初賽的數據集共有1000個樣本,其中訓練集中包含600例,測試集中共400例。其中訓練集中包含600例是具有label的,可以用於我們訓練模型;測試集中共400例沒有標簽,需要我們使用訓練好的模型進行預測。
賽題就是一個二分類預測問題,解題思路應該包括以下內容
- 數據讀取與處理
- 網絡模型搭建
- 模型的訓練
- 模型應用與提交預測結果
實戰應用
經過對賽題的分析,我們把任務分成四個小任務,首先第一步是:
1.數據讀取與處理
該心電數據的采樣率為500 Hz。為了方便參賽團隊用不同編程語言都能讀取數據,所有心電數據的存儲格式為MAT格式。該文件中存儲了12個導聯的電壓信號。訓練數據對應的標簽存儲在txt文件中,其中0代表正常,1代表異常。
我們由上述描述可以得知,
- 我們的數據保存在MAT格式文件中(這決定了后面我們要如何讀取數據)
- 采樣率為500 Hz(這個信息並沒有怎么用到,大家可以簡單了解一下,就是1秒采集500個點,由后面我們得知每個數據都是5000個點,也就是10秒的心電圖片)
- 12個導聯的電壓信號(這個是指采用12種導聯方式,大家可以簡單理解為用12個體溫計量體溫,從而得到更加准確的信息,下圖為導聯方式簡單介紹,大家了解下即可。要注意的是,既然提供了12種導聯,我們應該全部都用到,雖然我們僅使用一種導聯方式也可以進行訓練與預測,但是經驗告訴我們,采取多個特征會取得更優效果)
數據處理函數定義:
讀取一條數據進行顯示
我們由上述信息可以看出每種導聯都是由5000個點組成的列表,12種導聯方式使每個樣本都是12*5000的矩陣,類似於一張分辨率為12x5000的照片。
我們需要處理的就是把每個讀取出來,歸一化一下,送入網絡進行訓練可以了。
標簽處理方式
我這里是采用從reference.txt讀取,然后打亂保存到reference.csv中,注意一定要進行數據打亂操作,不然訓練效果很差。因為原始數據前面便簽全部是1,后面全部是0
數據迭代方式
數據讀取的方式我采用的是生成器的方式,這樣可以按batch讀取,加快訓練速度,大家也可以采用一下全部讀取,看個人的習慣了
2.網絡模型搭建
數據我們處理好了,后面就是模型的搭建了,我使用keras搭建的,操作簡單便捷,tf,pytorch,sklearn大家可以按照自己喜好來。
網絡模型可以選擇CNN,RNN,Attention結構,或者多模型的融合,拋磚引玉,此Baseline采用的一維CNN方式,一維CNN學習地址
模型搭建
用model.summary()輸出的網絡模型為
訓練參數比較少,大家可以根據自己想法更改。
3.網絡模型訓練
模型訓練
if __name__ == "__main__": """dat1 = get_feature("TRAIN101.mat") print("one data shape is",dat1.shape) #one data shape is (12, 5000) plt.plot(dat1[0]) plt.show()""" if (os.path.exists(MANIFEST_DIR)==False): create_csv() train_iter = xs_gen(train=True) test_iter = xs_gen(train=False) model = build_model() print(model.summary()) ckpt = keras.callbacks.ModelCheckpoint( filepath='best_model.{epoch:02d}-{val_acc:.2f}.h5', monitor='val_acc', save_best_only=True,verbose=1) model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) model.fit_generator( generator=train_iter, steps_per_epoch=500//Batch_size, epochs=20, initial_epoch=0, validation_data = test_iter, nb_val_samples = 100//Batch_size, callbacks=[ckpt], )
訓練過程輸出(最優結果:loss: 0.0565 - acc: 0.9820 - val_loss: 0.8307 - val_acc: 0.8800)
4.模型應用預測結果
預測數據
下面是前十條預測結果:
大家需要注意一下,我預測的方式和官方不同,需要大家自己根據賽題要求來進行預測提交。。
展望
此Baseline采用最簡單的一維卷積達到了88%測試准確率(可能會因為隨機初始化值上下波動),大家也可以多嘗試GRU,Attention,和Resnet等結果,測試准確率准確率會突破90+。
能力有限,寫的不好的地方歡迎大家批評指正。。
個人主頁--> https://xiaosongshine.github.io/
項目github地址:https://github.com/xiaosongshine/preliminary_challenge_baseline_keras
歡迎Fork+Star,覺得有用的話,麻煩小小鼓勵一下 ><