学习一下这个里面函数,文件和数据的操作。
除了知道输入是什么,还要知道输出是什么,什么类型,能进行什么操作。
class Dataset(torch.utils.data.Dataset):
def __init__(self, path, stim):
_, _, filenames = next(os.walk(path))
filenames = sorted(filenames)
all_data = []
all_label = []
for dat in filenames:
temp = pickle.load(open(os.path.join(path,dat), 'rb'), encoding='latin1')
all_data.append(temp['data'])
if stim == "Valence":
all_label.append(temp['labels'][:,:1]) #the first index is valence
elif stim == "Arousal":
all_label.append(temp['labels'][:,1:2]) # Arousal #the second index is arousal
self.data = np.vstack(all_data)[:, :32, ] #shape: (1280, 32, 8064) --> take only the first 32 channels
shape = self.data.shape
#perform segmentation=====
segments = 12
self.data = self.data.reshape(shape[0], shape[1], int(shape[2]/segments), segments)
#data shape: (1280, 32, 672, 12)
self.data = self.data.transpose(0, 3, 1, 2)
#data shape: (1280, 12, 32, 672)
self.data = self.data.reshape(shape[0] * segments, shape[1], -1)
#data shape: (1280*12, 32, 672)
#==========================
self.label = np.vstack(all_label) #(1280, 1) ==> 1280 samples,
self.label = np.repeat(self.label, 12)[:, np.newaxis] #the dimension 1 is lost after repeat, so need to unsqueeze (1280*12, 1)
del temp, all_data, all_label
def __len__(self):
return self.data.shape[0]
def __getitem__(self, idx):
single_data = self.data[idx]
single_label = (self.label[idx] > 5).astype(float) #convert the scale to either 0 or 1 (to classification problem)
batch = {
'data': torch.Tensor(single_data),
'label': torch.Tensor(single_label)
}
return batch
文件遍历学习
第一句值得深入研究
_, _, filenames = next(os.walk(path))
os.walk(路径)遍历文件,返回路径,路径下的文件夹,路径下的文件
path='E:\EEG\DATASET\DEAP\deap_set\data_preprocessed_python'
for root, dirs, files in os.walk(path):
print("root:"+root)
print(dirs)
print(files)
它的返回值是一个生成器,只能遍历打印或者用next(),会不断遍历路径下的所有文件夹。
next(可迭代对象,最后默认值)不停迭代,输出下一个对象。
可迭代对象,iterable对象。生成器也是一个可迭代对象。
如果我这样写呢,filenames = next(os.walk(path))
,返回的是一个三元的元组。生成器经过next()变为了元组。
可迭代对象与可迭代器,生成器的分析
可以直接作用于for循环的数据类型有以下几种:
- 一类是集合数据类型,如list、tuple、dict、set、str等;
- 一类是generator,包括生成器和带yield的generator function。
这些可以直接作用于for循环的对象统称为可迭代对象:Iterable。
可以使用isinstance()判断一个对象是否是Iterable对象:
>>> from collections.abc import Iterable
>>> isinstance([], Iterable)
可以被next()函数调用并不断返回下一个值的对象称为迭代器:Iterator。
生成器(generator)不但可以作用于for循环,还可以被next()函数不断调用并返回下一个值,直到最后抛出StopIteration错误表示无法继续返回下一个值了。与可迭代器的区别。
list、dict、str虽然是Iterable,却不是Iterator
为什么呢?
这是因为Python的Iterator对象表示的是一个数据流,Iterator对象可以被next()函数调用并不断返回下一个数据,直到没有数据时抛出StopIteration错误。可以把这个数据流看做是一个有序序列,但我们却不能提前知道序列的长度,只能不断通过next()函数实现按需计算下一个数据,所以Iterator的计算是惰性的,只有在需要返回下一个数据时它才会计算。
Iterator甚至可以表示一个无限大的数据流,例如全体自然数。而使用list是永远不可能存储全体自然数的。
把list、dict、str等Iterable变成Iterator可以使用iter()函数
总结
凡是可作用于for循环的对象都是Iterable类型,它们是有限的有规律的,确定的;
凡是可作用于next()函数的对象都是Iterator类型,它们表示一个惰性计算的序列,用到才给你;
集合数据类型如list、dict、str等是Iterable,但不是Iterator,不过可以通过iter()函数获得一个Iterator对象。
Python的for循环本质上就是通过不断调用next()函数实现的。
_, _, filenames = next(os.walk(path))
这句代码,就是为了得到遍历数据文件名。
1、sorted()
sorted()可以对所有可迭代类型进行排序,并且返回新的已排序的列表。语法如下:
sorted(iterable, cmp=None, key=None, reverse=False)
一共可接受4个参数,含义分别如下:
1.可迭代类型,例如字典、列表、
2.比较函数
3.可迭代类型中某个属性,对给定元素的每一项进行排序
4.降序或升序
加载数据
定义了两个列表,用于存放标签与数据
all_data = []#存数据
all_label = []#存标签
for dat in filenames:
temp = pickle.load(open(os.path.join(path,dat), 'rb'), encoding='latin1')
#加载数据,字典类型
all_data.append(temp['data'])
#取出数据,取出字典中data对应的值
if stim == "Valence":
all_label.append(temp['labels'][:,:1])
#取第一列,为Valence,L[0:3]表示,从索引0开始取,直到索引3为止,但不包括索引3。即索引0,1,2,正好是3个元素。
elif stim == "Arousal":
all_label.append(temp['labels'][:,1:2])
#取第二列, Arousal
我自己写了几句,分析理解了一下
import pprint as pp
import numpy as np
#pp = pprint.PrettyPrinter(indent=4)
path="E:\EEG\DATASET\DEAP\deap_set\data_preprocessed_python"
all_data = []
all_label = []
#遍历文件
_, _, filenames = next(os.walk(path))
filenames = sorted(filenames)
print(filenames)
output:
['s01.dat', 's02.dat', 's03.dat', 's04.dat', 's05.dat', 's06.dat', 's07.dat', 's08.dat', 's09.dat', 's10.dat', 's11.dat', 's12.dat', 's13.dat', 's14.dat', 's15.dat', 's16.dat', 's17.dat', 's18.dat', 's19.dat', 's20.dat', 's21.dat', 's22.dat', 's23.dat', 's24.dat', 's25.dat', 's26.dat', 's27.dat', 's28.dat', 's29.dat', 's30.dat', 's31.dat', 's32.dat']
#分析数据文件
temp = pickle.load(open(os.path.join(path,filenames[0]), 'rb'), encoding='latin1')
print(type(temp))#字典类型
print(temp.keys())#查看包括哪些keys,分别取出
output:
<class 'dict'>
dict_keys(['labels', 'data'])
#取出对应的数据
all_label=temp['labels']
all_data=temp['data']
print(np.shape(all_label))
print(np.shape(all_data))
output:
(40, 4)标签,40条视频
(40, 40, 8064)数据,40条视频,40个通道,每个通道8064个数据点
pp.pprint(all_labels)
output:
array([[7.71, 7.6 , 6.9 , 7.83],
[8.1 , 7.31, 7.28, 8.47],
[8.58, 7.54, 9. , 7.08],
[4.94, 6.01, 6.12, 8.06],
[6.96, 3.92, 7.19, 6.05],
[8.27, 3.92, 7. , 8.03],
pp.pprint(all_data)
output:
array([[[ 9.48231681e-01, 1.65333533e+00, 3.01372577e+00, ...,
-2.82648937e+00, -4.47722969e+00, -3.67692812e+00],
[ 1.24706590e-01, 1.39008270e+00, 1.83509881e+00, ...,
-2.98702069e+00, -6.28780884e+00, -4.47429041e+00],
[-2.21651099e+00, 2.29201682e+00, 2.74636923e+00, ...,
-2.63707760e+00, -7.40651010e+00, -6.75590441e+00],
...,
[ 2.30779684e+02, 6.96716323e+02, 1.19512165e+03, ...,
1.01080949e+03, 1.28312149e+03, 1.51996480e+03],
[-1.54180981e+03, -1.61798052e+03, -1.69268642e+03, ...,
-1.57842691e+04, -1.57823160e+04, -1.57808512e+04],
[ 6.39054310e-03, 6.39054310e-03, 6.39054310e-03, ...,
-9.76081241e-02, -9.76081241e-02, -9.76081241e-02]],
self.data = np.vstack(all_data)[:, :32, ] #沿着竖直方向将矩阵堆叠起来
#shape: (1280, 32, 8064) --> take only the first 32 channels