在學習和使用深度學習框架時,復現現有項目代碼是必經之路,也能加深對理論知識的理解,提高動手能力。本文參照相關博客整理項目常用組織方式,以及每部分功能,幫助更好的理解復現項目流程,文末提供分類示例項目。
1 項目組織
在做深度學習實驗或項目時,為了得到最優的模型結果,中間往往需要很多次的嘗試和修改。一般項目都包含以下幾個部分:
- 模型定義
- 數據處理和加載
- 訓練模型(Train&Validate)
- 訓練過程的可視化
- 測試(Test/Inference)
另外程序在組織過程中還應該滿足以下幾個要求:
- 模型需具有高度可配置性,便於修改參數、修改模型,反復實驗
- 代碼應具有良好的組織結構,使人一目了然
- 代碼應具有良好的說明,使其他人能夠理解
2 項目結構
- checkpoints/: 用於保存訓練好的模型,可使程序在異常退出后仍能重新載入模型,恢復訓練
- data/:數據相關操作,包括數據預處理、dataset實現等
- models/:模型定義,可以有多個模型,例如上面的AlexNet和ResNet34,一個模型對應一個文件
- utils/:可能用到的工具函數,在本次實驗中主要是封裝了可視化工具
- config.py:配置文件,所有可配置的變量都集中在此,並提供默認值
- main.py:主文件,訓練和測試程序的入口,可通過不同的命令來指定不同的操作和參數
- requirements.txt:程序依賴的第三方庫
- README.md:提供程序的必要說明
3 解析
3.1 __init__
- __init__ 可以為空,也可以定義包的屬性和方法,但必須存在,其他程序才能從這個目錄中讀取模塊和函數
3.2 數據加載
使用Dataset提供數據集的封裝,再使用Dataloader實現數據並行加載。
- def __init__(self..)
獲取圖片地址,並根據訓練、驗證和測試划分數據
- def __getitem__(self, index):
返回圖片的數據和label
- def __len__(self):
返回數據集數量
train_dataset = DogCat(opt.train_data_root, train=True) trainloader = DataLoader(train_dataset, batch_size = opt.batch_size, shuffle = True, num_workers = opt.num_workers) for ii, (data, label) in enumerate(trainloader): train()
3.3 模型定義
型的定義主要保存在models/目錄下,其中BasicModule是對nn.Module的簡易封裝,提供快速加載和保存模型的接口。
nn.Module主要包括save和load兩個方法
from models import AlexNet
關於模型定義:
- 盡量使用nn.Sequential(比如AlexNet)
- 將經常使用的結構封裝成子Module(比如GoogLeNet的Inception結構,ResNet的Residual Block結構)
- 將重復且有規律性的結構,用函數生成(比如VGG的多種變體,ResNet多種變體都是由多個重復卷積層組成)
3.4 工具函數
可能會用到一些helper方法,這些方法可以統一放在utils/文件夾下,需要使用時再引入。在本例中主要是封裝了可視化工具visdom的一些操作,
3.5 配置文件
可配置的參數主要包括:
數據集參數(文件路徑、batch_size等)
訓練參數(學習率、訓練epoch等)
模型參數
在實際使用時,並不需要每次都修改config.py,只需要通過命令行傳入所需參數,覆蓋默認配置即可。
3.6 main函數
提到了fire
main中包括train、val、test、help等
訓練的主要步驟如下:
- 定義網絡
- 定義數據
- 定義損失函數和優化器
- 計算重要指標
- 開始訓練
- 訓練網絡
- 可視化各種指標
- 計算在驗證集上的指標
4 示例分類代碼
#coding:utf8 from config import opt import os import torch as t import models from data.dataset import DogCat from torch.utils.data import DataLoader from torch.autograd import Variable from torchnet import meter from utils.visualize import Visualizer from tqdm import tqdm def test(**kwargs): opt.parse(kwargs) import ipdb; ipdb.set_trace() # configure model model = getattr(models, opt.model)().eval() if opt.load_model_path: model.load(opt.load_model_path) if opt.use_gpu: model.cuda() # data train_data = DogCat(opt.test_data_root,test=True) test_dataloader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers) results = [] for ii,(data,path) in enumerate(test_dataloader): input = t.autograd.Variable(data,volatile = True) if opt.use_gpu: input = input.cuda() score = model(input) probability = t.nn.functional.softmax(score)[:,0].data.tolist() # label = score.max(dim = 1)[1].data.tolist() batch_results = [(path_,probability_) for path_,probability_ in zip(path,probability) ] results += batch_results write_csv(results,opt.result_file) return results def write_csv(results,file_name): import csv with open(file_name,'w') as f: writer = csv.writer(f) writer.writerow(['id','label']) writer.writerows(results) def train(**kwargs): opt.parse(kwargs) vis = Visualizer(opt.env) # step1: configure model model = getattr(models, opt.model)() if opt.load_model_path: model.load(opt.load_model_path) if opt.use_gpu: model.cuda() # step2: data train_data = DogCat(opt.train_data_root,train=True) val_data = DogCat(opt.train_data_root,train=False) train_dataloader = DataLoader(train_data,opt.batch_size, shuffle=True,num_workers=opt.num_workers) val_dataloader = DataLoader(val_data,opt.batch_size, shuffle=False,num_workers=opt.num_workers) # step3: criterion and optimizer criterion = t.nn.CrossEntropyLoss() lr = opt.lr optimizer = t.optim.Adam(model.parameters(),lr = lr,weight_decay = opt.weight_decay) # step4: meters loss_meter = meter.AverageValueMeter() confusion_matrix = meter.ConfusionMeter(2) previous_loss = 1e100 # train for epoch in range(opt.max_epoch): loss_meter.reset() confusion_matrix.reset() for ii,(data,label) in tqdm(enumerate(train_dataloader),total=len(train_data)): # train model input = Variable(data) target = Variable(label) if opt.use_gpu: input = input.cuda() target = target.cuda() optimizer.zero_grad() score = model(input) loss = criterion(score,target) loss.backward() optimizer.step() # meters update and visualize loss_meter.add(loss.data[0]) confusion_matrix.add(score.data, target.data) if ii%opt.print_freq==opt.print_freq-1: vis.plot('loss', loss_meter.value()[0]) # 進入debug模式 if os.path.exists(opt.debug_file): import ipdb; ipdb.set_trace() model.save() # validate and visualize val_cm,val_accuracy = val(model,val_dataloader) vis.plot('val_accuracy',val_accuracy) vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format( epoch = epoch,loss = loss_meter.value()[0],val_cm = str(val_cm.value()),train_cm=str(confusion_matrix.value()),lr=lr)) # update learning rate if loss_meter.value()[0] > previous_loss: lr = lr * opt.lr_decay # 第二種降低學習率的方法:不會有moment等信息的丟失 for param_group in optimizer.param_groups: param_group['lr'] = lr previous_loss = loss_meter.value()[0] def val(model,dataloader): ''' 計算模型在驗證集上的准確率等信息 ''' model.eval() confusion_matrix = meter.ConfusionMeter(2) for ii, data in enumerate(dataloader): input, label = data val_input = Variable(input, volatile=True) val_label = Variable(label.type(t.LongTensor), volatile=True) if opt.use_gpu: val_input = val_input.cuda() val_label = val_label.cuda() score = model(val_input) confusion_matrix.add(score.data.squeeze(), label.type(t.LongTensor)) model.train() cm_value = confusion_matrix.value() accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum()) return confusion_matrix, accuracy def help(): ''' 打印幫助的信息: python file.py help ''' print(''' usage : python file.py <function> [--args=value] <function> := train | test | help example: python {0} train --env='env0701' --lr=0.01 python {0} test --dataset='path/to/dataset/root/' python {0} help avaiable args:'''.format(__file__)) from inspect import getsource source = (getsource(opt.__class__)) print(source) if __name__=='__main__': import fire fire.Fire()
參考:https://github.com/chenyuntc/pytorch-best-practice/blob/master/PyTorch%E5%AE%9E%E6%88%98%E6%8C%87%E5%8D%97.md
