# Example: MobileNet V2 (以 MobileNet V2 为例)
import tensorflow as tf
import tensorflow_hub as hub
# Spatial input size for the classifier; the channel dimension (3) is
# appended below when declaring the layer's input_shape.
IMAGE_SHAPE = (224, 224)

# Build the model layer by layer: the TF Hub MobileNet V2 classification
# module (fetched via the hub.tensorflow.google.cn mirror), which produces a
# 1001-wide output, followed by a Softmax so the exported model returns
# normalized probabilities.
# NOTE(review): the hub module's raw output is presumably logits -- confirm.
m = tf.keras.Sequential()
m.add(
    hub.KerasLayer(
        "https://hub.tensorflow.google.cn/google/tf2-preview/mobilenet_v2/classification/4",
        input_shape=IMAGE_SHAPE + (3,),
        output_shape=[1001],
    )
)
m.add(tf.keras.layers.Softmax())
class MyModule(tf.Module):
    """Wrap a Keras model so it can be exported with a serving signature.

    The wrapped model is stored as a tracked attribute of this ``tf.Module``
    so that it is saved together with the module, and ``output`` is the
    traced function used as the SavedModel signature.
    """

    def __init__(self, model, name=None):
        """Store *model* for export.

        Args:
            model: Callable (e.g. a ``tf.keras.Model``) applied to the input
                tensor in ``output``.
            name: Optional module name forwarded to ``tf.Module``; defaults
                to a name derived from the class name.
        """
        # tf.Module.__init__ sets up the module's name and name scope; without
        # it, attributes such as ``name`` / ``name_scope`` are missing.
        super().__init__(name=name)
        self.model = model

    # NOTE(review): shape=None accepts a float32 tensor of any rank, so the
    # exported signature does not advertise the expected image shape.
    # Consider [None, 224, 224, 3] if all clients send batched images.
    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def output(self, input):
        """Run the wrapped model and return its result under one key.

        Args:
            input: float32 tensor passed straight to the wrapped model.

        Returns:
            dict mapping ``"probability"`` to the model's output tensor.
        """
        # The argument is deliberately still called ``input`` (shadowing the
        # builtin): with no TensorSpec name, the argument name becomes the
        # signature's input key, so renaming it would break serving clients.
        result = self.model(input)
        return {"probability": result}
# Wrap the model and export it. The signature is registered explicitly under
# the default serving key, using the concrete function traced from
# ``output``'s input_signature.
# NOTE(review): the trailing "1" directory presumably follows the TF Serving
# version-directory layout -- confirm.
module = MyModule(m)
tf.saved_model.save(
    module,
    './savedmodel/mobilenetv2/1',
    signatures={
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            module.output.get_concrete_function(),
    },
)