from LAC
问题描述
在这个框架下,要计算loss的时候:
for i, ((inputs_w, inputs_s), targets) in enumerate(train_loader):
# ---------------------------- compute loss -----------------------------------
我对读取train_loader
产生了疑问,为什么要这么读取呢?
一般来说,是这样读取的:
for i, x in enumerate(train_loader):
可是本来是x
的地方多了那么多括号,很奇怪;
思考方式:
倒着推理,是因为从train_loader
里面读取出来的,所以就是它里面的问题;
train_loader
是怎么创建的呢?
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=args.workers, pin_memory=True, drop_last=True)
其他参数都和数据没有关系,是train_dataset
的锅了;
train_dataset, val_dataset, test_dataset = get_datasets(args)
在函数get_datasets()
中:
def get_datasets(args):
...
train_dataset = data_handler(source_data['train']['images'][split_idx['train']], source_data['train']['labels'][split_idx['train'], :], args.dataset_dir, transform=train_transform)
source_data
没问题,是data_handler
的锅;
用的coco
,
HANDLER_DICT = {
'voc': VOC2012_handler,
'coco': COCO2014_handler,
'nus': NUS_WIDE_handler,
'cub': CUB_200_2011_handler,
}
data_handler = HANDLER_DICT[args.dataset_name]
去找handler
,
class COCO2014_handler(Dataset):
def __init__(self, X, Y, data_path, transform=None):
self.X = X
self.Y = Y
self.transform = transform
self.data_path = data_path
def __getitem__(self, index):
x = Image.open(self.data_path+'/'+self.X[index]).convert('RGB')
x = self.transform(x)
y = self.Y[index]
return x, y
# -> (inputs_s, inputs_w), targets
def __len__(self):
return len(self.X)
其中在__getitem__()
中对x
进行了transform
处理,问题出在这,
在函数get_datasets()
中找到,其使用的transform
:
train_transform = TransformUnlabeled_WS(args)
class TransformUnlabeled_WS(object):
def __init__(self, args):
self.weak = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize((args.img_size, args.img_size)),
transforms.ToTensor()])
self.strong = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize((args.img_size, args.img_size)),
CutoutPIL(cutout_factor=0.5),
RandAugment(),
transforms.RandomApply([
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor()])
def __call__(self, x):
weak = self.weak(x)
strong = self.strong(x)
return weak, strong
# 这里返回了2个
落脚到class TransformUnlabeled_WS()
上,在函数__call__()
返回了weak, strong
两个变量,所以之前会写(inputs_w, inputs_s)
,而不是inputs
;
getitem() 调用情况
以__
开头的函数是隐式调用,我之前从逻辑上推理不出来是在什么时候用的,是我解决不了该问题的关键;
缺乏python函数调用常识,惭愧、惭愧;