TensorFlow-similarity 学习笔记2
2021SC@SDUSC
学习内容:Tensorflow Similarity Supervised Learning Hello World
目录:similarity/examples/supervised_hello_world.ipynb
TensorFlowSimilarity 是一个专注于让相似学习快捷简便的python库。
本学习笔记演示了如何使用TensorFlow Similarity在一小部分MNIST classes的基础上来训练SimilarityModel(),这个模型能够从MNIST数据集中查询和提取出相似的图片
DATA preparation
我们将加载 MNIST 数据集并将我们的训练数据限制为 10 个类中的 N 个(默认为 6 个),以展示模型如何从训练期间未见过的类中找到类似的示例。该模型无需重新训练即可将匹配推广到未见过的类的能力是我们希望使用metric learning的主要原因之一。
警告:TensorFlowSimilarity期望 y_train 是一个 IntTensor,其中包含每个示例的类 ID,而不是传统上用于多类分类的标准分类编码。
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
要有效地学习相似性模型,每批必须包含每个类至少 2 个示例。
为了方便做到这一点,tf_similarity提供Samplers(),使我们能够设置每批次的类数和每个类的最小示例数。在这里,我们正在创建一个MultiShotMemorySampler(),它允许您采样内存内数据集,并提供每个类的多个示例。
TensorFlowSimilarity提供各种samplers,以满足不同的要求,包括用于Single-shot学习的SingleShotMemorySampler(),直接与 TensorFlow 数据集目录集成的 TFDatasetMultiShotMemorySampler()和 TFRecordDatasetSampler() 允许我们从存储在磁盘上的非常大的数据集(TFRecords shards)中取样。
CLASSES = [2, 3, 1, 7, 9