AlexNet

# model.py

import torch
from torch import nn

class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier.

    The network has five convolution layers (48/128/192/192/128 output
    channels) followed by three fully connected layers.

    Args:
        num_classes: Size of the final output layer (number of classes).
        init_weights: When true, re-initialise Conv2d/Linear layers with
            ``_initialize_weights`` instead of keeping PyTorch's defaults.

    Input is a float tensor of shape ``(N, 3, H, W)``. The layers were sized
    for 224x224 images; other sizes are accepted as long as the feature map
    stays non-empty, because it is pooled to 6x6 before the classifier.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        # Spatial sizes in the comments below are for a 3x224x224 input.
        self.feature = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # -> 48x55x55
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> 48x27x27

            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # -> 128x27x27
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> 128x13x13

            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # -> 192x13x13
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # -> 192x13x13
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # -> 128x13x13
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> 128x6x6
        )

        # Parameter-free pooling that pins the feature map to 6x6 so the
        # classifier's 128*6*6 input size holds for any input resolution.
        # For 224x224 inputs the map is already 6x6 and this is an identity.
        # It owns no parameters/buffers, so state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return raw class scores (logits) of shape ``(N, num_classes)``."""
        x = self.feature(x)
        x = self.avgpool(x)
        # Keep the batch dimension, flatten channels and spatial dims.
        x = torch.flatten(x, start_dim=1)
        out = self.classifier(x)
        return out

    def _initialize_weights(self):
        """Re-initialise every Conv2d and Linear layer in place.

        Conv2d weights use Kaiming-normal (fan_out, ReLU); Linear weights use
        a normal distribution with mean 0 and std 0.01; biases are set to 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
# train.py

import torch
import torch.nn as nn
from torchvision import transforms,datasets,utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Pre-processing pipelines: random crop + horizontal flip for training,
# a fixed 224x224 resize for validation. Both normalise each channel with
# mean 0.5 and std 0.5.
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
}

# Root directory of the data sets: the parent of the current working directory.
data_root = os.path.abspath(os.path.join(os.getcwd(), "../"))
# Path to the flower data set (expects "train" and "val" sub-folders).
image_path = os.path.join(data_root, "data_set", "flower_data")
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])

train_num = len(train_dataset)
print(train_num)
print(train_dataset[0][0].shape)

# Invert ImageFolder's {class_name: index} mapping to {index: class_name}
# and persist it so the prediction script can turn indices back into names.
flower_list = train_dataset.class_to_idx
cla_dict = {val: key for key, val in flower_list.items()}

json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

validata_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=data_transform["val"])
val_num = len(validata_dataset)
validata_loader = torch.utils.data.DataLoader(validata_dataset, batch_size=4, shuffle=True, num_workers=0)

print(len(validata_dataset) == len(validata_loader.dataset))

# Grab one validation batch for the optional visual check below.
# The builtin next() is used because DataLoader iterators do not expose a
# .next() method on current PyTorch releases.
test_data_iter = iter(validata_loader)
test_image, test_label = next(test_data_iter)

# def imshow(img):
#     img = img / 2 + 0.5  # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()
# print(test_image)
# # print labels
# print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
# # show images
# imshow(utils.make_grid(test_image))

# Build the 5-class model with the custom weight initialisation and move it
# to the selected device.
net = AlexNet(num_classes=5, init_weights=True)

net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0002)

save_path = './AlexNet.pth'
best_acc = 0.0
# Number of batches per epoch; used for the progress bar and the mean loss.
num_batches = len(train_loader)
for epoch in range(10):
    # ---- train ----
    net.train()
    running_loss = 0.0
    t1 = time.perf_counter()  # start of the timer for one training epoch
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # Convert once to a Python float so neither the running sum nor the
        # progress line holds on to the tensor.
        batch_loss = loss.item()
        running_loss += batch_loss

        # Text progress bar: percent done, 50-character bar, current loss.
        rate = (step + 1) / num_batches
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss:{:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, batch_loss), end="")
    print()
    print(time.perf_counter() - t1)

    # ---- validate ----
    net.eval()
    acc = 0.0  # number of correctly classified validation samples
    with torch.no_grad():
        for data_test in validata_loader:
            test_images, test_labels = data_test
            outputs = net(test_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == test_labels.to(device)).sum().item()
        accurate_test = acc / val_num
        # Keep only the weights of the best-scoring epoch so far.
        if accurate_test > best_acc:
            best_acc = accurate_test
            torch.save(net.state_dict(), save_path)
        # Mean loss is the sum over batches divided by the batch COUNT
        # (dividing by the last zero-based index would be off by one and
        # would divide by zero when there is a single batch).
        print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' % (epoch + 1, running_loss / max(num_batches, 1), accurate_test))




# predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

# Same deterministic pre-processing as the validation pipeline used in
# training: resize to 224x224, convert to tensor, normalise to mean/std 0.5.
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

# Convert to RGB so that four-channel images (e.g. PNG with alpha) become
# three-channel and can go through data_transform.
img = Image.open("../Tulip.png").convert('RGB')
# print(img)
# plt.imshow(img)
# plt.show()
img = data_transform(img)
# expand batch dimension: (C, H, W) -> (1, C, H, W)
img = torch.unsqueeze(img, dim=0)

# read class_indict: maps the class index (as a string key) to the class name
try:
    with open('./class_indices.json') as json_file:
        class_indict = json.load(json_file)
except (OSError, ValueError) as e:
    # Missing/unreadable file or malformed JSON: report and stop with the
    # same exit status the script used before.
    print(e)
    raise SystemExit(-1)

# create model
model = AlexNet(num_classes=5)
# load model weights; map_location="cpu" lets a checkpoint saved from a GPU
# run load on a machine without CUDA (the model stays on the CPU here).
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path, map_location="cpu"))

model.eval()
with torch.no_grad():
    # Drop the batch dimension, then turn logits into class probabilities.
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).item()

# Print the predicted class name and its probability.
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值