Two ways to create a model
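1: Functional API (chained layer calls)
First create an input layer, inputs, with tf.keras.Input. Then call each layer on the previous layer's output.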
```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))  # input layer: 3 features per sample
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
```
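To check the functional model, you can call it directly on a batch of data. This is a minimal sketch that assumes random NumPy input with 3 features per sample. The data is made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Same functional model as above: 3 input features -> 4 hidden units -> 5 classes.
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Hypothetical input: a batch of 2 samples with 3 features each.
batch = np.random.rand(2, 3).astype("float32")
probs = model(batch).numpy()  # calling the model runs one forward pass

print(probs.shape)        # (2, 5): one row per sample, one column per class
print(probs.sum(axis=1))  # each softmax row sums to (approximately) 1
```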
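2: Subclassing tf.keras.Model (object-oriented style)
Create the layers in __init__ and connect them in call.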
```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        # Layers are created once here and reused on every call
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        # Forward pass: defines how the layers are connected
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()
```
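A subclassed model does not know its input shape until it is first called, unlike the functional version. Its weights therefore only exist after that first call. The sketch below reuses the class above and feeds it a hypothetical all-zeros input to show this.

```python
import numpy as np
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()
weights_before = len(model.weights)  # 0: the input size is unknown, so nothing is built yet

# Hypothetical input: one sample with 3 features. The first call builds the layers.
out = model(np.zeros((1, 3), dtype="float32"))
# Parameters: (3*4 + 4) + (4*5 + 5) = 16 + 25 = 41
n_params = sum(int(np.prod(tuple(w.shape))) for w in model.weights)

print(weights_before, tuple(out.shape), n_params)  # 0 (1, 5) 41
```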
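Attributes
Attribute | Description |
---|---|
layers | The model's layers |
metrics_names | Display labels for all outputs (e.g. the loss and each metric) |
run_eagerly | Whether to run in eager mode. Defaults to False, which means the model is compiled into a static graph |
sample_weights | |
state_updates | |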
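This short sketch reads layers and sets run_eagerly. It assumes the subclassed MyModel from above. The optimizer and loss strings are placeholders, chosen only so that compile() can be called.

```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = MyModel()

# `layers` lists the tracked sub-layers in the order they were assigned.
layer_types = [type(layer).__name__ for layer in model.layers]
print(layer_types)  # ['Dense', 'Dense']

# run_eagerly can be switched on through compile(); "sgd"/"mse" are placeholders.
model.compile(optimizer="sgd", loss="mse", run_eagerly=True)
print(model.run_eagerly)  # True
```

run_eagerly=True is mainly a debugging aid. The model then runs step by step like ordinary Python code, which is slower than the compiled graph but easier to step through.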
Methods
compile
```python
compile(
    optimizer,
    loss=None,
    metrics=None,
    loss_weights=None,
    sample_weight_mode=None,
    weighted_metrics=None,
    target_tensors=None,
    distribute=None,
    **kwargs
)
```
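Parameter | Description |
---|---|
optimizer | (string or optimizer object) The optimizer, e.g. 'adam' or tf.keras.optimizers.Adam() |
loss | (string, loss object, or function) The loss. If the model has multiple outputs, you can specify a different loss for each output |
metrics | (list of strings) Metrics to track, e.g. ['accuracy', 'mse'] |
loss_weights | |
sample_weight_mode | |
weighted_metrics | |
target_tensors | |
distribute | |
**kwargs | |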
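Here is a minimal sketch of compile() followed by a short fit(). It assumes a 5-class problem with integer labels, which is why it uses SparseCategoricalCrossentropy, and random placeholder data.

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# optimizer and loss can be objects (shown) or strings such as
# "adam" and "sparse_categorical_crossentropy"; metrics is a list of names.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

# Hypothetical training data: 32 samples, integer class labels in [0, 5).
rng = np.random.default_rng(0)
x_train = rng.random((32, 3)).astype("float32")
y_train = rng.integers(0, 5, size=(32,))

history = model.fit(x_train, y_train, epochs=2, batch_size=8, verbose=0)
print(sorted(history.history))  # per-epoch values for the loss and each metric
```

For a model with several outputs, loss can be a list (in output order) or a dict keyed by output layer name instead of a single value.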
evaluate
```python
evaluate(
    x=None,
    y=None,
    batch_size=None,
    verbose=1,
    sample_weight=None,
    steps=None,
    callbacks=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False
)
```
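Parameter | Description |
---|---|
x | (NumPy array; tensor; list of tensors; dict; tf.data dataset; keras.utils.Sequence) Input data |
y | |
batch_size | (int) Number of samples per batch. Defaults to 32. Do not specify it if the input already produces batches (a tf.data dataset, a generator, or a keras.utils.Sequence) |
verbose | |
sample_weight | |
steps | (int) Total number of batches to evaluate before the evaluation round is considered finished. With the default None, evaluation runs over the whole input |
callbacks | |
max_queue_size | |
workers | |
use_multiprocessing | |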
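The sketch below shows the two common ways to feed evaluate(). It assumes a compiled 5-class model and random placeholder data. Because a metric is configured, evaluate() returns [loss, metric].

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
hidden = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax")(hidden)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Hypothetical test data: 40 samples.
rng = np.random.default_rng(0)
x_test = rng.random((40, 3)).astype("float32")
y_test = rng.integers(0, 5, size=(40,))

# NumPy input: evaluate() splits the data into batches of batch_size.
loss, acc = model.evaluate(x_test, y_test, batch_size=10, verbose=0)

# tf.data input is already batched, so batch_size is not passed;
# steps=2 stops after 2 of the 4 batches instead of exhausting the dataset.
ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(10)
partial = model.evaluate(ds, steps=2, verbose=0)

print(loss, acc, partial)
```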
evaluate_generator
```python
evaluate_generator(
    generator,
    steps=None,
    callbacks=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
    verbose=0
)
```
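evaluate_generator consumes a Python generator that yields (x, y) batches. Later TensorFlow releases deprecate it because evaluate() accepts generators directly. This hedged sketch therefore passes a hypothetical generator to evaluate() instead. The generator is infinite, so steps is required to tell evaluate() when to stop.

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
hidden = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax")(hidden)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

def batch_generator(batch_size=8, seed=0):
    """Hypothetical infinite generator yielding (x, y) batches of random data."""
    rng = np.random.default_rng(seed)
    while True:
        x = rng.random((batch_size, 3)).astype("float32")
        y = rng.integers(0, 5, size=(batch_size,))
        yield x, y

# The generator never ends, so steps=5 limits evaluation to 5 batches.
loss = model.evaluate(batch_generator(), steps=5, verbose=0)
print(loss)  # a single scalar, since no metrics were configured
```

The same pattern applies to fit_generator and predict_generator: pass the generator to fit() or predict() together with steps (or steps_per_epoch for fit()).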
fit
fit_generator
get_layer
```python
get_layer(
    name=None,
    index=None
)
```
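get_layer looks up a layer either by name or by index. This sketch gives the layers explicit names, hidden and probs. Both names are hypothetical.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation="relu", name="hidden")(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax", name="probs")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

hidden = model.get_layer(name="hidden")                # lookup by name
last = model.get_layer(index=len(model.layers) - 1)   # lookup by position in model.layers
print(hidden.units, last.name)                         # 4 probs
```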
load_weights
```python
load_weights(
    filepath,
    by_name=False
)
```
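load_weights restores weights written by save_weights into a model with the same architecture. This minimal sketch assumes a hypothetical build_model() helper and a temporary file path.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def build_model():
    """Hypothetical helper: builds the same 3 -> 4 -> 5 architecture each time."""
    inputs = tf.keras.Input(shape=(3,))
    x = tf.keras.layers.Dense(4, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

source = build_model()
target = build_model()  # same architecture, different random initial weights

# The ".weights.h5" suffix selects the HDF5 format (Keras 3 requires this suffix).
path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
source.save_weights(path)
target.load_weights(path)

x = np.random.rand(4, 3).astype("float32")
same = np.allclose(source.predict(x, verbose=0), target.predict(x, verbose=0))
print(same)  # True: target now holds source's weights
```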
predict
```python
predict(
    x,
    batch_size=None,
    verbose=0,
    steps=None,
    callbacks=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False
)
```
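predict runs the forward pass over all inputs, batch by batch, and returns a single NumPy array. This sketch uses random placeholder inputs.

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Hypothetical data: 10 samples, processed in batches of 4 (4 + 4 + 2)
# and concatenated into one array.
samples = np.random.rand(10, 3).astype("float32")
probs = model.predict(samples, batch_size=4, verbose=0)
predicted_class = probs.argmax(axis=1)  # most likely class per sample

print(probs.shape, predicted_class.shape)  # (10, 5) (10,)
```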
predict_generator
```python
predict_generator(
    generator,
    steps=None,
    callbacks=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
    verbose=0
)
```
predict_on_batch
```python
predict_on_batch(x)
```
reset_metrics
reset_states
save
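Saves the model (architecture, weights, and training configuration) to a single HDF5 file. The save_format argument can select the TensorFlow SavedModel format instead.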
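This sketch saves the whole model to an HDF5 file and loads it back with tf.keras.models.load_model. It assumes a temporary path, and the .h5 extension is what selects HDF5 here. The model is left uncompiled to keep the example minimal.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# The ".h5" extension selects the HDF5 format; architecture and weights go into one file.
path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(path)
restored = tf.keras.models.load_model(path)

x_new = np.random.rand(3, 3).astype("float32")
same = np.allclose(model.predict(x_new, verbose=0), restored.predict(x_new, verbose=0))
print(same)  # True: the reloaded model reproduces the original predictions
```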
save_weights
summary
```python
summary(
    line_length=None,
    positions=None,
    print_fn=None
)
```
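summary prints a table of layers, output shapes, and parameter counts, and print_fn redirects that text. This sketch captures the output in a list. The exact table layout differs between TensorFlow/Keras versions, so only the totals are relied on.

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# print_fn receives the summary text instead of it going to stdout; the extra
# *args/**kwargs tolerate versions that pass additional arguments.
captured = []
model.summary(print_fn=lambda text, *args, **kwargs: captured.append(str(text)))
summary_text = "\n".join(captured)

# The totals section reports (3*4 + 4) + (4*5 + 5) = 41 parameters.
print(summary_text)
```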
test_on_batch
to_json
to_yaml
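to_json and to_yaml serialize only the model architecture, not the weights. This sketch does a JSON round trip with tf.keras.models.model_from_json. YAML is not shown because later TensorFlow releases disabled to_yaml for security reasons.

```python
import json

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# to_json returns the architecture as a JSON string (weights are not included).
config_json = model.to_json()
print(json.loads(config_json)["class_name"])

# Rebuilding from JSON gives the same architecture with freshly initialized weights.
rebuilt = tf.keras.models.model_from_json(config_json)
print(rebuilt.count_params() == model.count_params())  # True
```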
train_on_batch
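train_on_batch, test_on_batch, and predict_on_batch each process exactly one batch, which is useful in custom training loops. This sketch assumes a compiled model with an accuracy metric and one random placeholder batch.

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Hypothetical single batch of 8 samples.
rng = np.random.default_rng(0)
x_batch = rng.random((8, 3)).astype("float32")
y_batch = rng.integers(0, 5, size=(8,))

# Each call handles exactly this one batch: no epochs, no further batching.
train_logs = model.train_on_batch(x_batch, y_batch)  # one gradient update, returns [loss, accuracy]
test_logs = model.test_on_batch(x_batch, y_batch)    # no weight update, returns [loss, accuracy]
batch_probs = model.predict_on_batch(x_batch)        # forward pass only

print(train_logs, test_logs, np.asarray(batch_probs).shape)
```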
Reference:
https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/Model