train函数

本文深入探讨了在深度学习中train函数的作用和流程,包括数据加载、模型优化、损失计算及反向传播等关键步骤,帮助读者理解如何通过train函数进行有效的模型训练。
摘要由CSDN通过智能技术生成

每次训练都测试

def get_acc(output,label):
    total = output.shape[0]
    _,pred_label = output.max(1)
    return (pred_label == label).sum().data.item()/total

def train(net,train_data,valid_data,num_epochs,optimizer,criterion):
    if torch.cuda.is_available():
        net = net.cuda()
    time0 = time.time()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        time1 = time.time()
        for im,label in train_data:
            im = Variable(im.cuda())
            label = Variable(label.cuda())
            output = net(im)
            #print(output)
            loss = criterion(output,label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.data.item()
            train_acc += get_acc(output,label)
        if valid_data is not None:
            valid_loss = 0
      
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值