Unet网络源码的改进,批量处理图片

本文介绍了如何对Pytorch-UNet的predict_img函数进行修改,使其能够处理文件夹中的图片,并详细讲解了关键代码段。作者展示了如何在终端中批量预测并保存处理后的结果,同时涵盖了模型加载、参数调整和输出可视化。
摘要由CSDN通过智能技术生成

https://github.com/milesial/Pytorch-UNet

获取的源代码中,在predict.py只允许通过在终端输入预测图片进行预测,通过对部分代码的修改实现对文件夹下的图片进行处理,以及对predict.py中代码的理解。

import argparse
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask


def predict_img(net,
                full_img,
                device,
                scale_factor=1,  # 比例因子
                out_threshold=0.5):  # out_threshold 阈值
    net.eval()
    # BasicDataset.preprocess 对传入图片进行预处理,将原图按照scale值进行缩放,并且对图片进行归一化处理,对于不是mask的将图像通道转为CHW
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
    # 对图像升维,(3,360,640)->(1,3,360,640)
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        # (1,2,360,640)
        output = net(img)

        if net.n_classes > 1:
            # (2,360,640)
            probs = F.softmax(output, dim=1)[0]
        else:
            probs = torch.sigmoid(output)[0]

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((full_img.size[1], full_img.size[0])),
            transforms.ToTensor()
        ])
        # (2,720,1280)
        full_mask = tf(probs.cpu()).squeeze()

    if net.n_classes == 1:
        return (full_mask > out_threshold).numpy()
    else:
        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()


def get_args():
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--model', '-m', default='checkpoints/checkpoint_epoch10.pth', metavar='FILE',
                        help='Specify the file in which the model is stored')  # 模型权重
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', default='D:/Pytorch-UNet-master/video2imgs',
                        help='Filenames of input images')  # detect图片文件夹地址
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', default='D:/Pytorch-UNet-master/runs/detect/',
                        help='Filenames of output images')  # predict后保存的图片的文件夹地址
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white')
    parser.add_argument('--scale', '-s', type=float, default=0.5,
                        help='Scale factor for the input images')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')  # 双线性

    return parser.parse_args()


def get_output_filenames(args):
    def _generate_name(fn):
        return f'{os.path.splitext(fn)[0]}_OUT.png'

    return args.output or list(map(_generate_name, args.input))


def mask_to_image(mask: np.ndarray):
    # ndim 维度
    # np.argmax(mask, axis=0) 按列搜索最大值的索引值
    # Image.fromarray array转为Image
    if mask.ndim == 2:
        return Image.fromarray((mask * 255).astype(np.uint8))
    elif mask.ndim == 3:
        return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))


if __name__ == '__main__':
    args = get_args()  # 获取args中的参数
    in_files = args.input  # 获取输入图片名字/图片路径
    out_files = get_output_filenames(args)  # 获取输出图片名字/保存路径

    net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)  # 初始化网络模型结构

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 处理设备
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    # torch.load加载模型参数权重,load_state_dict将参数权重加载到新的模型
    net.load_state_dict(torch.load(args.model, map_location=device))
    logging.info('Model loaded!')
    # 遍历文件夹下的所有文件
    infiles = os.listdir(in_files)
    # enumerate 获取可遍历的数组的 下标和值
    for i, filename in enumerate(infiles):
        logging.info(f'\nPredicting image {filename} ...')
        img = Image.open(in_files+'/'+filename)

        # 执行predict_img函数
        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           device=device)

        # args.no_save : 默认false
        # parser 中用了action 表示一个开关,在终端中输入带有-n时启动开关,值为True,则在if判断后表示不保存
        if not args.no_save:
            # 输出图片的绝对路径, 输出图片名为输入图像名+_OUT的后缀
            out_filename = out_files + filename.split('.')[0] + '_OUT.jpg'
            result = mask_to_image(mask)
            result.save(out_filename)
            logging.info(f'Mask saved to {out_filename}')

        if args.viz:
            logging.info(f'Visualizing results for image {filename}, close to continue...')
            plot_img_and_mask(img, mask)

  • 5
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值