【Pytorch】random_split()随机划分后不同数据集做不同数据增强

torch.utils.data.dataset.random_split随机划分后对划分后数据处理

在使用torch.utils.data.dataset.random_split后,生成同属于Dataset类型的Subset类,若想对划分后的训练集(train)和验证集(validation)再进行处理,只需对train_set对象进行浅拷贝即可改变类内属性。

         data_set = MySegmentation(cfg, split='train')
	     # data_set.change_split()
	     n_val = int(len(data_set) * cfg["train"]["val_percent"])
	     n_train = len(data_set) - n_val
	     train_set, val_set = random_split(data_set, [n_train, n_val])
	     # 对划分后的数据集浅拷贝,修改默认类内属性
  		 train_set.dataset = copy(data_set)
         val_set.dataset.split = "val"
	     print("train dataset:{}".format(len(train_set)))
	     print("validation dataset:{}".format(len(val_set)))
	     train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True,**kwargs)
	     val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, drop_last=True, **kwargs)
	     return train_loader, val_loader

在随机划分数据集后,数据集为subset类,包含两个对象dataset和indices,其中indice为对应随机抽出的索引位置。其中dataset在划分的 训练集 和 测试集中数据仍指向相同地址,改变其中一个对象属性则都会全部修改。
在这里插入图片描述
通过使用浅拷贝的方法,改变其中一个数据集的指向地址,再对属性进行修改,就会修改指定的对象属性。
在这里插入图片描述
上图则是浅拷贝之后改变数据集地址。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值