dataloader AttributeError: Can't pickle local object 'trainer_synapse.<locals>.worker_init_fn' (no need to change num_workers)


I ran into the error above while running the TransUNet code (https://github.com/Beckschen/TransUNet, paper: https://arxiv.org/pdf/2102.04306.pdf). The fix you find online is almost always to set the DataLoader's num_workers to 0, but that makes training noticeably slower.

My own understanding is fairly shallow: when the DataLoader starts its worker processes (via the spawn start method, the default on Windows), it has to pickle worker_init_fn to hand it to each worker, and a function defined locally inside trainer_synapse cannot be pickled. That is also why num_workers=0 makes the error disappear: no workers are spawned, so nothing has to be pickled.
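Whether you hit this depends on which multiprocessing start method your platform uses; a quick way to check (not part of the original post):

import multiprocessing as mp

print(mp.get_start_method())  # 'spawn' on Windows/macOS, 'fork' on Linux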

The original trainer function looks like this:

def trainer_synapse(args, model, snapshot_path):
    from datasets.dataset_synapse import Synapse_dataset, RandomGenerator
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    # max_iterations = args.max_iterations
    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))

    print("The length of train set is: {}".format(len(db_train)))

    def worker_init_fn(worker_id):  # local function: the object that cannot be pickled
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
                             worker_init_fn=worker_init_fn,)
    ...

The code defines worker_init_fn inside trainer_synapse, which makes it a local object (trainer_synapse.<locals>.worker_init_fn). When the DataLoader creates its workers via the spawn start method, each worker process receives a pickled copy of everything it needs, including worker_init_fn, and Python's pickle cannot serialize local functions. That is exactly the AttributeError in the title.
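The limitation is easy to reproduce outside of TransUNet (a minimal repro; outer and local_fn are made-up names):

import pickle

def outer():
    def local_fn(x):  # defined inside another function, i.e. a local object
        return x
    return local_fn

try:
    pickle.dumps(outer())
except AttributeError as e:
    print(e)  # Can't pickle local object 'outer.<locals>.local_fn'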

So I simply changed the code to:

def worker_init_fn(worker_id):
    # now module-level, so pickle can serialize it for the worker processes
    random.seed(1234 + worker_id)


def trainer_synapse(args, model, snapshot_path):
    from datasets.dataset_synapse import Synapse_dataset, RandomGenerator
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    # max_iterations = args.max_iterations
    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))

    print("The length of train set is: {}".format(len(db_train)))

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
                             worker_init_fn=worker_init_fn,)
    ...

(In other words, I moved def worker_init_fn outside of def trainer_synapse, and replaced args.seed with the author's default of 1234. That was just me being lazy; if you need to keep the configurable random seed, you can wire the parameter through yourself, for example via a seed entry in a config file such as seedconfig.cfg, or with the sketch below.)
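If you do want args.seed to keep flowing into the init function without a config file, one picklable alternative is functools.partial over a module-level function. A minimal sketch, keeping the repo's seeding logic (seed_worker is my own name, not from the repo):

import pickle
import random
from functools import partial

def seed_worker(worker_id, seed):
    # module-level function: picklable, unlike a closure defined
    # inside trainer_synapse
    random.seed(seed + worker_id)

# a partial over a module-level function survives pickling, so spawned
# DataLoader workers can receive it with the seed already bound
init_fn = partial(seed_worker, seed=1234)  # or seed=args.seed
pickle.dumps(init_fn)  # no AttributeError

Then pass worker_init_fn=init_fn to the DataLoader exactly as before.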

In any case, I got it running, and even applied TransUNet to my own dataset, so my course project is sorted. Yay!

 

