tfserving模型部署見:https://www.cnblogs.com/bincoding/p/13266685.html
demo代碼:https://github.com/haibincoder/tf_tools
對應restful入參:
{
"inputs": {
"input": [[13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712, 1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10]],
"drop_out": 1,
"sequence_length": [26]
},
"signature_name":"predict"
}
python代碼:
import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

# Build the stub on a plaintext channel to the TF Serving gRPC port.
# The public grpc API is used here instead of the deprecated
# grpc.beta.implementations wrapper, so there is no need to reach into the
# wrapper's private `_channel` attribute.
channel = grpc.insecure_channel('localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

# Model name and signature to call.
request = predict_pb2.PredictRequest()
request.model_spec.name = 'ner'
# request.model_spec.version = 'latest'
request.model_spec.signature_name = 'predict'

# Build the inputs; the dtypes must match the dtypes declared by the model
# signature.
x_data = [[13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712, 1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10]]
drop_out = 1.0
sequence_length = [26]
request.inputs['input'].CopyFrom(tf.make_tensor_proto(x_data, dtype=tf.int32))
request.inputs['sequence_length'].CopyFrom(tf.make_tensor_proto(sequence_length, dtype=tf.int32))
request.inputs['drop_out'].CopyFrom(tf.make_tensor_proto(drop_out, dtype=tf.float32))

# Returns the CRF result: the emission probability matrix and the state
# transition probability matrix.
result = stub.Predict(request, timeout=10.0)  # 10 secs timeout
print(result)
java pom:
<!-- Maven dependencies for the Java gRPC client shown after this fragment. -->
<dependencies>
<!-- NOTE(review): presumably supplies the generated TensorFlow / TF Serving
     protobuf and gRPC classes (TensorProto, Predict, PredictionServiceGrpc)
     used by the Java client; confirm against the artifact contents. -->
<dependency>
<groupId>com.yesup.oss</groupId>
<artifactId>tensorflow-client</artifactId>
<version>1.4-2</version>
</dependency>
<!-- gRPC transport (shaded Netty). The three io.grpc artifacts below are
     pinned to the same version, 1.14.0; keep them in step when upgrading. -->
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty-shaded</artifactId>
<version>1.14.0</version>
</dependency>
<!-- gRPC protobuf marshalling support. -->
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
<version>1.14.0</version>
</dependency>
<!-- gRPC stub base classes (blocking / async stubs). -->
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
<version>1.14.0</version>
</dependency>
</dependencies>
java代碼:
public static void main(String[] args) {
ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 8500).usePlaintext(true).build();
// 這里使用block模式
PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
// 創建請求
Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
// 模型名稱和模型方法名預設
Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName("ner");
modelSpecBuilder.setSignatureName("predict");
predictRequestBuilder.setModelSpec(modelSpecBuilder);
// 設置入參,訪問默認是最新版本,如果需要特定版本可以使用tensorProtoBuilder.setVersionNumber方法
List<Float> input = Arrays.asList(13f, 45f, 13f, 13f, 49f, 1f, 49f, 196f, 594f, 905f, 48f, 231f, 318f, 712f, 1003f, 477f, 259f, 291f, 287f, 161f, 65f, 62f, 82f, 68f, 2f, 10f);
TensorProto.Builder inputTensorProto = TensorProto.newBuilder();
inputTensorProto.setDtype(DataType.DT_INT32);
inputTensorProto.addAllFloatVal(input); # !!! INT64需使用addAllInt64Val
TensorShapeProto.Builder inputShapeBuilder = TensorShapeProto.newBuilder();
inputShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
inputShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(input.size()));
inputTensorProto.setTensorShape(inputShapeBuilder.build());
int dropout = 1;
TensorProto.Builder dropoutTensorProto = TensorProto.newBuilder();
dropoutTensorProto.setDtype(DataType.DT_FLOAT);
dropoutTensorProto.addIntVal(dropout);
TensorShapeProto.Builder dropoutShapeBuilder = TensorShapeProto.newBuilder();
dropoutShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
dropoutTensorProto.setTensorShape(dropoutShapeBuilder.build());
List<Integer> seqLength = Collections.singletonList(26);
TensorProto.Builder seqLengthTensorProto = TensorProto.newBuilder();
seqLengthTensorProto.setDtype(DataType.DT_INT32);
seqLengthTensorProto.addAllIntVal(seqLength);
TensorShapeProto.Builder seqLengthShapeBuilder = TensorShapeProto.newBuilder();
seqLengthShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
seqLengthTensorProto.setTensorShape(seqLengthShapeBuilder.build());
predictRequestBuilder.putInputs("input", inputTensorProto.build());
predictRequestBuilder.putInputs("drop_out", dropoutTensorProto.build());
predictRequestBuilder.putInputs("sequence_length", seqLengthTensorProto.build());
// 訪問並獲取結果
Predict.PredictResponse predictResponse = stub.withDeadlineAfter(3, TimeUnit.SECONDS).predict(predictRequestBuilder.build());
Map<String, TensorProto> result = predictResponse.getOutputsMap();
// CRF模型結果,發射概率矩陣和狀態概率矩陣
System.out.println("預測值是:" + result.toString());
}
注意事項:
- 請求中每個入參的type必須和模型簽名定義的type保持一致,否則會報錯,例如:Expects arg[0] to be float but int32 is provided
模型簽名的入參定義(名稱、dtype、shape)可以到tfserving的metadata網頁查看:
tfserving restful網頁:http://localhost:8501/v1/models/ner/metadata
tfserving部署方法見:https://www.cnblogs.com/bincoding/p/13266685.html