tensorflow model.fit()解读

tf v2.4 文档 链接: fit

示例

fit(
    x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,
    validation_split=0.0, validation_data=None, shuffle=True, class_weight=None,
    sample_weight=None, initial_epoch=0, steps_per_epoch=None,
    validation_steps=None, validation_batch_size=None, validation_freq=1,
    max_queue_size=10, workers=1, use_multiprocessing=False
)

参数说明

参数描述
xInput data. 可以是numpy数组、数组列表、tensorflow tensor、tensor列表;也可以是 tf.data dataset,需返回(inputs, targets) 或 (inputs, targets, sample_weights);也可以是generator or keras.utils.Sequence返回(inputs, targets) 或 (inputs, targets, sample_weights)。
yTarget data. 与输入数据x一样,它可以是Numpy数组或TensorFlow张量。它应与x一致。如果x是datasets, generators, 或keras.utils.Sequence实例,则不应指定y,因为target将从x获得。
batch_sizeInteger or None. 每次梯度更新的样本数。如果未指定,batch_size将默认为32。如果数据是datasets, generators或keras.utils.Sequence实例的形式,请不要指定batch_size(因为它们会批量生成)。
epochsInteger. epoch 是对所提供的整个x和y数据的迭代。与与initial_epoch一起时,epoch被理解为“最后的epoch”,此时模型不针对由epoch给定的值进行训练,而仅在达到索引epoch的epoch之前进行训练。
verbose取值为0、1或2。0=静音,1=进度条,2=每 epoch 一行。当记录到文件时,进度条不是特别有用,因此当不以交互方式运行时(例如,在生产环境中),建议使用verbose=2。
callbackskeras.Callback.Callback实例的列表。培训期间应用的回调列表。请参阅tf.keras.callbacks.Note tf.keras Callback.ProgbarLogger和tf.keras.Callback.History回调是自动创建的,不需要传递到model.fit.
validation_split介于0和1之间。表示用作验证数据的训练数据的部分。该模型将分离训练数据的这一部分,不会对其进行训练,并将在每个 epoch 结束时评估该数据的loss和任何模型 metric。在 shuffle 之前,从提供的x和y数据中的最后一个样本中选择验证数据。当x是 datasets, generators或keras.utils.Sequence 实例时,不支持此参数。
validation_data在每个 epoch 结束时评估损失和任何模型度量的数据。模型将不会基于该数据进行训练。因此,请注意,使用validation_split或validation_data提供的数据的验证丢失不受正则化层(如噪声和丢失)的影响。validationdata将覆盖validationsplit。validation_data可以是:Numpy数组或张量的元组(x_val,y_val),Numpy数组的元组(x_val、y_val和val_sample_weights),dataset对于前两种情况,必须提供batch_size。对于最后一种情况,可以提供validation_steps。注意validation_data并不支持x中支持的所有数据类型,例如dict、generator或keras.utils.Sequence
shuffleBoolean(是否在每个epoch之前对训练数据进行shuffuse)或str(for “batch”)。当x是 generator 时,将忽略此参数。'“batch”是处理HDF5数据限制的特殊选项;它以批量大小的块进行shuffle。当steps_per_epoch不是None时没有效果。
class_weight可选字典,将类索引(整数)映射到权重(浮点)值,用于加权损失函数(仅在训练期间)。这有助于告诉模型“更多地注意”来自未充分表示的类的样本。
sample_weight训练样本的可选Numpy权重数组,用于加权损失函数(仅在训练期间)。可以传递与 input 长度相同的 1D Numpy数组(权重和采样之间的1:1映射),或者在包含时间的情况下,可以传递具有形状(samples,sequence_length)的2D数组,以将不同的权重应用于每个采样的每个时间步长。当x是 datasets, generators 或keras.utils.Sequence实例时,不支持此参数,而是提供sample_weights作为x的第三个元素。
initial_epochInteger. 开始训练的 epoch,用于恢复之前的训练过程。
steps_per_epochInteger or None. 声明一个epoch已完成并开始下一个epech之前的步骤总数(batches of samples)。使用输入张量(如TensorFlow数据张量)进行训练时,默认的“None”等于数据集中的样本数除以批大小,如果无法确定,则为1。如果x是tf.data数据集,并且’steps_per_epoch’为None’,则epoch将运行,直到耗尽输入数据集。传递无限重复的数据集时,必须指定steps_per_epoch参数。input 为数组时不支持此参数。
validation_steps仅当提供了validation_data并且是tf.data数据集时才相关。在每个epoch 末尾执行验证时,在停止之前要绘制的步骤总数(batches of samples)。如果“validation_steps”为“None”,则验证将一直运行,直到validation_data数据集耗尽。在无限重复数据集的情况下,它将进入无限循环。如果指定了“validation_steps”,并且只使用数据集的一部分,则计算将从每个epoch的数据集开始。这确保每次使用相同的验证样本。
validation_batch_sizeInteger or None. 每个验证批次的样本数量。如果未指定,则默认为batch_size。如果数据是datasets, generators 或keras.utils.Sequence实例的形式,请不要指定validation_batch_size(因为它们会生成 batch)。
validation_freq仅在提供验证数据时相关。整数或collections_abc.Container实例(例如,列表、元组等)。如果是整数,则指定在执行新的验证运行之前要运行的训练时段数,例如,validation_freq=2每2个时段运行验证。如果是容器,则指定要在其上运行验证的epoch,例如validation_freq=[1,2,10]在第一个、第二个和第十个epoch的末尾运行验证。
max_queue_sizeInteger. 仅用于generator 或keras.utils.Sequence输入。generator队列的最大值。如果未指定,max_queue_size将默认为10。
workers仅用于 generator 或keras.utils.Sequence输入。使用基于进程的线程时要启动的最大进程数。如果未指定,worker 将默认为1。如果为0,将在主线程上执行generator。
use_multiprocessingBoolean. 仅用于 generator 或keras.utils.Sequence输入。如果为True,则使用基于进程的线程。如果未指定,use_multiprocessing将默认为False。由于该实现依赖于多处理,因此不应将不可拾取(non-picklable)的参数传递给 generator,因为它们不能轻松传递给子进程。

