import random
class MyDataset:
    """Minimal mini-batch iterator over paired ``(text, label)`` samples.

    Iterating the dataset (``for batch_text, batch_label in dataset``) yields
    ``(batch_text, batch_label)`` tuples of lists holding at most
    ``batch_size`` elements each; the final batch of a pass may be shorter.
    When ``shuffle`` is true, the sample order is re-shuffled at the start of
    every pass, and each text always stays paired with its own label.

    NOTE: the dataset is its own iterator (``__iter__`` returns ``self``), so
    two nested/simultaneous loops over the same instance share one cursor.

    Args:
        all_text: Indexable sequence of input samples.
        all_label: Indexable sequence of labels, same length as ``all_text``.
        batch_size: Maximum number of samples per batch; must be positive.
        shuffle: Whether to shuffle sample order at the start of each pass.

    Raises:
        ValueError: If ``all_text`` and ``all_label`` differ in length, or
            ``batch_size`` is not positive.
    """

    def __init__(self, all_text, all_label, batch_size, shuffle):
        # Explicit checks rather than ``assert``: asserts vanish under -O.
        if len(all_text) != len(all_label):
            raise ValueError(
                "all_text and all_label must have the same length "
                f"({len(all_text)} != {len(all_label)})"
            )
        # A non-positive batch size would make __next__ return empty batches
        # without ever advancing the cursor, i.e. iterate forever.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.all_text = all_text
        self.all_label = all_label
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Shuffling all_text and all_label independently would break the
        # text/label pairing, so __iter__ shuffles an index permutation
        # instead. Initialise the iteration state here (in order, without
        # touching the RNG) so next() before iter() does not raise
        # AttributeError.
        self.cursor = 0
        self.shuffle_index = list(range(len(all_text)))

    def __getitem__(self, index):
        """Return the ``(text, label)`` pair stored at ``index``.

        For backward compatibility an ``index >= len(self)`` returns
        ``(None, None)`` instead of raising ``IndexError``.
        """
        if index < len(self):
            return self.all_text[index], self.all_label[index]
        return None, None

    def __iter__(self):
        """Begin a new pass: rewind the cursor and, if enabled, reshuffle.

        Returns:
            ``self`` — the dataset acts as its own iterator.
        """
        self.cursor = 0
        self.shuffle_index = list(range(len(self)))
        if self.shuffle:
            random.shuffle(self.shuffle_index)
        return self

    def __next__(self):
        """Return the next ``(batch_text, batch_label)`` pair of lists.

        Raises:
            StopIteration: Once every sample of the current pass was returned
                (the signal that ends a ``for`` loop).
        """
        # Stop on the current pass's permutation length so a stale/short
        # permutation can never yield empty batches forever.
        if self.cursor >= len(self.shuffle_index):
            raise StopIteration
        # Slicing clamps at the end, so the last batch may be shorter.
        batch_index = self.shuffle_index[self.cursor:self.cursor + self.batch_size]
        self.cursor += len(batch_index)
        # Go through self[...] so subclasses overriding __getitem__ still apply.
        pairs = [self[i] for i in batch_index]
        batch_text = [text for text, _ in pairs]
        batch_label = [label for _, label in pairs]
        return batch_text, batch_label

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.all_text)
def get_data():
    """Return a tiny toy text-classification corpus.

    Returns:
        A ``(texts, labels)`` tuple of two parallel lists. The labels appear
        to mark sentiment (1 = positive, 0 = negative) judging by the texts.
    """
    samples = [
        ("今天天气正好", 1),
        ("晚上的麻辣烫很难吃", 0),
        ("这件衣服很难看", 0),
        ("早上空腹吃早饭不健康", 0),
        ("晚上早点睡觉很健康", 1),
    ]
    texts = [sentence for sentence, _ in samples]
    labels = [tag for _, tag in samples]
    return texts, labels
if __name__ == "__main__":
    # Demo: run 10 epochs of shuffled mini-batches of size 2 over the toy data.
    texts, labels = get_data()
    loader = MyDataset(texts, labels, batch_size=2, shuffle=True)
    for _ in range(10):
        print("*" * 100)
        # A for-loop calls loader.__iter__() once (rewind + reshuffle),
        # then loader.__next__() per batch until StopIteration.
        for batch_text, batch_label in loader:
            print(batch_text, batch_label)
# 【无标题】 (Untitled)
# 于 2024-07-17 01:03:31 首次发布 (first published 2024-07-17 01:03:31)