Dataset是一个包装类:先将数据包装成Dataset(的子类),再传入DataLoader中,然后借助DataLoader更加快捷地对数据进行操作。
DataLoader是一个比较重要的类,它为我们提供的常用操作有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程)
当我们继承 Dataset 类之后,需要重写 __len__ 方法,该方法返回数据集的大小;以及 __getitem__ 方法,该方法支持从 0 到 len(self)-1 的整数索引。
import os

import pandas as pd
from torch.utils.data import Dataset
class PTB(Dataset):
    """Sliding-window dataset over a directory of battery CSV files.

    Every regular file in ``data_dir`` is read with ``pandas.read_csv``
    (GBK encoding). Each file with at least ``window_size`` rows is cut into
    overlapping windows of ``window_size`` consecutive rows whose start rows
    are ``window_step`` apart; files shorter than one window are skipped.
    Each sample is a 2-D NumPy slice ``values[start:start + window_size]``
    of the file's ``DataFrame.values``.
    """

    def __init__(self, data_dir, split, battery_dataset=None, *,
                 window_size=32, window_step=1, **kwargs):
        """
        Args:
            data_dir (str): Directory containing the CSV files.
            split (str): Split name. Stored as ``self.split``; it is not used
                to select files.
            battery_dataset (list, optional): List that windows are appended
                to. A fresh list is created when omitted (the previous ``[]``
                default was shared by every instance, so datasets grew each
                time a new one was built).
            window_size (int): Rows per window. Defaults to 32.
            window_step (int): Distance in rows between the start rows of
                consecutive windows. Defaults to 1.
            **kwargs: Accepted for call compatibility; ignored.

        Raises:
            ValueError: If ``window_size`` or ``window_step`` is not positive.
            OSError / pandas parser errors: propagated from reading files.
        """
        super().__init__()
        if window_size < 1 or window_step < 1:
            raise ValueError("window_size and window_step must be positive")

        self.data_dir = data_dir
        self.split = split
        # Assigned up front so an empty directory yields a valid, empty dataset.
        self.battery_dataset = [] if battery_dataset is None else battery_dataset

        # os.listdir order is arbitrary; sort so sample order is reproducible.
        for name in sorted(os.listdir(data_dir)):
            path = os.path.join(data_dir, name)
            if not os.path.isfile(path):
                continue  # sub-directories etc. cannot be parsed by read_csv
            values = pd.read_csv(path, encoding="gbk").values
            n_rows = values.shape[0]
            if n_rows < window_size:
                continue  # too short for even one full window
            # Kept for compatibility: array of the last file that was windowed.
            self.battery_frame = values
            # The last full window starts at n_rows - window_size, hence "+ 1".
            for start in range(0, n_rows - window_size + 1, window_step):
                self.battery_dataset.append(values[start:start + window_size])

    def __len__(self):
        """Return the number of windows in the dataset."""
        return len(self.battery_dataset)

    def __getitem__(self, idx):
        """Return window ``idx``: a ``(window_size, n_columns)`` array slice."""
        return self.battery_dataset[idx]