返回

A History object. 其 History.history 属性是连续时段的训练损失值和度量值的记录,以及验证损失值和验证度量值(如果适用)。

错误抛出

类迭代器输入的解包行为:一种常见的模式是将tf.data.Dataset、generator或tf.keras.utils.Sequence传递给fit的x参数,这实际上不仅会产生特征(x),还可以产生目标(y)和样本权重。Keras要求这种迭代器的输出是明确的。迭代器应该返回长度为1、2或3的元组,其中可选的第二个和第三个元素将分别用于y和sample_weight。提供的任何其他类型都将封装在一个长度为一的元组中,有效地将所有内容都视为“x”。当生成dict时,它们仍然应该遵循顶级元组结构。例如({“x0”:x0,“x1”:x1},y)。Keras不会试图将功能、目标和权重从单个dict的键中分离出来。namedtuple是不支持的数据类型,因为它的行为类似于有序数据类型(元组)和映射数据类型(dict)。例,给定namedtuple形式的namedtuple: (“example_tuple”,[“y”,“x”]),在解释值时是否反转元素的顺序是模糊的。更糟糕的是形式为namedtuple(“other_tuple”,[“x”,“y”,“z”])的元组,其中不清楚元组是打算解包为x、y和sample_weight,还是作为单个元素传递给x。因此,如果数据处理代码遇到namedtuple会抛出ValueError如下所示。

RuntimeError如果从未编译模型,或者,如果model.fit包装在tf.function中。
ValueError如果提供的输入数据与模型预期的不匹配,或者输入数据为空。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值