state_dict()函數可以返回所有的狀態數據。load_state_dict()函數可以加載這些狀態數據。
推薦使用:
#保存 t.save(net.state_dict(),"net.pth") #加載 net2=Net() net2.load_state_dict(t.load("net.pth"))
不推薦直接save與load,因為這種方式嚴重依賴模型定義方法以及文件路徑結構等,容易出問題。
t.save(net,"net.pth") net2=t.load("net.pth")
【PyTorch中已封裝的網絡模型】https://pytorch.org/docs/stable/torchvision/index.html
從上圖看出,有針對分類問題、語義分割、目標識別、視頻分類的模型。
以分類模型為例,PyTorch中已封裝的模型如下:
使用方式,參考標黃部分
######################################## 1、使用torchvision加載並預處理數據集 #### datasets的ImageFolder讀圖 from torchvision.datasets import ImageFolder dataset=ImageFolder("E:/data/dogcat_2/train/") #獲取路徑,返回的是所有圖的data、label from torchvision import transforms as T #設置格式化條件 transform=T.Compose([T.Resize((64,64)), T.ToTensor(), #PIL Image轉Tensor,[0,255]自動歸一化為[0,1] T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]) #標准化,減均值除標准差 ]) dataset=ImageFolder("E:/data/dogcat_2/train/",transform=transform) testset=ImageFolder("E:/data/dogcat_2/test/",transform=transform) #### DataLoader from torch.utils.data import DataLoader dataloader=DataLoader( dataset,batch_size=4,shuffle=True,num_workers=2 ) testloader=DataLoader(testset,batch_size=4,shuffle=True,num_workers=2) #### 顯示第1個batch的4幅圖(隨機) from torchvision.transforms import ToPILImage from torchvision.utils import make_grid dataiter = iter(dataloader) (images, labels) = dataiter.next() print(labels) #打印標簽 show=ToPILImage() show(make_grid(images*0.5+0.5)).resize((4*64,64)) ######################################## 2、定義網絡 from torchvision import models net=models.alexnet() ######################################## 3、定義損失函數和優化器 import torch.nn as nn from torch import optim criterion=nn.CrossEntropyLoss() #交叉熵損失函數 optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9) #隨機梯度下降法,指定要調整的參數和學習率,動量算法加速更新權重 ######################################## 4、訓練網絡並更新網絡參數 for epoch in range(2): # 在整個數據集上輪番訓練多次,輪訓一次叫一個回合(epoch) running_loss = 0.0 for i, data in enumerate(dataloader, 0): # 輸入數據 inputs, labels = data # 梯度清零 optimizer.zero_grad() # forward + backward outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() #更新參數 optimizer.step() # 打印一些關於訓練的統計信息 running_loss += loss.item() if i % 200 == 199: # 每 200 個batch打印一次 print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200)) running_loss = 0.0 print('Finished Training') ######################################## 5、測試網絡 import torchvision as tv import torch as t #datasets測試集中前4幅圖,並輸出標簽 dataiter = iter(testloader) (images, labels) = dataiter.next() #返回1個batch(4張圖) # 輸出圖像和正確的類標簽 #print('實際的label:', ' '.join('%5s' % classes[labels[j]] for j in range(4))) show(tv.utils.make_grid((images+1)/2)).resize((400,100)) #測試 outputs = net(images) #預測上邊得到的batch(4張圖),返回得分(每一類都打分) _, predicted = t.max(outputs, 1) #每1張圖得分最高的那個類的下標 print(outputs) print(predicted) #print('預測結果:', ' '.join('%5s' % classes[predicted[j]] for j in range(4))) show(tv.utils.make_grid((images+1)/2)).resize((400,100)) #測試整個測試集 correct = 0 #預測正確的圖片數 total = 0 #總共的圖片數 with t.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = t.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('10000張測試集中的准確率: %d %%' % (100 * correct / total))