Keras貓狗大戰三:加載模型,預測目錄中圖片,畫混淆矩陣


版權聲明:本文為博主原創文章,歡迎轉載,並請註明出處。聯繫方式:460356155@qq.com

 一、加載模型,預測測試集

%matplotlib inline
import matplotlib.pyplot as plt

import os
import itertools
import cv2

import numpy as np
from sklearn.metrics import confusion_matrix

from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model

dst_path = r'D:\BaiduNetdiskDownload\small'
model_file = r"D:\fastai\projects\cats_and_dogs_small_1.h5"
test_dir = os.path.join(dst_path, 'test')

batch_size = 20

model = load_model(model_file)

# Scale pixel values to [0, 1]; presumably this mirrors the rescaling used
# when the model was trained — TODO confirm against the training script.
test_datagen = ImageDataGenerator(rescale=1. / 255)

# One sub-directory of test_dir per class; class_mode='binary' yields 0/1 labels.
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(150, 150),
    batch_size=batch_size,
    class_mode='binary')

# `steps` is a number of batches and should be an integer; round up so a
# final partial batch is still evaluated.
test_steps = (test_generator.samples + batch_size - 1) // batch_size
test_loss, test_acc = model.evaluate_generator(test_generator, steps=test_steps)
# evaluate_generator reports accuracy as a fraction in [0, 1]; scale it to a
# percentage before printing it with a '%' sign.
print('test acc: %.3f%%' % (test_acc * 100))
Found 400 images belonging to 2 classes.
test acc: 0.747%

二、預測測試集,畫混淆矩陣
def get_input_xy(src=None, class_indices=None, target_size=(150, 150)):
    """Load image files into a model-ready batch and derive their true labels.

    Each file is read with OpenCV, resized to ``target_size``, converted from
    OpenCV's BGR channel order to RGB and scaled to [0, 1] by dividing by 255.
    The label comes from the file name: a file whose base name starts with a
    key of ``class_indices`` gets that key's value (e.g. ``cat.1.jpg`` -> 0).

    NOTE(review): cv2.resize uses bilinear interpolation by default, whereas
    Keras' directory iterator resizes with nearest-neighbour by default, so
    these inputs may differ slightly from what the evaluation generator sees.

    Args:
        src: Iterable of image file paths. ``None`` means no images.
        class_indices: Mapping from file-name prefix to integer label.
            Defaults to ``{'cat': 0, 'dog': 1}``.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``(pre_x, true_y)``. ``pre_x`` is a float numpy array of the stacked
        RGB images, and ``true_y`` is a list of integer labels in the same
        order. For an empty ``src``, ``pre_x`` is an empty array of shape
        ``(0,)``.

    Raises:
        ValueError: If a file name starts with none of the known prefixes.
        OSError: If OpenCV cannot read an image file.
    """
    if src is None:
        src = []
    if class_indices is None:
        class_indices = {'cat': 0, 'dog': 1}

    pre_x = []
    true_y = []

    for path in src:
        _, fn = os.path.split(path)
        # Resolve the label before decoding the image so a badly named file
        # fails fast with a clear message instead of a None label that would
        # only break later inside confusion_matrix.
        label = next((index for prefix, index in class_indices.items()
                      if fn.startswith(prefix)), None)
        if label is None:
            raise ValueError('cannot infer class label from file name: %r' % path)

        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError('could not read image: %r' % path)
        img = cv2.resize(img, target_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        pre_x.append(img)
        true_y.append(label)

    pre_x = np.array(pre_x) / 255.0

    return pre_x, true_y


def plot_sonfusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on the current figure.

    The misspelled name ("sonfusion") is kept so existing callers keep working.

    Args:
        cm: Square confusion matrix (array-like). Rows are true labels and
            columns are predicted labels, the layout returned by
            ``sklearn.metrics.confusion_matrix``.
        classes: Tick labels for both axes, one per row/column of ``cm``.
        normalize: If True, divide each row by its sum so every cell shows the
            fraction of that true class. Rows summing to zero are shown as 0.
        title: Figure title.
        cmap: Matplotlib colormap for the heat map.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalise BEFORE drawing so the heat-map colours and the cell text
        # show the same values; guard against all-zero rows (no NaNs).
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_sums != 0)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts print as-is; fractions (or any float matrix) get 2 decimals.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # Light text on dark cells, dark text on light cells.
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment='center',
                 color='white' if cm[i, j] > thresh else 'black')

    plt.ylabel('True label')
    plt.xlabel('Predict label')
    # Lay out only after the axis labels exist so they are not clipped.
    plt.tight_layout()


test = sorted(os.listdir(test_dir))

images = []

# Collect the path of every JPEG image, walking one sub-directory per class.
for testpath in test:
    class_dir = os.path.join(test_dir, testpath)
    if not os.path.isdir(class_dir):
        # Skip stray files sitting next to the class folders.
        continue
    for fn in sorted(os.listdir(class_dir)):
        if fn.lower().endswith(('.jpg', '.jpeg')):
            fd = os.path.join(class_dir, fn)
            images.append(fd)

# Build the normalised input batch and the true labels (inferred from file names).
pre_x, true_y = get_input_xy(images)

# Predict class indices. For a single sigmoid output, threshold at 0.5 (what
# Sequential.predict_classes did); for a multi-unit output, take the argmax.
# Spelled out explicitly because predict_classes was removed in newer Keras.
proba = model.predict(pre_x)
if proba.shape[-1] > 1:
    pred_y = proba.argmax(axis=-1)
else:
    pred_y = (proba > 0.5).astype('int32').ravel()

# Plot the confusion matrix. Pin the label set so the matrix is always 2x2,
# and name the ticks after get_input_xy's mapping (cat -> 0, dog -> 1).
confusion_mat = confusion_matrix(true_y, pred_y, labels=[0, 1])
plot_sonfusion_matrix(confusion_mat, classes=['cat', 'dog'])
plt.show()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM