模型训练的基本操作

模型训练的基本操作

一、 数据读取以及预处理

1. 数据集下载

​ pytorch中保存有很多已有数据集,一般存于torchversion.datasets文件中,包含分类、目标检测、分割、姿态估计、光流估计、GAN、超分辨率重构、视频、物体追踪、声音、人脸、遥感等数据集

# 训练集
train_dataset = torchversion.datasets.MNIST(root = './data',
                              train = True,
                              transform = my_transform,
                              download = True)
# 测试集
test_dataset = torchversion.datasets.MNIST(root = './data',
                             train = False,
                             transform = my_transform)

2. 数据读取

​ 读取数据集,并分割出训练集以及测试集

	root_dir = '/dataset/'	# 数据集地址

	train_dir = data_dir + '/train'	# 训练集

	valid_dir = data_dir + '/valid'	# 测试集

3. transform:数据预处理与数据增强

(1)TransForm:

​ transform主要是用于数据预处理以及数据增强操作。transform主要定义了两个类型的操作:

  • 数据预处理以及数据增强操作:具体的数据变换,如旋转、裁切、转换为张量以及正则化等操作
  • 组合操作:transform.compose
from torchvision import transforms
(2)常见数据预处理/增强方式:

​ transform.compose:作用于数据,可以用于组合多个数据变换(数据增强以及数据预处理)

​ torch.nn.Sequential:作用于张量,用于组合多个神经网络层从而定义前向传播过程

