2 (自我拓展)部署花的識別模型(學習tensorflow實戰google深度學習框架)


kaggle競賽的inception模型已經能夠提取圖像很好的特征,后續訓練出一個針對當前圖片數據的全連接層,進行花的識別和分類。這里見書即可,不再贅述。

書中使用google參加Kaggle競賽的inception模型重新訓練一個全連接神經網絡,對五種花進行識別,我姑且命名為模型flower_photos_model。我進一步拓展,將lower_photos_model模型進一步保存,然后部署和應用。然后,我們直接調用遷移之后又訓練好的模型,對花片進行預測。

這里討論兩種方式:使用import_meta_graph和使用saver()

首先,原書的遷移學習的代碼需要做一些改動。

writer = tf.summary.FileWriter('./graphs/flower_photos_model_graph', sess.graph)
saver.save(sess, "Saved_model/flower_photos_model.ckpt")

 Saver()方式

我相較於訓練flower_photos_model模型時,增添了一個變量的定義:

即label_index=tf.argmax(final_tensor,1)

def main():
    #先定義相同的計算圖再加載遷移學習的模型
    bottleneck_input = tf.placeholder(tf.float32, [None, BOTTLENECK_TENSOR_SIZE], name='BottleneckInputPlaceholder')
    with tf.name_scope('final_training_ops'):
        weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.001))
        biases = tf.Variable(tf.zeros([n_classes]))
        logits = tf.matmul(bottleneck_input, weights) + biases
        final_tensor = tf.nn.softmax(logits)
        label_index=tf.argmax(final_tensor,1)
#利用import_meta_graph和import_graph_def加載的變量均不允許與當前定義計算圖有沖突。
#saver = tf.train.Saver()則只加載當前計算圖中定義的。
    saver = tf.train.Saver()
    
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.700)  
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        saver.restore(sess, "Saved_model/flower_photos_model.ckpt")
        #還是要加載一下inception模型 
        MODEL_DIR = './inception_dec_2015'
        MODEL_FILE= 'tensorflow_inception_graph.pb'
        with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])
        print (bottleneck_tensor)
        print (jpeg_data_tensor)
        #為了在tensorboard中觀察加載的計算圖。
        writer = tf.summary.FileWriter('./graphs/flower_photos_model_graph_use', sess.graph)
        writer.close()
        #image_path='./data/xiaojie_application/xiaojie_rose.jpg'
        image_path='./data/xiaojie_application/xiaojie_sunflowers.jpg'
        #image_path='./data/xiaojie_application/5547758_eea9edfd54_n.jpg'
        
        """測試一張圖片,能否獲取瓶頸向量。
        image_data = gfile.FastGFile(image_path, 'rb').read()
        print (sess.run(jpeg_data_tensor,{jpeg_data_tensor:image_data}))
        print ("xiaojie1")
        print (sess.run(bottleneck_tensor,{jpeg_data_tensor:image_data}))
        """
        label_index_value=evalution_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor,bottleneck_input,label_index)
        #print (label_index_value)
        classes=['daisy','dandelion','roses','sunflowers','tulips']
        print ("預測的花的類型:",classes[label_index_value[0]])

相關的函數的定義:

evalution_xiaojie輸出預測的分類index。
def evalution_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor,bottleneck_input,label_index):
#輸出一張圖片的預測結果    bottleneck_values=get_bottleneck_values_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor)
    bottlenecks = []
    bottlenecks.append(bottleneck_values)
    label_index_value = sess.run(label_index, feed_dict={
            bottleneck_input: bottlenecks})
    return label_index_value

獲取瓶頸向量(關於瓶頸向量,見書)

def get_bottleneck_values_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor):
    #瓶頸向量
    if not os.path.exists(CACHE_DIR): os.makedirs(CACHE_DIR)
    bottleneck_path = get_bottleneck_path_xiaojie(CACHE_DIR,image_path)
    print (bottleneck_path)
    if not os.path.exists(bottleneck_path):
        image_data = gfile.FastGFile(image_path, 'rb').read()
        bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor, bottleneck_tensor)
        bottleneck_string = ','.join(str(x) for x in bottleneck_values)
        with open(bottleneck_path, 'w') as bottleneck_file:
            bottleneck_file.write(bottleneck_string)
    else:
        with open(bottleneck_path, 'r') as bottleneck_file:
            bottleneck_string = bottleneck_file.read()
            bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
    return bottleneck_values

使用inception模型計算瓶頸向量

def run_bottleneck_on_image(sess, image_data, image_data_tensor, bottleneck_tensor):
    print("yes")
    bottleneck_values = sess.run(bottleneck_tensor, {image_data_tensor: image_data})
    bottleneck_values = np.squeeze(bottleneck_values)
    print("no")
    return bottleneck_values

