AlexNet PyTorch code with annotations

The code comes from a Zhihu article:

pytorch实现AlexNet(CNN经典网络模型详解) - 知乎

model.py

#model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(  # pack the convolutional feature-extraction layers
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] (fractional sizes are floored)
            nn.ReLU(inplace=True),  # inplace=True overwrites the input tensor to save memory
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]; channel counts are half of the original paper
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            # fully connected layers
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  # flatten everything except dim 0 (the batch dimension)
        # alternatively, view() can keep the batch dim and flatten the rest
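        #   equivalent sketch (not in the original code): x = x.view(x.size(0), -1)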
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')  # Kaiming (He) initialization
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)  # initialize from a normal distribution (mean 0, std 0.01)
                nn.init.constant_(m.bias, 0)
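
A quick sanity check of the shapes annotated above (a minimal sketch, not part of the original post; it assumes model.py is importable from the current directory): a dummy [1, 3, 224, 224] input should produce a [1, 128, 6, 6] feature map, which flattens to 128 * 6 * 6 = 4608 features, matching the first Linear layer.

# shape_check.py (hypothetical helper script)
import torch
from model import AlexNet

net = AlexNet(num_classes=5)
dummy = torch.randn(1, 3, 224, 224)   # one fake RGB image
feats = net.features(dummy)
print(feats.shape)                    # expected: torch.Size([1, 128, 6, 6])
print(net(dummy).shape)               # expected: torch.Size([1, 5])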

train.py

# train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

# device : GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # use the GPU if available, otherwise fall back to CPU
print(device)  # print which device is in use

# data preprocessing and augmentation; data_transform is a dict keyed by "train" / "val"
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),  # randomly crop an area and resize it to 224x224
                                 transforms.RandomHorizontalFlip(),  # randomly flip images horizontally
                                 transforms.ToTensor(),  # convert to a tensor the network can take; scales pixel values from 0-255 to 0-1
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),  # shift each channel to roughly [-1, 1]
    # validation transforms
    "val": transforms.Compose([transforms.Resize((224, 224)),  # must be (224, 224), not a bare 224
                               # Resize((224, 224)) forces a 224x224 output; a single int would scale the shorter side to 224 and keep the aspect ratio
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
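# Note (illustration, not in the original post): Normalize applies (x - mean) / std per channel,
# so with mean = std = 0.5 a pixel value in [0, 1] is mapped into [-1, 1], e.g.
#   norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
#   pix = torch.tensor([0.0, 0.5, 1.0]).view(3, 1, 1)  # one fake pixel, one value per channel
#   norm(pix).view(-1)                                  # -> tensor([-1., 0., 1.])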
# data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
data_root = os.getcwd()  # os.getcwd() returns the current working directory
image_path = data_root + "/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])  # ImageFolder is a generic loader that labels each image by its sub-folder name


train_num = len(train_dataset)  # number of images in train_dataset



# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}

flower_list = train_dataset.class_to_idx  # Dict with items (class_name, class_index).
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)  # encode the Python dict as a JSON string
# indent=4 pretty-prints the output so it is easier to read


with open('class_indices.json', 'w') as json_file:  # write the mapping to class_indices.json
    json_file.write(json_str)
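# With the five flower classes above, class_indices.json should look roughly like this
# (index -> class name); predict.py later uses it to turn a predicted index back into a label:
# {
#     "0": "daisy",
#     "1": "dandelion",
#     "2": "roses",
#     "3": "sunflower",
#     "4": "tulips"
# }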

batch_size = 32  # batch size: number of samples processed in one training step

train_loader = torch.utils.data.DataLoader(train_dataset,  # DataLoader splits the dataset into batches
                                           batch_size=batch_size, shuffle=True,  # reshuffle the data every epoch
                                           num_workers=0)  # number of worker processes for loading data
# With num_workers=0 the batches are loaded synchronously in the main process instead of being
# prefetched by worker processes; this is slower but avoids multiprocessing issues.

# validation dataset
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0)
# same loading steps as for the training set
test_data_iter = iter(validate_loader)  # iter() turns the DataLoader into an iterator
test_image, test_label = next(test_data_iter)  # fetch one batch of validation images and labels (newer PyTorch removed .next(); use the built-in next())
# print(test_image[0].size(),type(test_image[0]))
# print(test_label[0],test_label[0].item(),type(test_label[0]))


# to display the images below, first change batch_size of validate_loader to 4
# def imshow(img):
#     img = img / 2 + 0.5  # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()
#
# print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
# imshow(utils.make_grid(test_image))


net = AlexNet(num_classes=5, init_weights=True)  # 5 output classes for the flower dataset

net.to(device)  # move the model onto the selected device
# loss function: cross-entropy
loss_function = nn.CrossEntropyLoss()
# optimizer: Adam over all model parameters, learning rate 0.0002
optimizer = optim.Adam(net.parameters(), lr=0.0002)

# path where the best model weights will be saved
save_path = './AlexNet.pth'
# best validation accuracy seen so far, initialized to 0
best_acc = 0.0

# alternate training and validation, one epoch at a time
for epoch in range(30):  # train for 30 epochs
    # train
    net.train()  # training mode: the dropout layers defined above are active
    running_loss = 0.0  # accumulated loss for this epoch
    t1 = time.perf_counter()
    # time.perf_counter() returns a high-resolution clock reading; the difference between two
    # calls gives the elapsed seconds, which is used here to time one epoch.
    for step, data in enumerate(train_loader, start=0):  # enumerate yields the batch index and the batch
        images, labels = data  # each batch from train_loader is a tuple of (images, labels)
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = net(images.to(device))  # forward pass
        loss = loss_function(outputs, labels.to(device))  # loss between outputs and labels
        loss.backward()  # backpropagate to compute gradients
        optimizer.step()  # update the parameters

        # print statistics
        running_loss += loss.item()  # accumulate the scalar loss of each batch
        # print train process
        rate = (step + 1) / len(train_loader)  # progress within the epoch

        # build a simple text progress bar
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")  # loss here is the loss of the current batch
    print()
    print(time.perf_counter() - t1)  # print how long this epoch took

    # validate
    net.eval()  # evaluation mode: dropout is disabled, all neurons are used
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():  # no gradients are tracked during validation
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # forward pass on the validation batch
            predict_y = torch.max(outputs, dim=1)[1]  # index of the largest logit in each row, i.e. the predicted class
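            # e.g. (illustration, not in the original post): torch.max(torch.tensor([[0.1, 0.7, 0.2]]), dim=1)
            # returns (values=tensor([0.7000]), indices=tensor([1])); the trailing [1] keeps only the indices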
            acc += (predict_y == val_labels.to(device)).sum().item()  # number of correct predictions in this batch
        val_accurate = acc / val_num  # validation accuracy for this epoch
        # keep track of the best accuracy seen so far
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)  # save the best weights
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / (step+1), val_accurate))

print('Finished Training')

predict.py

#predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
#img = Image.open("./dandelion.jpg")  # test a dandelion image
img = Image.open("./sunflower.jpg")   # test a sunflower image
#img = Image.open("./rose.jpg")       # test a rose image
plt.imshow(img)
#plt.show()
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)  # add a batch dimension: [3, 224, 224] -> [1, 3, 224, 224]

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
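# Note (assumption, not in the original post): if the weights were saved from a GPU run but
# prediction happens on a CPU-only machine, load them with map_location, e.g.
#   model.load_state_dict(torch.load(model_weight_path, map_location='cpu'))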
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)  # output was squeezed to a 1-D tensor of class scores, so softmax over dim=0 gives class probabilities; always pass dim explicitly (relying on the implicit default is deprecated)
    predict_cla = torch.argmax(predict).numpy()  # index of the highest probability, converted to a numpy value
print(class_indict[str(predict_cla)], predict[predict_cla].item())  # print the predicted class name and its probability
plt.show()

spile_data.py

DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

Download and extract the flower_photos dataset (see DATA_URL above) into flower_data/, then split it into training and validation sets with the script below (I don't fully understand every detail of this part).

#spile_data.py

import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/'+cla)

mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/'+cla)

split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    eval_index = random.sample(images, k=int(num*split_rate))
    for index, image in enumerate(images):
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
    print()

print("processing done!")

If something is unclear, set a breakpoint or add a print() to inspect the tensor shapes.

Force yourself to take notes.
