官方代码使用TensorFlow训练一个卷积神经网络进行二元分类任务,数据集包含猫和狗的图像。让我们逐步分析这段代码:
1. 用到的库:
- `argparse`:用于解析命令行参数的模块。
- `tensorflow`:用于构建和训练机器学习模型的TensorFlow库。
- `os`:Python的`os`模块,用于与操作系统进行交互。
2. 使用`argparse`定义命令行参数:
- `--mode`:指定运行模式,可以是"train"或"test"(默认是"train")。
- `--num_epochs`:训练的轮数(默认为5)。
- `--batch_size`:训练的批次大小(默认为32)。
- `--learning_rate`:优化器的学习率(默认为0.001)。
- `--data_dir`:存储训练数据的目录(默认为"/gemini/data-1")。
- `--train_dir`:保存训练好的模型的目录(默认为"/gemini/output")。
3. 定义函数`_decode_and_resize`:
- 该函数接受文件名和标签作为输入。
- 它读取图像文件,解码图像,将其调整为固定大小(150x150),并将像素值缩放到[0, 1]的范围内。
- 函数返回调整大小后的图像和标签。
4. 构建训练数据集:
- 通过拼接猫和狗的训练文件路径,创建了训练数据集的文件名张量 `train_filenames`。
- 通过拼接零(猫的标签)和一(狗的标签),创建了对应的标签张量 `train_labels`。
- 使用 `tf.data.Dataset.from_tensor_slices` 将文件名和标签组成的元组作为数据集的元素,创建了训练数据集 `train_dataset`。
- 调用 `map` 方法,将 `_decode_and_resize` 函数应用于数据集中的每个元素,以进行图像解码和调整大小操作。
- 可以取消注释 `train_dataset.shuffle(buffer_size=20000)` 来在训练时对数据进行随机洗牌(打乱顺序)。
5. 构建卷积神经网络模型:
- 使用 TensorFlow 构建了一个序列模型 `model`,该模型由一系列卷积层、池化层和全连接层组成。
- 模型的输入形状是(150, 150, 3),即150x150像素的彩色图像。
- 模型的最后一层是一个包含两个神经元的 softmax 输出层,用于进行二元分类。
6. 编译模型:
- 使用 `model.compile` 方法编译模型,指定了优化器(Adam优化器)和损失函数(稀疏分类交叉熵)。
- 还定义了模型评估的指标,这里使用了稀疏分类准确度作为指标。
7. 训练模型:
- 使用 `model.fit` 方法来训练模型,将训练数据集 `train_dataset` 传递给它,指定训练轮数为 `args.num_epochs`。
8. 保存模型:
- 使用 `model.save` 方法将训练好的模型保存到指定的目录 `args.train_dir` 中。
9. 构建测试数据集:
- 通过拼接猫和狗的测试文件路径,创建了测试数据集的文件名张量 `test_filenames`。
- 通过拼接零(猫的标签)和一(狗的标签),创建了对应的标签张量 `test_labels`。
- 使用 `tf.data.Dataset.from_tensor_slices` 将文件名和标签组成的元组作为数据集的元素,创建了测试数据集 `test_dataset`。
10. 评估模型:
- 通过迭代测试数据集 `test_dataset`,使用模型进行预测,并计算分类准确度(`sparse_categorical_accuracy`)。
- 最后,打印测试准确度的结果。