Google機器學習筆記(七)TF.Learn 手寫文字識別


轉載請注明作者:夢里風林
Google Machine Learning Recipes 7
官方中文博客 - 視頻地址
Github工程地址 https://github.com/ahangchen/GoogleML
歡迎Star,也歡迎到Issue區討論

mnist問題

  • 計算機視覺領域的Hello world
  • 給定55000個圖片,處理成28*28的二維矩陣,矩陣中每個值表示一個像素點的灰度,作為feature
  • 給定每張圖片對應的字符,作為label,總共有10個label,是一個多分類問題

TensorFlow

  • 可以按教程用Docker安裝,也可以直接在Linux上安裝
  • 你可能會擔心,不用Docker的話怎么開那個notebook呢?其實notebook就在主講人的Github頁
  • 可以用這個Chrome插件:npviewer直接在瀏覽器中閱讀ipynb格式的文件,而不用在本地啟動iPython notebook
  • 我們的教程在這里:ep7.ipynb
  • 把代碼從ipython notebook中整理出來:tflearn_mnist.py

代碼分析

  • 下載數據集
mnist = learn.datasets.load_dataset('mnist')

恩,就是這么簡單,一行代碼下載解壓mnist數據,每個img已經灰度化成長784的數組,每個label已經one-hot成長度10的數組

在我的深度學習筆記看One-hot是什么東西

  • numpy讀取圖像到內存,用於后續操作,包括訓練集(只取前10000個)和驗證集
data = mnist.train.images
labels = np.asarray(mnist.train.labels, dtype=np.int32)
test_data = mnist.test.images
test_labels = np.asarray(mnist.test.labels, dtype=np.int32)
max_examples = 10000
data = data[:max_examples]
labels = labels[:max_examples]
  • 可視化圖像
def display(i):
    img = test_data[i]
    plt.title('Example %d. Label: %d' % (i, test_labels[i]))
    plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r)
    plt.show()

用matplotlib展示灰度圖

  • 訓練分類器
    • 提取特征(這里每個圖的特征就是784個像素值)
feature_columns = learn.infer_real_valued_columns_from_input(data)
  • 創建線性分類器並訓練
classifier = learn.LinearClassifier(feature_columns=feature_columns, n_classes=10)
classifier.fit(data, labels, batch_size=100, steps=1000)

注意要制定n_classes為labels的數量

  • 分類器實際上是在根據每個feature判斷每個label的可能性,
  • 不同的feature有的重要,有的不重要,所以需要設置不同的權重
  • 一開始權重都是隨機的,在fit的過程中,實際上就是在調整權重

  • 最后可能性最高的label就會作為預測輸出

  • 傳入測試集,預測,評估分類效果

result = classifier.evaluate(test_data, test_labels)
print result["accuracy"]

速度非常快,而且准確率達到91.4%

可以只預測某張圖,並查看預測是否跟實際圖形一致

# here's one it gets right
print ("Predicted %d, Label: %d" % (classifier.predict(test_data[0]), test_labels[0]))
display(0)
# and one it gets wrong
print ("Predicted %d, Label: %d" % (classifier.predict(test_data[8]), test_labels[8]))
display(8)
  • 可視化權重以了解分類器的工作原理
weights = classifier.weights_
a.imshow(weights.T[i].reshape(28, 28), cmap=plt.cm.seismic)

  • 這里展示了8個張圖中,每個像素點(也就是feature)的weights,
  • 紅色表示正的權重,藍色表示負的權重
  • 作用越大的像素,它的顏色越深,也就是權重越大
  • 所以權重中紅色部分幾乎展示了正確的數字

Next steps


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM