目录
2.2.1 读取train,evaluation和test图片至变量data中。
2.2.2 设计one shot learning的学习过程
2.2.3 真实数据的one shot learning过程
1. 代码中所有Classes和类之间的调用关系
如下:
类 | 调用类 | 再调用类 | 所在文件 |
Class MiniImageNetDataSet | - | - | data.py |
Class Classifier | Class MetaConvolution | Class MetaNetwork | meta_matching_network.py |
Class TaskContextEncoder | Class TaskTransformer | - | |
Class DistanceNetwork | - | - | |
Class Extractor | - | - | |
Class AttentionClassify | - | - |
2. 代码思路解析
2.1 代码的入口
train_meta_matching_network.py
在data.py中设置图片路径和csv路径:
resizetargetpath
csv_file_dir
2.2 数据读取
2.2.1 读取train,evaluation和test图片至变量data中。
data = dataset.MiniImageNetDataSet(batch_size=batch_size, classes_per_set=classes_per_set, samples_per_class=samples_per_class, shuffle_classes=True)
data.datasets["train"] size:[64,600,84,84,3] # 共64类,每个类600张图片,共38400张图片。[84,84,3]是图片的长宽和3个channels data.datasets["eval"] size:[16,600,84,84,3] # 共16类,每个类600张图片,共9600张图片 data.datasets["test"] size:[20,600,84,84,3] # 共20类,每个类600张图片,共12000张图片
2.2.2 设计one shot learning的学习过程
experiment = Experim