瓶頸向量有一個緩存文件,這也是類似於原書訓練遷移學習模型時的做法

def get_bottleneck_path_xiaojie(CACHE_DIR,image_path):
    file_name_suffix=image_path.split('/')[-1]
    file_name_no_suffix=file_name_suffix.split('.')[0]
    bottleneck_file_name=file_name_no_suffix+('_cache.txt')
    bottleneck_path=os.path.join(CACHE_DIR, bottleneck_file_name)
    return bottleneck_path

定義的全局變量

BOTTLENECK_TENSOR_SIZE = 2048
n_classes = 5
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
CACHE_DIR='./data/xiaojie_application/cache_bottleneck/'

Saver方式的說明:

Saver只能導出持久化模型中與當前代碼定義計算圖相匹配的部分。

因此,對於之前inception也需要再一次重新加載。

此外,當前代碼定義計算圖,比持久化模型flower_photos_model多定義了一個變量,即label_index=tf.argmax(final_tensor,1),即輸出預測的分類index。

import_meta_graph方式

import_meta_graph方式與saver方式的不同點在於會導入完整的計算圖,因此當前代碼不能定義和要加載計算圖相互沖突的部分。

相關函數定義的代碼均不變。只將main函數的內容和全局變量改為:

def main():
    #如果使用tf.train.import_meta_graph的話,就會重復加載計算圖。因此,避免重復,當前代碼中不能定義重復的。
     
    #saver = tf.train.Saver()
    saver = tf.train.import_meta_graph("Saved_model/flower_photos_model.ckpt.meta")
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.700)  
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    #with tf.Session() as sess:
        #如果直接使用saver = tf.train.Saver()和restore還原一個model.ckpt文件,是不可能將之前遷移學習那個模型利用import_graph_def加載的inception模型加載進來的。
        saver.restore(sess, "Saved_model/flower_photos_model.ckpt")

        bottleneck_tensor= sess.graph.get_tensor_by_name(import_BOTTLENECK_TENSOR_NAME)
        jpeg_data_tensor = sess.graph.get_tensor_by_name(import_JPEG_DATA_TENSOR_NAME)
        print (bottleneck_tensor)
        print (jpeg_data_tensor)

        writer = tf.summary.FileWriter('./graphs/flower_photos_model_graph_use', sess.graph)
        writer.close()
        image_path='./data/xiaojie_application/xiaojie_rose.jpg'
        #image_path='./data/xiaojie_application/xiaojie_sunflowers.jpg'
        #image_path='./data/xiaojie_application/5547758_eea9edfd54_n.jpg'
        
        """測試一張圖片
        image_data = gfile.FastGFile(image_path, 'rb').read()
        print (sess.run(jpeg_data_tensor,{jpeg_data_tensor:image_data}))
        print ("xiaojie1")
        print (sess.run(bottleneck_tensor,{jpeg_data_tensor:image_data}))
        """
        bottleneck_input= sess.graph.get_tensor_by_name("BottleneckInputPlaceholder:0")
        final_tensor = sess.graph.get_tensor_by_name("final_training_ops/Softmax:0")
        label_index=tf.argmax(final_tensor,1)
        label_index_value=evalution_xiaojie(sess,image_path,jpeg_data_tensor,bottleneck_tensor,bottleneck_input,label_index)
        print (label_index_value)
        classes=['daisy','dandelion','roses','sunflowers','tulips']
        print ("預測的花的類型:",classes[label_index_value[0]]) 

全局變量改為:

import_BOTTLENECK_TENSOR_NAME = 'import/pool_3/_reshape:0'

import_JPEG_DATA_TENSOR_NAME = 'import/DecodeJpeg/contents:0'

這是因為,使用import_meta_graph方式的話,當前代碼不能定義任何與持久化模型中計算圖沖突的節點。此外,在flower_photos_model模型對全連接層進行訓練的過程中,已經利用import_graph_def的方式導入google Inception v3的持久化模型pb文件,因此,已經包括了google的模型。通過在tensorboard中查看,會發現,所有導入的模塊節點之前會帶上import節點。因此,在訓練flower_photos_model模型時,使用的是pool_3/_reshape:0獲取張量,而此時,只能使用import/pool_3/_reshape:0'獲取張量。

只能使用import/pool_3/_reshape:0'獲取張量。

        final_tensor = sess.graph.get_tensor_by_name("final_training_ops/Softmax:0")

然后,我們再定義一個label_index

        label_index=tf.argmax(final_tensor,1)

因此,同saver模型一樣,所有的其它函數接口和實現都不用變。

最后的結果很nice。可以識別五種花朵,可以直接部署應用。

程序附件

鏈接:https://pan.baidu.com/s/11YtyDEyV84jONPi9tO2TCw 密碼:8mfj


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM