基於tensorflow的MNIST手寫識別


這個例子,是學習tensorflow的人員通常會用到的,也是基本的學習曲線中的一環。我也是!

 

這個例子很簡單,這里,就是簡單的說下,不同的tensorflow版本,相關的接口函數,可能會有不一樣喲。在TensorFlow的中文介紹文檔中的內容,有些可能與你使用的tensorflow的版本不一致了,我這里用到的tensorflow的版本就有這個問題。 另外,還給大家說下,例子中的MNIST所用到的資源圖片,在原始的官網上,估計很多人都下載不到了。我也提供一下下載地址。

 

我的tensorflow的版本信息:

>>> import tensorflow as tf
>>> print tf.VERSION    
1.0.1
>>> print tf.GIT_VERSION
v1.0.0-65-g4763edf-dirty
>>> print tf.COMPILER_VERSION
4.8.4

 

下面,就看看,我參考的中文tensorflow網站的代碼,在自己的環境里,運行的結果。

 1 [root@bogon tensorflow]# python
 2 Python 2.7.5 (default, Nov  6 2016, 00:28:07) 
 3 [GCC 4.8.5 20150623 (Red Hat 4.8.5-11)] on linux2
 4 Type "help", "copyright", "credits" or "license" for more information.
 5 >>> import tensorflow.examples.tutorials.mnist.input_data as input_data
 6 >>> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
 7 Traceback (most recent call last):
 8   File "<stdin>", line 1, in <module>
 9   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py", line 211, in read_data_sets
10     SOURCE_URL + TRAIN_IMAGES)
11   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 208, in maybe_download
12     temp_file_name, _ = urlretrieve_with_retry(source_url)
13   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 165, in wrapped_fn
14     return fn(*args, **kwargs)
15   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 190, in urlretrieve_with_retry
16     return urllib.request.urlretrieve(url, filename)
17   File "/usr/lib64/python2.7/urllib.py", line 94, in urlretrieve
18     return _urlopener.retrieve(url, filename, reporthook, data)
19   File "/usr/lib64/python2.7/urllib.py", line 240, in retrieve
20     fp = self.open(url, data)
21   File "/usr/lib64/python2.7/urllib.py", line 203, in open
22     return self.open_unknown_proxy(proxy, fullurl, data)
23   File "/usr/lib64/python2.7/urllib.py", line 222, in open_unknown_proxy
24     raise IOError, ('url error', 'invalid proxy for %s' % type, proxy)
25 IOError: [Errno url error] invalid proxy for http: '10.90.1.101:8080'
26 >>> 
27 >>> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
28 Extracting MNIST_data/train-images-idx3-ubyte.gz 29 Extracting MNIST_data/train-labels-idx1-ubyte.gz 30 Extracting MNIST_data/t10k-images-idx3-ubyte.gz 31 Extracting MNIST_data/t10k-labels-idx1-ubyte.gz 32 >>> import tensorflow as tf
33 >>> x = tf.placeholder(tf.float32, [None, 784])
34 >>> W = tf.Variable(tf.zeros([784,10]))
35 >>> b = tf.Variable(tf.zeros([10]))
36 >>> y = tf.nn.softmax(tf.matmul(x,W) + b)
37 >>> y_ = tf.placeholder("float", [None,10])
38 >>> cross_entropy = -tf.reduce_sum(y_*tf.log(y))
39 >>> train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
40 >>> init = tf.initialize_all_variables() 41 WARNING:tensorflow:From <stdin>:1: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
42 Instructions for updating:
43 Use `tf.global_variables_initializer` instead.
44 >>> init = tf.global_variables_initializer() 45 >>> sess = tf.Session()
46 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.
47 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
48 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
49 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
50 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
51 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
52 >>> sess.run(init)
53 >>> for i in range(1000):
54 ...   batch_xs, batch_ys = mnist.train.next_batch(100)
55 ...   sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
56 ... 
57 >>> correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
58 >>> accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
59 >>> print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
60 0.9088
61 >>> 

上述日志,是我的測試全過程記錄,上面反映的信息有如下幾點:

1. 紅色部分的錯誤,因為我本地機器是通過代理上網的,這個過程中,tensorflow會用urllib進行MNIST的圖片資源的下載,由於網絡問題,資源文件下載失敗。

2. 都有哪些資源文件要下載呢?追蹤日志中的文件/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py第211行前后:

def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000):
  if fake_data:

    def fake():
      return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

 TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' TEST_IMAGES = 't10k-images-idx3-ubyte.gz' TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   SOURCE_URL + TRAIN_IMAGES)
  with open(local_file, 'rb') as f:
    train_images = extract_images(f)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   SOURCE_URL + TRAIN_LABELS)
  with open(local_file, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   SOURCE_URL + TEST_IMAGES)
  with open(local_file, 'rb') as f:
    test_images = extract_images(f)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   SOURCE_URL + TEST_LABELS)
  with open(local_file, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError(
        'Validation size should be between 0 and {}. Received: {}.'
        .format(len(train_images), validation_size))

  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
  validation = DataSet(validation_images,
                       validation_labels,
                       dtype=dtype,
                       reshape=reshape)
  test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)

  return base.Datasets(train=train, validation=validation, test=test)

看到上面紅色的部分,就是這里需要下載的圖片資源文件。這個,我的網絡環境是下載不了的。我通過其他途徑下載到了這里需要的資源。我將下載的圖片資源,放在了我進入python時所在的路徑下。雖然直接下載沒有成功,但是在當前路徑下還是創建了MNIST_data的目錄的。如下圖,紅色圈目錄就是程序創建的目錄。我將下載的train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz放在MNIST_data目錄了

然后,再次執行mnist = input_data.read_data_sets("MNIST_data/", one_hot=True),就ok了,不會報錯。得到28-31行的輸出信息。

3. 執行到第40行的代碼時,爆出WARNING,提示用新的函數,按照提示信息,執行了第41行的代碼,OK。說明版本兼容性,在tensorflow中需要注意

4. 執行后,得到結果,如60行顯示,識別率為0.9088。

 

關於MNIST的這個例子的手寫識別性能的理論,不是本博文的重點,讀者可以參照MNIST相關的文章自行學習。

最后,附上MNIST這個例子中,用到的資源圖片下載地址,點擊進行下載。(說明:需要積分才能下載的,諒解)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM