tensorflow MNIST新手教程


官方教程代码如下:

 1 import gzip  
 2 import os  
 3 import tempfile  
 4   
 5 import numpy  
 6 from six.moves import urllib  
 7 from six.moves import xrange  # pylint: disable=redefined-builtin  
 8 import tensorflow as tf  
 9 from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
10 
11 mnist = read_data_sets("MNIST_data/", one_hot=True) 
12 
13 x = tf.placeholder(tf.float32,[None, 784]) #图像输入向量  
14 W = tf.Variable(tf.zeros([784,10]))  #权重,初始化值为全零  
15 b = tf.Variable(tf.zeros([10]))  #偏置,初始化值为全零  
16   
17 #进行模型计算,y是预测,y_ 是实际  
18 y = tf.nn.softmax(tf.matmul(x,W) + b)  
19   
20 y_ = tf.placeholder("float", [None,10])  
21   
22 #计算交叉熵  
23 cross_entropy = -tf.reduce_sum(y_*tf.log(y))  
24 #接下来使用BP算法来进行微调,以0.01的学习速率  
25 train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  
26   
27 #上面设置好了模型,添加初始化创建变量的操作  
28 init = tf.global_variables_initializer()  
29 #启动创建的模型,并初始化变量  
30 sess = tf.Session()  
31 sess.run(init)  
32 #开始训练模型,循环训练1000次  
33 for i in range(1000):  
34     #随机抓取训练数据中的100个批处理数据点  
35     batch_xs, batch_ys = mnist.train.next_batch(100)  
36     sess.run(train_step, feed_dict={x:batch_xs,y_:batch_ys})  
37       
38 ''''' 进行模型评估 '''  
39   
40 #判断预测标签和实际标签是否匹配  
41 correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))   
42 accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))  
43 #计算所学习到的模型在测试数据集上面的正确率  
44 print( sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels}) )

运行出现错误,不能导入数据,解决方案如下:

1..坑爹的GWF,严重阻碍人工智能的发展,与习大大十九大报告背道而驰,只能靠爱国青年曲线救国。方法为修改mnist.py文件(tensorflow.contrib.learn.python.learn.datasets.mnist),SOURCE_URL从 'https://storage.googleapis.com/cvdf-datasets/mnist/'改为 'http://yann.lecun.com/exdb/mnist/',如此就能运行了。

2.手动下载 http://yann.lecun.com/exdb/mnist/,处理代码如下,注意一定要把图片像素除以255,否则正确率只有0.098

 1 import tensorflow as tf
 2 import struct
 3 import numpy as np
 4 
 5 with open('train-labels.idx1-ubyte','rb') as lb:
 6    magic,n=struct.unpack('>II',lb.read(8))
 7    labels = np.fromfile(lb,dtype=np.uint8)
 8 
 9 with open('train-images.idx3-ubyte','rb') as img:
10    magic,num,rows,cols=struct.unpack('>IIII',img.read(16))
11    images = np.fromfile(img,dtype=np.uint8).reshape(-1,784)
12    images = images.astype(np.float32)
13    images = np.multiply(images, 1.0 / 255.0)
14 
15 with open('t10k-labels.idx1-ubyte','rb') as lb:
16    magic,n=struct.unpack('>II',lb.read(8))
17    testlabels = np.fromfile(lb,dtype=np.uint8)
18 
19 with open('t10k-images.idx3-ubyte','rb') as img:
20    magic,num,rows,cols=struct.unpack('>IIII',img.read(16))
21    testimages = np.fromfile(img,dtype=np.uint8).reshape(-1,784)
22    testimages = testimages.astype(np.float32)
23    testimages = np.multiply(testimages, 1.0 / 255.0)
24 
25 x = tf.placeholder(tf.float32, [None, 784])
26 W = tf.Variable(tf.zeros([784, 10]))
27 b = tf.Variable(tf.zeros([10]))
28 y = tf.nn.softmax(tf.matmul(x, W) + b)
29 y_ = tf.placeholder(tf.float32, [None, 10])
30 cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
31 train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
32 
33 sess = tf.InteractiveSession()
34 tf.global_variables_initializer().run()
35 
36 labels = sess.run(tf.one_hot(labels,10))
37 testlabels = sess.run(tf.one_hot(testlabels,10))
38 
39 for _ in range(1000):
40   a=np.random.permutation(np.arange(60000))
41   bx = images[a[:100]]
42   by = labels[a[:100]]
43   sess.run(train_step,feed_dict={x:bx,y_:by})
44 
45 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
46 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
47 print(sess.run(accuracy, feed_dict={x:testimages, y_:testlabels}))

执行后结果约为0.92.

另外下列onehot方法非常好用:

def dense_to_one_hot(labels_dense, num_classes):
  """Convert class labels from scalars to one-hot vectors."""
  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM