深度学习的训练、验证、测试的模板代码

文章目录

从头搭建一个深度学习的模型,基本上都可以从这个框架去套用。
包括了最基础的模型的定义、训练、验证和测试

完整的框架代码

import torch
import torch.nn as nn
import torch.optim as optim
from  torch.utils.data import DataLoader,Dataset

# 定义模型
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1=nn.Linear(100,50)
        self.fc2=nn.Linear(50,10)

    def forward(self,x):
        x=torch.relu(self.fc1(x))
        x=self.fc2(x)
        return x

# 定义训练集
class TrainDataset(Dataset):
    def __init__(self,data_path,target_path):
        super(TrainDataset,self).__init__()
        pass  # 初始化

    def __getitem__(self, index):
        pass# 根据index返回数据
        # return data,target
    def __len__(self):
        pass # 返回数据的长度
        # return len(xxx)

# 定义验证集
class ValDataset(Dataset):
    def __init__(self, data_path, target_path):
        super(ValDataset, self).__init__()
        pass  # 初始化

    def __getitem__(self, index):
        pass  # 根据index返回数据
        # return data,target

    def __len__(self):
        pass  # 返回数据的长度
        # return len(xxx)

# 定义测试集
class TestDataset(Dataset):
    def __init__(self, data_path, target_path):
        super(TestDataset, self).__init__()
        pass  # 初始化

    def __getitem__(self, index):
        pass  # 根据index返回数据
        # return data,target

    def __len__(self):
        pass  # 返回数据的长度
        # return len(xxx)

# 训练
def train(model,optimizer,criterion,train_loader,device):
    total_loss=0
    # 将模型设置为训练模式
    model.train()
    # 遍历数据集
    for data,target in train_loader:
        # 将数据移到设备上
        data,target=data.to(device),target.to(device)
        # 梯度清零
        optimizer.zero_grad()
        # 前向传播
        output=model(data)
        # 计算损失
        loss=criterion(output,target)
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        # 统计损失
        total_loss+=loss.item()

    # 计算平均损失
    avg_loss=total_loss/len(train_loader)

    return avg_loss

# 验证
def validate(model,criterion,val_loader,device):
    total_loss=0
    # 将模型设置为验证模式
    model.eval()

    with torch.no_grad():
        #遍历验证集
        for data,target in val_loader:
            # 将数据移到设备上
            data,target=data.to(device),target.to(device)
            # 前向传播
            output=model(data)
            # 计算损失
            loss=criterion(output,target)
            # 统计损失
            total_loss+=loss.item()

    # 计算平均损失
    avg_loss=total_loss/len(val_loader)

    return avg_loss

def test(model,criterion,test_loader,device):
    total_loss=0
    total_correct=0
    # 将模型设置为测试模式
    model.eval()

    with torch.no_grad():
        # 遍历测试集
        for data,target in test_loader:
            # 将数据迁移到设备上
            data,target=data.to(device),target.to(device)
            # 前向传播
            output=model(data)
            # 计算损失
            loss=criterion(output,target)
            # 统计损失
            total_loss+=loss.item()
            # 计算准确率
            _,predicted =torch.max(output.data,1)
            total_correct+=(predicted==target).sum().item()

    # 计算平均损失
    avg_loss=total_loss/len(test_loader)
    # 计算准确率
    accuracy=total_correct/len(test_loader.dataset)

    return avg_loss,accuracy

# 主函数
def main():
    # 定义超参数
    epoch=10
    lr=0.01
    batch_size=32

    # 定义设备
    device =torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 定义训练集、验证集、测试集 的数据
    train_data=TrainDataset(data_path='',target_path='')
    val_data = ValDataset(data_path='', target_path='')
    test_data = TestDataset(data_path='', target_path='')

    # 定义训练集、验证集、测试集 的数据加载器
    train_loader=DataLoader(train_data,batch_size=batch_size,shuffle=True)
    val_loader=DataLoader(val_data,batch_size=batch_size,shuffle=True)
    test_loader=DataLoader(test_data,batch_size=batch_size,shuffle=True)

    # 定义模型
    model=Model().to(device)
	# 加载预训练权重
	model.load_state_dict(torch.load('xxx.pth')
    # 定义损失函数和优化器
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.SGD(model.parameters(),lr=lr)

    # 训练和验证
    for epoch in range(1,epoch+1):
        train_loss=train(model,optimizer,criterion,train_loader,device)
        val_loss=validate(model,criterion,val_loader,device)
        # 显示训练集和验证集的损失
		if (epoch%display_iter)==0:
        	print(f"Epoch: {epoch}, Train loss:{train_loss:.4f},Val loss: {val_loss: .4f}")
        # 保存权重
        if (epoch%snapshot_iter)==0:
        	torch.save(model.state_dict(),"xxx.pth")

    # 测试
    test_loss,test_accuracy=test(model,criterion,test_loader,device)
    print(f"Test loss:{test_loss:.4f},Test accuracy:{test_accuracy:.4f}")

if __name__=='__main__':
    main()
  • 8
    点赞
  • 48
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

zyw2002

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值