MMSegmentation 模型训练结果批量推理及结果保存脚本

import os
import torch
import cv2
import argparse
import numpy as np 
from pprint import pprint
from tqdm import tqdm
from mmseg.apis import init_model, inference_model


DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# 测试图像所在文件夹
IMAGE_FILE_PATH = r"dataset\test\images"
# 模型训练结果的config配置文件路径
CONFIG = r'work_dir\dmnet_r50\dmnet_r50-d8_4xb4-160k_ade20k-512x512.py'
# 模型训练结果的权重文件路径
CHECKPOINT = r'work_dir\dmnet_r50\best_mIoU_iter_25600.pth'
# 模型推理测试结果的保存路径,每个模型的推理结果都保存在`{save_dir}/{模型config同名文件夹}`下,如文末图片所示。
SAVE_DIR = r"work_dir\infer_results"

def parse_args():
    parser = argparse.ArgumentParser(description='Visualize CAM')
    parser.add_argument('--img', default=IMAGE_FILE_PATH, help='Image file')
    parser.add_argument('--config', default=CONFIG ,help='Config file')
    parser.add_argument('--checkpoint', default=CHECKPOINT, help='Checkpoint file')
    parser.add_argument('--device', default=DEVICE, help='device')
    parser.add_argument('--save_dir', default=SAVE_DIR, help='save_dir')

    args = parser.parse_args()
    return args

def make_full_path(root_list, root_path):
    file_full_path_list = []
    for filename in root_list:
        file_full_path = os.path.join(root_path, filename)
        file_full_path_list.append(file_full_path)
    return file_full_path_list


def read_filepath(root):
    from natsort import natsorted
    test_image_list = natsorted(os.listdir(root))
    test_image_full_path_list = make_full_path(test_image_list, root)
    return test_image_full_path_list

def main():
    args = parse_args()
    
    model_mmseg = init_model(args.config, args.checkpoint, device=args.device)
    
    for imgs in tqdm(read_filepath(args.img)):
        result= inference_model(model_mmseg, imgs)
        pred_mask = result.pred_sem_seg.data.squeeze(0).detach().cpu().numpy().astype(np.uint8)
        pred_mask[pred_mask == 1] =255
        save_path = os.path.join(args.save_dir, f"{os.path.basename(args.config).split('.')[0]}")
        
        if not os.path.exists(save_path):
            os.makedirs(save_path)
            
        cv2.imwrite(os.path.join(save_path, f"{os.path.basename(result.img_path).split('.')[0]}.png"), pred_mask, [cv2.IMWRITE_PNG_COMPRESSION, 0])


if __name__ == '__main__':
    main()

结果保存格式如下:每个模型的推理结果都保存在{save_dir}/{模型config同名文件夹}
在这里插入图片描述

  • 10
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 8
    评论
中文句子关系推断是一项重要的自然语言处理任务,可以用于文本分类、情感分析、问答系统等领域。在huggingface中,使用预训练模型进行中文句子关系推断的实现非常简单,下面是一个示例代码: ```python from transformers import AutoTokenizer, AutoModelForSequenceClassification # 加载中文BERT模型 tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese") model = AutoModelForSequenceClassification.from_pretrained("bert-base-chinese") # 准备数据集 sentences = ["这是一个正向句子", "这是一个负向句子"] labels = [1, 0] # 进行数据预处理 inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt") # 进行模型训练和微调 outputs = model(**inputs, labels=labels) loss = outputs.loss logits = outputs.logits # 进行模型评估和推理 predictions = logits.argmax(dim=1) ``` 在上面的代码中,我们使用了中文BERT模型进行句子关系推断的训练和微调,使用了PyTorch框架进行模型训练推理。在进行模型训练和微调时,我们需要指定模型输入和输出的格式,以及损失函数和优化器的选择。在进行模型评估和推理时,我们可以使用模型输出的logits进行分类,得到模型对于输入句子的分类结果。 需要注意的是,上面的代码只是一个简单的示例,实际应用中还需要根据具体任务进行模型调整和性能优化。同时,在进行中文句子关系推断的实战中,还需要注意数据集的选择和预处理,以及模型训练的超参数的选择等方面。
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

卖报的大地主

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值