DataWhale训练营个人笔记2:DogVsCat数据集-数据分析

文章介绍了如何使用Python的argparse库解析命令行参数,并结合TensorFlow构建一个卷积神经网络模型,用于训练猫狗图像分类。包括数据预处理、模型结构、编译和训练过程。
摘要由CSDN通过智能技术生成
import argparse  # 导入argparse库,用于命令行参数解析
import tensorflow as tf
import os  # 用于文件操作

# 创建命令行参数解析器
parser = argparse.ArgumentParser(description='Process some integers')
parser.add_argument('--mode', default='train', help='train or test')  # 添加一个--mode参数,默认值为'train'
parser.add_argument("--num_epochs", default=5, type=int)  # 添加一个--num_epochs参数,默认值为5
parser.add_argument("--batch_size", default=32, type=int)  # 添加一个--batch_size参数,默认值为32
parser.add_argument("--learning_rate", default=0.001)  # 添加一个--learning_rate参数,默认值为0.001
parser.add_argument("--data_dir", default="/gemini/data-1")  # 添加一个--data_dir参数,默认值为"/gemini/data-1"
parser.add_argument("--train_dir", default="/gemini/output")  # 添加一个--train_dir参数,默认值为"/gemini/output"
args = parser.parse_args()  # 解析命令行参数并存储到args中


def _decode_and_resize(filename, label):
    # 读取图像文件并解码
    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    # 调整图像大小并进行归一化处理
    image_resized = tf.image.resize(image_decoded, [150, 150]) / 255.0
    return image_resized, label


if __name__ == "__main__":
    train_dir = args.data_dir + "/train"  # 获取训练数据集目录
    cats = []  # 创建一个空列表存储猫的图像文件路径
    dogs = []  # 创建一个空列表存储狗的图像文件路径
    # 遍历训练数据集目录的文件
    for file in os.listdir(train_dir):
        if file.startswith("dog"):  # 如果文件以"dog"开头,说明是狗的图像文件
            dogs.append(train_dir + "/" + file)  # 将狗的图像文件路径添加到dogs列表中
        else:
            cats.append(train_dir + "/" + file)  # 否则,将猫的图像文件路径添加到cats列表中
    print("dogSize:%d catSize:%d" % (len(cats), len(dogs)))  # 打印猫和狗的图像文件数量
    train_cat_filenames = tf.constant(cats[:10000])  # 创建一个tensor保存前1万个猫的图像文件路径
    train_dog_filenames = tf.constant(dogs[:10000])  # 创建一个tensor保存前1万个狗的图像文件路径
    train_filenames = tf.concat([train_cat_filenames, train_dog_filenames], axis=-1)  # 沿着最后一个维度拼接猫和狗的图像文件路径
    train_labels = tf.concat([
        tf.zeros(train_cat_filenames.shape, dtype=tf.int32),  # 创建一个全零的tensor作为猫的标签
        tf.ones(train_dog_filenames.shape, dtype=tf.int32)  # 创建一个全一的tensor作为狗的标签
    ], axis=-1)  # 沿着最后一个维度拼接猫和狗的标签

    train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))  # 创建训练数据集

    # 使用_decode_and_resize函数对训练数据集进行预处理
    train_dataset = train_dataset.map(map_func=_decode_and_resize,
                                      num_parallel_calls=tf.data.experimental.AUTOTUNE)
    train_dataset = train_dataset.shuffle(buffer_size=20000)  # 打乱训练数据集
    train_dataset = train_dataset.batch(args.batch_size)  # 设定批次大小
    train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)  # 数据预取

    # 构建卷积神经网络模型
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(150, 150, 3)),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Conv2D(128, 3, activation="relu"),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Conv2D(128, 3, activation="relu"),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dense(2, activation="softmax")
    ])

    # 编译模型
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=args.learning_rate),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
    )

    model.fit(train_dataset, epochs=args.num_epochs)  # 训练模型
    model.save(args.train_dir)  # 保存模型

    # 构建测试数据集
    test_cat_filenames = tf.constant(cats[10000:])  # 创建一个tensor保存后面的猫的图像文件路径
    test_dog_filenames = tf.constant(dogs[10000:])  # 创建一个tensor保存后面的狗的图像文件路径
    test_filenames = tf.concat([test_cat

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值