(亲测有效)将AnomalyGPT的web_demo转换为本地代码

下面依次是原图、缺陷图和模型返回的二值化图像，实现代码见文末。

设置模型参数并加载权重文件，方式与 web_demo.py 中的相同。

from model.openllama import OpenLLAMAPEFTModel
import torch
import numpy as np
import argparse
import matplotlib.pyplot as plt
import cv2
from PIL import Image as PILImage
import os

# GPU index used both for the CUDA context and as the map_location when loading
# checkpoints; defined once so the two settings cannot drift apart.
DEVICE_ID = 5
DEVICE = torch.device(f'cuda:{DEVICE_ID}')
torch.cuda.set_device(DEVICE_ID)

# No options are defined; parse_args() is kept so the script still supports
# --help and rejects unexpected command-line arguments.
parser = argparse.ArgumentParser("AnomalyGPT", add_help=True)
command_args = parser.parse_args()

# Model configuration (same settings as web_demo.py).
args = {
    'model': 'openllama_peft',
    'imagebind_ckpt_path': 'pretrained_ckpt/imagebind_ckpt/imagebind_huge.pth',
    'vicuna_ckpt_path': 'pretrained_ckpt/vicuna_ckpt/7b_v0',
    'anomalygpt_ckpt_path': 'code/ckpt/train_supervised/pytorch_model.pt',
    'delta_ckpt_path': 'pretrained_ckpt/pandagpt_ckpt/7b/pytorch_model.pt',
    'stage': 2,
    'max_tgt_len': 128,
    'lora_r': 32,
    'lora_alpha': 32,
    'lora_dropout': 0.1
}
# Build the model.
model = OpenLLAMAPEFTModel(**args)
# Load weights in order: the PandaGPT delta checkpoint first, then the
# AnomalyGPT checkpoint on top of it. strict=False tolerates keys missing from
# each state dict (presumably each checkpoint covers only part of the model --
# TODO confirm).
for ckpt_key in ('delta_ckpt_path', 'anomalygpt_ckpt_path'):
    delta_ckpt = torch.load(args[ckpt_key], map_location=DEVICE)
    model.load_state_dict(delta_ckpt, strict=False)
# load_state_dict copies the tensors into the model's parameters, so drop the
# last checkpoint dict to release the GPU memory it would otherwise pin.
del delta_ckpt
model = model.eval().half().cuda()

print('[!] init the 7b model over ...')

下面是封装好的缺陷检测函数（输入文件不存在或推理失败时返回 None）：

def detect_anomalies(input_image_path, normal_image_path, max_length, top_p, temperature,
                     prompt_text="Is there any anomaly in the image?"):
    """Run AnomalyGPT on one query image, optionally against a normal reference image.

    Args:
        input_image_path: Path to the image to inspect. Required; a missing,
            empty, or None path is reported and yields None.
        normal_image_path: Optional path to a normal (reference) image; pass
            None or "" to run without one.
        max_length: Forwarded to the model as ``max_tgt_len``.
        top_p: Sampling ``top_p`` forwarded to ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        prompt_text: Question sent to the model (a simplified prompt by default).

    Returns:
        ``(response, pixel_output)`` as returned by the module-level
        ``model.generate(..., web_demo=True)``, or None when an input file is
        missing, generation raises, or no pixel output was produced.
        ``pixel_output`` is presumably a torch tensor holding the pixel-level
        anomaly map (callers convert it with ``.cpu().numpy()``).
    """
    # Check for a falsy path first: os.path.exists(None) raises TypeError
    # instead of returning False.
    if not input_image_path or not os.path.exists(input_image_path):
        print("Input image file does not exist.")
        return None
    if normal_image_path and not os.path.exists(normal_image_path):
        print("Normal image file does not exist.")
        return None

    # Model inference.
    try:
        response, pixel_output = model.generate({
            'prompt': prompt_text,
            'image_paths': [input_image_path],
            'normal_img_paths': [normal_image_path] if normal_image_path else [],
            'audio_paths': [],
            'video_paths': [],
            'thermal_paths': [],
            'top_p': top_p,
            'temperature': temperature,
            'max_tgt_len': max_length,
            'modality_embeds': []  # assumed unused for image-only queries -- TODO confirm
        }, web_demo=True)
    except Exception as e:
        # Broad catch is deliberate: this demo helper reports any inference
        # failure and signals it to the caller by returning None.
        print(f"Error in model generation: {e}")
        return None

    if pixel_output is None:
        print("No output generated by the model.")
        return None

    return response, pixel_output

