使用 PyTorch 做 ResNet 迁移学习实现图像分类（完整流程）

图像处理:

先将每张图像用黑边补成正方形（保持内容的长宽比），再缩放（resize）到 224*224（ResNet 的标准输入尺寸为 224*224）。

def Image_PreProcessing(imagepath, targetpath, size=224):
    """Pad an image to a square with black borders, resize it and write it out.

    The shorter side is padded symmetrically (an odd extra pixel goes to the
    bottom/right), so the content keeps its aspect ratio. The resulting square
    is then resized to ``size x size`` (224x224 by default, the ResNet input
    size) and written to ``targetpath`` with OpenCV.

    :param imagepath: path of the source image.
    :param targetpath: output path; OpenCV picks the encoding from its file
        extension (e.g. ``.jpg`` for JPEG).
    :param size: side length of the square output image, default 224.
    :raises FileNotFoundError: if OpenCV cannot read ``imagepath``.
    :raises OSError: if OpenCV fails to write ``targetpath``.
    """
    # cv2.imread returns None (it does not raise) for missing/unreadable files.
    im = cv2.imread(imagepath, cv2.IMREAD_COLOR)
    if im is None:
        raise FileNotFoundError('cannot read image: {}'.format(imagepath))
    h, w = im.shape[:2]

    # Pixels the short side needs so that both sides equal the long side;
    # one of dh/dw is always 0.
    longest_edge = max(h, w)
    dh = longest_edge - h
    dw = longest_edge - w
    top, bottom = dh // 2, dh - dh // 2
    left, right = dw // 2, dw - dw // 2

    # Black border (BGR) so the padded area carries no content.
    BLACK = [0, 0, 0]
    squared = cv2.copyMakeBorder(im, top, bottom, left, right,
                                 cv2.BORDER_CONSTANT, value=BLACK)

    resized = cv2.resize(squared, (size, size))
    if not cv2.imwrite(targetpath, resized):
        raise OSError('cannot write image: {}'.format(targetpath))

图像增强

