【GiantPandaCV導語】這學期參加了一個比賽,有比較大的代碼量,在這個過程中暴露出來很多問題。由於實驗記錄很糟糕,導致結果非常混亂、無法進行有效分析,也沒能進行有效的回溯。趁比賽完結,打算重構一下代碼,順便參考一些大型項目的管理方法。本文將總結如何高效、標准化管理深度學習實驗。以下總結偏個人,可能不適宜所有項目,僅供參考。
1. 目前的管理方法
因為有很多需要嘗試的想法,但是又按照下圖這種時間格式來命名文件夾,保存權重。每次運行嘗試的方法只是記錄在本子上和有道雲筆記上。

筆記截圖:

總體來說,這種管理方法不是很理想。一個實驗運行的時間比較久,跨度很久,而之前調的參數、修改的核心代碼、想要驗證的想法都已經很模糊了,甚至有些時候可能看到一組實驗跑完了,忘記了這個實驗想要驗證什么。
這樣的實驗管理是低效的,筆者之前就了解到很多實驗管理的方法、庫的模塊化設計,但這些方法都沉寂在收藏夾中,無用武之地。趁着這次比賽結束,好好對代碼進行重構、完善實驗管理方法、總結經驗教訓。同時也參考了交流群里蔣神、雪神等大佬的建議,總結了以下方法。
2. 大型項目實例
先推薦一個模板,是L1aoXingyu@Github分享的模板項目,鏈接如下:
https://github.com/L1aoXingyu/Deep-Learning-Project-Template
如果長期維護一個深度學習項目,代碼的組織就比較重要了。如何設計一個簡單而可擴展的結構是非常重要的。這就需要用到軟件工程中的OOP設計

簡單介紹一下:
- 實驗配置的管理(實驗配置就是深度學習實驗中的各種參數) 
          
- 使用yacs管理配置。
 - 配置文件一般分默認配置(default)和新增配置(argparse)
 
 - 模型的管理 
          
- 使用工廠模式,根據傳入參數得到對應模型。
 
 
├──  config
│    └── defaults.py  - here's the default config file.
│
│
├──  configs  
│    └── train_mnist_softmax.yml  - here's the specific config file for specific model or dataset.
│ 
│
├──  data  
│    └── datasets  - here's the datasets folder that is responsible for all data handling.
│    └── transforms  - here's the data preprocess folder that is responsible for all data augmentation.
│    └── build.py  		   - here's the file to make dataloader.
│    └── collate_batch.py   - here's the file that is responsible for merges a list of samples to form a mini-batch.
│
│
├──  engine
│   ├── trainer.py     - this file contains the train loops.
│   └── inference.py   - this file contains the inference process.
│
│
├── layers              - this folder contains any customed layers of your project.
│   └── conv_layer.py
│
│
├── modeling            - this folder contains any model of your project.
│   └── example_model.py
│
│
├── solver             - this folder contains optimizer of your project.
│   └── build.py
│   └── lr_scheduler.py
│   
│ 
├──  tools                - here's the train/test model of your project.
│    └── train_net.py  - here's an example of train model that is responsible for the whole pipeline.
│ 
│ 
└── utils
│    ├── logger.py
│    └── any_other_utils_you_need
│ 
│ 
└── tests					- this foler contains unit test of your project.
     ├── test_data_sampler.py
 
         
         
        另外推薦一個封裝的非常完善的庫,deep-person-reid, 鏈接:https://github.com/KaiyangZhou/deep-person-reid,這次總結中有一部分代碼參考自以上模型庫。
3. 熟悉工具
與上邊推薦的模板庫不同,個人覺得可以進行簡化處理,主要用到的python工具有:
- argparse
 - yaml
 - logging
 
前兩個用於管理配置,最后一個用於管理日志。
3.1 argparse
argparse是命令行解析工具,分為四個步驟:
-  
import argparse
 -  
parser = argparse.ArgumentParser()
 -  
parser.add_argument()
 -  
parser.parse_args()
 
第2步創建了一個對象,第3步為這個對象添加參數。
parser.add_argument('--batch_size', type=int, default=2048,
                    help='batch size')  # 8192
parser.add_argument('--save_dir', type=str,
                    help="save exp floder name", default="exp1_sandwich")
 
         
         
        --batch_size將作為參數的key,它對應的value是通過解析命令行(或者默認)得到的。type可以選擇int,str。
parser.add_argument('--finetune', action='store_true',
                    help='finetune model with distill')
 
        action可以指定參數處理方式,默認是“store”代表存儲的意思。如果使用"store_true", 表示他出現,那么對應參數為true,否則為false。
第4步,解析parser對象,得到的是可以通過參數訪問的對象。比如可以通過args.finetune 得到finetune的參數值。
3.2 yaml
yaml是可讀的數據序列化語言,常用於配置文件。
支持類型有:
- 標量(字符串、證書、浮點)
 - 列表
 - 關聯數組 字典
 
語法特點:
- 大小寫敏感
 - 縮進表示層級關系
 - 列表通過 "-" 表示,字典通過 ":"表示
 - 注釋使用 "#"
 
安裝用命令:
pip install pyyaml
 
        舉個例子:
name: tosan
age: 22
skill:
  name1: coding
  time: 2years
job:
  - name2: JD
    pay: 2k
  - name3: HW
    pay: 4k
 
        注意:關鍵字不能重復;不能使用tab,必須使用空格。
處理的腳本:
import yaml 
f = open("configs/test.yml", "r")
y = yaml.load(f)
print(y)
 
        輸出結果:
YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.
  y = yaml.load(f)
{'name': 'tosan', 'age': 22, 'skill': {'name1': 'coding', 'time': '2years'}, 'job': [{'name2': 'JD', 'pay': '2k'}, {'name3': 'HW', 'pay': '4k'}]}
 
        這個警告取消方法是:添加默認loader
import yaml 
f = open("configs/test.yml", "r")
y = yaml.load(f, Loader=yaml.FullLoader)
print(y)
 
        保存:
content_dict = {
	'name':"ch",
}
f = open("./config.yml","w")
print(yaml.dump(content_dict, f))
 
        支持的類型:
# 支持數字,整形、float
pi: 3.14 
# 支持布爾變量
islist: true
isdict: false
# 支持None 
cash: ~
# 時間日期采用ISO8601
time1: 2021-6-9 21:59:43.10-05:00
#強制轉化類型
int_to_str: !!str 123
bool_to_str: !!str true
# 支持list
- 1
- 2
- 3
# 復合list和dict
test2:
  - name: xxx
    attr1: sunny
    attr2: rainy
    attr3: cloudy
 
        3.3 logging
日志對程序執行情況的排查非常重要,通過日志文件,可以快速定位出現的問題。本文將簡單介紹使用logging生成日志的方法。
logging模塊介紹
logging是python自帶的包,一共有五個level:
- debug: 查看程序運行的信息,調試過程中需要使用。
 - info: 程序是否如預期執行的信息。
 - warn: 警告信息,但不影響程序執行。
 - error: 出現錯誤,影響程序執行。
 - critical: 嚴重錯誤
 
logging用法
import logging
logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')
logging.info("program start")
 
        format參數設置了時間,規定了輸出的格式。
import logging
 #先聲明一個 Logger 對象
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
#然后指定其對應的 Handler 為 FileHandler 對象
handler = logging.FileHandler('Alibaba.log')
#然后 Handler 對象單獨指定了 Formatter 對象單獨配置輸出格式
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
 
        Filehandler是用於將日志寫入到文件,如這里將所有日志輸出到Alibaba.log文件夾中。
3.4 補充argparse和yaml的配合
# process argparse & yaml
if not args.config:
    opt = vars(args)
    args = yaml.load(open(args.config), Loader=yaml.FullLoader)
    opt.update(args)
    args = opt
else:  # yaml priority is higher than args
    opt = yaml.load(open(args.config), Loader=yaml.FullLoader)
    opt.update(vars(args))
    args = argparse.Namespace(**opt)
 
        4. 實驗管理
實驗的完整記錄需要以下幾方面內容:
- 日志文件:記錄運行全過程的日志。
 - 權重文件:運行過程中保存的checkpoint。
 - 可視化文件:tensorboard中運行得到的文件。
 - 配置文件:詳細記錄當前運行的配置(調參必備)。
 - 文件備份:用於保存當前版本的代碼,可以用於回滾。
 
那么按照以下方式進行組織:
exp
	- 實驗名+日期
		- runs: tensorboard保存的文件
		- weights: 權重文件
		- config.yml: 配置文件
		- scripts: 核心文件備份
			- train.py
			- xxxxxxxx
 
        代碼實現:
import logging
import argparse
import yaml 
parser = argparse.ArgumentParser("ResNet20-cifar100")
parser.add_argument('--batch_size', type=int, default=2048,
                    help='batch size')  # 8192
parser.add_argument('--learning_rate', type=float,
                    default=0.1, help='init learning rate')  parser.add_argument('--config', help="configuration file",
                    type=str, default="configs/meta.yml")
parser.add_argument('--save_dir', type=str,
                    help="save exp floder name", default="exp1")
args = parser.parse_args()
# process argparse & yaml
if not args.config:
    opt = vars(args)
    args = yaml.load(open(args.config), Loader=yaml.FullLoader)
    opt.update(args)
    args = opt
else:  # yaml priority is higher than args
    opt = yaml.load(open(args.config), Loader=yaml.FullLoader)
    opt.update(vars(args))
    args = argparse.Namespace(**opt)
args.exp_name = args.save_dir + "_" + datetime.datetime.now().strftime("%mM_%dD_%HH") + "_" + \
    "{:04d}".format(random.randint(0, 1000))
# 文件處理
if not os.path.exists(os.path.join("exp", args.exp_name)):
    os.makedirs(os.path.join("exp", args.exp_name))
# 日志文件
log_format = "%(asctime)s %(message)s"
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt="%m/%d %I:%M:%S %p")
fh = logging.FileHandler(os.path.join("exp", args.exp_name, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
logging.info(args)
# 配置文件
with open(os.path.join("exp", args.exp_name, "config.yml"), "w") as f:
    yaml.dump(args, f)
# Tensorboard文件
writer = SummaryWriter("exp/%s/runs/%s-%05d" %
                       (args.exp_name, time.strftime("%m-%d", time.localtime()), random.randint(0, 100)))
# 文件備份
create_exp_dir(os.path.join("exp", args.exp_name),
               scripts_to_save=glob.glob('*.py'))
def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.mkdir(path)
    print('Experiment dir : {}'.format(path))
    if scripts_to_save is not None:
        if not os.path.exists(os.path.join(path, 'scripts')):
            os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)
 
        5. 結果

6. 參考文獻
https://github.com/L1aoXingyu/Deep-Learning-Project-Template
https://sungwookyoo.github.io/tips/ArgParser/
https://github.com/KaiyangZhou/deep-person-reid
https://www.cnblogs.com/pprp/p/10624655.html
