MXNet學習:預測結果-識別單張圖片


用到了model里的FeedForward.load和predict

import os
import mxnet as mx
import numpy as np
import Image
from collections import namedtuple

Batch = namedtuple('Batch',['data'])
synsets = [0,1,2,3,4,5,6,7,8,9]


def predict(img_url,model,synsets):
    img = Image.open(img_url)
    img = img.convert('L')
    img = img.resize((28,28),Image.ANTIALIAS)
    img.save(img_url)
    img = np.asarray(img,dtype=np.uint8)
    img = img.reshape(1,1,28,28).astype(np.float32)/255
    val = mx.io.NDArrayIter(data=img)
    res =  model.predict(X=val)[0]
    for i in range(0,10):
        print "%d: %.2f" % (synsets[i],res[i])


model = mx.model.FeedForward.load('MNIST_MXNet',100)
while(1):
    img_url = raw_input("Enter the img_url: ")
    predict(img_url,model,synsets)

save時用到的是 model.save('MNIST_MXNet',100) 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM