keras模型使用
keras官方文档中文版:https://keras.io/zh/
from keras.layers import Input
from keras.models import Model
from keras import optimizers
from keras.callbacks import ModelCheckpoint
input = Input(shape=(c.size_train[0], c.size_train[1], 4)) #keras模型的输入是keras tensor
pred = Net(input) #自定义网络
model = Model(input, pred) #将指定输入输出的网络变成Model:将培训和评估程序添加到网络(Network)中。
model.summary() #打印网络的字符串摘要。
adam = optimizers.Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
model.compile(optimizer=adam, loss=['categorical_crossentropy' for _ in range(4)], loss_weights=[10.0, 0.1, 0.1, 0.1], metrics=['accuracy'])
modelcheck = ModelCheckpoint(model_weights, monitor='val_l1_acc', save_best_only=False, mode='auto')
callable = [modelcheck] #
model.load_weights(args['model'] + '/' + a[-1], by_name=True)
H = model.fit_generator(generator=train_set,
steps_per_epoch=train_numb // c.batch_size,
epochs=c.num_epochs,
verbose=1,
validation_data=val_set,
validation_steps=valid_numb // c.batch_size,
callbacks=callable,
max_q_size=1)
#generator
以上为某中使用方式的代码片段
一. Input()
keras.engine.input_layer #模块keras.engine.input_layer,输入层代码(Input和InputLayer)。
Input()用于实例化Keras tensor(张量)。
- Keras张量是来自底层后端(Theano,TensorFlow或CNTK)的张量对象,我们通过某些属性进行扩充,这些属性允许我们仅通过了解模型的输入和输出来构建Keras模型。
- 例如,如果a,b和c是Keras张量,则可以这样做:model = Model(input = [a,b],output = c)
- 添加的Keras属性是:
_keras_shape:通过Keras端形状推断传播的整数形状元组。
_keras_history:应用于张量的最后一层。 可以递归地从该层检索整个图层图。
def Input(shape: Any = None,
batch_shape: Any = None,
name: Any = None,
dtype: Any = None,
sparse: bool = False,
tensor: Any = None) -> {__len__}
参数
- shape:形状元组(整数),不包括批量大小。 例如,shape =(32,)表示预期输入将是32维向量的批次。
- batch_shape:形状元组(整数),包括批量大小。 例如,batch_shape
=(10,32)表示预期输入将是10个32维向量的批次。 batch_shape =(None,32)表示任意数量的32维向量的批次。 - name:图层的可选名称字符串。 在模型中应该是唯一的(不要重复使用相同的名称两次)。 如果没有提供,它将自动生成。
- dtype:输入所期望的数据类型,作为字符串(float32,float64,int32 …)
- sparse:一个布尔值,指定要创建的占位符是否为稀疏。
- tensor:可选的现有张量以包装到Input层。 如果设置,该图层将不会创建占位符张量。
返回值
一个tensor张量。
例子
from keras.layers import Input
from keras.models import Model
#这是Keras的逻辑回归
x = Input(shape=(32,))
y = Dense(16, activation='softmax')(x) #Dense模型的输出
model = Model(x, y) #见下条说明
二. Model(Network)类
keras.engine.training #模块keras.engine.training,Keras引擎的训练相关部分。
class Model(Network)
The Model class adds training & evaluation routines to a Network.
“Model类将培训和评估程序添加到网络(Network)中。”
from keras.models import Model
model = Model(input, output) #定义模型的输入输出
model.summary() #见下条说明
–> 1. Network(Layer)类
keras.engine.network #模块keras.engine.network。网络是组合图层的方式:模型的拓扑形式。
class Network(Layer)
A Network is a directed acyclic graph of layers.It is the topological form of a “model”. A Model is simply a Network with added training routines.
“Network是层(Layer)的有向无环图。它是“模型”的拓扑形式。 Model只是一个增加了训练程序的网络。”
属性
-
name
-
inputs
-
outputs
-
layers
-
input_spec (list of class instances“类实例列表”)
each entry describes one required input: ndim; dtype.
“每个条目描述一个必需的输入:ndim;dtype。” -
trainable (boolean)
-
input_shape
-
output_shape
-
weights (list of variables)
-
trainable_weights (list of variables)
-
non_trainable_weights (list of variables)
-
losses
-
updates
-
state_updates
-
stateful
方法
- __ call __
- summary
- get_layer
- get_weights
- set_weights
- get_config
- compute_output_shape
- save
- add_loss
- add_update
- get_losses_for
- get_updates_for
- to_json
- to_yaml
- reset_states
类方法
- from_config
Raises(增加/附加?)
- TypeError: if input tensors are not Keras tensors (tensors returned by Input).
“如果输入张量不是Keras张量(Input返回的张量)”
–> 2. Layer(object)类
keras.engine.base_layer 模块keras.engine.base_layer。包含基础Layer类,所有层都从该类继承。
Abstract base layer class. “抽象基础层 类”
属性
- input, output: Input/output tensor(s). Note that if the layer is used more than once (shared layer), this is ill-defined and will raise an exception. In such cases, use layer.get_input_at(node_index).
- input_mask, output_mask: Mask tensors. Same caveats apply as input, output.
- input_shape: Shape tuple. Provided for convenience, but note that there may be cases in which this attribute is ill-defined (e.g. a shared layer with multiple input shapes), in which case requesting input_shape will raise an Exception. Prefer using layer.get_input_shape_at(node_index).
- input_spec: List of InputSpec class instances each entry describes one required input: ndim; dtype. A layer with n input tensors must have an input_spec of length n.
- name: String, must be unique within a model.
- non_trainable_weights: List of variables.
- output_shape: Shape tuple. See input_shape.
- stateful:Boolean indicating whether the layer carries additional non-weight state. Used in, for instance, RNN cells to carry information between batches.
- supports_masking: Boolean indicator of whether the layer supports masking, typically for unused timesteps in a sequence.
- trainable: Boolean, whether the layer weights will be updated during training.
- trainable_weights: List of variables.
- uses_learning_phase:Whether any operation of the layer uses K.in_training_phase() or K.in_test_phase().
- weights: The concatenation of the lists trainable_weights and non_trainable_weights (in this order).
方法
- call(x, mask=None): Where the layer’s logic lives.
- __call __(x, mask=None): Wrapper around the layer logic (call).
If x is a Keras tensor:
Connect current layer with last layer from tensor: self._add_inbound_node(last_layer)
Add layer to tensor history
If layer is not built:
Build from x._keras_shape
- compute_mask(x, mask)
- compute_output_shape(input_shape)
- count_params()
- get_config()
- get_input_at(node_index)
- get_input_mask_at(node_index)
- get_input_shape_at(node_index)
- get_output_at(node_index)
- get_output_mask_at(node_index)
- get_output_shape_at(node_index)
- get_weights()
- set_weights(weights)
类方法
- from_config(config)
Internal methods(内部方法)
- _add_inbound_node(layer, index=0)
- assert_input_compatibility()
- build(input_shape)
三. summary()
keras.engine.network.Network #模块keras.engine.network中的Network类,Network 类的具体说明见上一节:
keras.engine.network #模块keras.engine.network。网络是组合图层的方式:模型的拓扑形式。
打印网络的字符串摘要。
def summary(self,
line_length: Any = None,
positions: Any = None,
print_fn: Any = None) -> None
参数
line_length: Total length of printed lines “打印行的总长度”
(e.g. set this to adapt the display to different terminal window sizes.
“ 例如,将其设置为使显示适应不同的终端窗口大小。”)
positions: Relative or absolute positions of log elements in each line. If not provided, defaults to [.33, .55, .67, 1.].
“每行中log元素的相对或绝对位置。 如果未提供,则默认为[.33,.55,.67, 1.] ”
print_fn: Print function to use.(要使用的打印功能。)
It will be called on each line of the summary. You can set it to a custom function in order to capture the string summary.It defaults to print (prints to stdout).
“它将在摘要的每一行上调用。 您可以将其设置为自定义函数以捕获字符串摘要。 默认为打印(打印到标准输出)。”