霹雳吧啦 Notes: Building the GoogLeNet Network in PyTorch

GoogLeNet network model

import torch
from torch import nn
import torch.nn.functional as F


class GoogLeNet(nn.Module):
    def __init__(self,num_classes=1000,aux_logits=True,init_weights=False):
        super(GoogLeNet,self).__init__()
        self.aux_logits=aux_logits

        self.conv1=BasicConv2d(3,64,kernel_size=7,stride=2,padding=3)
        self.maxpool1=nn.MaxPool2d(3,stride=2,ceil_mode=True)
        self.conv2=BasicConv2d(64,64,kernel_size=1)
        self.conv3=BasicConv2d(64,192,kernel_size=3,padding=1)
        self.maxpool2=nn.MaxPool2d(3,stride=2,ceil_mode=True)

        self.inception3a=Inception(192,64,96,128,16,32,32)
        self.inception3b=Inception(256,128,128,192,32,96,64)
        self.maxpool3=nn.MaxPool2d(3,stride=2,ceil_mode=True)

        self.inception4a=Inception(480,192,96,208,16,48,64)
        self.inception4b=Inception(512,160,112,224,24,64,64)
        self.inception4c=Inception(512,128,128,256,24,64,64)
        self.inception4d=Inception(512,112,144,288,32,64,64)
        self.inception4e=Inception(528,256,160,320,32,128,128)
        self.maxpool4=nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.inception5a=Inception(832,256,160,320,32,128,128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1=InceptionAux(512,num_classes)
            self.aux2=InceptionAux(528,num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.dropout = nn.Dropout(0.4)
        self.fc= nn.Linear(1024,num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self,x):
        x=self.conv1(x)
        x=self.maxpool1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        x=self.maxpool2(x)

        x=self.inception3a(x)
        x=self.inception3b(x)
        x=self.maxpool3(x)
        x=self.inception4a(x)
        if self.training and self.aux_logits:
            aux1=self.aux1(x)

        x=self.inception4b(x)
        x=self.inception4c(x)
        x=self.inception4d(x)
        if self.training and self.aux_logits:
            aux2=self.aux2(x)

        x=self.inception4e(x)
        x=self.maxpool4(x)
        x=self.inception5a(x)
        x=self.inception5b(x)
        x=self.avgpool(x)
        x=torch.flatten(x,1)
        x=self.dropout(x)
        x=self.fc(x)
        if self.training and self.aux_logits:
            return x,aux2,aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight,mode='fan_out',nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias,0)
            elif isinstance(m,nn.Linear):  # must align with the Conv2d check, otherwise Linear layers are never initialized
                nn.init.normal_(m.weight,0,0.01)
                nn.init.constant_(m.bias,0)




class Inception(nn.Module):
    def __init__(self,in_channels,ch1x1,ch3x3red,ch3x3,ch5x5red,ch5x5,pool_proj):
        super(Inception,self).__init__()

        self.branch1=BasicConv2d(in_channels,ch1x1,kernel_size=1)

        self.branch2= nn.Sequential(
            BasicConv2d(in_channels,ch3x3red,kernel_size=1),
            BasicConv2d(ch3x3red,ch3x3,kernel_size=3,padding=1) # padding=1 keeps the output size equal to the input
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels,ch5x5red,kernel_size=1),
            BasicConv2d(ch5x5red,ch5x5,kernel_size=5,padding=2)
        )
        self.branch4=nn.Sequential(
            nn.MaxPool2d(kernel_size=3,stride=1,padding=1),
            BasicConv2d(in_channels,pool_proj,kernel_size=1)
        )

    def forward(self,x):
        branch1=self.branch1(x)
        branch2=self.branch2(x)
        branch3=self.branch3(x)
        branch4=self.branch4(x)

        outputs=[branch1,branch2,branch3,branch4]
        return torch.cat(outputs,1)
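
# Note (explanatory comment, not in the original post): the Inception output channel
# count is ch1x1 + ch3x3 + ch5x5 + pool_proj, because the four branches preserve the
# spatial size and are concatenated along dim=1. For example, inception3a above yields
# 64 + 128 + 32 + 32 = 256 channels, which matches inception3b's in_channels of 256.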

class InceptionAux(nn.Module):
    def __init__(self,in_channels,num_classes):
        super(InceptionAux,self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5,stride=3)  # 14x14 input -> 4x4 output for both aux heads
        self.conv = BasicConv2d(in_channels,128,kernel_size=1)   # reduce channels to 128
        self.fc1 = nn.Linear(2048,1024)                          # 128 * 4 * 4 = 2048 features after flatten
        self.fc2 = nn.Linear(1024,num_classes)

    def forward(self,x):
        x=self.averagePool(x)
        x=self.conv(x)
        x=torch.flatten(x,1)
        x=F.dropout(x,0.5,training=self.training)
        x=F.relu(self.fc1(x),inplace=True)
        x=F.dropout(x,0.5,training=self.training)
        x=self.fc2(x)
        return x



class BasicConv2d(nn.Module):
    def __init__(self,in_channels,out_channels,**kwargs):
        super(BasicConv2d,self).__init__()
        self.conv = nn.Conv2d(in_channels,out_channels,**kwargs)
        self.relu=nn.ReLU(inplace=True)

    def forward(self,x):
        x=self.conv(x)
        x=self.relu(x)
        return x
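
A quick sanity check can be appended to the model file. This is a minimal sketch added here (not part of the original post) to confirm that, in training mode with aux_logits=True, the forward pass returns the main output plus the two auxiliary outputs; the __main__ guard keeps it from running when the training script imports GoogLeNet:

if __name__ == "__main__":
    model = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    model.train()                               # aux outputs are only returned in training mode
    dummy = torch.randn(1, 3, 224, 224)         # GoogLeNet expects 224x224 RGB input
    main_out, aux2, aux1 = model(dummy)
    print(main_out.shape, aux2.shape, aux1.shape)   # each should be torch.Size([1, 5])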

 

Training script

 

import json
import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from GoogleNet import GoogLeNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# dataset transforms: random augmentation for training, deterministic resize for validation
data_transform={
    "train":torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(224),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ]),
    "val":torchvision.transforms.Compose([
        torchvision.transforms.Resize((224,224)),  # 244 in the original is presumably a typo; the network is trained on 224x224 inputs
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ])}

image_path = "D:\pycharm project\噼里啪啦pytroch\\flower_data"
train_dataset = torchvision.datasets.ImageFolder(image_path+"/train",transform=data_transform["train"])
train_num=len(train_dataset)
flower_list= train_dataset.class_to_idx
cla_dict = dict((val,key) for key,val in flower_list.items())
json_str=json.dumps(cla_dict,indent=4)
with open('class_indices.json','w') as json_file:
    json_file.write(json_str)
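
# class_indices.json maps the numeric label back to the class (folder) name; for the
# 5-class flower_data set used in this tutorial it would look roughly like
# {"0": "daisy", "1": "dandelion", ...} (exact names depend on your folder names).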


batch_size=32
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)

valdata_dataset = torchvision.datasets.ImageFolder(image_path+"/val",transform=data_transform["val"])
val_num = len(valdata_dataset)
val_loader = DataLoader(valdata_dataset,batch_size=val_num,shuffle=False)  # the whole validation set is evaluated as a single batch

net=GoogLeNet(num_classes=5,aux_logits=True,init_weights=True)
net.to(device)
loss_function=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(net.parameters(),lr=0.0003)
writer=SummaryWriter("keshihua")
best_acc = 0.0
save_path='googlenet.pth'
step_train=0
for epoch in range(50):
    net.train()
    running_loss=0.0
    for data in train_loader:
        images,labels = data
        optimizer.zero_grad()
        logits,aux_logits2,aux_logits1=net(images.to(device))
        loss0=loss_function(logits,labels.to(device))
        loss1=loss_function(aux_logits1,labels.to(device))
        loss2=loss_function(aux_logits2,labels.to(device))
        loss=loss0+loss1*0.3+loss2*0.3
        loss.backward()
        optimizer.step()

        running_loss+=loss.item()
        writer.add_scalar("train_loss",running_loss,step_train)
        step_train+=1

    net.eval()
    acc = 0.0
    with torch.no_grad():
        for data_test in val_loader:
            test_images,test_labels = data_test
            outputs = net(test_images.to(device))   # eval mode: only the main classifier output is returned
            predict = torch.max(outputs,dim=1)[1]
            acc += (predict==test_labels.to(device)).sum().item()
        accuracy_test = acc/val_num
        writer.add_scalar("accuracy",accuracy_test,epoch)
        print("epoch={0} accuracy={1:.4f} training loss={2:.4f}".format(epoch+1,accuracy_test,running_loss))
        if accuracy_test > best_acc:
            best_acc = accuracy_test    # update the best accuracy, otherwise the model is re-saved every epoch
            torch.save(net.state_dict(),save_path)
print("训练完成")
writer.close()
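
After training, the saved weights can be reused for prediction. Below is a minimal sketch (not in the original post) that reuses device and data_transform["val"] from the script above; the image path test.jpg is a placeholder, and strict=False is needed because the checkpoint still contains the aux1/aux2 weights while the inference model is built with aux_logits=False:

from PIL import Image

net = GoogLeNet(num_classes=5, aux_logits=False)
net.load_state_dict(torch.load(save_path, map_location=device), strict=False)  # ignore the stored aux weights
net.to(device)
net.eval()

with open('class_indices.json') as f:
    class_indict = json.load(f)

img = data_transform["val"](Image.open("test.jpg").convert("RGB")).unsqueeze(0)  # placeholder image path
with torch.no_grad():
    prob = torch.softmax(net(img.to(device)), dim=1)
    idx = int(torch.argmax(prob, dim=1))
print(class_indict[str(idx)], float(prob[0, idx]))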




Results

TensorBoard visualization
