最近一個月都在做肺結節的檢測,學到了不少東西,運行的項目主要是基於這篇論文,在github上可以查到項目代碼。
我個人總結的肺結節檢測可以分為三個階段,數據預處理,網絡搭建及訓練,結果評估。
這篇博客主要分析一下項目預處理部分的代碼實現。
預處理的全部代碼都在prepare.py中,對原始數據進行處理,輸出預處理后的數據。
首先是主函數
def preprocess_luna():
luna_segment = config['luna_segment']#存放CT掩碼的路徑
savepath = config['preprocess_result_path']#存放預處理后數據的路徑
luna_data = config['luna_data']#LUNA16的原始數據
luna_label = config['luna_label']#存放所有病例標簽的文件
finished_flag = '.flag_preprocessluna'#是否已經預處理過的標志
print('starting preprocessing luna')
if not os.path.exists(finished_flag):
annos = np.array(pandas.read_csv(luna_label))
pool = Pool()#開啟線程池
if not os.path.exists(savepath):
os.mkdir(savepath)
for setidx in xrange(10):#十份數據
print 'process subset', setidx
filelist = [f.split('.mhd')[0] for f in os.listdir(luna_data+'subset'+str(setidx)) if f.endswith('.mhd') ]#原始數據為.mhd文件,只保留文件名,去掉.mhd后綴
if not os.path.exists(savepath+'subset'+str(setidx)):
os.mkdir(savepath+'subset'+str(setidx))#為每份數據創建存放預處理結果的文件夾
partial_savenpy_luna = partial(savenpy_luna, annos=annos, filelist=filelist,#函數修飾器,將一些參數預先設定,后面調用更簡潔
luna_segment=luna_segment, luna_data=luna_data+'subset'+str(setidx)+'/',
savepath=savepath+'subset'+str(setidx)+'/')
N = len(filelist)
#savenpy(1)
_=pool.map(partial_savenpy_luna,range(N))#將函數調用在序列的每個元素上,返回一個含有所有返回值的列表
pool.close()#關閉線程池
pool.join()
print('end preprocessing luna')
f= open(finished_flag,"w+")#預處理結束,寫入結束標志
上面的代碼就是預處理的全部代碼,當然里面還調用了其它函數,主要的流程就是針對十份數據中的每一份,針對每一份中的每一個case(CT圖像,以.mhd格式存儲),分別處理,為加快速度,在每份數據中,針對每個文件分別開啟一個線程,我試過不用線程,采取循環處理,速度確實慢了很多。
上面的代碼中最關鍵的就是數據預處理函數savenpy_luna,在主函數中已經設定好一部分參數,如文件名列表,標簽,掩碼,原始數據,預處理存儲路徑,萬事俱備,只欠實現,接下來就看一下savenpy_luna的代碼。
def savenpy_luna(id, annos, filelist, luna_segment, luna_data,savepath):
islabel = True
isClean = True
resolution = np.array([1,1,1])
# resolution = np.array([2,2,2])
name = filelist[id]
sliceim,origin,spacing,isflip = load_itk_image(os.path.join(luna_data,name+'.mhd'))#加載原始數據
Mask,origin,spacing,isflip = load_itk_image(os.path.join(luna_segment,name+'.mhd'))#加載相應的掩碼
if isflip: #這一步沒看懂
Mask = Mask[:,::-1,::-1]
newshape = np.round(np.array(Mask.shape)*spacing/resolution).astype('int')#獲取mask在新分辨率下的尺寸
m1 = Mask==3 #LUNA16的掩碼有兩種值,3和4
m2 = Mask==4
Mask = m1+m2 #將兩種掩碼合並
xx,yy,zz= np.where(Mask) #確定掩碼的邊界
box = np.array([[np.min(xx),np.max(xx)],[np.min(yy),np.max(yy)],[np.min(zz),np.max(zz)]])
box = box*np.expand_dims(spacing,1)/np.expand_dims(resolution,1) #對邊界即掩碼的最小外部長方體應用新分辨率
box = np.floor(box).astype('int')
margin = 5
extendbox = np.vstack([np.max([[0,0,0],box[:,0]-margin],0),np.min([newshape,box[:,1]+2*margin],axis=0).T]).T #對box留置一定空白
this_annos = np.copy(annos[annos[:,0]==(name)]) #讀取該病例對應標簽
if isClean:
convex_mask = m1
dm1 = process_mask(m1) #對掩碼采取膨脹操作,去除肺部黑洞
dm2 = process_mask(m2)
dilatedMask = dm1+dm2
Mask = m1+m2
extramask = dilatedMask ^ Mask
bone_thresh = 210
pad_value = 170
if isflip:
sliceim = sliceim[:,::-1,::-1]
print('flip!')
sliceim = lumTrans(sliceim) #對原始數據閾值化,並歸一化
sliceim = sliceim*dilatedMask+pad_value*(1-dilatedMask).astype('uint8') #170對應歸一化話后的水,掩碼外的區域補充為水
bones = (sliceim*extramask)>bone_thresh #210對應歸一化后的骨頭,凡是大於骨頭的區域都填充為水
sliceim[bones] = pad_value
sliceim1,_ = resample(sliceim,spacing,resolution,order=1) #對原始數據重采樣,即采用新分辨率
sliceim2 = sliceim1[extendbox[0,0]:extendbox[0,1], #將extendbox內數據取出作為最后結果
extendbox[1,0]:extendbox[1,1],
extendbox[2,0]:extendbox[2,1]]
sliceim = sliceim2[np.newaxis,...]
np.save(os.path.join(savepath, name+'_clean.npy'), sliceim)
np.save(os.path.join(savepath, name+'_spacing.npy'), spacing)
np.save(os.path.join(savepath, name+'_extendbox.npy'), extendbox)
np.save(os.path.join(savepath, name+'_origin.npy'), origin)
np.save(os.path.join(savepath, name+'_mask.npy'), Mask)
if islabel:
this_annos = np.copy(annos[annos[:,0]==(name)]) #一行代表一個結節,所以一個病例可能對應多行標簽
label = []
if len(this_annos)>0:
for c in this_annos:
pos = worldToVoxelCoord(c[1:4][::-1],origin=origin,spacing=spacing) #將世界坐標轉換為體素坐標
if isflip:
pos[1:] = Mask.shape[1:3]-pos[1:]
label.append(np.concatenate([pos,[c[4]/spacing[1]]]))
label = np.array(label)
if len(label)==0:
label2 = np.array([[0,0,0,0]]) #若沒有結節則設為全0
else:
label2 = np.copy(label).T
label2[:3] = label2[:3]*np.expand_dims(spacing,1)/np.expand_dims(resolution,1) #對標簽應用新的分辨率
label2[3] = label2[3]*spacing[1]/resolution[1] #對直徑應用新的分辨率
label2[:3] = label2[:3]-np.expand_dims(extendbox[:,0],1) #將box外的長度砍掉,也就是相對於box的坐標
label2 = label2[:4].T
np.save(os.path.join(savepath,name+'_label.npy'), label2)
print(name)
這段代碼屬實是長,慢慢看。
主要分為以下幾步。
- 加載原始數據和掩碼,用的是load_itk_image函數
- 求取掩碼的邊界,即非零部分的邊緣,求出一個box,然后對其應用新的分辨率,也就是重采樣,將分辨率統一,采用的函數是resample
- 將數據clip至-1200~600,此范圍外的數據置為-1200或600,然后再將數據歸一化至0~255,采用的是lum_trans函數
- 對掩碼進行一下膨脹操作,去除肺部的小空洞,采用的函數是process_mask,然后對原始數據應用新掩碼,並將掩碼外的數據值為170(水的HU值經過歸一化后的新數值)
- 將原始數據重采樣,再截取box內的數據即可。
- 讀取標簽,將其轉換為體素坐標,采用的函數是worldToVoxelCoord,再對其應用新的分辨率,最后注意,數據是box內的數據,所以坐標是相對box的坐標。
- 將預處理后的數據和標簽以.npy格式存儲
針對上面用到的各種工具函數,分別解析下。
加載圖像:主要用到sitk模塊,讀取原始數據的numpy表示,以及origin和space,origin就是真實坐標的原點,space就是分辨率
def load_itk_image(filename):
with open(filename) as f:
contents = f.readlines()
line = [k for k in contents if k.startswith('TransformMatrix')][0]
transformM = np.array(line.split(' = ')[1].split(' ')).astype('float')
transformM = np.round(transformM)
if np.any( transformM!=np.array([1,0,0, 0, 1, 0, 0, 0, 1])):
isflip = True
else:
isflip = False
itkimage = sitk.ReadImage(filename)
numpyImage = sitk.GetArrayFromImage(itkimage)
numpyOrigin = np.array(list(reversed(itkimage.GetOrigin())))
numpySpacing = np.array(list(reversed(itkimage.GetSpacing())))
return numpyImage, numpyOrigin, numpySpacing,isflip
重采樣:原始CT分辨率往往不一致,為便於應用網絡,需要統一分辨率
def resample(imgs, spacing, new_spacing,order=2):
if len(imgs.shape)==3:
new_shape = np.round(imgs.shape * spacing / new_spacing)
true_spacing = spacing * imgs.shape / new_shape
resize_factor = new_shape / imgs.shape
imgs = zoom(imgs, resize_factor, mode = 'nearest',order=order)
return imgs, true_spacing
elif len(imgs.shape)==4:
n = imgs.shape[-1]
newimg = []
for i in range(n):
slice = imgs[:,:,:,i]
newslice,true_spacing = resample(slice,spacing,new_spacing)
newimg.append(newslice)
newimg=np.transpose(np.array(newimg),[1,2,3,0])
return newimg,true_spacing
else:
raise ValueError('wrong shape')
坐標轉換:給定的標簽是世界坐標,單位是mm,需要轉換為體素坐標,也就是在像素體內的坐標
def worldToVoxelCoord(worldCoord, origin, spacing):
stretchedVoxelCoord = np.absolute(worldCoord - origin)
voxelCoord = stretchedVoxelCoord / spacing
return voxelCoord
掩碼處理:這里對掩碼進行膨脹處理
def process_mask(mask):
convex_mask = np.copy(mask)
for i_layer in range(convex_mask.shape[0]):
mask1 = np.ascontiguousarray(mask[i_layer])
if np.sum(mask1)>0:
mask2 = convex_hull_image(mask1)
if np.sum(mask2)>1.5*np.sum(mask1):
mask2 = mask1
else:
mask2 = mask1
convex_mask[i_layer] = mask2
struct = generate_binary_structure(3,1)
dilatedMask = binary_dilation(convex_mask,structure=struct,iterations=10)
return dilatedMask
歸一化:將數據歸一化至0~255
def lumTrans(img):
lungwin = np.array([-1200.,600.])
newimg = (img-lungwin[0])/(lungwin[1]-lungwin[0])
newimg[newimg<0]=0
newimg[newimg>1]=1
newimg = (newimg*255).astype('uint8')
return newimg
