最近，我使用 PyTorch 在自己的数据集上实现了音频分类，完整演练了从参数导入（数据集、模型等路径以及模型超参数）、训练设备信息、数据集信息、训练组件配置、特征提取，到模型训练及预测的全过程。
基于 PyTorch 的音频分类实现
1. 参数导入模块
该模块的主要功能是在模型训练前初始化参数：所有参数都用 argparse 定义在 get_args() 中，并通过 parser.parse_args() 解析后返回，代码如下：
def get_args():
    r"""Define and parse the command-line options used for training and prediction.

    Options use single-dash long names (e.g. ``-bs 32``). When ``-classes`` is not
    given on the command line, the class list is loaded after parsing with
    ``trans(r'E:\data\scatter\classes.txt')``. ``trans`` is defined elsewhere and
    presumably returns a list of class names (TODO confirm). Loading it lazily
    means the file is only read when it is actually needed.

    Returns:
        argparse.Namespace with the attributes t, dp, classes, infop, tp, bs, cn,
        e, lr, ld, wp and fp.
    """
    # Fixed: the original `description=train')` was missing its opening quote (SyntaxError).
    parser = argparse.ArgumentParser(description='train')
    parser.add_argument(
        '-t',
        type=str,
        default='pytorch-audioclassification-master',
        help="the theme's name of your task"
    )
    parser.add_argument(
        '-dp',
        type=str,
        default=r'E:\data\origin_data',
        help="train's directory"
    )
    parser.add_argument(
        '-classes',
        # `-classes a b c` -> ['a', 'b', 'c'].
        # The original `type=list` would turn a single CLI string into a list of characters.
        type=str,
        nargs='+',
        default=None,
        help="classes list"
    )
    parser.add_argument(
        '-infop',
        type=str,
        # NOTE(review): unlike the other defaults this path has no drive letter — confirm.
        default=r'\data\folder\refer.csv',
        help="DIF(folder information file)'s path"
    )
    parser.add_argument(
        '-tp',
        type=float,
        default=0.9,
        help="train folder's percent"
    )
    parser.add_argument(
        '-bs',
        type=int,
        default=16,
        help="folder's batch size"
    )
    parser.add_argument(
        '-cn',
        type=int,
        default=10,
        help='the number of classes'
    )
    parser.add_argument(
        '-e',
        type=int,
        default=5,
        help='epoch'
    )
    parser.add_argument(
        '-lr',
        type=float,
        default=0.001,
        help='learning rate'
    )
    parser.add_argument(
        '-ld',
        type=str,
        default=r'E:\workdir',
        help="the training log's save directory"
    )
    # Options read by the prediction step (args.wp / args.fp); previously undefined.
    parser.add_argument(
        '-wp',
        type=str,
        default=None,
        help="path of the saved model used for prediction"
    )
    parser.add_argument(
        '-fp',
        type=str,
        default=None,
        help="path of the .npy feature file to classify"
    )
    args = parser.parse_args()
    if args.classes is None:
        args.classes = trans(r'E:\data\scatter\classes.txt')
    return args
2. 训练设备信息
该模块的主要功能是：判断能否使用 GPU 训练——有可用的 CUDA 设备时使用 GPU，否则使用 CPU。
# Train on the CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
3. 数据集信息
该模块的主要功能是：根据初始化的参数（数据信息文件路径 args.infop、数据目录 args.dp、批大小 args.bs、训练集比例 args.tp）构建 train_dl 和 valid_dl，并得到 samples_num、train_num、valid_num，代码如下：
# Build the train/validation loaders and the associated sample counts.
(
    train_dl,
    valid_dl,
    samples_num,
    train_num,
    valid_num,
) = get_dataloader(
    args.infop,  # folder information file (see `-infop`)
    args.dp,     # data directory (see `-dp`)
    args.bs,     # batch size (see `-bs`)
    args.tp,     # train split fraction (see `-tp`)
)
4. 训练组件配置
该模块的主要功能是：构建网络模型并移动到训练设备上，同时配置优化器（Adam）和损失函数（交叉熵），代码如下：
# Loss for multi-class classification on raw logits.
loss_fn = CrossEntropyLoss()
# Network sized to the configured number of classes, placed on the training device.
model = AudioClassificationModel(num_classes=args.cn)
model = model.to(device)
# Adam over all model parameters with the configured learning rate.
optimizer = Adam(lr=args.lr, params=model.parameters())
class AudioClassificationModel(Module):
    """Four-stage CNN classifier with global average pooling and a linear head.

    Each stage is Conv2d (stride 2) -> ReLU -> BatchNorm2d. Channels grow
    2 -> 8 -> 16 -> 32 -> 64, and a Linear layer maps the 64 pooled features
    to ``num_classes`` logits.

    The input is expected to have 2 channels; what those channels represent
    (e.g. stereo spectrograms) is not established here — TODO confirm.

    Every layer is registered twice: as its own attribute (``conv1``, ``relu1``,
    ``bn1``, ...) and inside the ``conv`` Sequential that ``forward`` uses.
    As a result, ``state_dict()`` contains both key sets.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # (in_channels, out_channels, square kernel size, padding); every stage uses stride 2.
        stage_specs = [(2, 8, 5, 2), (8, 16, 3, 1), (16, 32, 3, 1), (32, 64, 3, 1)]
        feature_layers = []  # flat list of modules, unpacked into Sequential below
        for stage, (c_in, c_out, k, pad) in enumerate(stage_specs, start=1):
            conv = Conv2d(c_in, c_out, kernel_size=(k, k), stride=(2, 2), padding=(pad, pad))
            act = ReLU()
            norm = BatchNorm2d(c_out)
            # Keep the per-layer attribute names (conv1, relu1, bn1, ...) so checkpoint keys match.
            setattr(self, f"conv{stage}", conv)
            setattr(self, f"relu{stage}", act)
            setattr(self, f"bn{stage}", norm)
            feature_layers.extend((conv, act, norm))
        self.ap = AdaptiveAvgPool2d(output_size=1)
        self.classification = Linear(in_features=64, out_features=num_classes)
        self.conv = Sequential(*feature_layers)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for a batched 4-D input."""
        features = self.conv(x)
        pooled = self.ap(features)               # (N, 64, 1, 1)
        flat = pooled.flatten(start_dim=1)       # (N, 64)
        return self.classification(flat)
5. 模型训练
该模块的主要功能是：按轮次迭代训练模型，每轮在验证集上计算 accuracy、precision、recall 和 F1，并保存验证 F1 最优的模型，代码如下：
import copy

# Per epoch: one training pass over train_dl, then evaluation on valid_dl.
# NOTE(review): losses, accuracies, precisions, recalls and f1s are appended to
# below but are not created in this snippet. They are assumed to be empty lists
# initialised earlier — confirm.
for epoch in range(args.e):
    # Validation labels, predictions and raw logits for the current epoch only.
    prediction = []
    label = []
    score = []

    # ---------------- training ----------------
    model.train()
    # Pass the loader itself (not iter(loader)) so tqdm can read len() and show totals.
    train_bar = tqdm(train_dl, ncols=150, colour='red')
    train_loss = 0.
    num_train_batches = 0
    for step, (x_train, y_train) in enumerate(train_bar):
        x_train = x_train.to(device)
        y_train = y_train.to(device).long()  # CrossEntropyLoss expects integer class indices
        output = model(x_train)
        loss = loss_fn(output, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() returns a detached Python float; no clone/detach/cpu/numpy chain needed.
        train_loss += loss.item()
        num_train_batches += 1
        train_bar.set_description("Epoch:{}/{} Step:{}/{}".format(epoch + 1, args.e, step + 1, len(train_dl)))
        train_bar.set_postfix({"train loss": "%.3f" % loss.item()})
    # Mean batch loss; guard against an empty loader instead of dividing by zero.
    train_loss = train_loss / max(num_train_batches, 1)
    losses.append(train_loss)

    # ---------------- validation ----------------
    model.eval()
    valid_bar = tqdm(valid_dl, ncols=150, colour='red')
    # AUC / AP computation is disabled below; keep the names defined as placeholders.
    valid_auc = 0.
    valid_ap = 0.
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for step, (x_valid, y_valid) in enumerate(valid_bar):
            x_valid = x_valid.to(device)
            output = model(x_valid)  # shape: (N, num_classes)
            output_ = output.cpu()
            _, pred = torch.max(output_, 1)
            # extend() avoids `lst = lst + batch`, which re-copies the whole list every batch.
            prediction.extend(pred.tolist())
            label.extend(y_valid.tolist())
            score.extend(output_.tolist())
            valid_bar.set_description("Epoch:{}/{} Step:{}/{}".format(epoch + 1, args.e, step + 1, len(valid_dl)))
    valid_acc = accuracy_score(y_true=label, y_pred=prediction)
    valid_pre = precision_score(y_true=label, y_pred=prediction, average='weighted')
    valid_recall = recall_score(y_true=label, y_pred=prediction, average='weighted')
    valid_f1 = f1_score(y_true=label, y_pred=prediction, average='weighted')
    # To enable these, `score` must hold class probabilities (softmax of the logits), not raw logits.
    # valid_auc = roc_auc_score(y_true=label, y_score=score, average='weighted', multi_class="ovr")
    # valid_ap = average_precision_score(y_true=label, y_score=score)
    accuracies.append(valid_acc)
    precisions.append(valid_pre)
    recalls.append(valid_recall)
    f1s.append(valid_f1)
    # f1s already contains this epoch, so `>=` means "ties or beats every epoch so far".
    if valid_f1 >= max(f1s):
        # Snapshot a copy. `best_checkpoint = model` would only alias the live model,
        # which keeps training and silently overwrites the "best" weights.
        best_checkpoint = copy.deepcopy(model)
6. 模型预测
该模块的主要功能是：加载训练好的模型，对单个特征文件预测音频的类别，代码如下：
# Load a whole pickled model (not just a state_dict) saved with torch.save(model, ...).
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# On PyTorch >= 2.6, loading a full model also requires weights_only=False.
model = torch.load(args.wp, map_location=device)
model = model.to(device)
# Inference mode: BatchNorm must use its running statistics, not the stats of a single sample.
model.eval()
inputs = np.load(args.fp)
# NOTE(review): ToTensor() reorders an HxWxC array to CxHxW, and permute(1, 2, 0) undoes
# that reordering. The .npy is therefore presumably already stored channel-first — confirm.
# unsqueeze(0) adds the batch dimension.
inputs = ToTensor()(inputs).permute(1, 2, 0).unsqueeze(0).to(device)
with torch.no_grad():
    output = model(inputs)  # shape: (N, num_classes), N == 1 here
output_ = output.cpu()
# torch.max(x, dim) returns (values, indices); the predicted class is the index.
_, pred = torch.max(output_, 1)