深度学习:使用timm库的resnet101模型对 pokem图像集进行图像分类预测

import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader

from DIYdata_loader import DIYData_loader
from resnet import ResNet18

import timm

batchsz = 32
picture_resize = 224
lr = 1e-4
epochs = 10

device = torch.device('cuda')
torch.manual_seed(1234)

train_db = DIYData_loader(r'D:\python pycharm learning\清华大佬课程\青光眼ResNet\pokemon', picture_resize, mode='train')
val_db = DIYData_loader(r'D:\python pycharm learning\清华大佬课程\青光眼ResNet\pokemon', picture_resize, mode='val')
test_db = DIYData_loader(r'D:\python pycharm learning\清华大佬课程\青光眼ResNet\pokemon', picture_resize, mode='test')

train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=8)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=8)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=8)

viz = visdom.Visdom()


def evaluate(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()

    return correct / total


def main():
    # model = ResNet18(5).to(device)

    # 创建一个 ResNet101 模型,输出类别数为5
    model = timm.create_model('resnet101', pretrained=False, num_classes=5)

    # 从本地文件加载预训练权重
    state_dict = torch.load(r'D:\python pycharm learning\清华大佬课程\青光眼ResNet\resnet101_a1h-36d3f2aa.pth')

    # 修改全连接层的权重和偏置,使其适应新的类别数
    state_dict['fc.weight'] = state_dict['fc.weight'][:5, :]
    state_dict['fc.bias'] = state_dict['fc.bias'][:5]

    # 加载权重到模型
    model.load_state_dict(state_dict, strict=False)

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')
                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值