使用pytorch測試單張圖片(test single image with pytorch)


以下代碼實現使用pytorch測試一張圖片

引用文章:

https://www.learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/

from __future__ import print_function, division

from PIL import Image
import torch
from torchvision import transforms
import matplotlib.pyplot as plt


plt.ion()   # interactive mode

# 模型存儲路徑
model_save_path = '/home/guomin/.cache/torch/checkpoints/resnet18-customs-angle.pth'

# ------------------------ 加載數據 --------------------------- #
# Data augmentation and normalization for training
# Just normalization for validation
# 定義預訓練變換
preprocess_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])


class_names = ['0', '180', '270', '90']
# 這個順序很重要,要和訓練時候的類名順序一致

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ------------------------ 載入模型並且訓練 --------------------------- #
model = torch.load(model_save_path)
model.eval()
# print(model)

image_PIL = Image.open('image.jpg')
# 
image_tensor = preprocess_transform(image_PIL)
# 以下語句等效於 image_tensor = torch.unsqueeze(image_tensor, 0)
image_tensor.unsqueeze_(0)
# 沒有這句話會報錯
image_tensor = image_tensor.to(device)

out = model(image_tensor)
# 得到預測結果,並且從大到小排序
_, indices = torch.sort(out, descending=True)
# 返回每個預測值的百分數
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100

print([(class_names[idx], percentage[idx].item()) for idx in indices[0][:5]])

 結果返回:

[('270', 99.9299545288086), ('90', 0.06985548883676529), ('0', 0.0001458235055906698), ('180', 4.714601891464554e-05)]

 

注意: 這里,class_names的順序尤為重要,我這個順序的由來是因為我在訓練模型的時候是按照標簽名稱分類圖片,即我把文件名就命名為標簽的名字,然后里面存放這相應的訓練圖片。這樣的話 python 就會根據文件名的第一個首字母的順序排列文件,所以得到 class_names = ['0', '180', '270', '90'] ,因為python讀取字符串的時候不是按照自然序列的來讀,而是按照首字母的大小順序進行讀取。

以下是我的文件結構,大家可以參考下

 或者准確來說,我的class_names的由來使用代碼是:

# 獲取val圖片已得到類別class_names
image_datasets_val = datasets.ImageFolder(os.path.join(data_dir, 'val'), preprocess_transform) # 得到分類的種類名稱
class_names = image_datasets_val.classes

希望大家能好好理解這一部分,加油!

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM