跟着李沐学AI代码复现-07Finetune,08BoundBox,09ObjectDetection

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data

import torchvision
from torchvision import transforms

import matplotlib.pyplot as plt

import os

'''
从互联网上下载hotdog数据集
使用微调
我没有做对照实验。
'''

# 超参数定义
batch_size = 128
epoch_num = 15 
lr = 0.001
if torch.cuda.device_count() >= 1:
    device = torch.device(f'cuda:0')
else:
    device = torch.device('cpu')


def load():
    # 这里是和前面的resnet不同的地方,
    # 使用RGB通道的均值和标准差,以标准化每个通道
    normalize = torchvision.transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_augs = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(224),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        normalize])

    test_augs = torchvision.transforms.Compose([
        torchvision.transforms.Resize([256, 256]),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        normalize])
    # 加载数据集
    train_imgs = torchvision.datasets.ImageFolder(r'E:\1MyDocuments\学习\12深度学习\recurrence\data\hotdog\train',
                                                  transform=train_augs)
    test_imgs = torchvision.datasets.ImageFolder(r'E:\1MyDocuments\学习\12深度学习\recurrence\data\hotdog\train',
                                                 transform=test_augs)

    # 将数据集转为特定格式 
    train_iter = data.DataLoader(train_imgs, batch_size, shuffle=True, num_workers=8)
    test_iter = data.DataLoader(test_imgs, batch_size, shuffle=False, num_workers=8)

    return train_iter, test_iter


def netdefine():
    # 这里使用从互联网上的预训练模型
    net = torchvision.models.resnet18(pretrained=True)
    net.fc = nn.Linear(net.fc.in_features, 2)
    # 初始化参数
    nn.init.xavier_uniform_(net.fc.weight)
    # 模型迁移
    net = net.to(device)
    return net


def train(net, train_iter,finetune = True):
    # optimizer会根据梯度对所有参数进行更新
    # 注意这里的优化器,对不同层的参数设置不同训练率
    if finetune:
        params_1x = [param for name, param in net.named_parameters()
                if name not in ["fc.weight", "fc.bias"]]
        optimizer = torch.optim.SGD([{'params': params_1x},
                                    {'params': net.fc.parameters(),
                                    'lr': lr * 10}],
                                    lr=lr, weight_decay=0.001)
    else:
        optimizer = torch.optim.SGD(net.parameters(), lr=lr,
                                  weight_decay=0.001)

    loss = nn.CrossEntropyLoss()
    for epoch in range(0, epoch_num):
        net.train()
        total = 0
        correct = 0
        for X, y in train_iter:
            # 梯度置零
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            # 梯度反向传播,参数更新
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            # 计算准确率
            _, pred = torch.max(y_hat, 1)
            total += y.size(0)
            correct += (pred == y).sum().item()
        print(f'epoch = {epoch}, total = {total}, accuracy = {100 * correct / total}%')


def test(net, test_iter):
    net.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for X, y in test_iter:
            X = X.to(device)
            y = y.to(device)
            output = net(X)
            _, pred = torch.max(output, 1)
            total += y.size(0)
            correct += (pred == y).sum().item()
        print(f'test,total = {total}, accuracy = {100 * correct / total}%')


if __name__ == '__main__':
    # 展示数据集
    train_iter, test_iter = load()
    pretrained_net = torchvision.models.resnet18(pretrained=True)
    net = netdefine()
    train(net, train_iter)
    test(net, test_iter)




import matplotlib.pyplot as plt
from PIL import Image
import matplotlib.patches as patches

if __name__ == '__main__':
    # 读取图像
    img_path = r'E:\1MyDocuments\学习\12深度学习\recurrence\img\cat1.jpg'
    img = Image.open(img_path)
    print("这是原图片")

    # 显示原始图像
    fig, ax = plt.subplots()
    ax.imshow(img)

    # 添加边界框
    bbox = [30, 80, 470, 333]
    rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                             linewidth=2, edgecolor='blue', facecolor='none')
    ax.add_patch(rect)

    # 显示图像和边界框
    plt.show()
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import pandas as pd
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from PIL import Image
import matplotlib.patches as patches
'''
目标检测的数据集下载
'''

# 写一个DataSet类
class BananasDataset(torch.utils.data.Dataset):
    """一个用于加载香蕉检测数据集的自定义数据集"""
    def __init__(self, is_train):
        self.features, self.labels = self.read_data_bananas(is_train)
        print('read ' + str(len(self.features)) + (f' training examples' if
              is_train else f' validation examples'))

    def __getitem__(self, idx):
        return (self.features[idx].float().div(255), 
                self.labels[idx])
    def __len__(self):
        return len(self.features)
    
    def read_data_bananas(self,is_train):
        # 下面开始读取目标检测数据集
        data_dir = r'E:\1MyDocuments\学习\12深度学习\recurrence\data\banana-detection'
        csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
                             else 'bananas_val', 'label.csv')
        csv_data = pd.read_csv(csv_fname)
        csv_data = csv_data.set_index('img_name')
        images, targets = [], []
        for img_name, target in csv_data.iterrows():
            images.append(torchvision.io.read_image(
                os.path.join(data_dir, 'bananas_train' if is_train else
                         'bananas_val', 'images', f'{img_name}')))
            targets.append(list(target))
        return images, torch.tensor(targets).unsqueeze(1) / 256

def load_data_bananas(batch_size):
    """加载香蕉检测数据集"""
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
                                             batch_size, shuffle=True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
                                           batch_size)
    return train_iter, val_iter

if __name__ == '__main__':
    '''
    http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip
    '''
    batch_size, edge_size = 32, 256
    train_iter, _ = load_data_bananas(batch_size)
    batch = next(iter(train_iter))
    print(batch[0].shape, batch[1].shape)

    # 绘制
    fig, ax = plt.subplots()
    ax.imshow((torchvision.transforms.ToPILImage()(batch[0][0])))
    bbox = batch[1][0][0][1:5] * 256
    rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                             linewidth=2, edgecolor='blue', facecolor='none')
    ax.add_patch(rect)
    plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值