最近開始整理一下tensorflow,准備出一個tensorflow實戰系列,以饗讀者。
學習一個深度學習框架,一般遵循這樣的思路:數據如何讀取,如如何從圖片和標簽數據中讀出成tensorflow可以使用的數據,其次是如何搭建網絡,然后就是如何訓練模型,保存模型,使用模型。最后就是可視化了。
tensorflow上開發了很多有用的包:如tensorlayers,tflearns,slim等,這些包可以讓你很方便的構建網絡模型。
入門系列你可以直接按照tensorflow的官方文檔來跑就可以了。咱就不贅敘了。
實戰第一步,我們開始構建tensorflow的數據集。
tensorflow可以讀取很多種數據,1直接從磁盤上讀取jpg文件,這個比較費時間。2讀取csv格式的數據。這個我沒有深挖。3讀取bin格式的數據,它的例子中就有是讀取已經保存的bin文件的,在models/image文件夾下的一個例子。4tfrecords方法。這個方法比較方便,也是tensorflow的默認文件格式。
就用這個第四種方法了。
直接上存的代碼:
def createtraindata():
cwd='/home/xxx/data/imagedata/'
classes={'bird','dog','person'}
writer = tf.python_io.TFRecordWriter("train.tfrecords")//保存的tfrecord的文件名是train.tfrecords
for index, name in enumerate(classes):
class_path = cwd + name + "/"
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((224, 224))
img_raw = img.tobytes() #將圖片轉化為原生bytes
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString()) #序列化為字符串
writer.close()
代碼不難,就是一些平常的python操作。這個是我跑通了的。如果有問題請留言