令history = model.fit(...),用history使得训练结果可视化,并在过拟合之前提前结束训练(tf,keras)

这段代码展示了如何在Keras中构建模型并利用EarlyStopping回调进行训练优化。通过监控验证集损失,当验证损失在10个连续的epoch内没有改善时,训练将提前停止。同时,定义了一个辅助函数`plot_history`来绘制平均绝对误差和均方误差随训练进程的变化,帮助理解模型的训练效果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
hist.tail()

def plot_history(history):
  hist = pd.DataFrame(history.history)
  hist['epoch'] = history.epoch

  plt.figure()
  plt.xlabel('Epoch')
  plt.ylabel('Mean Abs Error [MPG]')
  plt.plot(hist['epoch'], hist['mae'],
           label='Train Error')
  plt.plot(hist['epoch'], hist['val_mae'],
           label = 'Val Error')
  plt.ylim([0,5])
  plt.legend()

  plt.figure()
  plt.xlabel('Epoch')
  plt.ylabel('Mean Square Error [$MPG^2$]')
  plt.plot(hist['epoch'], hist['mse'],
           label='Train Error')
  plt.plot(hist['epoch'], hist['val_mse'],
           label = 'Val Error')
  plt.ylim([0,20])
  plt.legend()
  plt.show()


plot_history(history)

 

model = build_model()

# patience 值用来检查改进 epochs 的数量
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)

history = model.fit(normed_train_data, train_labels, epochs=EPOCHS,
                    validation_split = 0.2, verbose=0, callbacks=[early_stop, PrintDot()])

plot_history(history)

 

将 pandas 导入为 PD 将 numpy 导入为 NP 将 Seaborn 导入为 SNS 将 matplotlib.pyplot 导入为 PLT %matplotlib 内联 将 TensorFlow 导入为 TF 导入随机 从 cv2 import 调整大小 from glob import glob 导入警告 warnings.filterwarnings(“ignore”)img_height = 244 img_width = 244 train_ds = tf.keras.utils.image_dataset_from_directory( 'D:/Faulty_solar_panel', validation_split=0.2, subset='training', image_size=(img_height, img_width), batch_size=32, seed=42, shuffle=True) val_ds = tf.keras.utils.image_dataset_from_directory( 'D:/Faulty_solar_panel', validation_split=0.2, subset='validation', image_size=(img_height, img_width), batch_size=32, seed=42, shuffle=True)class_names = train_ds.class_names 打印(class_names) train_dsbase_model = tf.keras.applications.VGG16( include_top=False、 weights='imagenet', input_shape=(img_height、img_width、3) ) base_model.trainable = False inputs = tf.keras.Input(shape=(img_height, img_width, 3)) x = tf.keras.applications.vgg16.preprocess_input(输入) x = base_model(x, training=False) x = tf.keras.layers.GlobalAveragePooling2D()(x) x = tf.keras.layers.Dropout(0.3)(x) 输出 = tf.keras.layers.Dense(90)(x) 模型 = tf.keras.Model(输入,输出) model.summary()model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])epoch = 15 model.fit(train_ds, validation_data=val_ds, epochs=纪元, 回调 = [ tf.keras.callbacks.EarlyStopping( monitor=“val_loss”, min_delta=1e-2, 耐心 = 3, verbose=1, restore_best_weights=) ] )# 微调 base_model.trainable = 真 对于 base_model.layers[:14] 中的 layer: layer.trainable =model.summary()model.compile(optimizer=tf.keras.optimizers.Adam(0.0001), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])epoch = 15 历史 = model.fit(train_ds, validation_data=val_ds, epochs=epoch, 回调 = [ tf.keras.callbacks.EarlyStopping( monitor=“val_loss”, min_delta=1e-2, 耐心 = 3, verbose=1, ) ] ) get_ac = history.history['准确性'] get_los = history.history['损失'] val_acc = history.history['val_accuracy'] val_loss = history.history['val_loss'] 纪元 = 范围(len(get_ac)) plt.plot(epochs, get_ac, 'g', label='训练数据的准确性') plt.plot(epochs, get_los, 'r', label='训练数据丢失') plt.title('训练数据准确性和损失') plt.legend(loc=0) plt.figure() plt.plot(epochs, get_ac, 'g', label='训练数据的准确性') plt.plot(epochs, val_acc, 'r', label='验证数据的准确性') plt.title('训练和验证准确性') plt.legend(loc=0) plt.figure() plt.plot(epochs, get_los, 'g', label='训练数据丢失') plt.plot(纪元, val_loss, 'r', label='验证数据丢失') plt.title('训练和验证损失') plt.legend(loc=0) plt.figure() plt.show()把这段代码使用的模型改为mobilenet模型提升精度,给出修改后的完整代码
最新发布
03-31
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Alocus_

如果我的内容帮助到你,打赏我吧

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值