本demo從pytorch官方的遷移學習示例修改而來,增加了以下功能:
- 根據AUC來迭代最優參數;
- 五折交叉驗證;
- 輸出驗證集錯誤分類圖片;
- 輸出分類報告並保存AUC結果圖片。
import os
import numpy as np
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.model_selection import KFold
from torch.autograd import Variable
import torch.optim as optim
import time
import copy
import shutil
import sys
import scikitplot as skplt
import matplotlib.pyplot as plt
import pandas as pd

plt.switch_backend('agg')
N_CLASSES = 2
BATCH_SIZE = 8
DATA_DIR = './data'
LABEL_DICT = {0: 'class_1', 1: 'class_2'}
# Model names accepted as the first command-line argument.
_SUPPORTED_MODELS = ('resnet', 'desnet')


def imshow(inp, title=None):
    """Imshow for Tensor.

    Undoes the normalisation applied by the data transforms (same mean/std
    constants), clips to [0, 1] and displays the image with matplotlib.
    """
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(100)


def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
    """Train ``model`` and keep the weights of the epoch with the best val AUC.

    Relies on module-level state prepared by the ``__main__`` section:
    ``dataloaders``, ``dataset_sizes``, ``class_names``, ``use_gpu``,
    ``report_file`` and ``label_file``.

    The validation loader must iterate in a deterministic order
    (``shuffle=False``): the image names are recovered by walking the
    loader's ``batch_sampler`` in parallel with the loader itself, which only
    lines up when both walks produce the same indices.

    Args:
        model: network to train (already moved to the GPU when ``use_gpu``).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch after the
            training phase.
        fold: fold number, used only in output file names.
        name: model name, used only in output file names.
        num_epochs: number of epochs to run.

    Returns:
        ``model`` with the best-scoring weights loaded.

    Side effects:
        Writes ``./prob_result/{name}_fold_{fold}_porb.csv`` and
        ``./roc_img/{name}_fold_{fold}_roc.png`` (when at least one epoch
        improved the AUC), writes the summary to ``report_file`` and the
        misclassified validation images to ``label_file``.
    """
    since = time.time()
    device = torch.device('cuda' if use_gpu else 'cpu')
    # Deep-copy the current weights; replaced whenever a better epoch appears.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_auc = 0.0
    best_desc = [0, 0, None]
    best_img_name = []
    plt_auc = None

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('- ' * 50)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)
            loader = dataloaders[phase]

            pred_chunks = []
            label_chunks = []
            img_chunks = []
            prob_chunks = []
            running_loss = 0.0
            running_corrects = 0

            if is_train:
                # Sample indices are not needed while training.
                sample_table = None
                batches = ((batch, None) for batch in loader)
            else:
                # Pair every batch with its dataset indices so that the file
                # names of misclassified images can be reported.
                sample_table = np.array(loader.dataset.imgs)
                batches = zip(loader, loader.batch_sampler)

            for (inputs, labels), index in batches:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Reset accumulated gradients.
                optimizer.zero_grad()

                # Forward pass; gradients are only tracked while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass + optimisation in the training phase only.
                    if is_train:
                        loss.backward()
                        optimizer.step()

                if not is_train:
                    img_chunks.append(sample_table[index])
                    prob_chunks.append(outputs.detach().cpu().numpy())

                pred_chunks.append(preds.cpu().numpy())
                label_chunks.append(labels.cpu().numpy())
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if is_train:
                # Step the schedule after the optimizer updates of this epoch.
                scheduler.step()

            print()
            phase_pred = np.concatenate(pred_chunks)
            phase_label = np.concatenate(label_chunks)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            # NOTE(review): the AUC is computed from hard class predictions,
            # not from the class scores, so it can differ from the saved ROC
            # curve -- confirm whether that is intended.
            epoch_auc = roc_auc_score(phase_label, phase_pred)
            print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc, epoch_auc))
            report = classification_report(phase_label, phase_pred, target_names=class_names)
            print(report)

            # Keep the model whenever validation yields a better AUC.
            if not is_train and epoch_auc > best_auc:
                best_auc = epoch_auc
                best_desc = epoch_acc, epoch_auc, report
                best_img_name = list(zip(np.concatenate(img_chunks), phase_pred))
                # Deep-copy the model weights.
                best_model_wts = copy.deepcopy(model.state_dict())
                plt_auc = phase_label, np.concatenate(prob_chunks, axis=0)

        print()

    if plt_auc is not None:
        true_labels, probs = plt_auc
        print(true_labels.shape, probs.shape)
        csv_file = pd.DataFrame(probs, columns=['class_1', 'class_2'])
        csv_file['true_label'] = [LABEL_DICT[int(x)] for x in true_labels]
        # NOTE(review): "porb" looks like a typo for "prob"; the file name is
        # kept unchanged so existing consumers keep working.
        csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
        skplt.metrics.plot_roc_curve(true_labels, probs, curves=['each_class'])
        plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
        # Release the figure so figures do not pile up across folds.
        plt.close()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    reports = ('The Desc according to the Best val Auc: \n'
               'ACC -> {:.4f}\nAUC -> {:.4f}\n\n{}').format(best_desc[0], best_desc[1],
                                                            best_desc[2])
    report_file.write(reports)
    print(reports)
    print('List the wrong judgement img ...')
    count = 0
    for (img_path, actual), pred in best_img_name:
        actual_label = int(actual)
        pred_label = int(pred)
        if actual_label != pred_label:
            tmp_word = f'{os.path.basename(img_path)}, actual: {LABEL_DICT[actual_label]}, ' \
                       f'pred: {LABEL_DICT[pred_label]}'
            print(tmp_word)
            label_file.write(tmp_word + '\n')
            count += 1
    print(f'This fold has {count} wrong records ...')

    # Load the best weights found.
    model.load_state_dict(best_model_wts)
    return model


def plot_img():
    """Display every training batch as an image grid titled with its classes."""
    for i, data in enumerate(dataloaders['train']):
        inputs, classes = data
        out = torchvision.utils.make_grid(inputs)
        imshow(out, title=[class_names[x] for x in classes])


# Adapt the label prefixes in this function to your own image file names.
def move_file(data, file_path, dir_path, root_path):
    """Copy the files in ``data`` into per-class folders of a fresh split dir.

    Files are read from ``root_path/dir_path`` and copied to
    ``root_path/dir_path/file_path/<label>/``, where ``<label>`` is the prefix
    ('class_2' or 'class_1') the file name starts with. Files matching neither
    prefix are skipped. An existing ``file_path`` directory is removed first.
    The current working directory is not changed.

    Args:
        data: iterable of file names located directly in the data directory.
        file_path: name of the split directory to (re)create, e.g. 'train'.
        dir_path: data directory, relative to ``root_path``.
        root_path: project root that ``dir_path`` is resolved against.
    """
    label_0 = 'class_2'
    label_1 = 'class_1'
    print(f'start copy the {file_path} file ...')
    src_dir = os.path.abspath(os.path.join(root_path, dir_path))
    dst_dir = os.path.join(src_dir, file_path)
    if os.path.exists(dst_dir):
        print(f'Find exist {file_path} file, the file will be dropped.')
        shutil.rmtree(dst_dir)
        print(f'Finish drop the {file_path} file.')

    os.mkdir(dst_dir)
    for d in data:
        for label in (label_0, label_1):
            # Compare the whole label prefix; a fixed two-character slice can
            # never equal these labels.
            if d.startswith(label):
                label_dir = os.path.join(dst_dir, label)
                os.makedirs(label_dir, exist_ok=True)
                shutil.copyfile(os.path.join(src_dir, d), os.path.join(label_dir, d))
                break
    print('finish this work ...')


def _build_model(name):
    """Return a pretrained backbone whose classifier outputs ``N_CLASSES``."""
    # NOTE(review): the heads end in Sigmoid but are trained with
    # CrossEntropyLoss, which expects raw logits. Kept as in the original so
    # saved scores stay comparable -- consider dropping the Sigmoid.
    if name == 'resnet':
        model = models.resnet50(pretrained=True)
        num_ftrs = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Linear(num_ftrs, N_CLASSES),
            nn.Sigmoid()
        )
        return model
    if name == 'desnet':
        model = models.densenet121(pretrained=True)
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, N_CLASSES),
            nn.Sigmoid()
        )
        return model
    raise ValueError(f'unknown model name: {name!r}')


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} <{"|".join(_SUPPORTED_MODELS)}>')
    model_name = sys.argv[1]
    # An inception model can be added here by yourself.
    if model_name == 'inception':
        raise Exception("not provide inception model ...")
    # Fail before any file is copied instead of crashing on a None model.
    if model_name not in _SUPPORTED_MODELS:
        raise ValueError(f'unknown model name: {model_name!r}, '
                         f'expected one of {_SUPPORTED_MODELS}')

    for out_dir in ('roc_img', 'prob_result', 'report', 'error_record', 'model'):
        os.makedirs(out_dir, exist_ok=True)

    kf = KFold(n_splits=5, shuffle=True, random_state=1)
    # All relative paths (data and outputs) are resolved against the
    # directory the script is started from.
    origin_path = os.getcwd()
    # Sorted so the seeded KFold split does not depend on listing order.
    dd_list = np.array(sorted(o for o in os.listdir(DATA_DIR)
                              if os.path.isfile(os.path.join(DATA_DIR, o))))

    data_transforms = {
        'train': transforms.Compose([
            # Random crop resized to 224x224.
            transforms.RandomResizedCrop(224),
            # Randomly flip the given PIL.Image horizontally with probability 0.5.
            transforms.RandomHorizontalFlip(),
            # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),  # HSV and contrast jitter
            # Convert a PIL.Image in [0, 255] (or an (H, W, C) numpy.ndarray)
            # to a FloatTensor of shape [C, H, W] in [0, 1.0].
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    use_gpu = torch.cuda.is_available()

    with open(f'./error_record/{model_name}_img_name_actual_pred.txt', 'w',
              encoding='utf-8') as label_file:
        for m, (train_idx, val_idx) in enumerate(kf.split(dd_list), start=1):
            with open(f'./report/{model_name}_fold_{m}_report.txt', 'w',
                      encoding='utf-8') as report_file:
                print(f'The {m} fold for copy file and training ...')
                move_file(dd_list[train_idx], 'train', DATA_DIR, origin_path)
                move_file(dd_list[val_idx], 'val', DATA_DIR, origin_path)

                image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
                                                          data_transforms[x])
                                  for x in ['train', 'val']}
                # Only the training loader is shuffled: train_model matches
                # validation batches to image names by iteration order.
                dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                             shuffle=(x == 'train'), num_workers=8,
                                             pin_memory=False)
                               for x in ['train', 'val']}

                dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

                class_names = image_datasets['train'].classes
                print('label mapping: ')
                print(image_datasets['train'].class_to_idx)

                model_ft = _build_model(model_name)
                if use_gpu:
                    model_ft = model_ft.cuda()

                criterion = nn.CrossEntropyLoss()
                optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
                # Decay the learning rate by a factor of 0.1 every 7 epochs.
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
                model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                                       m, model_name, num_epochs=25)
                print('Start save the model ...')
                torch.save(model_ft.state_dict(), f'./model/fold_{m}_{model_name}.pkl')
                print(f'The mission of the fold {m} finished.')
                print('# ' * 50)
