这是第一个版本,在此我们要明白最重要的一点就是对于取数据的时候的逻辑操作,在循环一个对象的时候,我们先去经历一次iter,然后iter会返回一个---next--对象,我们去到里面取数据,还有就是再取数据的时候我们要知道self【】指的就是这个对象的本身。
class MyDataset:
def __init__(self,all_text,all_label,batch_size):
self.all_text = all_text
self.all_label = all_label
self.batch_size = batch_size
assert len(all_text) == len(self.all_label) ,"数据和标签的长度都不一样,玩毛线啊!" # 预先 assert
self.cursor = 0
def __getitem__(self, index):#根据index一个一个获取
if index < len(self):
text = self.all_text[index]
label = self.all_label[index]
return text,label
else:
return None,None
def __iter__(self): #
return self
def __next__(self): #
# 判读取完没有
if self.cursor >= len(self):
raise StopIteration # 报一个错误 : 终止循环的信号,预习一下,try ,except ,异常捕获,报错机制
# 取一个batch_size 的数据
batch_text = []
batch_label = []
for i in range(self.batch_size):
text,label = self[self.cursor]#这个self就是对象本身,self[]就是__getitem__()方法,因为里面给的是index返回数据
if text != None:
batch_text.append(text)
batch_label.append(label)
# 光标后移
self.cursor += 1
return batch_text,batch_label
def __len__(self):
return len(self.all_text)
def get_data():
all_text = ["今天天气正好", "晚上的麻辣烫很难吃", "这件衣服很难看", "早上空腹吃早饭不健康", "晚上早点睡觉很健康"]
all_label = [1, 0, 0, 0,1]
return all_text,all_label
if __name__ == "__main__":
all_text, all_label = get_data()
batch_size = 4
epoch = 10
dataset = MyDataset(all_text,all_label,batch_size)
dataset2 = MyDataset(all_text,all_label,batch_size)
dataset3 = MyDataset(all_text,all_label,batch_size)
for e in range(epoch):
for batch_text,batch_label in dataset:
#当对象放在for in 里面的时候,他触发一次iter()方法,(iterator迭代器,返回一个具有--nett的对象)然后再去调用__next__()方法,
# 然后在__next__()方法里面去调用__getitem__()方法,然后返回数据,然后再把数据放到for循环里面
# 里面的循环是我们去dataset拿batchsize次数据
print(batch_text,batch_label)
#
这是第二个版本
在这个版本中我们想要 将取的数据进行打乱,前面被引掉的代码他打乱的逻辑是将数据整体打乱,然后去取数据,那样非常的消耗内存,我们可以固定数据不动,然后将取数据的方式改变一下,我们可以将索引打乱,然后进行数据的提取。因为每个循环都会进入iter,所以我们在iter中将索引给打乱,后面取数据也就不一样了,这样在next中取数据的时候也就不用判断会不会是空值了,直接取就行。
import random
class MyDataset:
def __init__(self,all_text,all_label,batch_size,shuffle):
self.all_text = all_text
self.all_label = all_label
self.batch_size = batch_size
self.shuffle = shuffle
# if self.shuffle == True:
# random.shuffle(self.all_text)
# random.shuffle(self.all_label)
assert len(all_text) == len(self.all_label) # 预先 assert
self.cursor = 0
def __getitem__(self, index):
if index < len(self):
text = self.all_text[index]
label = self.all_label[index]
return text,label
else:
return None,None
def __iter__(self): #
# print("__iter__")
self.cursor = 0
self.shuffle_index = [i for i in range(len(self))]
if self.shuffle:
random.shuffle(self.shuffle_index)
return self
def __next__(self): #
# 判读取完没有
if self.cursor >= len(self):
raise StopIteration # 报一个错误 : 终止循环的信号,预习一下,try ,except ,异常捕获,报错机制
# 取一个batch_size 的数据
batch_text = []
batch_label = []
for i in range(self.batch_size):
if self.cursor < len(self.shuffle_index):
index = self.shuffle_index[self.cursor]
text,label = self[index]
batch_text.append(text)
batch_label.append(label)
# 光标后移
self.cursor += 1
return batch_text,batch_label
def __len__(self):
return len(self.all_text)
def get_data():
all_text = ["今天天气正好", "晚上的麻辣烫很难吃", "这件衣服很难看", "早上空腹吃早饭不健康", "晚上早点睡觉很健康"]
all_label = [1, 0, 0, 0, 1]
return all_text,all_label
if __name__ == "__main__":
all_text, all_label = get_data()
batch_size = 2
epoch = 10
shuffle = True
dataset = MyDataset(all_text,all_label,batch_size,shuffle)
for e in range(epoch):
print("*"*100)
for batch_text,batch_label in dataset:
print(batch_text,batch_label)
第三个版本使用numpy
list1 = [1,2,3,4,5]
a = list1[1]
print(a)
2
# 我想要取前五个数据就只能是list1[0:6],它取得都是连续的
# 我们可以给list套一个numpy
list2 = np.array(list1)
b = list2[[1,3]]
print(b)
[2 4]
在使用numpy的时候,取数据就会变得很简单了,我们在一开始传入数据的时候就将其变成numpy形式,在打乱索引的时候也可以用numpy直接打乱即可最方便的是我们直接去掉了gititem,之前用getitem是为了在--next--中传一个索引值,然后调用getitem去取数据,然而我们用了numpy之后,它满足一个高级索引,就可以直接取到我们想要的数据了。
import random
import numpy as np
class MyDataset:
def __init__(self,all_text,all_label,batch_size,shuffle):
# self.all_text = all_text
# self.all_label = all_label
self.all_text = np.array(all_text) # 为什么换成numpy?不换行不行?
self.all_label = np.array(all_label) # 为什么换成numpy?不换行不行?
self.batch_size = batch_size
self.shuffle = shuffle
assert len(all_text) == len(self.all_label) # 预先 assert
self.cursor = 0
def __iter__(self): #
self.cursor = 0
self.shuffle_index = np.arange(len(self))
if self.shuffle:
np.random.shuffle(self.shuffle_index)
return self
def __next__(self): #
# 判读取完没有
if self.cursor >= len(self):
raise StopIteration # 报一个错误 : 终止循环的信号,预习一下,try ,except ,异常捕获,报错机制
# 取一个batch_size 的数据
batch_index = self.shuffle_index[self.cursor:self.cursor+self.batch_size]
batch_text = self.all_text[batch_index] # 为什么换成numpy?不换行不行?
batch_label = self.all_label[batch_index]
self.cursor += batch_size
return batch_text,batch_label
def __len__(self):
return len(self.all_text)
def get_data():
all_text = ["今天天气正好", "晚上的麻辣烫很难吃", "这件衣服很难看", "早上空腹吃早饭不健康", "晚上早点睡觉很健康"]
all_label = [1, 0, 0, 0, 1]
return all_text,all_label
if __name__ == "__main__":
all_text, all_label = get_data()
batch_size = 2
epoch = 10
shuffle = True
dataset = MyDataset(all_text,all_label,batch_size,shuffle)
for e in range(epoch):
print("*"*100)
for batch_text,batch_label in dataset:
print(batch_text,batch_label)
第四个版本
这个版本就是将dataset和dataloader给分开,这一部分记得怎么联系。
import random
import numpy as np
class MyDataset:
def __init__(self,all_text,all_label,batch_size,shuffle):
# self.all_text = all_text
# self.all_label = all_label
self.all_text = np.array(all_text) # 为什么换成numpy?不换行不行?
self.all_label = np.array(all_label) # 为什么换成numpy?不换行不行?
self.batch_size = batch_size
self.shuffle = shuffle
assert len(all_text) == len(self.all_label) # 预先 assert
def __iter__(self): #
return MyDataLoader(self)
def __len__(self):
return len(self.all_text)
class MyDataLoader:
def __init__(self,dataset):
self.cursor = 0
self.dataset = dataset
self.shuffle_index = np.arange(len(dataset))
if self.dataset.shuffle:
np.random.shuffle(self.shuffle_index)
def __next__(self): #
# 判读取完没有
if self.cursor >= len(dataset):
raise StopIteration # 报一个错误 : 终止循环的信号,预习一下,try ,except ,异常捕获,报错机制
# 取一个batch_size 的数据
batch_index = self.shuffle_index[self.cursor:self.cursor + self.dataset.batch_size]
batch_text = self.dataset.all_text[batch_index] # 为什么换成numpy?不换行不行?
batch_label = self.dataset.all_label[batch_index]
self.cursor += self.dataset.batch_size
return batch_text, batch_label
def get_data():
all_text = ["今天天气正好", "晚上的麻辣烫很难吃", "这件衣服很难看", "早上空腹吃早饭不健康", "晚上早点睡觉很健康"]
all_label = [1, 0, 0, 0, 1]
return all_text,all_label
if __name__ == "__main__":
all_text, all_label = get_data()
batch_size = 2
epoch = 10
shuffle = True
dataset = MyDataset(all_text,all_label,batch_size,shuffle)
for e in range(epoch):
print("*"*100)
for batch_text,batch_label in dataset:
print(batch_text,batch_label) #