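# 三維醫學圖像深度學習,數據增強方法(MONAI):RandHistogramShiftD、Flipd、Rotate90d
# (3-D medical-image deep learning: data augmentation with MONAI — RandHistogramShiftD, Flipd, Rotate90d)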


#coding:utf-8
import torch
from monai.transforms import Compose, RandHistogramShiftD, Flipd, Rotate90d
import matplotlib.pyplot as plt
import SimpleITK as sitk
# Dictionary keys that the transform pipeline in ``aug`` reads from its input dict.
KEYS = "image", "label"
class aug():
    """Augmentation pipeline: 90-degree rotation, flip, random intensity shift.

    Wraps a MONAI ``Compose`` applied to a dict holding the keys in ``KEYS``:

    * ``Rotate90d``  - rotate once (k=1) by 90 degrees in the plane of
      spatial axes (2, 3), applied to every key present;
    * ``Flipd``      - flip along spatial axes 1, 2 and 3, every key present;
    * ``RandHistogramShiftD`` - random nonlinear intensity-histogram remap
      (prob=1, 30 control points), applied to ``"image"`` ONLY.

    NOTE(review): the ``__main__`` script passes arrays shaped
    (1, 1, D, H, W). MONAI dictionary transforms treat the first dim as
    channel, so the axis indices above presumably count over (1, D, H, W).
    Re-check them if the input layout changes.
    """

    def __init__(self):
        self.random_rotated = Compose([
            Rotate90d(KEYS, k=1, spatial_axes=(2, 3), allow_missing_keys=True),
            Flipd(KEYS, spatial_axis=(1, 2, 3), allow_missing_keys=True),
            # Intensity augmentation must not touch the segmentation mask:
            # remapping label values would corrupt the class ids.
            RandHistogramShiftD("image", prob=1, num_control_points=30,
                                allow_missing_keys=True),
        ])

    def forward(self, x):
        """Run the pipeline on dict ``x`` and return the transformed dict."""
        return self.random_rotated(x)

    # Allow ``aug()(data)`` in addition to ``aug().forward(data)``.
    __call__ = forward

def save(before_x, after_x, new_path, new_name=""):
    """Write an augmented volume to ``new_path`` with a reference file's geometry.

    Args:
        before_x: path of the original (un-augmented) file. Only its
            origin, spacing and direction are read and copied to the output.
        after_x: augmented array or tensor, indexed as ``after_x[0, 0, ...]``.
            The two leading singleton dims are dropped before writing.
        new_path: destination file path.
        new_name: ``"image"`` writes the volume as int16; anything else
            (e.g. ``"label"``) writes it as uint8.

    NOTE(review): origin/spacing/direction are copied unchanged. If the
    augmentation rotated or flipped axes, the copied geometry is only an
    approximation (exact for rotations only when in-plane spacing is
    isotropic). Confirm this is acceptable for your data.
    """
    import numpy as np

    volume = after_x[0, 0, ...]
    if isinstance(volume, torch.Tensor):
        volume = volume.detach().cpu().numpy()
    volume = np.asarray(volume)

    # Cast to the pixel type selected by ``new_name`` (int16 image, uint8
    # mask) instead of keeping whatever dtype the augmentation returned.
    # Round and clip first so the cast does not truncate or wrap around.
    np_dtype = np.int16 if new_name == "image" else np.uint8
    if np.issubdtype(volume.dtype, np.floating):
        volume = np.rint(volume)
    limits = np.iinfo(np_dtype)
    volume = np.clip(volume, limits.min, limits.max).astype(np_dtype)

    # Only the header is needed from the reference, so skip loading voxels.
    reader = sitk.ImageFileReader()
    reader.SetFileName(before_x)
    reader.ReadImageInformation()

    out = sitk.GetImageFromArray(volume)
    out.SetDirection(reader.GetDirection())
    out.SetOrigin(reader.GetOrigin())
    out.SetSpacing(reader.GetSpacing())

    sitk.WriteImage(out, new_path)


if __name__ == "__main__":
    image = r"D:\MyData\3Dircadb1_fusion_date\image_2.nii"   # original image
    label = r"D:\MyData\3Dircadb1_fusion_date\liver_2.nii"   # label
    # NOTE(review): the output indices (image_0 / liver_1) differ from each
    # other and from the input (_2); confirm these file names are intended.
    new_path = r"D:\MyData\3Dircadb1_fusion_date\image_0.nii"  # augmented image
    new_path1 = r"D:\MyData\3Dircadb1_fusion_date\liver_1.nii"  # augmented label

    ct1 = sitk.GetArrayFromImage(sitk.ReadImage(image))
    seg1 = sitk.GetArrayFromImage(sitk.ReadImage(label))

    # Prepend two singleton dims; the axis indices used by ``aug`` are
    # written for this 5-D layout.
    m = {
        "image": torch.from_numpy(ct1[None, None, ...]),
        "label": torch.from_numpy(seg1[None, None, ...]),
    }
    augs = aug()
    print(m["image"].shape)
    data_dict = augs.forward(m)

    save(image, data_dict["image"], new_path, "image")
    save(label, data_dict["label"], new_path1, "label")

    print(data_dict["image"].shape)

    # Preview one slice. Clamp the index to the shallowest of the three
    # volumes so short scans do not raise IndexError.
    depth = min(ct1.shape[0],
                data_dict["image"].shape[2],
                data_dict["label"].shape[2])
    slice_idx = min(66, depth - 1)

    fig, axes = plt.subplots(1, 3)
    axes[0].imshow(ct1[slice_idx, ...])
    axes[1].imshow(data_dict["image"][0, 0, slice_idx, ...])
    axes[2].imshow(data_dict["label"][0, 0, slice_idx, ...])
    plt.show()

 


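# 免責聲明!
#
# 本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。
# (Disclaimer: articles reposted on this site are for personal study and reference only; the site accepts no legal liability for copyright. If your rights are infringed, contact yoyou2525@163.com to have the content removed.)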



 
# 粵ICP備18138465號   © 2018-2025 CODEPRJ.COM