Computer Vision 4: AlexNet Implementation in PyTorch

Defining the network

import torch
import torch.nn as nn


class AlexNet(nn.Module):       # subclass of nn.Module
    def __init__(self, num_classes=2, init_weights=False):   # no pretrained weights: the network is trained from scratch
        # call the parent-class (nn.Module) constructor
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11),               # input[3, 65, 65] (input size chosen to fit this machine's GPU memory), output[48, 55, 55]
            nn.ReLU(inplace=True),                          # activation
            nn.MaxPool2d(kernel_size=3, stride=2),          # max pooling, output[48, 27, 27]
            nn.BatchNorm2d(48),                             # normalization: keeps activations from growing so large before the next ReLU that training becomes unstable;
                                                            # the paper's Local Response Normalization is rarely used now, so BatchNorm is used instead
            nn.Conv2d(48, 128, kernel_size=5, padding=2),   # (27 + 2*padding - kernel_size)/stride + 1 = 27, output[128, 27, 27]
            nn.ReLU(inplace=True),                          # activation
            nn.MaxPool2d(kernel_size=3, stride=2),          # max pooling, output[128, 13, 13]
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 192, kernel_size=3, padding=1),  # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),  # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),  # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),          # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128*6*6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()      # fixed: was self.__initialize_weights(), which name-mangles and never matches the method below

    # forward pass
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
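
As a quick check that the shape comments above are consistent (e.g. the first conv maps 65 → (65 − 11)/1 + 1 = 55, and the last pooling maps 13 → (13 − 3)/2 + 1 = 6), a dummy batch can be pushed through the network. This snippet is just a sanity check, not part of the original code:

# sanity check: run a random batch through the network and inspect shapes
model = AlexNet(num_classes=2, init_weights=True)
x = torch.randn(1, 3, 65, 65)       # dummy batch [N, C, H, W]
print(model.features(x).shape)      # expected: torch.Size([1, 128, 6, 6])
print(model(x).shape)               # expected: torch.Size([1, 2])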


Image source and output paths

# source image directories
source_path0 = r'D:\cv\image\train_divided\1'
source_path1 = r'D:\cv\image\train_divided\0'
# output image directories
aim_dir0 = r'D:\cv\net\1'
aim_dir1 = r'D:\cv\net\0'

Data augmentation

import os
from PIL import Image
from torchvision import transforms


def dataEnhance(source_path, aim_dir, size):
    h = 0
    # list the files in the source directory
    file_list = os.listdir(source_path)
    # create the output directory if it does not exist
    if not os.path.exists(aim_dir):
        os.makedirs(aim_dir)

    # iterate over every image, saving four augmented copies of each
    for i in file_list:
        img = Image.open(os.path.join(source_path, i))
        print(img.size)

        # 1) plain resize
        h = h + 1
        transform1 = transforms.Compose([
            transforms.ToTensor(),                      # convert to tensor
            transforms.ToPILImage(),                    # and back to a PIL image
            transforms.Resize(size),
        ])
        img1 = transform1(img)
        img1.save(os.path.join(aim_dir, '%s.png' % h))

        # 2) color jitter + resize
        h = h + 1
        transform2 = transforms.Compose([
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            transforms.ToTensor(),                      # convert to tensor
            transforms.ToPILImage(),                    # and back to a PIL image
            transforms.Resize(size),
        ])
        img2 = transform2(img)
        img2.save(os.path.join(aim_dir, '%s.png' % h))

        # 3) random crop (padded if the image is smaller than 227) + resize
        h = h + 1
        transform3 = transforms.Compose([
            transforms.ToTensor(),                      # convert to tensor
            transforms.ToPILImage(),                    # and back to a PIL image
            transforms.RandomCrop(227, pad_if_needed=True),
            transforms.Resize(size),
        ])
        img3 = transform3(img)
        img3.save(os.path.join(aim_dir, '%s.png' % h))

        # 4) random rotation within ±60 degrees + resize
        h = h + 1
        transform4 = transforms.Compose([
            transforms.ToTensor(),                      # convert to tensor
            transforms.ToPILImage(),                    # and back to a PIL image
            transforms.RandomRotation(60),
            transforms.Resize(size),
        ])
        img4 = transform4(img)
        img4.save(os.path.join(aim_dir, '%s.png' % h))


# dataEnhance(source_path0, aim_dir0, (65, 65))
# dataEnhance(source_path1, aim_dir1, (65, 65))
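
Writing the augmented copies to disk, as above, fixes them once; the more common idiom is to declare the random transforms on the dataset itself, so every epoch sees freshly sampled variations. A minimal sketch of that alternative, reusing the directory layout above (this is not what the original code does, just the usual pattern):

from torchvision import transforms
from torchvision.datasets import ImageFolder

# random transforms are re-sampled every time an image is loaded
train_trans = transforms.Compose([
    transforms.RandomRotation(60),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    transforms.Resize((65, 65)),
    transforms.ToTensor(),
])
train_set = ImageFolder(root=r'D:\cv\image\train_divided', transform=train_trans)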

Loading the data

from torchvision.datasets import ImageFolder

# normalization statistics: the standard ImageNet per-channel mean and std
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# training-set loading
path = r"D:\cv\net"
trans = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
# root: directory containing the images; transform: preprocessing applied to each image
dataset = ImageFolder(root=path, transform=trans)
# wrap the dataset in a DataLoader
batch_size = 32         # assumed value: batch_size is used below but never defined in the original
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)    # num_workers: data-loading worker processes (0 = main process), not GPUs

# test-set loading
path1 = r'D:\cv\image\test_divided'
transf = transforms.Compose([
    transforms.Resize((65, 65)),
    transforms.ToTensor(),
    normalize,
])
datasettest = ImageFolder(root=path1, transform=transf)
test_loader = torch.utils.data.DataLoader(datasettest, batch_size=batch_size, shuffle=True, num_workers=0)
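
The mean/std above are the standard ImageNet statistics; if the training images look very different from natural photos, the statistics can instead be estimated from the training set itself. A rough sketch (one pass over unnormalized tensors; it assumes all images share the same size, which holds here since dataEnhance saves them at 65×65):

# estimate per-channel mean and std of the training images
raw = ImageFolder(root=path, transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(raw, batch_size=64, num_workers=0)
n, mean, sq = 0, torch.zeros(3), torch.zeros(3)
for data, _ in loader:
    n += data.size(0)
    # accumulate per-image channel means and mean squares
    mean += data.sum(dim=(0, 2, 3)) / (data.size(2) * data.size(3))
    sq += data.pow(2).sum(dim=(0, 2, 3)) / (data.size(2) * data.size(3))
mean /= n
std = (sq / n - mean.pow(2)).sqrt()     # std = sqrt(E[x^2] - E[x]^2)
print(mean, std)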

Training the model

import torch.nn.functional as f


def train_model(model, device, train_loader, optimizer, epoch):
    train_loss = 0
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # cross-entropy loss
        loss = f.cross_entropy(output, target)
        loss.backward()
        # update all the weights
        optimizer.step()
        # log every 300 batches
        if batch_index % 300 == 0:
            train_loss = loss.item()
            print("train epoch :{}\t train loss : {:.6f}".format(epoch, loss.item()))
    return train_loss
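
The post stops at train_model, so the driver loop is not shown. One plausible way to wire the pieces above together; the optimizer, learning rate, and epoch count below are assumptions, not from the original:

# hypothetical driver: optimizer, lr, and epoch count are illustrative choices
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AlexNet(num_classes=2, init_weights=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(1, 11):
    train_model(model, device, train_loader, optimizer, epoch)

    # evaluate on the test set after each epoch
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    print('test accuracy: {:.2%}'.format(correct / total))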