DataLoader 和 Dataset

Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,我们再使用DataLoader这个类来更加快捷的对数据进行操作。

DataLoader是一个比较重要的类,它为我们提供的常用操作有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程)
当我们集成了一个 Dataset类之后,我们需要重写 len 方法,该方法提供了dataset的大小; getitem 方法, 该方法支持从 0 到 len(self)的索引

from torch.utils.data import Dataset
class PTB(Dataset):
    """battery dataset."""
    def __init__(self, data_dir, split,battery_dataset=[],**kwargs):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            data_dir (string): data path0
        """
        super().__init__()
        self.data_dir = data_dir
        try:
            for file in os.listdir(self.data_dir):
                # print("file",os.path.join(data_dir,file))
                df = pd.read_csv(os.path.join(data_dir,file), encoding="gbk")

                # self.battery_frame = df.values
                # # print("self.battery_frame",self.battery_frame)
                # # print("self.battery_frame",self.battery_frame.shape)
                # battery_dataset.append(self.battery_frame)

                windows=32
                windows_move=1
                if df.shape[0]>=windows:
                    self.battery_frame = df.values
                    # print("self.battery_frame",self.battery_frame)
                    # print("self.battery_frame",self.battery_frame.shape)
                    
                    feature_num = self.battery_frame.shape[0]-windows+windows_move
                    for index in range(0,feature_num,windows_move):
                        feature_df = self.battery_frame[index:(index + windows)]                
                        battery_dataset.append(feature_df)
                    self.battery_dataset = battery_dataset
        except RuntimeError:
            pass
        print(len(self.battery_dataset))
    def __len__(self):
        #返回文件数据的数目
        print(len(self.battery_dataset))
        return len(self.battery_dataset)
        # return 1800000
    def __getitem__(self, idx):
        #接收一个索引,返回一个样本(tensor维度相同)
        print (idx)
        # battery = self.battery_frame.get_chunk(128).as_matrix().astype('float')
        # battery = self.battery_dataset[idx].as_matrix().astype('float')
        battery = self.battery_dataset[idx]
        print("__getitem__",battery.shape)

        return battery
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值