- 🍨 本文为🔗365天深度学习训练营中的学习记录博客
- 🍖 原作者：K同学啊 | 接辅导、项目定制
import torch
from torch.autograd import Variable  # noqa: F401 -- retained; train/test below no longer need it


def train(model, train_loader, loss_model, optimizer, log_interval=1000):
    """Run one training epoch of ``model`` over ``train_loader``.

    The model is moved to the module-level ``device`` (defined elsewhere in
    this file) and switched to training mode. For every batch the gradients
    are reset, ``loss_model(outputs, labels)`` is computed and
    back-propagated, and ``optimizer.step()`` is applied.

    Args:
        model: the ``torch.nn.Module`` to train; it is moved to ``device``.
        train_loader: iterable yielding ``(images, labels)`` batches.
        loss_model: callable ``loss_model(outputs, labels)`` returning a
            scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        log_interval: print the current batch loss every ``log_interval``
            batches, starting with batch 0. Defaults to 1000.
    """
    model = model.to(device)
    model.train()
    for i, (images, labels) in enumerate(train_loader):  # batch index starts at 0
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4;
        # plain tensors take part in autograd directly.
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_model(outputs, labels)
        loss.backward()
        optimizer.step()

        if i % log_interval == 0:
            # .item() yields a plain Python float for formatting.
            print('[%5d] loss: %.3f' % (i, loss.item()))


def test(model, test_loader, loss_model):
    """Evaluate ``model`` on ``test_loader``; print and return loss/accuracy.

    Predictions and (one-hot) targets are compared position by position
    after taking ``argmax`` over the last axis, e.g. one comparison per
    character when tensors are shaped like (batch, positions, classes) --
    assumed layout, confirm against the model/dataset.

    Args:
        model: the ``torch.nn.Module`` to evaluate (expected to already be
            on ``device``; this function does not move it).
        test_loader: iterable yielding ``(X, y)`` batches.
        loss_model: callable ``loss_model(pred, y)`` returning a scalar loss.

    Returns:
        ``(accuracy, avg_loss)``. ``avg_loss`` is the mean of the per-batch
        losses. ``accuracy`` is the fraction of correctly predicted positions
        over all compared positions, so it always lies in [0, 1].

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    test_loss, correct, total, num_batches = 0.0, 0, 0, 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_model(pred, y).item()
            # argmax over the class axis (last dim; same as dim=2 for 3-D).
            pred_classes = pred.argmax(dim=-1)
            y_classes = y.argmax(dim=-1)
            correct += (pred_classes == y_classes).sum().item()
            # Normalise by compared positions, not by samples: with several
            # positions per sample, dividing by len(dataset) could give an
            # "accuracy" larger than 1.
            total += y_classes.numel()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches")

    test_loss /= num_batches
    correct /= total
    print(f"Avg loss: {test_loss:>8f} \n")
    print(f"test acc: {correct:>8f} \n")
    return correct, test_loss


# pytest collects module-level functions whose names start with "test";
# this evaluation helper takes real arguments and is not a test case.
test.__test__ = False
任务完成，下周继续加油！
第P10周：PyTorch实现车牌识别
最新推荐文章于 2024-08-15 16:33:35 发布