# coding=utf-8
"""Train an AlexNet-style CNN on Fashion-MNIST with the d2l helper library.

Run this file directly to load the data and train. Importing it only builds
``net`` and defines the helpers below; no data is loaded and no training runs.
"""
import os

# NOTE(review): presumably a workaround for Intel OpenMP aborting when two
# copies of its runtime get loaded ("OMP: Error #15"); it silences that check
# rather than fixing the duplicate load. It is set before importing torch so it
# is already in the environment when torch loads its native libraries.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torch  # noqa: E402  (must follow the environment tweak above)
from torch import nn  # noqa: E402

# Side length of the square, single-channel input images. The hard-coded 6400
# classifier width below is only valid for this size.
IMAGE_SIZE = 224
BATCH_SIZE = 128
LR = 0.01
NUM_EPOCHS = 10

# Per-sample output shapes for a (N, 1, 224, 224) input are noted on the right;
# each follows from out = floor((in + 2*padding - kernel) / stride) + 1.
net = nn.Sequential(
    # Feature extractor.
    nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),  # 96x54x54
    nn.MaxPool2d(kernel_size=3, stride=2),                              # 96x26x26
    nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),           # 256x26x26
    nn.MaxPool2d(kernel_size=3, stride=2),                              # 256x12x12
    nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),          # 384x12x12
    nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),          # 384x12x12
    nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),          # 256x12x12
    nn.MaxPool2d(kernel_size=3, stride=2),                              # 256x5x5
    nn.Flatten(),                                                       # 6400
    # Classifier: 6400 = 256 * 5 * 5 from the last pooling stage; the final
    # layer emits 10 logits (Fashion-MNIST has 10 classes).
    nn.Linear(6400, 4096), nn.ReLU(), nn.Dropout(p=0.5),
    nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(p=0.5),
    nn.Linear(4096, 10),
)


def print_layer_shapes(model=net, input_shape=(1, 1, IMAGE_SIZE, IMAGE_SIZE)):
    """Feed a random tensor through ``model`` layer by layer, printing each shape.

    Debug aid for checking that the hard-coded classifier width agrees with the
    feature extractor's output.

    Args:
        model: An iterable of layers, such as ``nn.Sequential``.
        input_shape: Shape of the random input tensor.

    Returns:
        The tensor produced by the last layer.
    """
    x = torch.randn(*input_shape)
    # Shape inspection only: no gradients are needed.
    with torch.no_grad():
        for layer in model:
            x = layer(x)
            print(layer.__class__.__name__, "output shape:\t", x.shape)
    return x


def main():
    """Load Fashion-MNIST, train ``net`` with ``d2l.train_ch6`` and show the plot."""
    # Imported here so the model above can be built and inspected without d2l
    # being installed.
    from d2l import torch as d2l

    # Images are resized to IMAGE_SIZE so the flattened features are 6400 wide.
    train_iter, test_iter = d2l.load_data_fashion_mnist(BATCH_SIZE, resize=IMAGE_SIZE)
    d2l.train_ch6(net, train_iter, test_iter, NUM_EPOCHS, LR, d2l.try_gpu())
    # NOTE(review): presumably keeps the figure drawn during training on screen.
    d2l.plt.show()


if __name__ == "__main__":
    main()
调用 d2l 库，使用 AlexNet 模型
最新推荐文章于 2024-08-15 11:31:36 发布