import torch
from torch import nn
class WDCNN(nn.Module):
    """1-D CNN: one wide-kernel conv stage followed by four small-kernel stages.

    Maps an input of shape ``(batch, in_channels, input_length)`` to logits of
    shape ``(batch, num_classes)``. All layers live in a single
    ``nn.Sequential`` stored as ``self.model1``, in the original order, so
    state-dict keys (``model1.0.weight``, ...) stay the same as before.

    Args:
        in_channels: Number of input channels. Defaults to 1.
        num_classes: Number of output logits. Defaults to 10.
        input_length: Sequence length the model is built for. It determines
            the input size of the first fully connected layer; the default
            2048 yields 64 channels x 3 positions = 192 features.

    Raises:
        ValueError: If ``input_length`` is too short to survive the
            convolution/pooling stack (the feature map would become empty).
    """

    def __init__(
        self,
        *,
        in_channels: int = 1,
        num_classes: int = 10,
        input_length: int = 2048,
    ) -> None:
        super().__init__()
        feature_length = self._feature_length(input_length)
        self.model1 = nn.Sequential(
            # Stage 1: wide kernel (64) with stride 16 downsamples the raw
            # sequence by ~16x before the small-kernel stages.
            nn.Conv1d(in_channels, 16, kernel_size=64, stride=16, padding=24),
            nn.BatchNorm1d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2, stride=2),
            # Stages 2-4: kernel 3 with padding 1 keeps the length; each pool halves it.
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2, stride=2),
            # Stage 5: no padding, so this conv shortens the sequence by 2.
            nn.Conv1d(64, 64, kernel_size=3),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Flatten(),
            # Input size derived from input_length instead of a hard-coded 192.
            nn.Linear(64 * feature_length, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, num_classes),
        )

    @staticmethod
    def _feature_length(input_length: int) -> int:
        """Return the sequence length remaining after the conv/pool stack.

        Computed arithmetically (no dry forward pass) so that BatchNorm running
        statistics are never touched during construction.
        """

        def shrink(length: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
            # Output-length formula of nn.Conv1d / nn.MaxPool1d with
            # dilation=1 and floor rounding (the defaults used above).
            out = (length + 2 * padding - kernel) // stride + 1
            if out < 1:
                raise ValueError(
                    f"input_length={input_length} is too short for WDCNN; "
                    "the feature map becomes empty"
                )
            return out

        length = shrink(input_length, kernel=64, stride=16, padding=24)
        length = shrink(length, kernel=2, stride=2)
        for _ in range(3):  # stages 2-4: length-preserving conv, then pool
            length = shrink(length, kernel=3, padding=1)
            length = shrink(length, kernel=2, stride=2)
        length = shrink(length, kernel=3)  # stage 5: unpadded conv
        return shrink(length, kernel=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits for ``x`` of shape (batch, in_channels, input_length)."""
        return self.model1(x)
# NOTE(review): these two lines run on every import of this module; consider
# moving them under the __main__ guard if the module is imported elsewhere.
wdcnn = WDCNN()
print(wdcnn)

if __name__ == "__main__":
    # Network sanity check: a single-channel, length-2048 sample should yield
    # one row of 10 logits (the final layer is nn.Linear(100, 10)).
    # eval() + no_grad() so the check neither updates BatchNorm running
    # statistics nor builds an autograd graph.
    wdcnn.eval()
    with torch.no_grad():
        sample = torch.ones(1, 1, 2048)
        output = wdcnn(sample)
    print(output.shape)