常用的增强方式包括翻转、加噪、拉伸、颜色抖动等。下面的代码实现了其中的亮度调整、旋转、左右翻转、拉伸，以及四角和中心裁剪。

    def get_savename(self, operate):
        """Build a unique, timestamped ``.jpg`` path inside ``self.export_path_base``.

        The file name is ``_<YYYYmmddHHMMSS><4 digits>.jpg``. The trailing
        digits are the last four digits of the current time in microseconds.
        The output directory is created (including missing parents) if needed.
        If a file with the generated name already exists (e.g. several calls
        within the clock's resolution), a numeric suffix ``_1``, ``_2``, ...
        is appended so earlier augmentations are not overwritten.

        :param operate: name of the augmentation operation; currently unused,
            kept for interface compatibility.
        :return: full path for the image to save, or ``None`` if the path
            could not be prepared (the error is printed).
        """
        import time

        try:
            # Read the clock once so the date part and the sub-second tail
            # describe the same instant.
            now = time.time()
            tail_time = str(round(now * 1000000))[-4:]
            head_time = time.strftime("%Y%m%d%H%M%S", time.localtime(now))
            label = head_time + tail_time

            out_path = self.export_path_base
            # makedirs also creates missing parent directories.
            os.makedirs(out_path, exist_ok=True)

            stem = os.path.join(out_path, '_' + label)
            savename = stem + ".jpg"
            counter = 1
            while os.path.exists(savename):
                savename = '{}_{}.jpg'.format(stem, counter)
                counter += 1
            return savename

        except Exception as e:
            print(e)

    def lightness(self, light):
        """Save a copy of the image at ``self.rootPath`` with scaled brightness.

        Every pixel value is multiplied by ``light`` through ``Image.point``.
        Recommended values: 0.87, 1.07 (darker < 1.0 < lighter).
        The copy is written to the path returned by ``self.get_savename``.
        Errors are logged with ``logger`` rather than raised.

        :param light: brightness factor.
        """
        # Assigned before the try so the error handler can always use it.
        operate = 'lightness_' + str(light)
        try:
            with Image.open(self.rootPath) as image:
                # Scale brightness by multiplying each pixel value.
                out = image.point(lambda p: p * light)
                savename = self.get_savename('')
                # Same JPEG quality as the other augmentation methods.
                out.save(savename, quality=100)
        except Exception as e:
            logger.error('ERROR %s', operate)
            logger.error(e)

    def rotate(self, angle):
        """Save a copy of the image at ``self.rootPath`` rotated by ``angle`` degrees.

        Uses ``Image.rotate`` with default arguments, so the output keeps the
        input canvas size. Typical values: 15, 30.
        The copy is written to the path returned by ``self.get_savename``.
        Errors are logged with ``logger`` rather than raised.
        """
        try:
            operate = 'rotate_' + str(angle)
            with Image.open(self.rootPath) as img:
                rotated = img.rotate(angle)
                target = self.get_savename('')
                rotated.save(target, quality=100)
        except Exception as err:
            logger.error('ERROR %s', operate)
            logger.error(err)

    def transpose(self):
        """Save a left-right mirrored copy of the image at ``self.rootPath``.

        The copy is written to the path returned by ``self.get_savename``.
        Errors are logged with ``logger`` rather than raised.
        """
        try:
            operate = 'transpose'
            with Image.open(self.rootPath) as img:
                mirrored = img.transpose(Image.FLIP_LEFT_RIGHT)
                target = self.get_savename('')
                mirrored.save(target, quality=100)
        except Exception as err:
            logger.error('ERROR %s', operate)
            logger.error(err)

    def deform(self):
        """Save two square, aspect-distorting resizes of the image at ``self.rootPath``.

        With ``(w, h)`` the original size, the first copy is resized to
        ``w x w`` and the second to ``h x h``. Each copy is written to a path
        returned by ``self.get_savename``. Errors are logged with ``logger``
        rather than raised.
        """
        try:
            operate = 'deform'
            with Image.open(self.rootPath) as img:
                width, height = (int(v) for v in img.size)
                for side in (width, height):
                    stretched = img.resize((side, side))
                    stretched.save(self.get_savename(''), quality=100)
        except Exception as err:
            logger.error('ERROR %s', operate)
            logger.error(err)

    def crop(self):
        """Save five sub-crops of the image at ``self.rootPath``.

        Each crop is ``scale = 0.875`` of the original width and height. The
        crops are taken, in this order, from the top-left, bottom-left,
        top-right and bottom-right corners and from the centre. Each one is
        written to a path returned by ``self.get_savename``. Errors are logged
        with ``logger`` rather than raised.
        """
        operate = 'crop'
        try:
            with Image.open(self.rootPath) as image:
                w, h = image.size
                scale = 0.875
                # Size of every crop window.
                ww = int(w * scale)
                hh = int(h * scale)

                # Upper-left origin of each window: lu, ld, ru, rd, centre.
                origins = [
                    (0, 0),
                    (0, h - hh),
                    (w - ww, 0),
                    (w - ww, h - hh),
                    ((w - ww) // 2, (h - hh) // 2),
                ]
                for left, upper in origins:
                    # Pillow's box is (left, upper, right, lower): right/lower
                    # are absolute coordinates, i.e. origin + window size.
                    out = image.crop((left, upper, left + ww, upper + hh))
                    savename = self.get_savename('')
                    out.save(savename, quality=100)
        except Exception as e:
            logger.error('ERROR %s', operate)
            logger.error(e)

数据预处理:

包括划分训练集、验证集、测试集，以及数据归一化等。注意：下面的代码只读取 train 和 test 两个子目录，test 集同时充当验证集。

def readImg(path):
    '''Load an image file as 3-channel RGB, resized to 224x224 (bilinear).

    Used as the ``loader`` of ``ImageFolder``. The file is opened in a
    ``with`` block so its handle is released as soon as the pixels have been
    copied. ``convert``/``resize`` return new, fully loaded images that do
    not depend on the closed file.
    '''
    with Image.open(path) as img:
        return img.convert('RGB').resize((224, 224), Image.BILINEAR)


def ImageDataset(args):
    """Build the train/test ``ImageFolder`` datasets and their ``DataLoader``s.

    ``args.data_dir`` must contain ``train`` and ``test`` sub-directories in
    ``ImageFolder`` layout (one folder per class). Images are loaded with
    ``readImg``, resized to 224x224, converted to tensors and normalised with
    the mean/std below. No random augmentation is applied to either split.

    :param args: namespace providing ``data_dir``, ``batch_size`` and
        ``num_workers``.
    :return: ``(dataloaders, dataset_sizes, class_names)``. The first two are
        dicts keyed by ``'train'`` / ``'test'``; the class names come from the
        train split. Only the train loader shuffles.
    """
    splits = ('train', 'test')

    # ImageNet channel statistics (the usual normalisation for torchvision's
    # pretrained models).
    normalize = tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
    # Both splits share the same deterministic pipeline. The former train-only
    # CenterCrop(224) after Resize(224) was a no-op and has been dropped.
    data_transforms = {
        split: tv.transforms.Compose(
            [tv.transforms.Resize([224, 224]), tv.transforms.ToTensor(), normalize]
        )
        for split in splits
    }

    data_dir = args.data_dir
    image_datasets = {
        split: datasets.ImageFolder(os.path.join(data_dir, split),
                                    data_transforms[split], loader=readImg)
        for split in splits
    }
    dataloaders = {
        split: torch.utils.data.DataLoader(image_datasets[split],
                                           batch_size=args.batch_size,
                                           shuffle=(split == 'train'),
                                           num_workers=args.num_workers)
        for split in splits
    }
    dataset_sizes = {split: len(image_datasets[split]) for split in splits}
    class_names = image_datasets['train'].classes
    return dataloaders, dataset_sizes, class_names

使用 ResNet 做迁移学习（冻结预训练的主干网络，只训练新替换的全连接分类层）

def model(num_classes=40, verbose=True):
    """Build a ResNet-34 transfer-learning classifier.

    Loads torchvision's pretrained ResNet-34 and freezes every one of its
    parameters. It then replaces the final fully connected layer with a new
    trainable head:
    ``Linear(in_features, 256) -> Dropout(0.4) -> ReLU -> Linear(256, num_classes)``.

    :param num_classes: number of output classes (default 40, the value that
        was previously hard-coded).
    :param verbose: print the pretrained architecture before modification
        (the original behaviour).
    :return: the modified ``nn.Module``.
    """
    net = tv.models.resnet34(pretrained=True)
    if verbose:
        print(net)
    # Freeze the pretrained backbone so only the new head is trained.
    for param in net.parameters():
        param.requires_grad = False

    # The head is created after freezing, so its parameters keep the default
    # requires_grad=True.
    num_fcs = net.fc.in_features
    net.fc = nn.Sequential(
        nn.Linear(num_fcs, 256),
        nn.Dropout(p=0.4),
        nn.ReLU(inplace=True),
        nn.Linear(256, num_classes)
    )
    return net

编写训练函数

def train_model(args, model, criterion, optimizer, scheduler, num_epochs, dataset_sizes, use_gpu,
                dataloaders=None):
    """Train ``model`` and return it loaded with the best test-accuracy weights.

    Each epoch runs a training pass over ``dataloaders['train']`` followed by
    an evaluation pass over ``dataloaders['test']``. Per-phase loss/accuracy
    are printed, appended to ``result.txt`` and logged through the
    module-level ``writer``; presumably a TensorBoard ``SummaryWriter``
    created elsewhere in the file (TODO confirm). Every
    ``args.save_epoch_freq`` epochs the weights of the unwrapped module are
    saved to ``<args.save_path>/epoch_<epoch>.pth``.

    :param args: namespace with ``start_epoch``, ``batch_size``, ``print_freq``,
        ``save_epoch_freq`` and ``save_path``.
    :param model: network to train (optionally wrapped in ``DataParallel``).
    :param criterion: loss function taking ``(outputs, labels)``.
    :param optimizer: optimizer over ``model``'s parameters.
    :param scheduler: LR scheduler, stepped once per epoch after training.
    :param num_epochs: exclusive upper bound of the epoch index.
    :param dataset_sizes: dict of sample counts per phase.
    :param use_gpu: run on ``cuda`` when true, else on ``cpu``.
    :param dataloaders: dict with ``'train'`` / ``'test'`` loaders; defaults to
        the module-level ``dataloders`` for backward compatibility.
    :return: ``model`` with the best test-accuracy weights loaded.
    """
    import math

    if dataloaders is None:
        dataloaders = dataloders  # module-level global created in __main__

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    device = torch.device('cuda' if use_gpu else 'cpu')

    # Start a fresh log once, for a run from scratch. (This used to sit inside
    # the phase loop, so epoch 0's test phase deleted its own train line.)
    if args.start_epoch == 0 and os.path.exists('result.txt'):
        os.remove('result.txt')

    # When resuming, advance the scheduler to the resumed epoch so the LR
    # matches an uninterrupted run. PyTorch may warn that step() precedes
    # optimizer.step() here; that is expected.
    for _ in range(args.start_epoch):
        scheduler.step()

    last_inputs = None
    for epoch in range(args.start_epoch, num_epochs):
        # Each epoch has a training and an evaluation phase.
        for phase in ('train', 'test'):
            is_train = phase == 'train'
            # train(False) is equivalent to eval().
            model.train(is_train)

            running_loss = 0.0
            running_corrects = 0
            seen = 0
            num_batches = math.ceil(dataset_sizes[phase] / args.batch_size)
            tic_batch = time.time()

            for i, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track autograd history only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                batch_len = inputs.size(0)
                seen += batch_len
                running_loss += loss.item() * batch_len
                running_corrects += (preds == labels).sum().item()
                last_inputs = inputs

                if is_train and (i + 1) % args.print_freq == 0:
                    # param_groups[0]['lr'] is the LR actually in use this epoch
                    # (StepLR.get_lr() reports a pre-decayed value).
                    print(
                        '[Epoch {}/{}]-[batch:{}/{}] lr:{:.6f} {} Loss: {:.6f}  Acc: {:.4f}  Time: {:.4f} sec/batch'.format(
                            epoch + 1, num_epochs, i + 1, num_batches,
                            optimizer.param_groups[0]['lr'], phase,
                            running_loss / seen, running_corrects / seen,
                            (time.time() - tic_batch) / args.print_freq))
                    tic_batch = time.time()

            if is_train:
                # Step once per epoch, after this epoch's optimizer updates.
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            with open('result.txt', 'a') as f:
                f.write('Epoch:{}/{} {} Loss: {:.4f} Acc: {:.4f} \n'.format(epoch + 1, num_epochs, phase, epoch_loss,
                                                                            epoch_acc))

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            writer.add_scalar(phase + '/Loss', epoch_loss, epoch)
            writer.add_scalar(phase + '/Acc', epoch_acc, epoch)

            # Keep a deep copy of the best weights seen on the test split.
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        if (epoch + 1) % args.save_epoch_freq == 0:
            os.makedirs(args.save_path, exist_ok=True)
            # Save the unwrapped module so checkpoints load into a plain model
            # (as --resume does) even when training under DataParallel.
            parallel_types = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
            net = model.module if isinstance(model, parallel_types) else model
            torch.save(net.state_dict(), os.path.join(args.save_path, "epoch_" + str(epoch) + ".pth"))

    # Record the model graph using the last batch seen (if any ran).
    if last_inputs is not None:
        writer.add_graph(model, (last_inputs,))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best test Accuracy: {:.4f}'.format(best_acc))

    # Restore the best weights before returning.
    model.load_state_dict(best_model_wts)
    return model

主函数调用：

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='classification')
    # Root directory of the image data (expects train/ and test/ sub-folders).
    parser.add_argument('--data-dir', type=str, default='image')
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--num-epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.045)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--print-freq', type=int, default=1)
    parser.add_argument('--save-epoch-freq', type=int, default=1)
    parser.add_argument('--save-path', type=str, default='output')
    parser.add_argument('--resume', type=str, default='', help='For training from one checkpoint')
    parser.add_argument('--start-epoch', type=int, default=0, help='Corresponding to the epoch of resume')
    args = parser.parse_args()

    # Read data. `dataloders` must keep this module-level name: train_model
    # reads it as a global.
    dataloders, dataset_sizes, class_names = ImageDataset(args)

    with open('class_names.json', 'w') as f:
        json.dump(class_names, f)

    use_gpu = torch.cuda.is_available()
    print("use_gpu:{}".format(use_gpu))

    # Build the network under its own name so the `model` factory function is
    # not shadowed by its result.
    net = model()
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            # torch.load unpickles the file: only resume from trusted checkpoints.
            state_dict = torch.load(args.resume, map_location='cpu')
            # Checkpoints saved from a DataParallel-wrapped model prefix every
            # key with 'module.'; strip it so they fit the unwrapped network.
            state_dict = {
                (k[len('module.'):] if k.startswith('module.') else k): v
                for k, v in state_dict.items()
            }
            net.load_state_dict(state_dict)
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if use_gpu:
        net = torch.nn.DataParallel(net)
    net.to(torch.device('cuda' if use_gpu else 'cpu'))

    # Cross-entropy loss for single-label classification.
    criterion = nn.CrossEntropyLoss()

    # SGD over all parameters; frozen backbone parameters never receive
    # gradients, so only the new classification head is actually updated.
    optimizer_ft = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.00004)

    # Decay LR by a factor of 0.98 every epoch.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=1, gamma=0.98)

    net = train_model(args=args,
                      model=net,
                      criterion=criterion,
                      optimizer=optimizer_ft,
                      scheduler=exp_lr_scheduler,
                      num_epochs=args.num_epochs,
                      dataset_sizes=dataset_sizes,
                      use_gpu=use_gpu)

    # save_path may not exist yet if no per-epoch checkpoint was written.
    os.makedirs(args.save_path, exist_ok=True)
    # Save the unwrapped module so the file loads straight into model().
    to_save = net.module if isinstance(net, torch.nn.DataParallel) else net
    torch.save(to_save.state_dict(), os.path.join(args.save_path, 'best_model_wts.pth'))

    writer.close()

 

  • 10
    点赞
  • 152
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值