import torch
import numpy as np
import torchvision
import torch.nn as nn

from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import time
import os
import copy

print("Torchvision Version:", torchvision.__version__)

# Experiment configuration.
data_dir = "./hymenoptera_data"   # must contain "train" and "val" sub-folders
batch_size = 32
input_size = 224
model_name = "resnet"
num_classes = 2
num_epochs = 15
feature_extract = True            # True: freeze the backbone, train only the new head

# ImageNet per-channel statistics; the pretrained ResNet weights expect inputs
# normalized with exactly these values.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

data_transforms = {
    # Training uses random augmentation.
    "train": transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]),
    # Validation must be deterministic so accuracy is comparable across
    # epochs: resize + center crop, no random augmentation.
    "val": transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]),
}

image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ["train", "val"]
}
dataloader_dict = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True)
    for x in ["train", "val"]
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pull one batch up front as a smoke test of the input pipeline.
inputs, labels = next(iter(dataloader_dict["train"]))


# Load the ResNet model and replace its fully connected layer.
def set_parameter_requires_grad(model, feature_extract):
    """Freeze every parameter of ``model`` when ``feature_extract`` is true.

    Args:
        model: Module whose parameters may be frozen (modified in place).
        feature_extract: When truthy, ``requires_grad`` is set to ``False``
            on all parameters; otherwise the model is left untouched.
    """
    if feature_extract:
        for param in model.parameters():
            param.requires_grad = False


def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a classifier with a fresh output layer for ``num_classes`` classes.

    Only ``model_name == "resnet"`` (ResNet-18) is supported.

    Args:
        model_name: Architecture identifier.
        num_classes: Number of output units of the replacement ``fc`` layer.
        feature_extract: Passed to ``set_parameter_requires_grad``; when true
            the backbone is frozen before the new head is attached, so only
            the head remains trainable.
        use_pretrained: Whether to load pretrained weights.

    Returns:
        ``(model, input_size)``, or ``(None, None)`` for an unsupported
        ``model_name``.
    """
    if model_name == "resnet":
        # NOTE: newer torchvision releases prefer the ``weights=`` argument;
        # ``pretrained=`` is kept for compatibility with older versions.
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # The new layer is created after freezing, so it keeps
        # requires_grad=True.
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    else:
        print("model not implemented")
        return None, None

    return model_ft, input_size

model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
print('-' * 200)


def train_model(model, dataloaders, loss_fn, optimizer, num_epochs):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a "train" phase followed by a "val" phase. Batches are
    moved to the module-level ``device``.

    Args:
        model: Network to optimize (modified in place).
        dataloaders: Mapping with "train" and "val" keys, each an iterable of
            ``(inputs, labels)`` batches exposing a ``.dataset`` with a length.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar loss.
        optimizer: Optimizer over the trainable parameters of ``model``.
        num_epochs: Number of train/val passes.

    Returns:
        ``(model, val_acc_history)`` where ``val_acc_history`` holds one
        validation accuracy per epoch.
    """
    # state_dict must be *called*: deep-copying the bound method would not
    # snapshot the weights and would break load_state_dict below whenever no
    # epoch improves on the initial best_acc.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.
    val_acc_history = []

    for epoch in range(num_epochs):
        for phase in ["train", "val"]:
            running_loss = 0.
            running_corrects = 0.
            if phase == "train":
                model.train()
            else:
                model.eval()

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)

                # Gradients are only tracked during the training phase.
                with torch.autograd.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)

                preds = outputs.argmax(dim=1)
                if phase == "train":
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Weight the batch-mean loss by the batch size so the epoch
                # average is correct even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds.view(-1) == labels.view(-1)).item()

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)

            print("Phase{} loss:{}, acc:{}".format(phase, epoch_loss, epoch_acc))

            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == "val":
                val_acc_history.append(epoch_acc)

    model.load_state_dict(best_model_wts)
    return model, val_acc_history


# Run 1: pretrained backbone; only parameters with requires_grad=True are
# handed to the optimizer.
model_ft = model_ft.to(device)
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model_ft.parameters()),
    lr=0.001,
    momentum=0.9,
)
loss_fn = nn.CrossEntropyLoss()
print("feature extraction: 我們不再改變訓練模型的參數,而是只更新我們改變過的部分模型參數。"
      "我們之所以叫它feature extraction是因為我們把預訓練的CNN模型當做一個特征提取模型,利用提取出來的特征做來完成我們的訓練任務。")
_, ohist = train_model(model_ft, dataloader_dict, loss_fn, optimizer, num_epochs=num_epochs)

print("-" * 200)

# Run 2: same architecture, no pretrained weights, all parameters trainable.
# The freshly built model itself must be moved to the device, optimized and
# trained; reusing model_ft here would merely continue training run 1.
model_scratch, _ = initialize_model(model_name, num_classes, feature_extract=False, use_pretrained=False)
model_scratch = model_scratch.to(device)
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model_scratch.parameters()),
    lr=0.001,
    momentum=0.9,
)
loss_fn = nn.CrossEntropyLoss()
# NOTE(review): the message below describes fine-tuning, but this model is
# built with use_pretrained=False (trained from scratch) — confirm intent.
print("fine tuning: 從一個預訓練模型開始,我們改變一些模型的架構,然后繼續訓練整個模型的參數。")
_, scratch_ohist = train_model(model_scratch, dataloader_dict, loss_fn, optimizer, num_epochs=num_epochs)

# Compare per-epoch validation accuracy of the two runs.
plt.title("Accuracy vs. Training Epoch")
plt.xlabel("Training Epoch")
plt.ylabel("Accuracy")
plt.plot(range(1, num_epochs + 1), ohist, label="Pretrained")
plt.plot(range(1, num_epochs + 1), scratch_ohist, label="No_pretrained")
plt.ylim((0, 1.))
plt.xticks(np.arange(1, num_epochs + 1, 1.0))
plt.legend()
plt.show()