深度学习-二分类问题

学习查询方法

链接: link
这是TensorFlow的API网址,里面对函数意义以及各参数进行了详细的描述。

加载数据

# 获取训练集Dataset
train_dataset, valid_dataset = keras.utils.image_dataset_from_directory(
    path+"/train",
    labels="inferred",#"inferred"标签根据目录推断生成
    class_names=["cat","dog"],#类名显示列表,必须与子目录匹配
    label_mode="binary",#二分类问题,标签编码为01
    image_size=(150, 150),#指定图像大小
    validation_split=0.2,#划分一部分给验证集
    subset="both",#要返回的数据的子集。“培训”、“验证”或“两者兼而有之”。
    seed=42,
    batch_size=32)#数据批次的大小
train_dataset = train_dataset.map(lambda x,label : (x/255,label)).prefetch(1)#在对图像数据进行归一化处理,将图像的像素值从[0, 255]区间缩放到[0, 1]区间。同时,它使用了prefetch(1)方法来预取数据,以提高数据读取效率。
valid_dataset = valid_dataset.map(lambda x,label : (x/255,label)).prefetch(1)

# 获取测试集Dataset
test_dataset = keras.utils.image_dataset_from_directory(
    path+"/test",
    labels="inferred",
    class_names=["cat","dog"],
    label_mode="binary",
    shuffle=False,#不打乱数据
    image_size=(150, 150),
    batch_size=32)
test_dataset = test_dataset.map(lambda x,label : (x/255,label)).prefetch(1)

构建模型

# 从keras加载VGG16网络
input_data = keras.layers.Input(shape=(150,150,3))
vgg16_model = keras.applications.VGG16(include_top=False, #不要全连接层
                                       weights=None, #不要预训练权重
                                       input_tensor=input_data)

# 添加自己的全连接层,并创建模型
x = keras.layers.Flatten()(vgg16_model.output)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(256, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)
output_data = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.models.Model(input_data, output_data)
model.summary()

#将Keras模型转换为点格式并保存到文件。
plot_model(model,to_file="model.png",
            show_shapes=True,
            show_layer_activations=True)

optimizer = keras.optimizers.RMSprop(learning_rate=5e-5)#选择优化模型,选择合适的学习率
model.compile(loss='binary_crossentropy',
              optimizer=optimizer,
              metrics=['acc'])

model_file = "cat_dog_vgg16.keras"
ck_callback = keras.callbacks.ModelCheckpoint( #保存Keras模型或模型权重的回调。
    model_file,
    monitor="val_acc",#监视验证集的准确度
    save_best_only=True#只保存最佳的
    )
训练模型,画出各种指标
# 训练模型
max_epochs = 40
history = model.fit(train_dataset, 
                    epochs=max_epochs,
                    validation_data=valid_dataset,
                    callbacks=[ck_callback])

# 获取训练数据,画训练过程性能指标变化图
history = history.history
acc = history['acc']
val_acc = history['val_acc']
loss = history['loss']
val_loss = history['val_loss']
epo = np.arange(len(acc))
plt.figure()
plt.subplot(121);plt.grid();plt.title("accuracy");
plt.plot(epo, acc, label='train')
plt.plot(epo,val_acc, label='val');plt.legend()
plt.subplot(122);plt.grid();plt.title("loss");
plt.plot(epo, loss, label='train')
plt.plot(epo,val_loss, label='val');plt.legend()
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值