模型训练框架

#导入库
%matplotlib inline
import datetime

from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms

# Load CIFAR-10 from disk (download=False: the files must already exist under
# data_path). Both splits share one preprocessing pipeline: convert each image
# to a tensor, then normalize every RGB channel with the given mean/std.
data_path = 'data/p1ch7/'
cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
cifar10 = datasets.CIFAR10(data_path, train=True, download=False,
                           transform=cifar10_transform)
cifar10_val = datasets.CIFAR10(data_path, train=False, download=False,
                               transform=cifar10_transform)

# Keep two CIFAR-10 classes and relabel them 0/1:
# original label 0 -> 0 (class_names[0]), original label 2 -> 1 (class_names[1]).
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']
# Filter on label_map's keys so the set of kept classes is defined in one place.
cifar2 = [(img, label_map[label])
          for img, label in cifar10
          if label in label_map]
cifar2_val = [(img, label_map[label])
              for img, label in cifar10_val
              if label in label_map]

# Network definition.
class Net(nn.Module):
    """Small CNN: two conv/tanh/max-pool stages followed by a 2-layer MLP head.

    The 8*8*8 flatten size of ``fc1`` corresponds to 3x32x32 inputs. The 3x3
    convs with padding=1 keep the spatial size, each 2x2 max-pool halves it
    (32 -> 16 -> 8), and ``conv2`` emits 8 channels.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Module.__init__``.
        n_classes: number of output logits (default 2, matching the two-class
            airplane/bird subset used for training).
    """

    def __init__(self, *args, n_classes: int = 2, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8 * 8 * 8, 32)
        # One logit per class, as CrossEntropyLoss expects (batch, n_classes).
        self.fc2 = nn.Linear(32, n_classes)

    def forward(self, x):
        """Return raw (unnormalized) class scores of shape (batch, n_classes)."""
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
        # flatten(1) keeps the batch dimension. Unlike view(-1, 512) it cannot
        # fold a mis-sized input into a different batch size; fc1 raises instead.
        out = out.flatten(1)
        out = torch.tanh(self.fc1(out))
        return self.fc2(out)

# Training loop.
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    """Train ``model`` in place for ``n_epochs`` passes over ``train_loader``.

    Batches are fed to the model as-is (no device transfer). After an epoch's
    batches are all processed, the mean batch loss is printed for epoch 1 and
    every 10th epoch.

    Args:
        n_epochs: number of full passes over ``train_loader``.
        optimizer: optimizer already bound to ``model``'s parameters.
        model: module mapping an ``imgs`` batch to outputs accepted by ``loss_fn``.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar loss.
        train_loader: sized iterable of ``(imgs, labels)`` batches.
    """
    n_batches = len(train_loader)
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for imgs, labels in train_loader:
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() yields a Python float, so no autograd graph is retained.
            loss_train += loss.item()

        # Report once per epoch, after every batch has contributed to the sum;
        # max(..., 1) avoids dividing by zero for an empty loader.
        if epoch == 1 or epoch % 10 == 0:
            print('{} Epoch{},Training loss{}'.format(
                datetime.datetime.now(), epoch, loss_train / max(n_batches, 1)))

# Shuffled training loader. The misspelled name `train_lodaer` is kept because
# the validate(...) call further down refers to it.
train_lodaer = torch.utils.data.DataLoader(
    cifar2,
    batch_size=64,
    shuffle=True,
)

# Fresh (CPU) network with plain SGD and cross-entropy loss.
model = Net()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

# Start training.
training_loop(
    model=model,
    train_loader=train_lodaer,
    loss_fn=loss_fn,
    optimizer=optimizer,
    n_epochs=100,
)

# Accuracy of the model on the training and validation sets.
# Evaluation loaders are unshuffled: batch order does not affect accuracy.
train_loader, val_loader = (
    torch.utils.data.DataLoader(split, batch_size=64, shuffle=False)
    for split in (cifar2, cifar2_val)
)

def validate(model, train_lodaer, val_loader):
    """Print and return the classification accuracy on the train and val splits.

    Args:
        model: classifier whose outputs hold per-class scores along dim 1.
        train_lodaer: iterable of ``(imgs, labels)`` batches for the training
            split (the spelling is kept so keyword callers keep working).
        val_loader: iterable of ``(imgs, labels)`` batches for the validation split.

    Returns:
        dict mapping ``"train"`` and ``"val"`` to accuracy in [0, 1]
        (0.0 for a loader that yields no samples).
    """
    # Send inputs to wherever the model's parameters live (CPU or GPU).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    accuracies = {}
    with torch.no_grad():
        for name, loader in [("train", train_lodaer), ("val", val_loader)]:
            correct = 0
            total = 0
            for imgs, labels in loader:
                imgs = imgs.to(device=device)
                labels = labels.to(device=device)
                outputs = model(imgs)
                _, predicted = torch.max(outputs, dim=1)
                total += labels.shape[0]
                correct += int((predicted == labels).sum())
            accuracy = correct / total if total else 0.0
            print("Accuracy {}:{:.2f}".format(name, accuracy))
            accuracies[name] = accuracy
    return accuracies

# Report accuracy of the model trained above on both splits.
validate(model, train_lodaer, val_loader)

# Save / load the model.
# Persist only the learned parameters (the state_dict), not the module object.
torch.save(model.state_dict(), data_path + 'birds_vs_airplanes.pt')
# Rebuild the architecture and load the saved parameters into it.
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state_dict needs, so no arbitrary pickled code can run.
loaded_model = Net()
loaded_model.load_state_dict(
    torch.load(data_path + 'birds_vs_airplanes.pt', weights_only=True))

# Train on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Training on device {device}")
import datetime

def training_loop(n_epochs, optimizer, model, loss_fn, train_loader, device=None):
    """Train ``model`` in place for ``n_epochs`` passes, moving batches to a device.

    After an epoch's batches are all processed, the mean batch loss is printed
    for epoch 1 and every 10th epoch.

    Args:
        n_epochs: number of full passes over ``train_loader``.
        optimizer: optimizer already bound to ``model``'s parameters.
        model: module mapping an ``imgs`` batch to outputs accepted by ``loss_fn``.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar loss.
        train_loader: sized iterable of ``(imgs, labels)`` tensor batches.
        device: where each batch is sent. Defaults to the device of the
            model's first parameter (CPU if it has none), so data always
            lands where the model lives.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    n_batches = len(train_loader)
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(device=device)
            labels = labels.to(device=device)

            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() yields a Python float, so no autograd graph is retained.
            loss_train += loss.item()

        # max(..., 1) avoids dividing by zero for an empty loader.
        if epoch == 1 or epoch % 10 == 0:
            print('{} Epoch{},Training loss{}'.format(
                datetime.datetime.now(), epoch, loss_train / max(n_batches, 1)))
# Device-aware run: a shuffled loader and a network created directly on `device`.
train_loader = torch.utils.data.DataLoader(cifar2, shuffle=True, batch_size=64)

# The optimizer is built over the parameters of the already-moved model.
model = Net().to(device=device)
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

training_loop(
    model=model,
    train_loader=train_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    n_epochs=100,
)

# Reload the saved parameters onto the selected device. map_location remaps
# tensors saved from another device; weights_only=True limits unpickling to
# tensor data, which is all a state_dict needs.
loaded_model = Net().to(device=device)
loaded_model.load_state_dict(
    torch.load(data_path + 'birds_vs_airplanes.pt', map_location=device,
               weights_only=True))
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值