原因是因為checkpoint設置好的確是保存了相關字段。但是其中設置的train_dataset卻已經走過了epoch輪,當你再繼續訓練時候,train_dataset是從第一個load_data開始。
# -*- coding:utf-8 -*-
import os
import numpy as np
import torch
import cv2
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from matplotlib import pyplot as plt
import os
from PIL import Image
os.environ ['KMP_DUPLICATE_LIB_OK'] ='True'
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)
fmap_block = list()
import torch.nn.functional as F
grad_block = list()
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
torch.manual_seed(1) # 設置隨機種子
rmb_label = {"1": 0, "100": 1}
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 2)
def forward(self, x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool1(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 參數設置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1
checkpoint_interval=5
# ============================ step 1/5 數據 ============================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "rmb_split"))
if not os.path.exists(split_dir):
raise Exception(r"數據 {} 不存在, 回到lesson-06\1_split_dataset.py生成數據".format(split_dir))
train_dir = os.path.join(split_dir, "train")
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
])
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
net = Net()
criterion = nn.CrossEntropyLoss() # 選擇損失函數
# ============================ step 4/5 優化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 選擇優化器
checkpointdict = torch.load('./checkpoint4.pkl')
net.load_state_dict(checkpointdict["model_state_dict"])
optimizer.load_state_dict(checkpointdict["optimizer_state_dict"])
startepoch = checkpointdict["epoch"]
# ============================ step 5/5 訓練 ============================
train_curve = list()
iter_count = 0
for epoch in range(startepoch+1,MAX_EPOCH):
loss_mean = 0.
correct = 0.
total = 0.
for counti in range(6):
for i, data in enumerate(train_loader):
if counti <5:
continue
else:
iter_count += 1
# forward
inputs, labels = data
outputs = net(inputs)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
# update weights
optimizer.step()
# 統計分類情況
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().sum().numpy()
# 打印訓練信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0.
# if ((epoch + 1) % checkpoint_interval == 0):
# checkpoint = {"model_state_dict": net.state_dict(),
# "optimizer_state_dict": optimizer.state_dict(),
# "epoch": epoch}
# path_checkpoint = './checkpoint{}.pkl'.format(epoch)
# torch.save(checkpoint, path_checkpoint)
# if ((epoch + 1) % 5 == 0):
# print("退出")
# break