proj view
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-GdBF7IMz-1685672388173)(/home/respecting-god/.config/Typora/typora-user-images/image-20230522182456185.png)]
tips:
- 如果使用task2-1作为示例时, 运行process.py的过程中需要确认 process调用的是函数
preprocess_ast_wav2vec(wav, fr)
1.1 任务简介
首个开源的儿科呼吸音数据集, 通过邀请11位医师标注;
292位参与测试者,共8.2个小时。
-
总共2683个录音文件record level, 被标记出了9089个呼吸音event level; (对比icbhi2017是920个录音文件)
-
录音文件被标记为 事件级别 event level 用于 task 1 任务, 和 record level, 用于task2 任务;
任务总共包含两大类,分别如下
# Important Assumption (used in model/metric.py)
# Normal is always index 0
# PQ, if exists, is index 1
def resp_classes(task, level):
assert task in (1,2), 'Task has to be either 1 or 2.'
assert level in (1,2), 'Level has to be either 1 or 2.'
if task==1:
if level==1:
CLASSES = ('Normal', 'Adventitious') # 2 class
elif level==2: # 7 class
CLASSES = ('Normal', 'Rhonchi', 'Wheeze', 'Stridor', 'Coarse Crackle', 'Fine Crackle', 'Wheeze & Crackle')
elif task==2:
if level==1: # 3 class;
CLASSES = ('Normal', 'Poor Quality', 'Adventitious')
elif leve-l==2: # 5 class;
CLASSES = ('Normal', 'Poor Quality', 'CAS', 'DAS', 'CAS & DAS')
return CLASSES
task1, 事件级别的分类, event level :
训练集: 6656份音频事件
测试集: 对应了2433份音频事件;
task2,录音级别的分类, record level,
训练集: 包含1949录音, (注意, 后续通过筛选 task2, 减少为1772 份录音;)
测试集: 734份录音,
1.2 数据预处理
preprocess.py
数据预处理, 详细的分析过程参考第9节;
其中,根据task_config.json 中的配置 data_loader, input_dir
选项中的是 task1 对应processed_wav2vec
or task2 对应processed_ast_wav2vec
,
根据上述不同的任务, preprocess() 函数将调用 不同的预处理函数, processed_wav2vec()
or processed_ast_wav2vec()
,
1.3 项目流程
train.py(): 是整个项目的执行过程的载体;
依次的顺序是,
- 实例化 训练集和验证集;
- 模型实例化:
- 损失函数和评价指标的设定;
- 可学习参数, 优化器以及学习率参数配置;
- 实例化训练类,
- 调度训练类中的
trian
函数, 开始训练;
2. DataLoader的实例化
加载音频文件函数
# location, data/SPRSound/Dataset.py
from torch.utils.data import Dataset
# RespDataLoader 中调用当前类 RespDataset();
class RespDataset(Dataset):
def __init__(self, data_dir, task, input_dir=None):
assert task in (1,2)
self.task = task
task_file_name = 'task1.csv' if task==1 else 'task2_filtered.csv'
# task_file_name = f'task{task}.csv'
self.csv = pd.read_csv(join(data_dir, task_file_name))
self.input_dir = input_dir
if input_dir is None: # note, 这里使用的原始划分的音频文件;
if task == 1: # 若果没有指定 input dir 用于训练的音频文件, 则 clip 中存放的是task1 的事件级别的检测任务;
self.dir = join(data_dir, 'clip')
else: # 如果, task2, 使用wav 文件,其中存放的是record 记录级别的事件;
self.dir = join(data_dir, 'wav')
else: # note , 这里是自定义 的文件夹;
self.dir = join(data_dir, input_dir)
def __len__(self):
return len(self.csv)
def __getitem__(self, index): # 这里获取的是音频, 和对应的label;
entry = self.csv.iloc[index]
wav_name = entry['wav_name']
target = (entry[f'label_{self.task}1'], entry[f'label_{self.task}2'])
if self.input_dir is None:
wav, _ = torchaudio.load(join(self.dir, wav_name))
else:
wav = torch.load(join(self.dir, wav_name), map_location='cpu')
# # normalize
# wav = (wav-37.3)/(2.3*2)
return wav, target
训练集和验证集分别通过调用, 以下函数进行实现;
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()
## 2.0 三个类之间的继承关系;
RespDataLoader(BaseDataLoader)
继承自 BaseDataLoader(DataLoader)
,
BaseDataLoader(DataLoader)
继承自pytorch
中DataLoader()
,
2.1 训练集的实例化:
data_loader = config.init_ob(
data_loader, module_data)
, 其中 参数配置中的data_loader
是指,Json 配置文件中,指定的类 RespDataLoader
, 通过将该类实例化为对象的过程中, 逐个在 重新初始化其父类, 最终将pytorch中的 DataLoader()
该基类重新初始化, 流程如下:
-
data_loader = config.init_ob(
data_loader, module_data)
-
—>
RespDataLoader(BaseDataLoader)
, 调用两个函数:
- 获取当前任务的整体数据集,dataset =
Datasets.RespDataset()
; - 通过重新初始化其父类,获得训练集和测试集的样本下标索引; 具体讲来,其中的
super().__init__(dataset, bt, shuffle, validation_split, num_workers, collate_fn= self.collate_fn)
通过传入参数,重新初始化其父类BaseDataLoader()
,下面进入父类中进行初始化,
- —->
BaseDataLoader(DataLoader)
, 初始化的过程中,分两步走:
-
self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
分别生成训练集,和测试集的下标索引。 -
重新初始化所对应的父类
DataLoader()
, 通过传入super().__init__(sampler= self.sampler, **self.init_kwargs)
其中**self.init_kwargs
包含了上一个子类传入的自定义collate_fn
方法; -
上一步中的,将训练集的下标索引,
self.sampler
, 和collate_fn函数
传入到了DataLoader()
中, 从而获取了训练集;
2.2 测试集的实例化:
valid_data_loader = data_loader.split_validation()
调用 BaseDataLoader()
中的 BaseDataLoader().split_validation()
函数,
该函数内部,传入了测试集的下标索引, 并且同样传入了 collate_fn()
函数,通过 **self.init_kwargs
函数;
然后通过调用 pytorch 中的 DataLoader()
获取数据集, DataLoader(sampler = self.valid_sampler, **self.init_kwargs)
,
2.3 class RespDataLoader()
# location: data_loader/data_loaders.py
def resp_classes(task, level):
根据当前任务,
返回当前任务上每个类别所对应的标签;
from data.SPRSound import Datasets
class RespDataLoader(BaseDataLoader):
def __init__(self, ...):
初始化,当前任务上的类别标签属性;
dataset = Datasets.RespDataset(data_dir, task= task, input_dir=input_dir)
# 使用当前类中的属性重新初始化父类BaseDataLoader , 对父类中的 __init__() 函数重新初始化;
super().__init__(dataset, bt, shuffle, validation_split, num_workers, collate_fn=self.collate_fn)
def collate_fn(self, batch):
tensors, targets = [], []
获取一个batch 中的 tensor, 以及对应的label;
# 此处,需要搞清楚,这里的 tensor 到底对应的 特征级别的 tensor, 用于后续直接输入到网络模型中;
# 还是这里tensor 依然代表的是音频数据的 tensor;
return tensors, targets
2.4 class BaseDataLoader()
note: 上面的类RespDataLoader()
,在使用 super().__init__()
函数时,将会重新对父类BaseDataLoader()
进行初始化, 注意, 在传入super().__init__()
中的参数时, 传入了自定义的collate_fn() 函数
# location: base/base_data_loader.py
from torch.utils.data import DataLoader
# 根据 RespDataLoader 中传来的 dataset, 完成训练集 和测试集的划分;
class BaseDataLoader(DataLoader):
def __init__(self, dataset, bt, shuffle, validation_split, num_workers, collate_fn= default_collate)
初始化,训练集测试集的分配比率;
# 分别获取训练集, 验证集的下标索引;
self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
# 注意到,这里的初始化参数通过子类RespDataLoader中, 重新传入参数赋值进来, 尤其关注到 collate_fn
# 被重新赋值;
self.init_kwargs = {
'dataset': dataset,
'batch_size':bt,
'shuffle':shuffle,
'collate_fn':collate_fn,
'num_workers':num_workers,
}
def _split_sampler(self, split):
# 将整体数据集,重新划分为训练集和测试集,
# 获取各自训练和验证集上,所对应的下标索引;
def split_validation(self):
# 用于获取验证集的数据,通过 属性,下标索引,
# 传入 DataLoader()
return DataLoader(sampler = self.valid_sampler, **self.init_kwargs)
3. 载入模型
model = config.init_obj('arch', module_arch)
通过关键字arch
获取Json 配置文件中的模型架构名称,
-
以及在当前任务上属于几分类问题,
-
该模型输入的 shape 形状;
之后,通过 getattr(module, module_name)(*args, **module_args)
进入当前调用的模型的初始化函数中去,
class ASTModel(nn.Module)
def __init__():
# 完成该模型的初始化;
4. 损失函数与评价指标的设定
设置当前任务上的损失函数和评价指标,同样是通过Json 文件中去设置的;
"loss": {
"type": "cross_entropy",
"args": {
"weight": [0.2, 0.5, 0.3]
}
},
"metrics": [
"accuracy", "specificity", "sensitivity_task2", "score_task2"
],
# 评价指标,包含4个方面, 精度, 特异度, 敏感度, 分数;
criterion = config.init_ftn('loss', module_loss, device=device)
metric = [getattr(module_metric, met) for met in config['metrics']]
5. 优化器以及学习率的配置
确认可学习参数, 构建优化器, 学习率;
trainable_params = filter(lambda p: p.requires_grad, model.parameters() )
# optimizer 中配置好, 优化器,学习率,可学习参数等信息;
optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_sheduler, optimizer)
同样,通过调用config_
中的参数, 取出其中 优化器以及学习率对应的参数信息;
"optimizer": {
"type": "Adam",
"args":{
"lr": 0.0001,
"weight_decay": 0,
"amsgrad": true
}
},
"lr_scheduler": {
"type": "StepLR",
"args": {
"step_size": 50,
"gamma": 0.1
}
},
6. 实例化训练类
训练类的继承关系,
Trainer()
继承自父类BaseTrainer()
, 而 BaseTrainer()
则是最初的基类;
-
trainer = Trainer():
实例化训练类,通过实例化, 该类 Trainer(),trainer = Trainer(传入模型,损失函数, 优化器, 训练集和测试集)
# 实例化,训练类;
trainer = Trainer(model, criterion, metrics, optimizer,
config = config, device = device,
data_loader=data_loader,
valid_data_loader=valid_data_loader,
lr_scheduler=lr_scheduler )
6.1 class BaseTrainer()
# current location: base/base_trainer.py
from logger import TensorboardWriter
class BaseTrainer:
def __init__():
初始以下各类属性, 模型, 损失函数, 评价指标;
优化器, epoch 数目;
监视器,用于监控模型的性能,保存住最佳模型,通过 min , val loss 来判断最佳;
可视化实例;
def _train_epoch():
由子类, 重写进行覆盖; 由下面的 train() 函数调用
def train():
train该函数, 在实例化子类Trainer()后,被调用,
作为训练函数的调用接口函数;
并且其自身,调用上面的 _train_epoch()函数;
监听模型性能: 根据指标的变化, 保存当前模型的权重文件;
调用下面的_save_checkpoiont()保存当前模型的训练过程;
def _save_checkpoint():
保存模型的训练信息,
包含模型的参数权重, 状态字典; 当前epoch 数目, 优化器参数;
def _resume_checkpoint();
从保存的训练信息中, 加载模型,继续训练;
6.2 class Trainer()
Trainer()
继承自父类BaseTrainer()
# current location: trainer/trainer.py
from base import BaseTrainer
class Trainer(BaseTrainer):
def __init__():
该初始化函数中,
设置属性,用来 传入训练集, 验证集; 模型;
传入当前任务上的评价指标;
# 传入参数, 重新初始化其父类 BaseTrainer 中的初始化函数;
super().__init__(model, criterion, metric_ftns, optimizer, config)
def _train_epoch(): 该函数,重写了父类中 _trian_epoch()中的方法;
是网络训练的主体部分, 整个训练过程,在这个函数中体现出来;
并将当前epoch 上训练得到的,结果保存在log 中;
for bt_idx, (data, target) in enumerate(self.data_loader):
...
def _valid_epoch();
用于每个epoch 训练结束时, 在_train_epoch() 函数中被调用,得到当前epoch 上的验证精度;
def _progress():
当前epoch 时, 每个batch 达到 self.log_step() 进行打印输出信息, 在_train_epoch() 函数中被调用;
def _createConfusionMatrix():
构建了混淆矩阵, 并且以热力图的形式保存,
当前未找到,调用关系;
6.3 训练流程
训练过程, 下面的第7节,对训练过程进行展开。
trainer.train()
由于 Trainer(BaseTrainer)
Trainer 继承自BaseTrainer
, 所以 trainer.train()
其中的 train() 函数是来自于父类中的函数;
所以 trainer.train()
其实调用的是BaseTrainer.train()
中的 train()
函数;
调用流程:
-
trainer. train()
–>BaseTrainer.train()
-
BaseTrainer.train()
该train() 函数中调用 –>self._train_epoch
() , 该函数在子类Trainer()
中重写,并实现; -
_train_epoch()
中调用 —>self.data_loader ()
, 而 data_loader 中每个batch 的数据加载流程 ,
7 . 训练过程
7.1 训练过程总览
训练过程,按照如下步骤进行分析:
- 训练过程中, 数据获取的流程
- 将优化器中的参数对应的梯度重新置零;
- 数据输入到模型中进行推理, 得到预测值;
- 将预测值和 标签输入到损失函数中,算出loss;
- 将损失开始反向传播,
- 更新优化器中的梯度
- 更新自定义的评价指标的中的性能参数;
- 将以上训练中性能信息 记录到
tensorboard
以及 logger 中; - 当前一个 epoch 训练完成后, 开始在验证集上,进行一次验证,调用验证函数;
- 打印信息,保存权重;
self.data_loader
每次取一个batch 的数据时候调用,最终会调用到 RespDataLoader().collate_fn()
类中的自定义函数,
该函数用于将取出的音频文件,以及对应的标签,打包成一个 batch
的张量数据进行返回。
训练集和测试集data_loder
, valid_data_loader
都是来自于同一个类(RespDataLoader)
的实例化对象, 故这里只以分析 data_loader
为例子,
训练过程中, 每次从训练集(
self.data_loader
)或者验证集(self.valid_data_loader
)中取出一个batch 的数据时,会执行
RespDataLoader().collate_fn()
函数, 用于返回一个batch 的数据。
7.2 训练中- 获取数据的流程:
data_loader
训练集是 RespDataLoader
的一个实例化对象, 通过先后继承父类 BaseDataLoader()
, DataLoader()
当每次从 self.data_loader
中取出一个batch 的数据时, 发生了如下调用事件,
-
调用 –> 私有类中的魔法函数
_BaseDataLoaderIter(object).__next__():
该函数中继续调用– >
self._next_data()
上述的意思即,在该__next__()
魔法函数中调用了 self._next_data()
,
_BaseDataLoaderIter(object)
自身类中,该 _next_data()
私有方法没有实现,
而是 在其子类_SingleProcessDataLoaderIter(_BaseDataLoaderIter)._next_data()
中实现了, 故调用其子类中的该方法。
故这里的实际调用关系是:
—> _BaseDataLoaderIter(object).__next__():
––> 私有单线程类中的方法 _SingleProcessDataLoaderIter(_BaseDataLoaderIter)._next_data()
# location: `torch.utils.data.dataloader.py`中,
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
-
1 而
_SingleProcessDataLoaderIter(_BaseDataLoaderIter)._next_data()
该方法在实现过程中调用 如下函数:—>
self._next_index()
, 当前子类中并没有实现,通过继承使用父类(_BaseDataLoaderIter)
中的该方法,而该父类中
self._next_index()
方法 则继续调用如下方法, –>
return next(self._sampler_iter)
,继续调用–>
torch.utils.data.sampler.py
中类BatchSampler.__iter__()
, 该函数实现了取出一个 batch 批次的数据,所对应的下标索引。2.2 在
self._next_index()
, 调用完成之后,获取了一个batch 数据的下标索引, 则继续调用
self._dataset_fetcher.fetch(index)
,—-> 该函数的实现则是调用了
_MapDatasetFetcher(_BaseDatasetFetcher).fetch()
方法# location: torch.utils.data._utils.fetch.py 中 class _MapDatasetFetcher(_BaseDatasetFetcher): def __init__(self, dataset, auto_collation, collate_fn, drop_last): super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last) def fetch(self, possibly_batched_index): if self.auto_collation: # 注意到, 这里通过self.dataset 该属性,获取了该下标所对应的数据; data = [self.dataset[idx] for idx in possibly_batched_index] else: data = self.dataset[possibly_batched_index] return self.collate_fn(data)
注意上面的
fetch()
该方法通过self.dataset
属性, 找到当前下标所对应的数据,通过
index
获取data
,发生如下的调用关系事件: —> fetch(index) –>
data = self.dataset[index]
—> 此时,会返回到
Dataset().__getitem__()
,而该
__getitem()
方法,通常是由在子类中实现,这里是RespDataset(Dataset)
,至此, 通过当前下标索引
index
, 获取data
, 注意的这里的data
, 指的是在数据集上,所对应的音频数据以及标签;这里需要通过数据预处理部分,
process.py
来确认,到底特征级别还是音频级别注意,这里获取的音频文件, 如果是自定义的方式,生成的
self.input_dir
, 这里的音频可能便是特征级别的数据;比如输入的
input_dir= processed_ast_wav2vec
, 则是自定义的音频数据,则代表的是特征,这里此时wav= (768, 128)
,
class RespDataset(Dataset):
def __init__():
读入当前任务task 所对应的 .csv 文件,csv 文件,包含了音频以及对应的标签信息;
读入音频文件, 根据传入的音频文件夹的位置;
def __len__():
返回csv 文件的长度,即当前任务上音频的总个数, 包括训练集和验证集;
def __getitem__(self, index): # 这里获取的是音频, 和对应的label;
entry = self.csv.iloc[index]
wav_name = entry['wav_name']
target = (entry[f'label_{self.task}1'], entry[f'label_{self.task}2'])
if self.input_dir is None:
wav, _ = torchaudio.load(join(self.dir, wav_name))
else:
wav = torch.load(join(self.dir, wav_name), map_location='cpu')
# # normalize
# wav = (wav-37.3)/(2.3*2)
return wav, target
2.3 在执行完, data = self.dataset(index)
–>self.dataset.__getitem(index)
后,
则继续执行类 _MapDatasetFetcher(_BaseDatasetFetcher)
中的最后一个方法, return self.collate_fn(data)
;
7.3 collate_fn()
的传递过程
2.4 而collate_fn()
该函数经历怎样的传递过程呢? 首先该方法在 RespDataLoader(BaseDataLoader).collate_fn()
中定义的,
在DataLoader
中调用 __iter()
后, 继续调用自身类中的私有函数_get_iterator()
函数,该函数中继续调用到_SingleProcessDataLoaderIter()
之后collate_fn()
,便在以下的各个类中进行传递 :
_SingleProcessDataLoaderIter()
—> _DatasetKind
—> _MapDatasetFetcher
;
终于,来到了最初在 RespDataLoader().collate_fn()
中设置的方法, 该方法的作用,是将获取的数据和标签打包成一个 batch 的数据,
然后进行返回, 返回的过程便是一个弹栈的过程:
先返回到 –> _SingleProcessDataLoaderIter()._next_data()
中 data= self._dataset_fetcher.fetch(index)
;
–> _BaseDataLoaderIter.__next__()
该魔法函数中的的 data = self._next_data()
—> 回到训练过程中的 for batch_idx, (data, target) in enumerate(self.data_loader):
至此,训练过程中, 训练集数据的提取过程分析完毕;
class RespDataLoader(BaseDataLoader):
def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True, task=1, level=1, input_dir='processed'):
self.CLASSES = resp_classes(task, level)
self.CLASS2INT = {label:i for (i, label) in enumerate(self.CLASSES)}
self.LEVEL = level
# note, dataset 获取训练集和 测试集;
dataset = Datasets.RespDataset(data_dir, task=task, input_dir=input_dir)
super().__init__(dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=self.collate_fn)
# 这里根据预处理,获取用于输入的 训练样本 和 标签;
def collate_fn(self, batch):
tensors, targets = [], []
# Gather in lists, and encode labels as indices
for wave, label in batch:
label = label[self.LEVEL-1] # 根据级别,获取当前的label 标签;
tensors += [wave]
targets += [torch.LongTensor([self.CLASS2INT[label]])]
# Group the list of tensors into a batched tensor
tensors = torch.stack(tensors)
targets = torch.stack(targets)
targets.squeeze_(1)
return tensors, targets
8. DataLoader与_BaseDataLoaderIter()
当创建一个 DataLoader()
实例化对象的时候, 实际是在通过 _BaseDataLoaderIter
来迭代数据集,
这样的设计方式,是为了将数据集 和 迭代数据的过程进行分离,
DataLoader()
: 用于管理 dataset, 兵准备好 迭代数据之前所需要的设置;
_BaseDataLoaderIter
: 则是执行,实际的迭代过程, 包括了从线程中获取数据;
这种将 数据集本身 与迭代数据过程的方法 进行分离的方式,
可以通过继承类_BaseDataLoaderIter
方式, 自定义一个子类,在该子类中重写 数据迭代的方式,从而更多的控制数据迭代的过程。
8.1 DataLoader
当在 DataLoader()
调用其中的魔法函数 __iter()
时, 该魔法函数返回的实际上是一个一个_BaseDataLoaderIter
,
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'.
def __iter__(self) -> '_BaseDataLoaderIter':
# When using a single worker the returned iterator should be
# created everytime to avoid reseting its state
# However, in the case of a multiple workers iterator
# the iterator is only created once in the lifetime of the
# DataLoader object so that workers can be reused
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
__iter()
继续调用自身类中的私有函数 _get_iterator()
函数, 可以看到,此时根据是否启用多线程,
将会返回不同的线程迭代数据集的方式, num_worker==0
, 则使用(单进程)主进程完成数据的迭代,
而无论是 单进程_SingleProcessDataLoaderIter(_BaseDataLoaderIter)
还是多进程,他们都是继承的同一个父类_BaseDataLoaderIter
,
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
8.2 _BaseDataLoaderIter
可以看到,这两个类都是继承自_BaseDataLoaderIter
,
_SingleProcessDataLoaderIter(_BaseDataLoaderIter)
_MultiProcessingDataLoaderIter(_BaseDataLoaderIter)
8.3 _SingleProcessDataLoaderIter()
# location: torch.utils.data.dataloader.py
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
可以看到,在执行 data = self._dataset_fetcher.fetch(index)
过程中,调用了私有类_DatasetKind
中的 create_fetcher
方法;
# location: torch.utils.data.dataloader.py
class _DatasetKind(object):
Map = 0
Iterable = 1
@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
create_fetcher
方法中,则继续调用私有类, _MapDatasetFetcher()
#location: torch.utils.data._utils.fetch.py
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
可以,看到从_SingleProcessDataLoaderIter()
开始,
collate_fn
该方法就一直被传递过来,中间在以下的各个类中进行传递如下过程 :
_SingleProcessDataLoaderIter()
—> _DatasetKind
—> _MapDatasetFetcher
;
9. 数据预处理
数据预处理,其实是整个项目的最开始,由于篇幅会较多,故放在这里分析;
tips:
- 如果使用task2-1作为示例时, 运行process.py的过程中需要确认 process调用的是函数
preprocess_ast_wav2vec(wav, fr)
preprocess.py` 数据预处理, 详细的分析过程参考第9节;
其中,根据task_config.json 中的配置 data_loader, input_dir
选项中的是 task1 对应processed_wav2vec
or task2 对应processed_ast_wav2vec
,
根据上述不同的任务, preprocess() 函数将调用 不同的预处理函数, processed_wav2vec()
or processed_ast_wav2vec()
,
# location, data/SPRSound/preprocess2_1.py
# 装饰器,函数, 当不同的任务时, 将会调用不同的预处理函数;
def preprocess(wav,fr):