tf v2.4 文档 链接: fit
示例
fit(
x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,
validation_split=0.0, validation_data=None, shuffle=True, class_weight=None,
sample_weight=None, initial_epoch=0, steps_per_epoch=None,
validation_steps=None, validation_batch_size=None, validation_freq=1,
max_queue_size=10, workers=1, use_multiprocessing=False
)
参数说明
参数 | 描述 |
---|---|
x | Input data. 可以是numpy数组、数组列表、tensorflow tensor、tensor列表;也可以是 tf.data dataset,需返回(inputs, targets) 或 (inputs, targets, sample_weights);也可以是generator or keras.utils.Sequence返回(inputs, targets) 或 (inputs, targets, sample_weights)。 |
y | Target data. 与输入数据x一样,它可以是Numpy数组或TensorFlow张量。它应与x一致。如果x是datasets, generators, 或keras.utils.Sequence实例,则不应指定y,因为target将从x获得。 |
batch_size | Integer or None. 每次梯度更新的样本数。如果未指定,batch_size将默认为32。如果数据是datasets, generators或keras.utils.Sequence实例的形式,请不要指定batch_size(因为它们会批量生成)。 |
epochs | Integer. epoch 是对所提供的整个x和y数据的迭代。与与initial_epoch一起时,epoch被理解为“最后的epoch”,此时模型不针对由epoch给定的值进行训练,而仅在达到索引epoch的epoch之前进行训练。 |
verbose | 取值为0、1或2。0=静音,1=进度条,2=每 epoch 一行。当记录到文件时,进度条不是特别有用,因此当不以交互方式运行时(例如,在生产环境中),建议使用verbose=2。 |
callbacks | keras.Callback.Callback实例的列表。培训期间应用的回调列表。请参阅tf.keras.callbacks.Note tf.keras Callback.ProgbarLogger和tf.keras.Callback.History回调是自动创建的,不需要传递到model.fit. |
validation_split | 介于0和1之间。表示用作验证数据的训练数据的部分。该模型将分离训练数据的这一部分,不会对其进行训练,并将在每个 epoch 结束时评估该数据的loss和任何模型 metric。在 shuffle 之前,从提供的x和y数据中的最后一个样本中选择验证数据。当x是 datasets, generators或keras.utils.Sequence 实例时,不支持此参数。 |
validation_data | 在每个 epoch 结束时评估损失和任何模型度量的数据。模型将不会基于该数据进行训练。因此,请注意,使用validation_split或validation_data提供的数据的验证丢失不受正则化层(如噪声和丢失)的影响。validationdata将覆盖validationsplit。validation_data可以是:Numpy数组或张量的元组(x_val,y_val),Numpy数组的元组(x_val、y_val和val_sample_weights),dataset对于前两种情况,必须提供batch_size。对于最后一种情况,可以提供validation_steps。注意validation_data并不支持x中支持的所有数据类型,例如dict、generator或keras.utils.Sequence |
shuffle | Boolean(是否在每个epoch之前对训练数据进行shuffuse)或str(for “batch”)。当x是 generator 时,将忽略此参数。'“batch”是处理HDF5数据限制的特殊选项;它以批量大小的块进行shuffle。当steps_per_epoch不是None时没有效果。 |
class_weight | 可选字典,将类索引(整数)映射到权重(浮点)值,用于加权损失函数(仅在训练期间)。这有助于告诉模型“更多地注意”来自未充分表示的类的样本。 |
sample_weight | 训练样本的可选Numpy权重数组,用于加权损失函数(仅在训练期间)。可以传递与 input 长度相同的 1D Numpy数组(权重和采样之间的1:1映射),或者在包含时间的情况下,可以传递具有形状(samples,sequence_length)的2D数组,以将不同的权重应用于每个采样的每个时间步长。当x是 datasets, generators 或keras.utils.Sequence实例时,不支持此参数,而是提供sample_weights作为x的第三个元素。 |
initial_epoch | Integer. 开始训练的 epoch,用于恢复之前的训练过程。 |
steps_per_epoch | Integer or None. 声明一个epoch已完成并开始下一个epech之前的步骤总数(batches of samples)。使用输入张量(如TensorFlow数据张量)进行训练时,默认的“None”等于数据集中的样本数除以批大小,如果无法确定,则为1。如果x是tf.data数据集,并且’steps_per_epoch’为None’,则epoch将运行,直到耗尽输入数据集。传递无限重复的数据集时,必须指定steps_per_epoch参数。input 为数组时不支持此参数。 |
validation_steps | 仅当提供了validation_data并且是tf.data数据集时才相关。在每个epoch 末尾执行验证时,在停止之前要绘制的步骤总数(batches of samples)。如果“validation_steps”为“None”,则验证将一直运行,直到validation_data数据集耗尽。在无限重复数据集的情况下,它将进入无限循环。如果指定了“validation_steps”,并且只使用数据集的一部分,则计算将从每个epoch的数据集开始。这确保每次使用相同的验证样本。 |
validation_batch_size | Integer or None. 每个验证批次的样本数量。如果未指定,则默认为batch_size。如果数据是datasets, generators 或keras.utils.Sequence实例的形式,请不要指定validation_batch_size(因为它们会生成 batch)。 |
validation_freq | 仅在提供验证数据时相关。整数或collections_abc.Container实例(例如,列表、元组等)。如果是整数,则指定在执行新的验证运行之前要运行的训练时段数,例如,validation_freq=2每2个时段运行验证。如果是容器,则指定要在其上运行验证的epoch,例如validation_freq=[1,2,10]在第一个、第二个和第十个epoch的末尾运行验证。 |
max_queue_size | Integer. 仅用于generator 或keras.utils.Sequence输入。generator队列的最大值。如果未指定,max_queue_size将默认为10。 |
workers | 仅用于 generator 或keras.utils.Sequence输入。使用基于进程的线程时要启动的最大进程数。如果未指定,worker 将默认为1。如果为0,将在主线程上执行generator。 |
use_multiprocessing | Boolean. 仅用于 generator 或keras.utils.Sequence输入。如果为True,则使用基于进程的线程。如果未指定,use_multiprocessing将默认为False。由于该实现依赖于多处理,因此不应将不可拾取(non-picklable)的参数传递给 generator,因为它们不能轻松传递给子进程。 |
返回
A History object. 其 History.history 属性是连续时段的训练损失值和度量值的记录,以及验证损失值和验证度量值(如果适用)。
错误抛出
类迭代器输入的解包行为:一种常见的模式是将tf.data.Dataset、generator或tf.keras.utils.Sequence传递给fit的x参数,这实际上不仅会产生特征(x),还可以产生目标(y)和样本权重。Keras要求这种迭代器的输出是明确的。迭代器应该返回长度为1、2或3的元组,其中可选的第二个和第三个元素将分别用于y和sample_weight。提供的任何其他类型都将封装在一个长度为一的元组中,有效地将所有内容都视为“x”。当生成dict时,它们仍然应该遵循顶级元组结构。例如({“x0”:x0,“x1”:x1},y)。Keras不会试图将功能、目标和权重从单个dict的键中分离出来。namedtuple是不支持的数据类型,因为它的行为类似于有序数据类型(元组)和映射数据类型(dict)。例,给定namedtuple形式的namedtuple: (“example_tuple”,[“y”,“x”]),在解释值时是否反转元素的顺序是模糊的。更糟糕的是形式为namedtuple(“other_tuple”,[“x”,“y”,“z”])的元组,其中不清楚元组是打算解包为x、y和sample_weight,还是作为单个元素传递给x。因此,如果数据处理代码遇到namedtuple会抛出ValueError如下所示。
RuntimeError | 如果从未编译模型,或者,如果model.fit包装在tf.function中。 |
---|---|
ValueError | 如果提供的输入数据与模型预期的不匹配,或者输入数据为空。 |