冻结ResNet50前几层并进行迁移学习(PyTorch)

在PyTorch中,加载ResNet50模型并冻结模型的前几层可以通过以下步骤进行:

import torch
from torchvision.models import resnet50

# 设置GPU环境
use_cuda = True
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

# 加载预训练的ResNet50模型
trained_model = resnet50(pretrained=True)

# 冻结需要保持不变的层,通常是前几个卷积层
for name, param in trained_model.named_parameters():
    if 'conv1' in name or 'bn1' in name or 'conv2' in name or 'bn2' in name or 'conv3' in name or 'bn3' in name:
        param.requires_grad = False

# 修改最后一层进行微调
model = nn.Sequential(*list(trained_model.children())[:-1],
                        Flatten(),  # [b,2048]
                        nn.Linear(2048, 4),  # 假设输出类别数为4
                        ).to(device)

# 损失、优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
for epoch in range(epochs):
    total_loss = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()

        # 前向传播
        outputs = model(imgs)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # 打印每个epoch的损失值
    print(f"Epoch {epoch+1} Loss: {total_loss /len(train_loader)}")

# 保存模型参数
torch.save(model.state_dict(), 'retrain_resnet50.pth')

# 测试模型
def evalute(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct/total

model.load_state_dict(torch.load('retrain_resnet50.pth'))
test_acc = evalute(model, test_loader)
print('test acc:', test_acc)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值