我发现网上的示例代码大多是 MNIST，而我并不做图像处理，所以怎么都不想拿这个例子来练手；
Wide & Deep 这类包含各种特征的模型才是我真正需要的；Iris 模型同样是用文本（CSV）数据训练的，所以非常简单；
本文给出用 Python 和 Java 访问 TensorFlow Serving 的代码。
Java 版本：使用 gRPC 访问 TensorFlow Serving 的代码
package io.github.qf6101.tensorflowserving;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import org.tensorflow.example.*;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
/**
 * Minimal gRPC client that sends a sample Iris {@link Example} to a TensorFlow Serving
 * {@code Predict} endpoint and prints the response.
 *
 * <p>See: https://www.jianshu.com/p/d82107165119
 * <p>See: https://github.com/grpc/grpc-java
 */
public class PssIrisGrpcClient {
    /** Host used when no host argument is passed to {@link #main}. */
    private static final String DEFAULT_HOST = "127.0.0.1";
    /** gRPC port used when no port argument is passed to {@link #main}. */
    private static final int DEFAULT_PORT = 8888;
    /** Deadline for the Predict call, so an unresponsive server cannot block the client forever. */
    private static final long PREDICT_DEADLINE_SECONDS = 10;
    /** How long to wait for in-flight calls to drain before forcing the channel closed. */
    private static final long SHUTDOWN_TIMEOUT_SECONDS = 5;

    /**
     * Builds the sample Iris example sent by {@link #main}.
     *
     * @return an {@code Example} with float features SepalLength=5.1, SepalWidth=3.3,
     *     PetalLength=1.7 and PetalWidth=0.5
     */
    public static Example createExample() {
        Map<String, Float> dataMap = new HashMap<>();
        dataMap.put("SepalLength", 5.1f);
        dataMap.put("SepalWidth", 3.3f);
        dataMap.put("PetalLength", 1.7f);
        dataMap.put("PetalWidth", 0.5f);
        return createExample(dataMap);
    }

    /**
     * Builds an {@code Example} with one single-value float feature per map entry.
     *
     * @param dataMap feature name to feature value; must not be null or contain null values
     * @return the assembled {@code Example}
     * @throws NullPointerException if {@code dataMap} or any of its values is null
     */
    public static Example createExample(Map<String, Float> dataMap) {
        Objects.requireNonNull(dataMap, "dataMap");
        Features features = Features.newBuilder()
                .putAllFeature(mapToFeatureMap(dataMap))
                .build();
        return Example.newBuilder().setFeatures(features).build();
    }

    /** Converts each (name, value) pair into a {@code Feature} holding a one-element float list. */
    private static Map<String, Feature> mapToFeatureMap(Map<String, Float> dataMap) {
        Map<String, Feature> resultMap = new HashMap<>();
        for (Map.Entry<String, Float> entry : dataMap.entrySet()) {
            FloatList floatList = FloatList.newBuilder().addValue(entry.getValue()).build();
            resultMap.put(entry.getKey(), Feature.newBuilder().setFloatList(floatList).build());
        }
        return resultMap;
    }

    /**
     * Sends one Predict request for the sample Iris example and prints the response.
     *
     * @param args optional {@code [host [port]]}; defaults to 127.0.0.1 and 8888
     * @throws NumberFormatException if a port argument is given but is not an integer
     */
    public static void main(String[] args) {
        String host = (args != null && args.length > 0) ? args[0] : DEFAULT_HOST;
        int port = (args != null && args.length > 1) ? Integer.parseInt(args[1]) : DEFAULT_PORT;
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port)
                // Channels are secure (TLS) by default; plaintext avoids needing certificates
                // for this local example.
                .usePlaintext()
                .build();
        try {
            PredictionServiceGrpc.PredictionServiceBlockingStub blockingStub =
                    PredictionServiceGrpc.newBlockingStub(channel)
                            .withDeadlineAfter(PREDICT_DEADLINE_SECONDS, TimeUnit.SECONDS);
            Predict.PredictResponse response = blockingStub.predict(buildRequest(createExample()));
            System.out.println(response);
        } finally {
            // Release the channel even when predict() throws (e.g. server unreachable).
            shutdownChannel(channel);
        }
    }

    /**
     * Builds a Predict request for model {@code iris} (version 1, signature
     * {@code classification}) whose {@code inputs} tensor holds the serialized examples.
     *
     * @param examples examples to send; each becomes one element of a 1-D DT_STRING tensor
     * @return the request ready to pass to the prediction service
     */
    private static Predict.PredictRequest buildRequest(Example... examples) {
        List<ByteString> serialized = new ArrayList<>(examples.length);
        for (Example example : examples) {
            serialized.add(example.toByteString());
        }
        com.google.protobuf.Int64Value version = com.google.protobuf.Int64Value.newBuilder()
                .setValue(1)
                .build();
        Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder()
                .setName("iris")
                .setVersion(version)
                .setSignatureName("classification")
                .build();
        // Shape [n]: one string element per serialized Example.
        TensorShapeProto.Dim batchDim = TensorShapeProto.Dim.newBuilder()
                .setSize(serialized.size())
                .build();
        TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(batchDim).build();
        TensorProto tensorProto = TensorProto.newBuilder()
                .setDtype(DataType.DT_STRING)
                .setTensorShape(shape)
                .addAllStringVal(serialized)
                .build();
        return Predict.PredictRequest.newBuilder()
                .setModelSpec(modelSpec)
                .putInputs("inputs", tensorProto)
                .build();
    }

    /** Shuts the channel down gracefully, falling back to a forced shutdown on timeout/interrupt. */
    private static void shutdownChannel(ManagedChannel channel) {
        channel.shutdown();
        try {
            if (!channel.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                channel.shutdownNow();
            }
        } catch (InterruptedException e) {
            channel.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }
}
需要增加如下 Maven 依赖：
<!-- Dependencies for the Java gRPC client: TensorFlow plus grpc-netty, grpc-protobuf and grpc-stub (all grpc artifacts pinned to the same version). NOTE(review): these artifacts do not appear to contain the generated tensorflow.serving.* classes (Model, Predict, PredictionServiceGrpc) imported by the client; presumably they must be generated from the TensorFlow Serving .proto files with protoc and the grpc-java plugin. The org.tensorflow.framework / org.tensorflow.example protos may likewise require the separate org.tensorflow:proto artifact. Confirm against a clean build. --> <!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow --> <dependency> <groupId>org.tensorflow</groupId> <artifactId>tensorflow</artifactId> <version>1.12.0</version> </dependency> <!-- https://mvnrepository.com/artifact/io.grpc/grpc-netty --> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-netty</artifactId> <version>1.20.0</version> </dependency> <!-- https://mvnrepository.com/artifact/io.grpc/grpc-protobuf --> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-protobuf</artifactId> <version>1.20.0</version> </dependency> <!-- https://mvnrepository.com/artifact/io.grpc/grpc-stub --> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-stub</artifactId> <version>1.20.0</version> </dependency>
Java 版本的输出结果：
outputs {
key: "scores"
value {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 1
}
dim {
size: 3
}
}
float_val: 0.9997806
float_val: 2.1938368E-4
float_val: 1.382611E-9
}
}
outputs {
key: "classes"
value {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
dim {
size: 3
}
}
string_val: "0"
string_val: "1"
string_val: "2"
}
}
# Python client: call TensorFlow Serving's Classify API over gRPC for the "iris" model.
import grpc
import tensorflow as tf
from tensorflow_serving.apis import classification_pb2, prediction_service_pb2_grpc

HOST = '127.0.0.1'
PORT = 8888  # gRPC port the model server listens on; adjust to your deployment.
TIMEOUT_SECONDS = 10.0


def _create_feature(v):
    """Wrap a single float value in a tf.train.Feature backed by a FloatList."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[v]))


def _create_example(data):
    """Build a tf.train.Example with one float feature per (name, value) pair in `data`."""
    features = {k: _create_feature(v) for k, v in data.items()}
    return tf.train.Example(features=tf.train.Features(feature=features))


# Hard-coded sample inputs, converted to Example protos.
data1 = {"SepalLength": 5.1, "SepalWidth": 3.3, "PetalLength": 1.7, "PetalWidth": 0.5}
data2 = {"SepalLength": 1.1, "SepalWidth": 1.3, "PetalLength": 1.7, "PetalWidth": 0.5}
examples = [_create_example(d) for d in (data1, data2)]

# Prepare the RPC request, addressing the model by name.
request = classification_pb2.ClassificationRequest()
request.model_spec.name = 'iris'
request.input.example_list.examples.extend(examples)

# The context manager closes the channel once the call completes (or fails).
with grpc.insecure_channel('%s:%d' % (HOST, PORT)) as channel:
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    response = stub.Classify(request, timeout=TIMEOUT_SECONDS)
print(response)
Python 代码看起来简单不少，但我们的线上服务都是用 Java 写的，不便集成，所以 Python 版本只能用来做一些离线的批量预测；
Python 版本的输出如下：
result {
classifications {
classes {
label: "0"
score: 0.9997805953025818
}
classes {
label: "1"
score: 0.00021938368445262313
}
classes {
label: "2"
score: 1.382611025668723e-09
}
}
classifications {
classes {
label: "0"
score: 0.0736534595489502
}
classes {
label: "1"
score: 0.8393719792366028
}
classes {
label: "2"
score: 0.08697459846735
}
}
}
model_spec {
name: "iris"
version {
value: 1
}
signature_name: "serving_default"
}
我个人其实非常喜欢 HTTP + JSON 接口，完全不用折腾 gRPC 这些麻烦的东西；尤其是 Java 的 gRPC，遇到了好多问题，非常崩溃。
不过据说 gRPC 的性能比 HTTP 好不少，所以线上只能用 gRPC。