While running the TransUNet code (https://github.com/Beckschen/TransUNet, paper: https://arxiv.org/pdf/2102.04306.pdf) I ran into the problem above. Almost every solution online says to set the DataLoader's num_workers to 0, but doing that slows training down.
My understanding is admittedly shallow: the DataLoader's newly spawned workers cannot find the worker_init_fn function defined in the "trainer_synapse.py" file.
The original trainer function looks like this:
```python
def trainer_synapse(args, model, snapshot_path):
    from datasets.dataset_synapse import Synapse_dataset, RandomGenerator
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    # max_iterations = args.max_iterations
    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    print("The length of train set is: {}".format(len(db_train)))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8,
                             pin_memory=True, worker_init_fn=worker_init_fn)
    ...
```
The code defines worker_init_fn inside trainer_synapse, which makes it a local function (`trainer_synapse.<locals>.worker_init_fn`). When the DataLoader starts its worker processes (on Windows this uses the spawn start method, which has to pickle worker_init_fn and send it to each worker), a local object like that cannot be pickled, so the workers fail to start.
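This pickling limitation is easy to reproduce outside of TransUNet entirely. A minimal, standalone sketch (my own example, not from the repo): a function defined inside another function cannot be pickled, while the same function at module level can, because pickle serializes functions by reference to their module-level qualified name.

```python
import pickle

def outer():
    # Defined inside another function -> a "local object" that pickle
    # cannot serialize, just like the nested worker_init_fn.
    def local_fn(x):
        return x
    return local_fn

# Defined at module top level -> picklable by reference.
def module_fn(x):
    return x

try:
    pickle.dumps(outer())
except (AttributeError, pickle.PicklingError) as e:
    print("cannot pickle nested function:", e)

# The module-level function round-trips through pickle without issue.
restored = pickle.loads(pickle.dumps(module_fn))
print(restored(7))
```

This is exactly why moving the definition to module level (as below) fixes the error without touching num_workers.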
So I simply changed the code to:
```python
def worker_init_fn(worker_id):
    random.seed(1234 + worker_id)

def trainer_synapse(args, model, snapshot_path):
    from datasets.dataset_synapse import Synapse_dataset, RandomGenerator
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    # max_iterations = args.max_iterations
    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    print("The length of train set is: {}".format(len(db_train)))

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8,
                             pin_memory=True, worker_init_fn=worker_init_fn)
    ...
```
(That is, I moved `def worker_init_fn` outside of `def trainer_synapse`, and replaced args.seed with the author's default of 1234. I did this purely for convenience; if you want to keep the configurable random seed, you can write your own way of passing the parameter in, for example via a seedconfig.cfg file.)
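If you do want to keep args.seed configurable without a config file, one sketch is `functools.partial`: a partial of a module-level function stays picklable, so it can be passed as worker_init_fn even with spawned workers. The names `seed_worker` and the binding shown are my own illustration, not from the TransUNet code.

```python
import random
from functools import partial

# Module-level function: picklable, so spawned DataLoader workers can receive it.
# `seed_worker` is a hypothetical name, not part of the original repo.
def seed_worker(worker_id, seed):
    random.seed(seed + worker_id)

# Inside trainer_synapse you would bind the seed once, e.g.:
#   trainloader = DataLoader(db_train, ..., num_workers=8,
#                            worker_init_fn=partial(seed_worker, seed=args.seed))
# Demonstration with the default seed:
init_fn = partial(seed_worker, seed=1234)
init_fn(0)  # equivalent to random.seed(1234 + 0)
print(random.random())
```

The partial object itself pickles cleanly because it only references the module-level `seed_worker` plus a plain integer, which is exactly what the worker processes need.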
Anyway, it now runs successfully, and I even applied TransUNet to my own dataset, so my course project is sorted. Woo!