本文任务:1、mmdetection框架下针对医学图像病灶检测的gradio检测界面可视化
2、结合vqa问答大模型实现本地预测(英文问答效果更好)
咸鱼:医学图像检测的科研狗
先上效果图:
一、环境
pytorch相关 :
- python=3.8.19=h955ad1f_0
- pytorch=2.0.1=py3.8_cuda11.8_cudnn8.7.0_0
- pytorch-cuda=11.8=h7e8668a_5
- torchaudio=2.0.2=py38_cu118
- torchtriton=2.0.0=py38
- torchvision=0.15.2=py38_cu118
mmcv:
- mmcv==2.0.0
- mmengine==0.10.3
语言模型:
- tokenizers==0.15.2
mmdetection:
-mmdet == 3.1.0
二、界面可视化实现
界面可视话基于gradio实现
2.1 gradio的基本用法为:
Gradio是一个Python库,它允许用户通过简单的代码快速创建模型的可视化界面。这对于机器学习模型的演示和测试非常有用,因为它可以让用户与模型进行交互,而无需深入了解前端开发技术。本章节将介绍Gradio的基本用法、参数指标,并提供一个简单的示例。
基本用法
安装Gradio
首先,安装Gradio库。可以通过pip进行安装:
pip install gradio
本文代码:
with gr.Blocks() as demo:
with gr.Tab("image and plot test"):
with gr.Row():
with gr.Column():
radio1 = gr.Radio(["pneumonia", "brain"], value="pneumonia",
label="type model", info="Where did they go?")
bottom1 = gr.Button(value="Model init")
out1 = gr.Textbox()
src_img1 = gr.Image(interactive=True,height=640,width=640, label="Input Image")
slider1 = gr.Slider(0, 1, value=0.5, step=0.1, label="Count",
info="Choose between 0 and 1")
bottom2 = gr.Button(value="Inference")
with gr.Column():
dst_img2 = gr.Image(height=480,width=640,label="Output Image")
out2 = gr.Textbox()
qin = gr.Textbox(value='What disease might the patient in this medical image have? and give treatment suggestions')
bottom3 = gr.Button(value="VAQ")
qout = gr.Textbox()
bottom4 = gr.Button(value="下一个病例")
# 当单选按钮变化时,更新模型
# radio1.input(mmdetection_model.init_model, inputs=radio1)
# 当选择button时,模型初始化
bottom1.click(mmdetection_model.init_model, inputs=radio1, outputs= out1)
# 当滑动条变化时,使用分析类的方法进行图像分析
# slider1.change(mmdetection_model.draw, inputs=[src_img1, slider1], outputs=dst_img2)
bottom2.click(mmdetection_model.draw, inputs=[src_img1, slider1], outputs=[out2,dst_img2])
bottom3.click(mmdetection_model.answer, inputs=[qin], outputs=qout)
bottom4.click(mmdetection_model.remove, outputs=None)
demo.launch()
参数指标
Gradio界面的创建主要涉及以下几个参数:
gr.Blocks()
: 用于创建一个块级布局,它是界面结构的最外层容器。gr.Tab()
: 标签页,可以在一个界面中创建多个独立的区域。gr.Row()
: 行,用于在水平方向上排列多个组件。gr.Column()
: 列,用于在垂直方向上排列组件。
界面中涉及到的函数主要有:
gr.Radio
gr.Radio
用于创建单选按钮组,用户可以从一组选项中选择一个。
基本用法:
gr.Radio(choices, value=None, label=None, inline=False, info=None)
choices
: 一个列表,包含所有可选的选项。value
: 默认选中的选项。label
: 单选按钮组的标题。inline
: 是否将选项内联显示(默认为False,即选项垂直排列)。info
: 提供关于此单选按钮组的额外信息。
gr.Button
gr.Button
用于创建按钮,用户点击后可以触发特定的操作。
基本用法:
gr.Button(label=None, action=None)
label
: 按钮上显示的文本。action
: 点击按钮时执行的函数或方法。
gr.Textbox
gr.Textbox
用于创建文本输入框,用户可以输入文本。
基本用法:
gr.Textbox(value="", placeholder=None, label=None, type=None, info=None)
value
: 输入框的初始值。placeholder
: 输入框内的提示文本。label
: 输入框的标题。type
: 输入框的类型(如text
,number
,date
等)。info
: 提供关于此输入框的额外信息。
gr.Image
gr.Image
用于创建图像输入组件,用户可以上传图片。
基本用法:
gr.Image(interactive=False, label=None, width=None, height=None)
interactive
: 是否允许用户在界面中直接操作图像(如缩放、移动等)。label
: 图像输入组件的标题。width
和height
: 图像显示的宽度和高度。
gr.Slider
gr.Slider
用于创建滑动条组件,用户可以通过滑动选择一个值。
基本用法:
gr.Slider(min, max, value=None, step=None, label=None, info=None)
min
和max
: 滑动条的最小值和最大值。value
: 滑动条的初始值。step
: 滑动条的步长,即用户每次滑动时值的变化量。label
: 滑动条的标题。info
: 提供关于此滑动条的额外信息。
2.2 可视化界面中其他函数的:
init_model模型初始化函数:初始化mmdetection中的网络并加载权重:
if isinstance(type, str):
if type == 'pneumonia':
config_file = '/demo/config/faster-rcnn_r50_fpn_2x_coco_RSNA_all.py'
checkpoint_file = "/demo/weight/feiyan_48.pth"
elif type == 'brain':
config_file = '/demo/config/faster-rcnn_r50_fpn_2x_coco_brain_all.py'
checkpoint_file = "/demo/weight/brain_48.pth"
time1 = time.time()
self.model = init_detector(config_file, checkpoint_file, device='cuda:0')
draw函数:模型推理+坐标转换+绘图:
def draw(self, img, scores=0.5):
result = inference_detector(self.model, img)
pre_bbox, pre_label = analysis_result(result,score_thr=scores)
#获取pre的box和label
pre_num_label = len(pre_label)
pre_boxes = np.zeros((pre_num_label, 5), dtype=np.uint16)
#获取pre的类别和bbox
for ix in range(pre_num_label):
xmax = float(pre_bbox[ix][2])
xmin = float(pre_bbox[ix][0])
ymax = float(pre_bbox[ix][3])
ymin = float(pre_bbox[ix][1])
pre_boxes[ix, 0:4] = [xmin, ymin, xmax, ymax]
#标签转换
pre_boxes[ix, 4] = pre_label[ix]
img = draw_imagev2(img, pre_boxes)
return img
大模型部分应用的为Hugging face中公开的医学问答大模型,并没有做其他改进仅进行了本地的预测因此这里将不再赘述。
def answer(self, question):
max_new_tokens=1200
pre_img=os.path.join(current_working_directory, 'tmp_result/pre_img.jpg')
prompts = [
pre_img,
"Question:" + question,
"Answer:"
]
answer = do_inference(self.VAQmodel, self.VAQprocessor, prompts, max_new_tokens)
return answer