torch.utils.data.dataset.random_split随机划分后对划分后数据处理
在使用torch.utils.data.dataset.random_split后,生成同属于Dataset类型的Subset类,若想对划分后的训练集(train)和验证集(validation)再进行处理,只需对train_set对象进行浅拷贝即可改变类内属性。
data_set = MySegmentation(cfg, split='train')
# data_set.change_split()
n_val = int(len(data_set) * cfg["train"]["val_percent"])
n_train = len(data_set) - n_val
train_set, val_set = random_split(data_set, [n_train, n_val])
# 对划分后的数据集浅拷贝,修改默认类内属性
train_set.dataset = copy(data_set)
val_set.dataset.split = "val"
print("train dataset:{}".format(len(train_set)))
print("validation dataset:{}".format(len(val_set)))
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True,**kwargs)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, drop_last=True, **kwargs)
return train_loader, val_loader
在随机划分数据集后,数据集为subset类,包含两个对象dataset和indices,其中indice为对应随机抽出的索引位置。其中dataset在划分的 训练集 和 测试集中数据仍指向相同地址,改变其中一个对象属性则都会全部修改。
通过使用浅拷贝的方法,改变其中一个数据集的指向地址,再对属性进行修改,就会修改指定的对象属性。
上图则是浅拷贝之后改变数据集地址。