首先是官方文档:
def fit(self,
x=None,
y=None,
batch_size=None,
epochs=1,
verbose=1,
callbacks=None,
validation_split=0.,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_batch_size=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False):
"""Trains the model for a fixed number of epochs (iterations on a dataset).
Arguments:
x: Input data. It could be:
- A Numpy array (or array-like), or a list of arrays
(in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset. Should return a tuple
of either `(inputs, targets)` or
`(inputs, targets, sample_weights)`.
- A generator or `keras.utils.Sequence` returning `(inputs, targets)`
or `(inputs, targets, sample_weights)`.
A more detailed description of unpacking behavior for iterator types
(Dataset, generator, Sequence) is given below.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely). If `x` is a dataset, generator,
or `keras.utils.Sequence` instance, `y` should
not be specified (since targets will be obtained from `x`).
我要讲的重点是输入x和y分别应该传的参数,如果传入的x数据是Dataset, generator, Sequence类型,他可以自动从数组里拆解出target,不应该再指定y了。他会根据你指定的输入和输出个数,自动从前往后拆解为输入和输出所对应的target(可以是多输入多输出,比如每一个数据组是[a,b,c,d,e],a,b是输入,c,d,e是3个输出分别对应的target)。
除了这三种数据类型,你都需要为其传入对应的y参数。如果你是多输入的模型,x就设置为列表,比如[a,b],a是一组输入,b是另一组输入。