1. NiftyNet項目概述
NiftyNet項目對tensorflow進行了比較好的封裝,實現了一整套的DeepLearning流程。將數據加載、模型加載,網絡結構定義等進行了很好的分離,抽象封裝成了各自獨立的模塊。雖然抽象的概念比較多,使得整個項目更為復雜,但是整體結構清晰,支持的模塊多。可擴展性還沒有進行試驗,暫時不是很清楚。 該項目能夠實現:
- 圖像分割
- 圖像分類
- gan
- Autoencoder
- 回歸
項目支持醫學圖像的讀取,提供的讀取器有:
- nibabel 支持.nii醫學文件格式
- simpleitk 支持.dcm和.mhd格式的醫療圖像
- opencv 支持.jpg等常見圖像讀取,讀取后通道順序為BGR
- skimage 支持.jpg等常見圖像讀取
- pillow 支持.jpg等常見圖像讀取
在使用中遇到了一些問題,其訓練的速度非常慢。最開始單個iter的平均訓練時間估計在40秒以上,有的iter時間會有200秒。現在主要在查找性能瓶頸。
一、 項目結構
niftynet.engine.application_driver(ApplicationDriver)定義並驅動着整個Application的生命周期,將配置數據進行解析后,實例化Application並啟動流程。
i. Application
Application 作為核心概念,承擔整個train或inference的主要功能。所有Application繼承於niftynet.application.base_application(簡稱為BaseApplication)。BaseApplication使用單例模式。
在Application類中,構建了Tensorflow的圖結構和創建Session用於驅動計算。
BaseApplication單例模式的具體實現有一點小問題。
Application所完成的工作具體可以划分成以下4個環節
- 輸入數據相關 數據加載,數據增強,數據取樣等,抽象在這兩個接口中在SegmentationApplication中,sampler支持:uniform, weighted, resized, balanced4種方式
initialise_dataset_loader()
initialise_sampler()
- 網絡結構相關 網絡結構的定義,參數的管理,自定義操作等,抽象在此接口中
initialise_network()
- 模型共享相關 完成由網絡的輸入到網絡的輸出,計算loss、gradient,創建optimizer等,抽象在此接口中
connect_data_and_network()
- 輸出解碼相關 inference將網絡輸出解碼操作,抽象在此接口中
interpret_output()
ii. Config
配置文件需要必須包含的模塊:
- [SYSTEM]
- [NETWORK]
- 如果action為train,那么config中需要包含[TRAINING]模塊
- 如果action為inference,那么config中需要包含[INFERENCE]模塊
- 額外的,根據特定的application,會需要包含指定名稱的模塊。如:
– [GAN]
– [SEGMENTATION]
– [REGRESSION]
– [AUTOENCODER]
- 除了以上的配置外,其他的數據會處理為input data source specifications【數據聲明模塊】
l 數據聲明模塊
Name |
解釋 |
例子 |
默認值 |
csv_file |
包含輸入圖像文件的列表 |
csvfile=filelist.csv |
'' |
pathtosearch |
如果沒有配置csv_file,則從此路徑下去搜索輸入圖像 |
pathtosearch=~/ct_data |
NiftyNet home folder |
filename_contains |
搜索輸入圖像時用於匹配的關鍵詞 |
filename_contains=foo, bar |
'' |
filenamenotcontains |
搜索輸入圖像時用於排除的關鍵詞 |
filenamenotcontains=ti, s1 |
'' |
filename_removefromid |
正則表達式,用於從輸入圖像的文件名中,解析出id |
filename_removefromid=foo |
'' |
interp_order |
插值法 |
interp_order=1 |
3 |
pixdim |
如果指定了,輸入的3D圖像會重新采樣到指定大小再送入網絡 |
pixdim=1.2, 1.2, 1.2 |
'' |
axcodes |
如果指定了,輸入的3D圖像會重新設定到指定的axcodes順序再送入網絡 參考文章 |
axcodes=L, P, S |
'' |
spatialwindowsize |
3個整數,指定輸入window的大小[能被8整除] |
spatialwindowsize=64, 64, 64 |
'' |
loader |
指定圖像讀取loader類型 |
loder=simpleitk |
None |
[interp_order] 當設定采樣方法為resize時,需要這個參數對圖片上采樣或下采樣 1表示雙線性插值
0表示最近鄰插值
3表示三次樣條插值
l [SYSTEM]
Name |
解釋 |
例子 |
默認值 |
cude_devices |
指定GPU |
cuda_devices=0,1 |
'' |
num_threads |
預處理線程的數量 |
num_threads=8 |
2 |
num_gpus |
訓練時使用GPU數量 |
num_gpus=2 |
1 |
model_dir |
保存或讀取模型權重和Log的位置 |
model_dir=~/niftynet/xxx |
config文件所在目錄 |
datasetsplitfile |
用於將數據划分成training/validation/inferenct字集 |
datasetsplitfile=~/nifnet/xxx |
./datasetsplitfile.csv |
event_handler |
注冊事件處理 |
eventhandler=modelrestorer |
modelsaver, modelrestorer, samplerthreading, applygradients, outputinterpreter, consolelogger, tensorboard_logger |
l [NETWORK]
Names |
解釋 |
例子 |
默認值 |
name |
所使用的網絡結構 |
name=niftynet.network.toynet.ToyNet |
‘’ |
activation_function |
設置網絡中使用的激活函數 |
activation_function=prelu |
Relu |
batch_size |
批大小 |
batch_size=10 |
2 |
smaller_final_batch_mode |
當總數據量不能被batch_size整除時,最后一個batch_size的方式 |
smaller_final_batch_mode=drop smaller_final_batch_mode=pad smaller_final_batch_mode=dynamic |
pad |
decay |
正則化參數 |
decay=1e-5 |
0.0 |
reg_type |
正則化類型 |
reg_type=L1 |
L2 |
volume_padding_size |
|
volume_padding_size=4, 4, 4 |
0, 0, 0 |
volume_padding_mode |
|
volume_padding_mode=symmetric |
minimum |
window_sampling |
采樣的類型 |
window_sampling=uniform 固定尺寸,相同的概率分布 window_sampling=weighted 固定尺寸,根據intensity作為概率分布 window_sampling=balanced 固定尺寸,每個label擁有相同采樣概率 window_sampling=resize 縮放圖像到window尺寸 |
uniform |
queue_length |
采樣時使用的buffer大小 |
queue_length=10 |
5 |
keep_prob |
如果網絡中使用了dropout |
keep_prob=0.2 |
1.0 |
l [TRAINING]
Name |
解釋 |
例子 |
默認值 |
optimizer |
優化器類型 |
optimizer=momentum |
adam |
sample_per_volume |
每個輸入圖像采樣的次數 |
sample_per_volume=5 |
1 |
lr |
學習率 |
lr=0.0001 |
0.1 |
loss_type |
loss計算方式 |
loss_type=CrossEntropy |
Dice |
starting_iter |
啟動的iter |
starting_iter=0 |
0 |
save_every_n |
保存的間隔 |
save_every_n=50 |
500 |
tensorboard_every_n |
tensorboard記錄的間隔 |
tensorboard_every_n=50 |
20 |
max_iter |
最大iter數 |
max_iter=3000 |
10000 |
max_checkpoints |
保存的最多checkpoint數 |
max_checkpoints=5 |
100 |
訓練時驗證
validation_every_n |
訓練時進行驗證的間隔 |
validation_every_n=10 |
-1 |
validation_max_iter |
驗證時iter的數量 |
validation_max_iter=5 |
1 |
exclude_fraction_for_validation |
驗證集的比重 |
exclude_fraction_for_validation=0.2 |
0.0 |
exclude_fraction_for_inference |
測試集的比重 |
exclude_fraction_for_inference=0.1 |
0.0 |
數據增強
rotation_angle |
旋轉 |
rotation_angle=-10.0, 10.0 |
‘’ |
scaling_percentage |
縮放 |
scaling_percentage=-20.0, 20.0 |
‘’ |
random_flipping_axes |
翻轉 |
random_flipping_axes=1,2 |
-1 |
l [INFERENCE]
Name |
解釋 |
例子 |
默認值 |
spatial_window_size |
網絡輸入尺寸大小 |
spatial_window_size=64,64,64 |
‘’ |
border |
輸入尺寸的邊框 |
border=5,5,5 |
0,0,0 |
inference_iter |
使用指定iter保存的權重文件 |
inference_iter=1000 |
-1 |
save_seg_dir |
保存輸出路徑 |
save_seg_dir=output/test |
output |
output_postfix |
輸出保存的后綴 |
output_postfix=_output |
_niftynet_out |
output_interp_order |
插值法 |
output_interp_order=0 |
0 |
dataset_to_infer |
使用的數據集,可選:’all’, ‘training’, ‘validation’, ‘inference’ |
dataset_to_infer=all |
‘’ |
iii. Reader & Dataset
n niftynet.io.image_reader模塊
ImageReader的主要作用是,遍歷一組目錄,搜索並返回一個圖像的列表,以及使用iterative的方式將數據加載到內存中。
ImageReader會創建一個tf.data.Dataset的對象,這樣使得模塊可以很方便地接入到基於tensorflow的程序中。
ImageReader的特點:
l 設計用於支持醫療圖像數據的格式
l 支持多模態輸入數據
l 支持tf.data.Dataset
n niftynet.contrib.dataset_sampler
sampler將 image reader作為輸入,從每張圖像中采取出結果輸出。
在很多的醫學圖像處理的情況中,由於GPU顯存的限制以及訓練效率等的考慮,網絡結構會對圖像的部分進行處理而非整張圖像。
iv. Network
項目中包含了一些已經實現的網絡:
- GAN:
– simulator_gan
– siple_gan
- Segmentation:
– highres3dnet, highres3dnetsmall, highres3dnetlarge
– toynet
– unet
– vnet
– dense_vnet
– deepmedic
– scalenet
– holisticnet
– unet_2d
- classification:
– resnet
– se_resnet
- autoencoder:
– vae
v. Loss
已提供支持的loss計算方式
- Segmentation
- CrossEntropy
- CrossEntropy_Dense
- Dice
- Dice_NS
- Dice_Dense
- Dice_Dense_NS
- Tversky
- GDSC
- WGDL
- SensSpec
- Gan
- CrossEntropy
- Regression
- L1Loss
- L2Loss
- RMSE
- MAE
- Huber
- Classification
- CrossEntropy
- AutoEncoder
- VariationalLowerBound
支持的優化器類型
- adam
- gradientdescent
- momentum
- nesterov
- adagrad
- rmsprop
vi. Event機制
NiftyNet項目的設計,使用了Signal和event handler模式,具體實現使用了blinker庫。這樣可以方便地將模型保存,tensorboard記錄等操作進行配置。
目前可供注冊的signal有:
- GRAPH_CREATED
- SESS_STARTED
- SESS_FINISHED
- ITER_STARTED
- ITER_FINISHED
信號處理函數注冊到對應的信號后,由引擎負責調用。
vii. Layer
網絡層的相關設計都封裝在Layer類中,可繼承layer類,實現定制化結構