import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
def cal_len(x):
    """Return the sequence length after one conv + pool stage.

    Computes ``floor(x / 2) - 1``. This equals the output length of an
    unpadded ``Conv1d(kernel_size=3, stride=1, padding=0)`` (``x - 2``)
    followed by ``max_pool1d(kernel_size=2, stride=2)`` (``floor(n / 2)``).
    """
    conv_out = x - 2  # kernel 3, no padding: loses two samples
    return conv_out // 2  # pool 2 with stride 2: floor-halve
class Net(nn.Module):
    """Small 1-D CNN over 12-channel sequences (presumably 12-lead ECG).

    The network has four conv(kernel 3, no padding) -> ReLU -> max-pool(2)
    stages. The first three pools use stride 2 and the last uses stride 1.
    ``forward`` returns the flattened feature map, shape ``(batch, 32 * L')``.

    Args:
        ecgsamples: Input length along the last axis. It is used only to
            size ``fc1``.

    Note:
        ``fc1`` (``n_features -> 5``) is constructed but NOT applied in
        ``forward``. Callers receive the flattened features.
    """

    def __init__(self, ecgsamples=5000):
        super().__init__()
        self.ecgsamples = ecgsamples
        self.conv1 = nn.Conv1d(12, 32, kernel_size=3, stride=1, padding=0)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=0)
        self.conv3 = nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=0)
        # Size fc1 from a dummy pass instead of hand arithmetic, so it always
        # matches the flattened output (channels * length). The previous
        # formula omitted the 32-channel factor. _features draws no random
        # numbers, so the parameter-initialisation sequence is unaffected.
        with torch.no_grad():
            dummy = torch.zeros(1, self.conv1.in_channels, ecgsamples)
            n_features = self._features(dummy)[0].numel()
        self.fc1 = nn.Linear(n_features, 5)

    def _features(self, x):
        """Run the four conv/ReLU/max-pool stages; returns ``(batch, 32, L')``."""
        x = F.max_pool1d(F.relu(self.conv1(x)), 2, stride=2, padding=0)
        x = F.max_pool1d(F.relu(self.conv2(x)), 2, stride=2, padding=0)
        # Third stage uses conv3 (previously conv2 by mistake, leaving conv3 unused).
        x = F.max_pool1d(F.relu(self.conv3(x)), 2, stride=2, padding=0)
        # NOTE(review): this fourth stage shares conv2's weights (there is no
        # conv4). Confirm the weight sharing is intended.
        x = F.max_pool1d(F.relu(self.conv2(x)), 2, stride=1, padding=0)
        return x

    def forward(self, x):
        """Return flattened conv features of shape ``(batch, 32 * L')``.

        ``x`` is expected as ``(batch, 12, length)``.
        """
        x = self._features(x)
        # Merge the channel and length axes into one feature axis per sample.
        return x.view(-1, x.shape[-1] * x.shape[-2])
# Build the network on the GPU when CUDA is available (CPU otherwise) and
# print a Keras-style per-layer / parameter-count summary. torchsummary is
# given the per-sample input shape (channels, length), without a batch axis.
ECG_SAMPLES = 10000
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = Net(ecgsamples=ECG_SAMPLES)
model = model.to(device)
summary(model, input_size=(12, ECG_SAMPLES))
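可以看出：在 Keras 中构建好模型后，可以直接调用 model.summary() 查看网络结构以及可训练参数总量等信息；而在 PyTorch 中，需要另外安装并导入第三方包 torchsummary，并且必须指定准确的 input_size（不含 batch 维度，例如这里的 (12, 10000)）。这是因为 torchsummary 会用该形状的虚拟输入实际执行一次前向传播，据此推算各层的输出形状和参数量，所以只有 input_size 准确，展示出的结构和参数信息才准确。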