MXNet学习:预测结果-识别单张图片


用到了model里的FeedForward.load和predict

import os
import mxnet as mx
import numpy as np
import Image
from collections import namedtuple

Batch = namedtuple('Batch',['data'])
synsets = [0,1,2,3,4,5,6,7,8,9]


def predict(img_url,model,synsets):
    img = Image.open(img_url)
    img = img.convert('L')
    img = img.resize((28,28),Image.ANTIALIAS)
    img.save(img_url)
    img = np.asarray(img,dtype=np.uint8)
    img = img.reshape(1,1,28,28).astype(np.float32)/255
    val = mx.io.NDArrayIter(data=img)
    res =  model.predict(X=val)[0]
    for i in range(0,10):
        print "%d: %.2f" % (synsets[i],res[i])


model = mx.model.FeedForward.load('MNIST_MXNet',100)
while(1):
    img_url = raw_input("Enter the img_url: ")
    predict(img_url,model,synsets)

save时用到的是 model.save('MNIST_MXNet',100) 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM