修改PyTorch官方示例，使其適用於自己的二分類遷移學習項目


本demo從pytorch官方的遷移學習示例修改而來,增加了以下功能:

  1. 根據驗證集AUC挑選並保留最優模型參數;
  2. 五折交叉驗證;
  3. 輸出驗證集中被錯誤分類的圖片名稱;
  4. 輸出分類報告並保存ROC曲線（含AUC）圖片。
      1 import os
      2 import numpy as np
      3 import torch
      4 import torch.nn as nn
      5 from torch.optim import lr_scheduler
      6 import torchvision
      7 from torchvision import datasets, models, transforms
      8 from torch.utils.data import DataLoader
      9 from sklearn.metrics import roc_auc_score, classification_report
     10 from sklearn.model_selection import KFold
     11 from torch.autograd import Variable
     12 import torch.optim as optim
     13 import time
     14 import copy
     15 import shutil
     16 import sys
     17 import scikitplot as skplt
     18 import matplotlib.pyplot as plt
     19 import pandas as pd
     20 
     # Headless backend: figures are only written to files, never shown on screen.
     plt.switch_backend('agg')
     # Class index -> class-folder name. Expected to agree with the ImageFolder
     # class_to_idx mapping that __main__ prints for each fold.
     LABEL_DICT = dict(enumerate(('class_1', 'class_2')))
     N_CLASSES = len(LABEL_DICT)
     BATCH_SIZE = 8
     # Flat directory holding every image; the per-fold train/val copies are created below it.
     DATA_DIR = './data'
     26 
     27 
     def imshow(inp, title=None, *, mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225), pause=100):
         """Show an image tensor that was normalized with ``transforms.Normalize``.

         Args:
             inp: image tensor laid out as (C, H, W); it is transposed to (H, W, C)
                 for matplotlib. It may live on the GPU or require grad: it is
                 detached and copied to the CPU first.
             title: optional figure title (e.g. a list of class names).
             mean: per-channel means to add back; the defaults are the ImageNet
                 statistics used by this script's ``Normalize`` transforms.
             std: per-channel standard deviations to multiply back.
             pause: seconds handed to ``plt.pause``. NOTE(review): the default of
                 100 blocks for roughly 100 s on every call (the PyTorch tutorial
                 uses 0.001); kept for compatibility -- confirm it is intended.
         """
         img = inp.detach().cpu().numpy().transpose((1, 2, 0))
         # Undo Normalize: x_norm = (x - mean) / std  =>  x = std * x_norm + mean.
         img = np.asarray(std) * img + np.asarray(mean)
         # Clamp into the [0, 1] float range matplotlib expects for RGB data.
         img = np.clip(img, 0, 1)
         plt.imshow(img)
         if title is not None:
             plt.title(title)
         plt.pause(pause)
     39 
     40 
     def _safe_auc(labels, scores):
         """Return the ROC AUC of *scores* against binary *labels*, or NaN.

         ``roc_auc_score`` raises ``ValueError`` when *labels* contains a single
         class (AUC is undefined). Returning NaN instead lets such an epoch be
         skipped by the ``>`` comparison in ``train_model`` (``nan > x`` is False)
         rather than aborting the whole cross-validation run.
         """
         if np.unique(labels).size < 2:
             return float('nan')
         return roc_auc_score(labels, scores)


     def _write_misclassified(img_preds):
         """Print, and append to ``label_file``, every misclassified image.

         Args:
             img_preds: iterable of ``((path, actual_idx), pred_idx)`` pairs.

         Returns:
             The number of misclassified images.
         """
         count = 0
         for (path, actual_label), pred_label in img_preds:
             actual_label, pred_label = int(actual_label), int(pred_label)
             if actual_label == pred_label:
                 continue
             tmp_word = f'{os.path.basename(path)}, actual: {LABEL_DICT[actual_label]}, ' \
                        f'pred: {LABEL_DICT[pred_label]}'
             print(tmp_word)
             label_file.write(tmp_word + '\n')
             count += 1
         return count


     def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
         """Train *model* for one cross-validation fold and keep the best-AUC weights.

         Reads the module-level names ``dataloaders``, ``class_names``, ``use_gpu``,
         ``report_file``, ``label_file`` and ``LABEL_DICT`` set up by ``__main__``.

         Every epoch runs a training pass and a validation pass. Whenever the
         validation ROC AUC (computed from the softmax probability of class index
         1) improves, the weights, metrics and per-image predictions are kept.
         Afterwards it:

         * writes the best epoch's validation probabilities to
           ``./prob_result/{name}_fold_{fold}_porb.csv``;
         * saves the ROC curve to ``./roc_img/{name}_fold_{fold}_roc.png``;
         * writes the best epoch's summary to ``report_file``;
         * writes every misclassified validation image to ``label_file``.

         Args:
             model: network to train (already on the GPU when ``use_gpu``).
             criterion: loss taking the model outputs and integer class labels.
             optimizer: optimizer over ``model``'s parameters.
             scheduler: LR scheduler, stepped once per epoch after the training pass.
             fold: fold number, used in output file names.
             name: model name, used in output file names.
             num_epochs: number of epochs to train.

         Returns:
             ``model`` with the best-validation-AUC weights loaded.
         """
         since = time.time()
         # Snapshot of the current weights; replaced whenever validation AUC improves.
         best_model_wts = copy.deepcopy(model.state_dict())
         # -inf so the first defined validation AUC is always recorded.
         best_auc = -float('inf')
         best_desc = (0.0, 0.0, None)
         best_img_name = []
         plt_auc = (None, None)

         # Validation predictions are paired with ``dataset.imgs`` by position, which
         # is only valid for a fixed iteration order. A shuffling loader's
         # batch_sampler draws a new permutation every time it is iterated, so use a
         # private sequential loader over the same validation dataset.
         val_src = dataloaders['val']
         val_loader = DataLoader(val_src.dataset, batch_size=val_src.batch_size,
                                 shuffle=False, num_workers=val_src.num_workers,
                                 collate_fn=val_src.collate_fn,
                                 pin_memory=val_src.pin_memory)
         loaders = {'train': dataloaders['train'], 'val': val_loader}
         class_ids = list(range(len(class_names)))

         for epoch in range(num_epochs):
             print('Epoch {}/{}'.format(epoch, num_epochs - 1))
             print('- ' * 50)

             for phase in ['train', 'val']:
                 is_train = phase == 'train'
                 model.train(is_train)
                 batch_probs, batch_preds, batch_labels = [], [], []
                 running_loss = 0.0
                 running_corrects = 0

                 # Build autograd graphs only while training.
                 with torch.set_grad_enabled(is_train):
                     for inputs, labels in loaders[phase]:
                         if use_gpu:
                             inputs, labels = inputs.cuda(), labels.cuda()

                         optimizer.zero_grad()
                         outputs = model(inputs)
                         loss = criterion(outputs, labels)
                         if is_train:
                             loss.backward()
                             optimizer.step()

                         scores = outputs.detach()
                         preds = scores.argmax(dim=1)
                         batch_probs.append(torch.softmax(scores, dim=1).cpu().numpy())
                         batch_preds.append(preds.cpu().numpy())
                         batch_labels.append(labels.cpu().numpy())
                         running_loss += loss.item() * inputs.size(0)
                         running_corrects += (preds == labels).sum().item()

                 if is_train:
                     # Since PyTorch 1.1 the scheduler must step after optimizer.step();
                     # stepping it first skips the initial learning-rate value.
                     scheduler.step()

                 print()
                 phase_prob = np.concatenate(batch_probs)
                 phase_pred = np.concatenate(batch_preds)
                 phase_label = np.concatenate(batch_labels)
                 n_samples = len(loaders[phase].dataset)
                 epoch_loss = running_loss / n_samples
                 epoch_acc = running_corrects / n_samples
                 # ROC AUC needs a continuous score; hard 0/1 predictions collapse the curve.
                 epoch_auc = _safe_auc(phase_label, phase_prob[:, 1])
                 print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
                     phase, epoch_loss, epoch_acc, epoch_auc))
                 report = classification_report(phase_label, phase_pred, labels=class_ids,
                                                target_names=class_names)
                 print(report)

                 # Keep everything about the best validation epoch seen so far.
                 if not is_train and epoch_auc > best_auc:
                     best_auc = epoch_auc
                     best_desc = (epoch_acc, epoch_auc, report)
                     best_img_name = list(zip(val_loader.dataset.imgs, phase_pred))
                     best_model_wts = copy.deepcopy(model.state_dict())
                     plt_auc = (phase_label, phase_prob)

             print()

         if plt_auc[0] is None:
             # num_epochs == 0, or every validation epoch had an undefined AUC.
             print('No valid validation AUC was obtained; skipping result files for this fold.')
             model.load_state_dict(best_model_wts)
             return model

         val_labels, val_probs = plt_auc
         print(val_labels.shape, val_probs.shape)
         csv_file = pd.DataFrame(val_probs, columns=['class_1', 'class_2'])
         csv_file['true_label'] = [LABEL_DICT[int(x)] for x in val_labels]
         # File name spelling ('porb') kept unchanged for existing consumers of these files.
         csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
         skplt.metrics.plot_roc_curve(val_labels, val_probs, curves=['each_class'])
         plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
         # Close the figure so figures do not pile up across folds.
         plt.close()
         time_elapsed = time.time() - since
         print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
         reports = 'The Desc according to the Best val Auc: \nACC -> {:.4f}\nAUC -> {:.4f}\n\n{}'.format(
             best_desc[0], best_desc[1], best_desc[2])
         report_file.write(reports)
         print(reports)
         print('List the wrong judgement img ...')
         count = _write_misclassified(best_img_name)
         print(f'This fold has {count} wrong records ...')

         # Load the best weights before handing the model back.
         model.load_state_dict(best_model_wts)
         return model
    149 
    150 
    def plot_img():
        """Display every training batch of the current fold as one image grid.

        Relies on the module-level ``dataloaders`` and ``class_names`` set up in
        ``__main__``; each grid is titled with the class name of every image in it.
        """
        for batch_inputs, batch_classes in dataloaders['train']:
            grid = torchvision.utils.make_grid(batch_inputs)
            names = [class_names[c] for c in batch_classes]
            imshow(grid, title=names)
    156 
    157 
    # Adapt `labels` (or the prefix rule below) to your own image file-naming scheme.
    def move_file(data, file_path, dir_path, root_path, labels=('class_2', 'class_1')):
        """Copy the image files named in *data* into per-class sub-directories.

        Each file ``<root_path>/<dir_path>/<d>`` whose name starts with one of
        *labels* is copied to ``<root_path>/<dir_path>/<file_path>/<label>/<d>``.
        An existing ``<file_path>`` directory is deleted first, so it only ever
        holds the current fold. Files that match no label are skipped. The process
        working directory is not changed.

        Args:
            data: iterable of file names (not paths) located inside *dir_path*.
            file_path: name of the split directory to (re)create, e.g. ``'train'``.
            dir_path: data directory, resolved against *root_path*.
            root_path: project root that *dir_path* is resolved against.
            labels: class-name prefixes; each one is also the sub-directory name.
        """
        print(f'start copy the {file_path} file ...')
        src_dir = os.path.join(root_path, dir_path)
        dst_dir = os.path.join(src_dir, file_path)
        if os.path.exists(dst_dir):
            print(f'Find exist {file_path} file, the file will be dropped.')
            shutil.rmtree(dst_dir)
            print(f'Finish drop the {file_path} file.')
        os.mkdir(dst_dir)

        for d in data:
            # Match the whole label as a prefix. A fixed-width slice such as d[:2]
            # can never equal a longer label like 'class_1'.
            label = next((lab for lab in labels if d.startswith(lab)), None)
            if label is None:
                continue
            label_dir = os.path.join(dst_dir, label)
            os.makedirs(label_dir, exist_ok=True)
            shutil.copyfile(os.path.join(src_dir, d), os.path.join(label_dir, d))
        print('finish this work ...')
    186 
    187 
    if __name__ == "__main__":
        # The first CLI argument selects the backbone and is embedded in every output
        # file name. Validate it before any directory is created or image copied.
        arch = sys.argv[1] if len(sys.argv) > 1 else None
        if arch == 'inception':
            # An inception_v3 branch could be added below, like resnet/desnet.
            raise Exception("not provide inception model ...")
        if arch not in ('resnet', 'desnet'):
            sys.exit(f'usage: python {sys.argv[0]} resnet|desnet')

        for out_dir in ('roc_img', 'prob_result', 'report', 'error_record', 'model'):
            os.makedirs(out_dir, exist_ok=True)

        # Preprocessing does not depend on the fold, so it is built once.
        data_transforms = {
            'train': transforms.Compose([
                # Random-scale crop resized to 224x224.
                transforms.RandomResizedCrop(224),
                # Flip horizontally with probability 0.5.
                transforms.RandomHorizontalFlip(),
                # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),  # optional brightness/contrast/saturation/hue jitter
                # PIL image or (H, W, C) ndarray in [0, 255] -> (C, H, W) float tensor in [0, 1].
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }

        kf = KFold(n_splits=5, shuffle=True, random_state=1)
        origin_path = '/home/project/'
        # All images sit flat in DATA_DIR; the per-fold 'train'/'val' sub-directories
        # are excluded by the isfile filter. os.listdir order is arbitrary, so sort
        # it to make the seeded KFold split reproducible across machines.
        dd_list = np.array(sorted(o for o in os.listdir(DATA_DIR)
                                  if os.path.isfile(os.path.join(DATA_DIR, o))))
        use_gpu = torch.cuda.is_available()

        # label_file and report_file are module-level names that train_model writes to.
        with open(f'./error_record/{arch}_img_name_actual_pred.txt', 'w') as label_file:
            for m, (train_idx, val_idx) in enumerate(kf.split(dd_list), start=1):
                with open(f'./report/{arch}_fold_{m}_report.txt', 'w') as report_file:
                    print(f'The {m} fold for copy file and training ...')
                    move_file(dd_list[train_idx], 'train', DATA_DIR, origin_path)
                    os.chdir(origin_path)
                    move_file(dd_list[val_idx], 'val', DATA_DIR, origin_path)
                    os.chdir(origin_path)

                    image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
                                                              data_transforms[x])
                                      for x in ['train', 'val']}
                    # Only the training set is shuffled: a fixed validation order keeps
                    # per-image predictions aligned with dataset.imgs when they are
                    # matched up by position.
                    dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                                 shuffle=(x == 'train'), num_workers=8,
                                                 pin_memory=False)
                                   for x in ['train', 'val']}
                    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
                    class_names = image_datasets['train'].classes
                    print('label mapping: ')
                    print(image_datasets['train'].class_to_idx)

                    # NOTE(review): nn.CrossEntropyLoss expects raw logits and applies
                    # log-softmax itself, so the Sigmoid heads below squash its input
                    # into (0, 1). Kept to preserve training behaviour; consider
                    # dropping it (Sigmoid has no parameters, so state_dict keys stay
                    # the same).
                    if arch == 'resnet':
                        model_ft = models.resnet50(pretrained=True)
                        num_ftrs = model_ft.fc.in_features
                        model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, N_CLASSES), nn.Sigmoid())
                    else:  # 'desnet' -> DenseNet-121
                        model_ft = models.densenet121(pretrained=True)
                        num_ftrs = model_ft.classifier.in_features
                        model_ft.classifier = nn.Sequential(nn.Linear(num_ftrs, N_CLASSES),
                                                            nn.Sigmoid())
                    if use_gpu:
                        model_ft = model_ft.cuda()

                    criterion = nn.CrossEntropyLoss()
                    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
                    # Decay the learning rate by a factor of 0.1 every 7 epochs.
                    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
                    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                                           m, arch, num_epochs=25)
                    print('Start save the model ...')
                    torch.save(model_ft.state_dict(), f'./model/fold_{m}_{arch}.pkl')
                    print(f'The mission of the fold {m} finished.')
                    print('# ' * 50)

     


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM