What variable pattern should be used to read from a DataLoader? What exactly goes inside the parentheses in `for i, ((inputs_w, inputs_s), targets) in enumerate(train_loader)`?

from LAC

Problem description

In this framework, the loss is computed like this:

for i, ((inputs_w, inputs_s), targets) in enumerate(train_loader):
	# ---------------------------- compute loss -----------------------------------
	

I was puzzled by this way of reading from train_loader. Why unpack it like that?
Normally, it would be read like this:

for i, x in enumerate(train_loader):

But the place where a plain x would normally go is full of extra parentheses, which looks odd.
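Those extra parentheses are just nested tuple unpacking in the `for` header. A minimal sketch with hypothetical stand-in data (no real images or tensors) shows the two equivalent ways of writing the loop:

```python
# Each yielded item has the shape ((weak, strong), target).
batches = [((10, 11), 0), ((20, 21), 1)]

# Nested pattern: destructure everything in the loop header.
for i, ((inputs_w, inputs_s), targets) in enumerate(batches):
    print(i, inputs_w, inputs_s, targets)

# Equivalent without the nested pattern: unpack inside the body.
for i, item in enumerate(batches):
    (inputs_w, inputs_s), targets = item
    print(i, inputs_w, inputs_s, targets)
```

Both loops print exactly the same thing; the nested pattern just saves one line per iteration.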

Approach:

Reason backwards: the items come out of train_loader, so the answer must lie inside it.
How was train_loader created?

train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

None of the other arguments affect the shape of the data, so train_dataset must be responsible.

train_dataset, val_dataset, test_dataset = get_datasets(args)

Inside the function get_datasets():

def get_datasets(args):
	...
	train_dataset = data_handler(source_data['train']['images'][split_idx['train']], source_data['train']['labels'][split_idx['train'], :], args.dataset_dir, transform=train_transform)

source_data is fine, so the culprit is data_handler.
The dataset in use is COCO:

HANDLER_DICT = {
    'voc': VOC2012_handler,
    'coco': COCO2014_handler,
    'nus': NUS_WIDE_handler,
    'cub': CUB_200_2011_handler,
}
data_handler = HANDLER_DICT[args.dataset_name]

Go find the handler:

class COCO2014_handler(Dataset):
    def __init__(self, X, Y, data_path, transform=None):
        self.X = X
        self.Y = Y
        self.transform = transform
        self.data_path = data_path

    def __getitem__(self, index):
        x = Image.open(self.data_path+'/'+self.X[index]).convert('RGB')
        x = self.transform(x)
        y = self.Y[index]
        return x, y
        # -> (inputs_w, inputs_s), targets

    def __len__(self):
        return len(self.X)
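One detail worth noting: the DataLoader's default collate function recursively preserves this tuple nesting when it batches samples, so a batch of `((weak, strong), y)` items comes out as `((weak_batch, strong_batch), y_batch)`. A pure-Python sketch of that recursive behavior, using hypothetical string stand-ins instead of tensors:

```python
# Minimal sketch of how default collation recurses into tuples.
# (torch's default_collate would stack tensors; here we just collect lists.)
def collate(items):
    first = items[0]
    if isinstance(first, tuple):
        # Recurse per tuple position, preserving the nesting.
        return tuple(collate([item[i] for item in items])
                     for i in range(len(first)))
    return list(items)

# Three samples shaped like ((weak, strong), label).
samples = [((f"w{i}", f"s{i}"), i) for i in range(3)]
(inputs_w, inputs_s), targets = collate(samples)
print(inputs_w)  # ['w0', 'w1', 'w2']
print(targets)   # [0, 1, 2]
```

This is why the batch yielded by train_loader can be unpacked with exactly the same nested pattern as a single sample.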

Inside __getitem__(), x is passed through self.transform. This is where the answer lies.
In get_datasets(), the transform being used is:

train_transform = TransformUnlabeled_WS(args)

class TransformUnlabeled_WS(object):
    def __init__(self, args):
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor()])

        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.img_size, args.img_size)),
            CutoutPIL(cutout_factor=0.5),
            RandAugment(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor()])

    def __call__(self, x):
        weak = self.weak(x)
        strong = self.strong(x)
        return weak, strong
        # two values are returned here

So it comes down to the class TransformUnlabeled_WS: its __call__() returns two values, weak and strong, which is why the loop unpacks (inputs_w, inputs_s) instead of a single inputs.
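Putting the whole chain together: `__getitem__` applies the transform, the transform's `__call__` returns `(weak, strong)`, so each sample is `((weak, strong), y)`. An end-to-end sketch with hypothetical stand-in classes (`FakeTransformWS` and `FakeHandler` are illustrative names, not from the LAC code; no real images involved):

```python
# Stand-in for TransformUnlabeled_WS: __call__ returns a (weak, strong) pair.
class FakeTransformWS:
    def __call__(self, x):
        return f"weak({x})", f"strong({x})"

# Stand-in for COCO2014_handler.
class FakeHandler:
    def __init__(self, X, Y, transform):
        self.X, self.Y, self.transform = X, Y, transform

    def __getitem__(self, index):
        x = self.transform(self.X[index])  # x is now a (weak, strong) pair
        return x, self.Y[index]            # -> ((weak, strong), y)

    def __len__(self):
        return len(self.X)

ds = FakeHandler(["img0", "img1"], [0, 1], FakeTransformWS())
(inputs_w, inputs_s), target = ds[0]
print(inputs_w, inputs_s, target)  # weak(img0) strong(img0) 0
```

The nested parentheses in the training loop mirror this sample structure exactly.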

When is __getitem__() actually called?

Methods whose names are wrapped in double underscores are invoked implicitly. I couldn't work out from the code alone when they were being called, and that gap was exactly why I got stuck on this problem.
A lack of basic knowledge about Python's special method calls; lesson learned.
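A small demonstration (with a hypothetical `Demo` class) of when these dunder methods fire implicitly, which is exactly the mechanism DataLoader and the transform rely on:

```python
class Demo:
    def __getitem__(self, i):   # fired by indexing: d[i]
        return i * 2
    def __len__(self):          # fired by len(d)
        return 5
    def __call__(self, x):      # fired by calling the instance: d(x)
        return x + 1

d = Demo()
print(d[3])    # 6  -> d.__getitem__(3); DataLoader fetches samples this way
print(len(d))  # 5  -> d.__len__(); DataLoader uses this for dataset size
print(d(10))   # 11 -> d.__call__(10); self.transform(x) works this way
```

So `self.transform(x)` in `__getitem__` is really `self.transform.__call__(x)`, and `dataset[index]` inside the DataLoader is really `dataset.__getitem__(index)`.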
