我有个文件夹,里面有一万个文件,每个文件都是N个T的容量,那么这就需要逐个文件、逐行读取,读取方法如下:
核心:构造IterableDataset
IterableDataset需要设置两个东西,一个是__init__
,一个是文件的路径列表
class MyIterableDataset(IterableDataset):
def __init__(self, file_list):
super(MyIterableDataset, self).__init__()
self.file_list = file_list # 这里设置所有待读取的文件的目录
def parse_file(self):
for file in self.file_list: # 逐个文件读取
print("读取文件:", file)
with open(file, 'r') as file_obj:
for line in file_obj: # 逐行读取
# 这里可以根据具体文件格式读取,但要记得 yield 一定要返回类似于1行、1个单位的数据,可以对数据加工
yield line
def __iter__(self):
# 如果 batch_size = 3,则会循环3次这个方法
return self.parse_file()
案例代码
有个文件夹datas
,下面有两个文件1.csv
与2.csv
两个文件,
1.csv
的内容是:
1,2,3,4,1
1,2,3,4,2
1,2,3,4,3
1,2,3,4,4
1,2,3,4,5
2.csv
的内容是:
2,2,3,4,1
2,2,3,4,2
2,2,3,4,3
2,2,3,4,4
2,2,3,4,5
逐文件、逐行读取的方法如下:
from torch.utils.data import IterableDataset, DataLoader
import glob
class MyIterableDataset(IterableDataset):
def __init__(self, file_list):
super(MyIterableDataset, self).__init__()
self.file_list = file_list
def parse_file(self):
for file in self.file_list:
print("读取文件:", file)
with open(file, 'r') as file_obj:
for line in file_obj:
yield line
def __iter__(self):
return self.parse_file()
if __name__ == '__main__':
all_file_list = glob.glob("datas/*.csv") # 得到datas目录下的所有csv文件的路径
dataset = MyIterableDataset(all_file_list)
# 这里batch_size=3,意味着每次读取dataloader都会循环三次dataset
# drop_last是指到最后,如果凑够了3个数据就返回,如果凑不够就舍弃掉最后的数据
dataloader = DataLoader(dataset, batch_size=3, drop_last=True)
for data in dataloader:
print(data)
结果:
读取文件: datas/1.csv
['1,2,3,4,1\n', '1,2,3,4,2\n', '1,2,3,4,3\n']
读取文件: datas/2.csv
['1,2,3,4,4\n', '1,2,3,4,5\n', '2,2,3,4,1\n']
['2,2,3,4,2\n', '2,2,3,4,3\n', '2,2,3,4,4\n']
可以看到:
- 读取完第1个文件后,读取第2个文件
- 把第2个文件的第一行和第1个文件的剩下两行凑在一起返回
- 舍弃掉了最后没凑齐3行数据的部分数据