以一个简单的数据为例
import torch
import torch.utils.data as data
import random
x = list(range(20))
class Mydataset(data.Dataset):
def __init__(self, x):
self.x = x
self.random = random.random()
self.real_reference_probability = 0 #自定义一个概率参数用于改变匹配数据的条件判断
def __getitem__(self, index):
x = self.x[index]
random_p = random.random()
if random_p < self.real_reference_probability:
y = 2 * x + self.random +1
z = True
else:
y = 2 * x
z = False
return x, y ,z
def __len__(self):
return len(self.x)
#定义dataload
my_dataset = Mydataset(x)
my_dataload = data.DataLoader(
my_dataset,
batch_size=5,
shuffle=True,
)
for epoch in range(10):
for i, data in enumerate(my_dataload):
# 将数据从 train_loader 中读出来,一次读取的样本数是32个
x, y ,z = data
print("epoch:", epoch, "的第" , i, z) #打印代替训练过程
my_dataset.real_reference_probability +=0.1 #通过改变自定的参数,逐步增大匹配的难度
打印结果,可以看到输出的z逐渐替换成true
epoch: 0 的第 0 tensor([False, False, False, False, False])
epoch: 0 的第 1 tensor([False, False, False, False, False])
epoch: 0 的第 2 tensor([False, False, False, False, False])
epoch: 0 的第 3 tensor([False, False, False, False, False])
epoch: 1 的第 0 tensor([False, True, False, False, False])
epoch: 1 的第 1 tensor([False, False, False, False, False])
epoch: 1 的第 2 tensor([False, False, False, False, True])
epoch: 1 的第 3 tensor([False, False, False, False, False])
epoch: 2 的第 0 tensor([False, True, False, True, False])
epoch: 2 的第 1 tensor([False, False, False, False, False])
epoch: 2 的第 2 tensor([False, False, True, False, False])
epoch: 2 的第 3 tensor([False, False, False, False, False])
epoch: 3 的第 0 tensor([False, False, True, False, True])
epoch: 3 的第 1 tensor([False, False, True, False, False])
epoch: 3 的第 2 tensor([False, False, False, True, False])
epoch: 3 的第 3 tensor([False, True, True, True, False])
epoch: 4 的第 0 tensor([False, False, False, False, True])
epoch: 4 的第 1 tensor([False, False, False, False, False])
epoch: 4 的第 2 tensor([False, True, True, True, True])
epoch: 4 的第 3 tensor([False, True, False, True, False])
epoch: 5 的第 0 tensor([ True, False, False, True, False])
epoch: 5 的第 1 tensor([ True, True, True, False, False])
epoch: 5 的第 2 tensor([ True, False, False, False, False])
epoch: 5 的第 3 tensor([ True, False, True, False, True])
epoch: 6 的第 0 tensor([ True, False, True, False, False])
epoch: 6 的第 1 tensor([False, False, False, False, True])
epoch: 6 的第 2 tensor([ True, False, True, True, True])
epoch: 6 的第 3 tensor([ True, True, True, False, True])
epoch: 7 的第 0 tensor([False, False, False, False, False])
epoch: 7 的第 1 tensor([False, True, False, True, True])
epoch: 7 的第 2 tensor([True, True, True, True, True])
epoch: 7 的第 3 tensor([False, True, True, True, False])
epoch: 8 的第 0 tensor([True, True, True, True, True])
epoch: 8 的第 1 tensor([ True, False, True, False, False])
epoch: 8 的第 2 tensor([False, True, True, True, True])
epoch: 8 的第 3 tensor([True, True, True, True, True])
epoch: 9 的第 0 tensor([True, True, True, True, True])
epoch: 9 的第 1 tensor([ True, True, False, True, True])
epoch: 9 的第 2 tensor([ True, True, True, False, True])
epoch: 9 的第 3 tensor([ True, True, True, True, False])