猫狗二分类

资源地址
链接:https://pan.baidu.com/s/1H2tQwYP1i-VIziS6XbLllQ 
提取码:q5mt 



import torchvision
from torch.utils.data import DataLoader
import torch
import time
from PIL import Image

# 搭建模型
model = torchvision.models.vgg19(pretrained=True)
model.classifier[-1] = torch.nn.Linear(4096, 2)
print(model)

# 初始化运行条件
if True:
    sp = '\n' + '--------' * 20 + '\n'
    root = './data'
    train_path = root + '/train'
    test_path = root + '/test'

    bs = 16
    lr = 0.0001
    epoch = 20
    device = 'cuda'
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params=model.parameters(), momentum=0.9, lr=lr)

    transform = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(224),
        torchvision.transforms.ToTensor()
    ])

# 读取数据
if True:
    train_data = torchvision.datasets.ImageFolder(train_path, transform)
    classes = train_data.classes
    train_iterator = DataLoader(train_data, bs, shuffle=True)
    # 展示数据细节
    if False:
        print(train_data, end=sp)
        print('class_to_idx: ', train_data.class_to_idx)
        print('classes: ', train_data.classes)
        print('extension: ', train_data.extensions)
        print('extra_repr: ', train_data.extra_repr())
        print('imgs: ', train_data.imgs)
        print('loader: ', train_data.loader)
        print('root: ', train_data.root)
        print('samples: ', train_data.samples)
        print('target_transform: ', train_data.target_transform)
        print('targets: ', train_data.targets)
        print('transform: ', train_data.transform)
        print('transforms: ', train_data.transforms)
        print('\n\n', end=sp)

        print(train_iterator, end=sp)
        print('batch_sampler: ', train_iterator.batch_sampler)
        print('batch_size: ', train_iterator.batch_size)
        print('collate_fn: ', train_iterator.collate_fn)
        print('dataset: ', train_iterator.dataset)
        print('drop_last: ', train_iterator.drop_last)
        print('generator: ', train_iterator.generator)
        print('multiprocessing_context: ', train_iterator.multiprocessing_context)
        print('num_workers: ', train_iterator.num_workers)
        print('persistent_workers: ', train_iterator.persistent_workers)
        print('pin_memory: ', train_iterator.pin_memory)
        print('prefetch_factor: ', train_iterator.prefetch_factor)
        print('sampler: ', train_iterator.sampler)
        print('timeout: ', train_iterator.timeout)
        print('worker_init_fn: ', train_iterator.worker_init_fn)
        print('\n\n', end=sp)


def train(model, iterator, optimizer, criterion):
    def accuracy(outputs, label):
        pre = torch.argmax(outputs, dim=1)
        acc_num = (pre == label).sum()
        return acc_num / len(label)

    start_time = time.monotonic()
    epoch_loss = 0.0
    epoch_acc = 0.0

    model = model.to(device)
    model.train()

    for (images, labels) in iterator:
        optimizer.zero_grad()

        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)

        loss = criterion(outputs, labels)
        acc = accuracy(outputs, labels)

        loss.backward()
        optimizer.step()

        epoch_loss += loss
        epoch_acc += acc
    cost_time = time.monotonic() - start_time
    return epoch_loss / len(iterator), epoch_acc / len(iterator), cost_time


if __name__ == '__main__':
    # 是否训练
    if False:
        for epoch in range(epoch):
            loss, acc, cost_t = train(model, train_iterator, optimizer, criterion)
            print(f'epoch: {epoch}\tcost time: {cost_t}\nloss: {loss}\tacc: {acc}')
        torch.save(model.state_dict(), 'cat_dog_classification.pth')
    model.load_state_dict(torch.load('cat_dog_classification.pth'))
    # 是否选择图片进行预测
    choice = input('是否选择图片进行预测? ')
    if choice in {'y', 'Y'}:
        path = input('输入图片路径(仅限于jpg): ')


        def classification(path):
            image = transform(Image.open(path))
            image = torch.unsqueeze(image, 0)
            out = model(image)
            poss = torch.softmax(out, dim=1)
            index = int(torch.argmax(out, dim=1))
            print('name: ', classes[index], '\tpossibility: ', float(poss[0, index]))
        classification(path)
参与评论 您还未登录,请先 登录 后发表或查看评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
©️2022 CSDN 皮肤主题:大白 设计师:CSDN官方博客 返回首页

打赏作者

嗷我懂了

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值