指定本地缺陷图片和正常（无缺陷）参考图片的路径，设置好其他参数后调用检测函数，即可得到模型返回的文本信息和像素级的缺陷图。

# Fixed typo: this was `nput_image_path`, which left `input_image_path`
# undefined for the cv2.imread call below.
input_image_path = 'defectTest/inpaintA.png'
normal_image_path = 'normalImage/U101_90500000483A_,rgb.179.756.png'  # set to None if no normal (reference) image is available
max_length = 512
top_p = 0.01
temperature = 1.0

image = cv2.imread(input_image_path)
# cv2.imread returns None (it does not raise) when the file is missing or unreadable.
if image is None:
    raise SystemExit(f"Could not read input image: {input_image_path}")
# Original size, used later to resize the model's output back to it.
image_height, image_width = image.shape[:2]

result = detect_anomalies(input_image_path, normal_image_path, max_length, top_p, temperature)
# detect_anomalies returns None on failure; unpacking None would raise an
# unhelpful TypeError, so stop with a clear message instead.
if result is None:
    raise SystemExit("Anomaly detection failed; see the messages above.")
response, pixel_output = result

对输出图像进行处理和保存

# Turn the model's pixel output into an 8-bit image: clamp to [0, 1], scale to
# 0..255, resize back to the input image's size, and save it as output.png.
# (Assumes the squeezed array is 2-D, i.e. a single grayscale map -- TODO confirm.)
heatmap = np.clip(pixel_output.cpu().numpy().squeeze(), 0, 1)
output = PILImage.fromarray((heatmap * 255).astype(np.uint8))
output = output.resize((image_width, image_height), PILImage.LANCZOS)
output.save('output.png')

将输出图像转换为二值图像（灰度值不低于阈值 50 的像素记为白色）：

# Binarize the output map: gray values >= threshold become white (255) and
# everything else black (0), producing a 1-bit (mode "1") image.
threshold = 50
gray_image = output.convert("L")
binary_image = gray_image.point(lambda value: 255 if value >= threshold else 0, '1')
binary_image.save('binary_image.png')

获取缺陷区域的中心点坐标，从而实现缺陷定位：

# Locate the centroid of every connected white region in `binary_image`, the
# mode "1" mask produced by the binarization step above (no need to repeat it
# or to re-read binary_image.png from disk).
# Converting to "L" gives a uint8 array with 0/255 values, the same as reading
# the saved PNG in grayscale; np.array makes a writable copy for OpenCV.
mask = np.array(binary_image.convert("L"))
# OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
# (contours, hierarchy); [-2] selects the contour list under both.
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

# Centroid (x, y) of each region, in pixel coordinates of the mask.
centers = []
for contour in contours:
    moments = cv2.moments(contour)
    if moments["m00"] != 0:
        # Area-weighted centroid from the contour's spatial moments.
        cX = int(moments["m10"] / moments["m00"])
        cY = int(moments["m01"] / moments["m00"])
    else:
        # Degenerate contours (single pixels, one-pixel-wide lines) have zero
        # area, so the moment centroid is undefined; use the mean of the
        # contour's points so these regions are still reported.
        cX, cY = (int(v) for v in contour.reshape(-1, 2).mean(axis=0))
    centers.append((cX, cY))

# An empty list means the mask contains no white pixels.
if not centers:
    print("No white regions found.")
else:
    for center in centers:
        print(f"Center: {center}")

如果希望输出图像保持原始尺寸，需要屏蔽图像裁剪操作，再将输出图像恢复为原始大小。这样做会影响检测精度；如果只追求精度、不需要恢复原始尺寸，可以忽略这一步。

 返回结果

  • 14
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值