CNN和RNN中如何引入BatchNorm

/* 版权声明:可以任意转载,转载时请标明文章原始出处和作者信息 .*/

                                                     author: 张俊林


Batch Normalization的基本思路和价值在之前一篇文章“Batch Normalization导读”介绍了,此处不赘述,背景知识请参考上面文章。


看到BN后,很明显能够看到这等于往传统的神经网络中增加了一个BN层,而且位置处于神经元非线性变换前,基本大多数网络结构都能很自然地融合进去,于是很自然的想法就是:如果用在CNN或者RNN效果会如何?是否也会增加收敛速度以及模型分类性能?CNN的使用方式在原始的Batch Normalization论文就说了,RNN则有相关后续研究跟进,目前看还没有特别明确的结论。


|CNNBatchNorm


CNNDNN不一样,某个卷积层包含多个FilterMap,而每个Filter Map其实是参数共享的,侦测同一类特征,是通过在输入图像上的局部扫描的方式遍历覆盖整个输入图像的,但是单个Filter Map本身可能是二维甚至多维的,如果是二维的,那么包含p*q个神经元。那么此时要应用BN其实有两种选择:


一种是把一个FilterMap看成一个整体,可以想象成是一个Filter Map对应DNN隐层中的一个神经元,所以一个Filter Map的所有神经元共享一个Scale

  • 18
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
好的,根据您的要求,以下是一个处理PTB心电数据集的Python代码,使用CNNRNN进行特征提取,并使用PyTorch实现: ```python import numpy as np import matplotlib.pyplot as plt import wfdb import pywt import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset # 数据预处理 def preprocess(data): # 滤波器 b, a = signal.butter(4, [0.5, 40], 'bandpass') # 去噪 data = signal.filtfilt(b, a, data) # 小波变换 coeffs = pywt.wavedec(data, 'db4', level=5) data = pywt.waverec(coeffs[:-1] + [None]*3, 'db4') return data # 加载PTB心电数据集 class PTBDataset(Dataset): def __init__(self, path): self.records = wfdb.get_record_list('ptbdb/' + path) def __len__(self): return len(self.records) def __getitem__(self, idx): record = self.records[idx] data, _ = wfdb.rdsamp('ptbdb/' + record) data = preprocess(data[:, 0]) data = torch.from_numpy(data).float() return data.unsqueeze(0) # CNN+RNN模型 class CNNRNN(nn.Module): def __init__(self): super(CNNRNN, self).__init__() self.cnn = nn.Sequential( nn.Conv1d(1, 64, kernel_size=5, stride=2, padding=2), nn.BatchNorm1d(64), nn.ReLU(), nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2), nn.BatchNorm1d(128), nn.ReLU(), nn.Conv1d(128, 256, kernel_size=5, stride=2, padding=2), nn.BatchNorm1d(256), nn.ReLU(), ) self.rnn = nn.GRU(256, 128, bidirectional=True, batch_first=True) self.fc = nn.Linear(256, 1) self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.cnn(x) x = x.permute(0, 2, 1) x, _ = self.rnn(x) x = self.fc(x) x = self.sigmoid(x) return x # 训练模型 def train_model(model, train_loader, optimizer, criterion, device): model.train() train_loss = 0 for data in train_loader: data = data.to(device) optimizer.zero_grad() output = model(data) loss = criterion(output.squeeze(), torch.ones_like(output)) loss.backward() optimizer.step() train_loss += loss.item() * data.size(0) train_loss /= len(train_loader.dataset) return train_loss # 测试模型 def test_model(model, test_loader, criterion, device): model.eval() test_loss = 0 with torch.no_grad(): for data in test_loader: data = data.to(device) output = model(data) loss = criterion(output.squeeze(), torch.ones_like(output)) test_loss += loss.item() * data.size(0) test_loss /= len(test_loader.dataset) return test_loss # 主函数 if __name__ == '__main__': # 设置超参数 batch_size = 64 learning_rate = 0.001 num_epochs = 10 # 加载数据集 train_dataset = PTBDataset('ptbdb_train') test_dataset = PTBDataset('ptbdb_test') train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 初始化模型、优化器和损失函数 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = CNNRNN().to(device) optimizer = optim.Adam(model.parameters(), lr=learning_rate) criterion = nn.BCELoss() # 训练模型 for epoch in range(num_epochs): train_loss = train_model(model, train_loader, optimizer, criterion, device) test_loss = test_model(model, test_loader, criterion, device) print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}') ``` 这段代码包含了以下步骤: 1. 数据预处理:使用带通滤波器、小波变换和去噪等方法对数据进行预处理。 2. 加载PTB心电数据集:使用wfdb库加载数据集。 3. CNN+RNN模型:使用CNNRNN进行特征提取。 4. 训练模型:使用Adam优化器和BCELoss损失函数训练模型。 注意:这段代码仅供参考,实际应用可能需要对数据预处理和模型结构进行调整。
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值