原型网络对训练集中不存在的类别也具有泛化能力,与孪生网络一样,也试图学习度量空间来进行分类,基本思想是创建每个类的原型表示,并根据类原型与查询点之间的距离对查询点(新点)进行分类。
使用原型网络对Omniglot字符集分类
import os
import glob
from PIL import Image
import numpy as np
import tensorflow as tf
root_dir = 'data/'
# 该文件包括语言名称、旋转信息和字符数量
train_split_path = os.path.join(root_dir, 'splits', 'train.txt')
with open(train_split_path, 'r') as train_split:
train_classes = [line.rstrip() for line in train_split.readlines()]
# 类的数量
no_of_classes = len(train_classes)
# 样本数量
num_examples = 20
img_width = 28
img_height = 28
channels = 1
#初始化数据集的类的数量、样本数量、图像高度和宽度
train_dataset = np.zeros([no_of_classes, num_examples, img_height, img_width], dtype=np.float32)
# 读取所有图像,转换为numpy数组