通过字典映射，根据配置项 cfg.ds 决定实例化哪个数据集类（各类分别定义在不同的 py 脚本中）
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset
# Map each accepted value of cfg.ds to the Dataset class that implements it.
ds_list = {"CUB": CUBirds, "SOP": SOP, "Cars": Cars, "Inshop": Inshop_Dataset}
try:
    ds_class = ds_list[cfg.ds]
except KeyError:
    # Still a KeyError (same type as before), but tell the user what is valid.
    raise KeyError(
        f"Unknown dataset {cfg.ds!r}; expected one of {sorted(ds_list)}"
    ) from None
print(f"ds_list[cfg.ds]: {ds_class}")  # e.g. proxy_anchor.dataset.Inshop.Inshop_Dataset
# ds_class holds a class object (not a method name); calling it instantiates
# the selected dataset for the training split.
ds_train = ds_class(cfg.path, "train", train_tr)
其中被实例化的 CUBirds 类定义如下：
cub.py
from .base import *
class CUBirds(BaseDataset):
    """CUB-200-2011 birds dataset, split by class index.

    Images are read from ``<root>/CUB_200_2011/images`` via torchvision's
    ``ImageFolder``; the integer label of each image is the class index
    ImageFolder assigns. Classes 0-99 form the ``'train'`` split and
    classes 100-199 form the ``'eval'`` split.

    Args:
        root: Directory that contains the ``CUB_200_2011`` folder.
        mode: ``'train'`` or ``'eval'``.
        transform: Optional transform, forwarded to ``BaseDataset``.

    Raises:
        ValueError: If ``mode`` is not ``'train'`` or ``'eval'``.
    """

    def __init__(self, root, mode, transform = None):
        self.root = root + '/CUB_200_2011'
        self.mode = mode
        self.transform = transform
        if self.mode == 'train':
            self.classes = range(0, 100)
        elif self.mode == 'eval':
            self.classes = range(100, 200)
        else:
            # Fail fast: otherwise self.classes stays unset and the filter
            # loop below dies with an unrelated-looking AttributeError.
            raise ValueError(
                f"CUBirds mode must be 'train' or 'eval', got {mode!r}")

        # NOTE(review): presumably initialises self.ys, self.I and
        # self.im_paths as empty lists — confirm in base.py.
        BaseDataset.__init__(self, self.root, self.mode, self.transform)

        image_folder = torchvision.datasets.ImageFolder(
            root=os.path.join(self.root, 'images'))
        index = 0
        for path, label in image_folder.imgs:
            # Skip non-image files whose names start with `._`
            # (presumably macOS metadata files).
            fn = os.path.basename(path)
            if label in self.classes and not fn.startswith('._'):
                self.ys.append(label)
                self.I.append(index)
                # ImageFolder paths already begin with the root we passed in;
                # joining them with self.root again would duplicate the prefix
                # whenever root is a relative path.
                self.im_paths.append(path)
                index += 1