說明
機器沒有聯網,所以要先把預訓練模型的權重文件下載到本地(示例代碼使用 ./models/ 目錄)。
先學習怎麼用模型做推斷,
然後再看怎麼使用 Dataset、DataLoader 和 transform,
接著看怎麼訓練和評估。
軟件和硬件
cuda
查看cuda 版本
whereis nvcc
/usr/local/cuda-10.0/bin/nvcc -V
cat /usr/local/cuda/version.txt
查看 cuDNN 版本:libcudnn.so 是軟連結,它最終指向的文件名中包含版本號(例如用 ls -l /usr/local/cuda/lib64/libcudnn.so* 查看)
GPU查看
lspci | grep -i nvidia
nvidia-smi
watch -n 1 nvidia-smi
示例代碼
import torch
import torch.cuda
import torch.nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import cv2
def get_model(model_path="./models/resnet101-5d3b4d8f.pth", device="cuda"):
    """Build a ResNet-101 and load locally stored weights into it.

    The machine has no network access, so the weights file must already be on
    disk. ``pretrained=False`` keeps torchvision from trying to download it.
    To use VGG16 instead, build ``torchvision.models.vgg16(pretrained=False)``
    and load ``./models/vgg16-397923af.pth``.

    Args:
        model_path: Path to a ResNet-101 ``state_dict`` file. The default is
            the file the script has always used.
        device: Device the model is moved to. Defaults to ``"cuda"``, which
            matches the previous ``model_ft.cuda()`` behavior.

    Returns:
        The torchvision ResNet-101 with the loaded weights, on ``device``.
    """
    model_ft = models.resnet101(pretrained=False)
    # Load the tensors onto the CPU first so the checkpoint loads no matter
    # which device it was saved from; the model is moved to ``device`` below.
    state_dict = torch.load(model_path, map_location="cpu")
    model_ft.load_state_dict(state_dict)
    model_ft.to(device)
    return model_ft


def _print_model_info(model_ft):
    """Debug helper: print the structure and parameter shapes of ``model_ft``."""
    # Model structure.
    print(model_ft)
    # Parameter names and their shapes.
    for name, parameters in model_ft.named_parameters():
        print(name, ':', parameters.size())
    # Each top-level child module (shows convolution settings); uncomment the
    # inner loop to also dump the raw weight values.
    print("#############-parameters")
    for child in model_ft.children():
        print(child)
        # for param in child.parameters():
        #     print(param)
def deal_img(img_path):
    """Load an image file and return a normalized (1, 3, 224, 224) CUDA tensor.

    All preprocessing runs on the CPU; only the finished tensor is moved to
    the GPU. The steps are:

    - read the file with OpenCV;
    - convert BGR to RGB;
    - resize to 224x224 (the aspect ratio is not preserved);
    - scale to [0, 1] with ``ToTensor``;
    - normalize with the ImageNet mean/std.

    Args:
        img_path: Path of an image file readable by ``cv2.imread``.

    Returns:
        A float tensor of shape (1, 3, 224, 224) on the current CUDA device.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError("cannot read image: {}".format(img_path))
    # OpenCV stores channels in BGR order, but the mean/std values below are
    # listed in RGB order, so convert before normalizing.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_new = cv2.resize(image, (224, 224))
    my_transforms = transforms.Compose(
        [
            transforms.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    my_tensor = my_transforms(image_new)
    # Add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    my_tensor = my_tensor.unsqueeze(0)
    return my_tensor.cuda()
def cls_inference(cls_model, imgpth):
    """Classify one image and return the index of the highest-scoring class.

    Args:
        cls_model: Classification network, e.g. the model from ``get_model``.
            Row 0 of its output holds the class scores for the single image.
        imgpth: Path of the image; it is preprocessed by ``deal_img``.

    Returns:
        The arg-max class index as a NumPy integer.
    """
    input_tensor = deal_img(imgpth)
    cls_model.eval()  # use inference behavior for BatchNorm/Dropout layers
    # Inference only: no_grad skips building the autograd graph, which saves
    # memory and time.
    with torch.no_grad():
        result = cls_model(input_tensor)
    result_npy = result.detach().cpu().numpy()
    return np.argmax(result_npy[0])
def feature_extract(cls_model, imgpth):
    """Return the feature vector that ``cls_model`` feeds into its ``fc`` layer.

    The ``fc`` head is temporarily replaced by ``LeakyReLU(0.1)``. The forward
    pass therefore returns the input of ``fc`` with that activation applied,
    instead of class scores. The original ``fc`` is restored afterwards, even
    on error, so the caller's model still works as a classifier. Previously
    the replacement was permanent.

    NOTE(review): for torchvision ResNets the input to ``fc`` comes after a
    ReLU and average pooling, so it should be non-negative and LeakyReLU
    should leave it unchanged. Confirm before relying on this for other
    models.

    Args:
        cls_model: Model with an ``fc`` attribute (e.g. a torchvision ResNet;
            VGG uses ``classifier`` instead and is not handled here).
        imgpth: Path of the image; it is preprocessed by ``deal_img``.

    Returns:
        NumPy array holding the features for the single input image.
    """
    original_fc = cls_model.fc
    cls_model.fc = torch.nn.LeakyReLU(0.1)
    try:
        cls_model.eval()
        input_tensor = deal_img(imgpth)
        with torch.no_grad():
            result = cls_model(input_tensor)
    finally:
        # Always put the classifier head back on the caller's model.
        cls_model.fc = original_fc
    result_npy = result.detach().cpu().numpy()
    return result_npy[0]
if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the sample image path.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "./pytorch/data/train/cat/08.jpg"
    model = get_model()
    # Class index predicted by the full classifier.
    cls_label = cls_inference(model, image_path)
    print(cls_label)
    # Feature vector taken from just before the fc layer.
    feature = feature_extract(model, image_path)
    print(feature)
參考
使用pytorch預訓練模型分類與特征提取 https://blog.csdn.net/u010165147/article/details/72829969?spm=1001.2014.3001.5502