tf.keras.Model使用saved_model,自定義輸入輸出signature


環境:tensorflow2.2

使用tf.keras.Model.save保存saved_model格式時,默認的input和output比較通用,input_1, input2, output_1,output_2

自定義輸入輸出名字:

import tensorflow as tf

sigs = [tf.TensorSpec([None,8], tf.float32, name="a"),
        tf.TensorSpec([None,8], tf.float32, name="b"),
        tf.TensorSpec([None,8], tf.float32, name="c")]

class FullyConnectDnnModel(tf.keras.Model):
  def __init__(self, name):
    super().__init__(name=name)
    self.h1 = tf.keras.layers.Dense(1024, activation='relu')
    self.h2 = tf.keras.layers.Dense(512, activation='relu')
    self.h3 = tf.keras.layers.Dense(256, activation='relu')
    self.h4 = tf.keras.layers.Dense(128, activation='relu')
    self.h5 = tf.keras.layers.Dense(1)

  @tf.function(input_signature=[sigs])
  def call(self, emb_layer_list):
    emb_layers = tf.concat(emb_layer_list, axis=1)
    layer1 = self.h1(emb_layers)
    layer2 = self.h2(layer1)
    layer3 = self.h3(layer2)
    layer4 = self.h4(layer3)
    logits = self.h5(layer4)
    predict = tf.nn.sigmoid(logits)
    return {"logits":logits, "predict": predict}

model = FullyConnectDnnModel("test")
emb_layer_list = []
for i in range(3):
    emb_layer_list.append(tf.constant(1.0, shape=[4, 8]))
out = model(emb_layer_list)

model.save("./saved_model")

注意:

①call方法的輸入是個list,那么input_signature的輸入需要是個list[list[tf.TensorSpec]],如果輸入是一個tensor,那么input_signature的輸入是list[tf.TensorSpec],相當於input_signature必須是了list,list里面是什么需要和call的輸入類型對齊.(測試發現tf.2.2版本,keras.Model下的call方法,Input_signature不能傳dict,會報錯)

②call方法可以返回dict,但是官方文檔是這樣寫的.....有誤導性:

 

后面saved_model的文檔又是這樣寫的.....就很氣....:

 

③如果call方法返回的是多個dict,那么signature中的outputs的name將會默認加上output_1, output_2的前綴,主要為了防止兩個dict的key沖突,因此如果想保持outputs的name和我們所設置的相同,那么只能返回一個dict.

保存之后執行:

saved_model_cli show --dir=./saved_model --all

 

設置input_signatures還可以通過調用keras.Model的_set_inputs方法.

參考:

1.https://www.tensorflow.org/api_docs/python/tf/keras/Model

2.https://www.tensorflow.org/guide/saved_model?hl=zh-tw#specifying_signatures_during_export


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM