import
torch
torch的总包torch.nn
网络层,通常自定义的网络都会继承nn.Module
torch.nn.functional
F里面都包括常用的函数,relu和pooling等torchvision
常用的数据集,MNIST和FashionMNIST等torchvision.transforms
数据集
使用FashionMNIST
# Training split of FashionMNIST, rooted at ./data/FashionMNIST.
# download=False: nothing is fetched, so the dataset files are expected to
# already be present under `root`.
# Every sample image goes through a one-step Compose pipeline (ToTensor()).
train_set = torchvision.datasets.FashionMNIST(
    transform=transforms.Compose([transforms.ToTensor()]),
    download=False,
    train=True,
    root='./data/FashionMNIST',
)

# Iterate train_set in mini-batches of 10 samples. No `shuffle` argument is
# passed, so the DataLoader default (no shuffling) applies.
# NOTE(review): training loaders usually use shuffle=True — confirm whether the
# fixed sample order is intentional here.
train_dataloader = torch.utils.data.DataLoader(dataset=train_set, batch_size=10)
建立网络
class nopaNet(nn.Module):
def __init__(self):
super(nopaNet, self).__init__()
#输入图像[batch,1,28,28]
#out_channel 代表有几个filter
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
self.dense1 = nn.Linear(in_features=12*4*4, out_features=120)
self.dense2 = nn.Linear(