data_transforms = {
    '''定义训练集:数据预处理方式'''
    'train': transforms.Compose([
        transforms.RandomRotation(45), 	# 随机旋转:-45~45度
        transforms.CenterCrop(224),	 # 中心裁剪:对大小不一的数据进行裁剪,裁减成相同规格。1024*1024 (reset) 256*256 (围绕中心裁剪) 224*224
        transforms.RandomHorizontalFlip(p=0.5),	# 随机水平翻转,p为概率
        transforms.RandomVerticalFlip(p=0.5),	# 随机垂直翻转,p为概率
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),	# brightness:亮度,contrast:对比度,saturation:饱和度,hue:色相 transforms.RandomGrayscale(p=0.0025),	# 随机转灰度图
        transforms.ToTensor(),	# 将数据集转换为Tensor
        # 标准化图像:迁移学习所用到模型的均值和标准差
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
   '''定义训练集:数据预处理方式'''
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

4. dataLoader:批数据生成器

(1)数据集dataset

​ pytorch中数据集一般为TensorDataset数据集,元素为(x,y),其中x为属性,y为标签。

​ TensorDataset数据集可以通过下列代码进行制作

# 1. 通过Data.TensorDataset
train_dataset = Data.TensorDataset(x_train,y_train)

# 2.1 通过imageFolder制作
from torchvision import datasets
train_dataset = datasets.ImageFolder(
	root = train_dir,	# 大类文件路径
    transform = train_transform,	# 数据预处理方式
)

# 2.2 可以通过序列生成式同时生成train以及valid的dataset

image_dataset = {
    son:datasets.ImageFolder(
        root = os.path.join(root_dir, son),	# root_dir为主目录,son为子目录
        transform = data_transform[son]	# 见data_transform
    )
   	for son in ['train', 'valid']
}

​ 其中imageFolder假设:

  • 文件为目录结构;

  • 每个文件夹存储同一类别图片;

  • 文件名为类名

(2)通过DataLoader制作批数据生成迭代器:

​ 制作时,主要为如下几个参数{dataset, batch_size, shuffle, num_workers}

import torch.utils.data as Data

train_dataloader = Data.DataLoader(
	dataset = train_dataset, # 数据集,要求为Data.TensorDataset模型,有(x, y)
    batch_size = BATCH_SIZE, # 批数据集
    shuffle = True,	# 是否打乱
    num_workers = 2 # 多线程
)

valid_dataloader = Data.DataLoader(
	dataset = valid_dataset, # 数据集,要求为Data.TensorDataset模型,有(x, y)
    batch_size = BATCH_SIZE, # 批数据集
    shuffle = True,	# 是否打乱
    num_workers = 2 # 多线程
)

二、模型搭建


1. 常用套路

(1) 卷积操作

​ Conv操作一般是将Conv+BN+ReLU通过nn.Sequential()组合在一起

self.conv = nn.Sequential(
		nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
	)

2. 模板

(1) 通过init()函数定义所需的网络层

class Bottleneck(nn.Module):
    # 依残差网络为例
    def __init__(self, inplanes, planes, stride, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True)
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(planes, planes * self.extention, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(planes * self.extention),
            nn.ReLU(inplace=True)
        )

        self.relu = nn.ReLU(inplace=True)    

(2)通过forward()函数定义前向传播路径

 # 前向传播
    def forward(self, x):
        shortcut = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = out + shortcut  # 不能写作out+=shortcut
        out = self.relu(out)
        return out

三、模型训练与保存


1. 模型训练

(1)创建模型并加载到GPU
model = ResNet50(bottleneck,[3,4,6,3],num_class)	# 创建模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')	# 设置GPU
model = model.to(device)	# 将模型加载到GPU中

(2)训练以及验证

针对每一个epoch:模型训练需要对每个epoch进行训练(train)以及验证(valid):

  • 训练过程(train):
    • 使用model.train(),开起BN以及DroupOut等操作
  • 验证过程(valid):
    • 使用model.eval(),关闭BN以及DroupOut等操作
    • 通过with torch.no_grad()关闭梯度计算
for phase in ['train','valid']:
	if phase == 'train':
		model.train()   # 训练
        torch.set_grad_enabled(True)	# 打开梯度计算
	else:
		model.eval()    # 验证
        torch.set_grad_enabled(False)	# 关闭梯度计算,减少显存压力
	
    """
    train以及valid的传播(前向更新参数、反向计算loss)
    """
    for inputs,lables in dataloader[phase]:
        # 将数据导入GPU
        inputs = inputs.to(device)
        labels = labels.to(device)
        # 梯度清0
        optimizer.zero_grad()
        # 将数据传入模型
        outputs = resnet50(inputs)
        # 计算损失
        loss = criterion(outputs,labels)
        if phase == 'train':
            # 误差反向传播并更新参数
            loss.backward()
            optimizer.step()  

(3)相关参数的计算
loss_sum = 0.0
for phase in ['train','valid']:
    ...
    # 计算损失值
    loss_sum += loss.item()
    # 计算准确度
    -,preds = torch.max(outputs.data, dim=1)
    num_correct += preds.eq(labels.data).cpu().sum()
    num_data += labels.size(0)
   	acc_iter = num_correct / num_data
    # 输出:每一个batch的loss值以及acc
    if phase == 'train':
		print('train: [epoch:%d, iter:%d] batch_Loss: %.03f | Acc: %.3f%% ' % (epoch + 1, (batch_idx + 1 + epoch * len(data_loader[phase])), sum_loss / (batch_idx + 1), 100.*acc_val))
    
# Valid:度量
print('Waiting Val...')
acc_val_epoch = correct / total
acc_val__epoch_list.append(acc_val)
print('Val\'s Acc is: %.3f%%' % (100 * acc_val_epoch))
(4)保存模型
for phase in ['train','valid']:
    ...
    if phase == 'valid' and epoch_acc > bestAcc:
        # 更新最优值
        best_acc = epoch_acc
        best_model_dict = copy.deepcopy(model.state_dict())
        # 记录相关数据
        state = {
            'state_dict':model.state_dict(),
            'beat_acc': best_acc,
            'optimizer':optimizer.state_dict()
        }
        # 保存模型
       	torch.save(state,file_path)

for epoch in range(num_epoch):
    ...
    model.load_state_dict(best_model_dict)	# 加载训练的最好的模型,方便返回
    return model

四、模型测试


​ 模型测试主要包括两个方面:加载模型、获取测试结果

1. 加载模型

​ 加载模型时主要有两个操作:① 创建模型对象并导入GPU、 ② 导入已训练模型的参数

def load_model():
    # 创建model模型(类),并导入GPU
    test_model = model.ResNet50(model.Bottleneck, [3, 4, 6, 3], 14)
    test_model.to(device)
    # 将参数导入模型
    checkpoint = torch.load('checkpoints/best.pth')
  	test_model.load_state_dict(checkpoint['model_state_dict'])

    return test_model

2. 测试模型

​ 测试模型主要涉及到两个步骤:数据转换、预测

​ 数据转换时:需要先导入image图像,再通过transform将其转换为torch.tensor数据,最后传入cuda中(因为模型也在cuda)

​ 预测时:因为训练的模型是一个四维张量(batch,H,W,C),而我们检测的数据是一个三维张量。所以需要将其转换为一个四维张量,再传入模型。获取最大值即可

def  get_model_preds(test_model,img_pth):

    # 转换数据并传入cuda
    image = Image.open(img_pth)
    image_tensor = transforms.ToTensor()(image)
    image_tensor = image_tensor.to(device)

    # 预测
    image_tensor = image_tensor.unsqueeze(0)
    outputs = test_model(image_tensor)
    print(outputs)
    pred = torch.argmax(outputs, dim=1).item()

    return pred

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值