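Radiology reporting is a complex task that requires detailed image understanding, the integration of multiple inputs (including comparison with prior imaging), and precise language generation. This makes it well suited to the development and use of generative multimodal models. Here, we extend report generation to include localizing individual findings on the image, a task we call grounded report generation. Prior work indicates that grounding is important for clarifying image understanding and interpreting AI-generated text. Grounded reporting therefore stands to improve the utility and transparency of automated report drafting. To evaluate grounded reporting, we propose RadFact, a novel evaluation framework that leverages the reasoning capabilities of large language models (LLMs). RadFact assesses the factuality of individual generated sentences, as well as the correctness of generated spatial localizations when present. We introduce MAIRA-2, a large multimodal model that combines a radiology-specific image encoder with an LLM and is trained for grounded report generation on chest X-rays. MAIRA-2 uses more comprehensive inputs than previously explored: the current frontal image, the current lateral image, the prior frontal image and prior report, and the Indication, Technique and Comparison sections of the current report. We demonstrate that these additions significantly improve report quality and reduce hallucinations, establishing a new state of the art for findings generation (without grounding) on MIMIC-CXR, while demonstrating the feasibility of grounded reporting as a novel and richer task.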
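RadFact scores a report sentence by sentence. The sketch below shows only the sentence-level part, under our reading of that idea. It assumes "logical precision" is the fraction of generated sentences entailed by the reference report, and "logical recall" is the fraction of reference sentences entailed by the generated report.

- In RadFact an LLM makes the entailment judgement. Here `is_entailed` is a hypothetical callable.
- `toy_judge` is a deliberately crude exact-match stand-in so the example runs offline.
- The spatial (bounding-box) part of RadFact is not modeled.
- The reference report is invented for illustration.

```python
from typing import Callable, Sequence

# Hypothetical judge signature: does the list of premise sentences entail the hypothesis?
# In RadFact an LLM plays this role; any callable with this shape can be plugged in.
EntailmentJudge = Callable[[Sequence[str], str], bool]


def logical_precision_recall(
    generated: Sequence[str],
    reference: Sequence[str],
    is_entailed: EntailmentJudge,
) -> tuple[float, float]:
    """Sentence-level precision and recall of a generated report against a reference."""
    # Precision: fraction of generated sentences supported by the reference report.
    supported = sum(is_entailed(reference, sentence) for sentence in generated)
    # Recall: fraction of reference sentences supported by the generated report.
    covered = sum(is_entailed(generated, sentence) for sentence in reference)
    precision = supported / len(generated) if generated else 0.0
    recall = covered / len(reference) if reference else 0.0
    return precision, recall


def toy_judge(premise: Sequence[str], hypothesis: str) -> bool:
    """Crude offline stand-in for the LLM: case-insensitive exact sentence match."""
    return hypothesis.strip().lower() in {p.strip().lower() for p in premise}


generated = [
    "There is a large right pleural effusion.",
    "The left lung is clear.",
    "No pneumothorax is identified.",
]
reference = [  # hypothetical reference report
    "There is a large right pleural effusion.",
    "Heart size is normal.",
]
precision, recall = logical_precision_recall(generated, reference, toy_judge)
print(f"precision={precision:.2f} recall={recall:.2f}")  # precision=0.33 recall=0.50
```

An exact-match judge misses paraphrases ("no pneumothorax" vs "pneumothorax is absent"). That gap is why a reasoning model is used as the judge in practice.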
https://huggingface.co/microsoft/maira-2
Fine-tuned from: vicuna-7b-1.5, RAD-DINO-MAIRA-2
Code
MAIRA-2 requires transformers>=4.46.0.dev0, so for now you may need to install transformers from source:

```shell
pip install git+https://github.com/huggingface/transformers.git@main
```
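You can check whether the installed transformers meets the `>=4.46.0.dev0` requirement before loading the model. If `packaging` is available, `packaging.version.Version` handles full PEP 440 and is the more robust choice; `packaging` is a transformers dependency.

The sketch below is a stdlib-only alternative. It assumes simple version strings such as `4.46.0` or `4.46.0.dev0`, and rejects anything else (for example release candidates). The helper names `version_key` and `meets_minimum` are hypothetical.

```python
import re
from importlib import metadata


def version_key(version: str) -> tuple:
    """Sort key for simple versions like '4.46.0' or '4.46.0.dev0' (not full PEP 440)."""
    match = re.fullmatch(r"(\d+(?:\.\d+)*)(?:\.dev(\d+))?", version)
    if match is None:
        raise ValueError(f"unsupported version string: {version!r}")
    release = tuple(int(part) for part in match.group(1).split("."))
    release = release + (0,) * (3 - len(release))  # treat '4.46' like '4.46.0'
    dev = match.group(2)
    # A .devN pre-release sorts before the final release with the same number.
    return (release, 0, int(dev)) if dev is not None else (release, 1, 0)


def meets_minimum(installed: str, minimum: str = "4.46.0.dev0") -> bool:
    return version_key(installed) >= version_key(minimum)


if __name__ == "__main__":
    try:
        installed = metadata.version("transformers")
        print(installed, "OK" if meets_minimum(installed) else "too old for MAIRA-2")
    except metadata.PackageNotFoundError:
        print("transformers is not installed")
    except ValueError as err:
        print(f"could not compare versions: {err}")
```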
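```python
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

# trust_remote_code=True lets transformers load MAIRA-2's own model and processor code from the Hub.
model = AutoModelForCausalLM.from_pretrained("microsoft/maira-2", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/maira-2", trust_remote_code=True)

# Use the GPU when one is available; CPU inference works but is much slower.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.eval()  # inference mode (disables dropout)
model = model.to(device)
```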
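Whether the model fits on the device depends mostly on the size of its weights. In the transformers versions this example targets, `from_pretrained` loads weights in float32 unless a `torch_dtype` is given.

The rough, weights-only estimate below assumes about 7 billion parameters, the size of the vicuna-7b-1.5 base named above. The image encoder adds more, and activations and the KV cache are not counted. The function name and the byte table are illustrative.

```python
# Bytes used per parameter by common floating-point dtypes.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gib(num_params: float, dtype: str = "float32") -> float:
    """Approximate memory for the model weights alone, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


# Assumption: ~7e9 parameters (the vicuna-7b-1.5 base LLM only).
for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: ~{weight_memory_gib(7e9, dtype):.1f} GiB")
# float32: ~26.1 GiB
# bfloat16: ~13.0 GiB
```

If GPU memory is tight, passing `torch_dtype=torch.bfloat16` to `from_pretrained` roughly halves the weight footprint, provided the hardware supports bfloat16.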
We need some data to demonstrate a forward pass. For this example, we take one study from the IU X-ray dataset, which has a permissive license.
```python
import requests
from PIL import Image


def get_sample_data() -> dict[str, Image.Image | str]:
    """
    Download chest X-rays from IU-Xray, which we didn't train MAIRA-2 on. License is CC.
    We modified this function from the Rad-DINO repository on Huggingface.
    """
    frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"

    def download_and_open(url: str) -> Image.Image:
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
        response.raise_for_status()  # fail early on HTTP errors instead of handing PIL a bad stream
        return Image.open(response.raw)

    frontal_image = download_and_open(frontal_image_url)
    lateral_image = download_and_open(lateral_image_url)

    sample_data = {
        "frontal": frontal_image,
        "lateral": lateral_image,
        "indication": "Dyspnea.",
        "comparison": "None.",
        "technique": "PA and lateral views of the chest.",
        "phrase": "Pleural effusion.",  # For the phrase grounding example. This patient has pleural effusion.
    }
    return sample_data


sample_data = get_sample_data()
```
Report Generation Without Grounding
```python
processed_inputs = processor.format_and_preprocess_reporting_input(
    current_frontal=sample_data["frontal"],
    current_lateral=sample_data["lateral"],
    prior_frontal=None,  # Our example has no prior
    indication=sample_data["indication"],
    technique=sample_data["technique"],
    comparison=sample_data["comparison"],
    prior_report=None,  # Our example has no prior
    return_tensors="pt",
    get_grounding=False,  # For this example we generate a non-grounded report
)

processed_inputs = processed_inputs.to(device)
with torch.no_grad():
    output_decoding = model.generate(
        **processed_inputs,
        max_new_tokens=300,  # Set to 450 for grounded reporting
        use_cache=True,
    )

# generate() returns the prompt followed by the completion; keep only the new tokens.
prompt_length = processed_inputs["input_ids"].shape[-1]
decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
decoded_text = decoded_text.lstrip()  # Findings generation completions have a single leading space
prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
print("Parsed prediction:", prediction)
```
Example output:

```text
('There is a large right pleural effusion.', [(0.055, 0.275, 0.445, 0.665)]),
('The left lung is clear.', None),
('No pneumothorax is identified.', None),
('The cardiomediastinal silhouette is within normal limits.', None),
('The visualized osseous structures are unremarkable.', None)
```
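The parsed prediction pairs each sentence with either `None` or a list of boxes. The sketch below turns that structure into a plain report string plus a list of (sentence, box) findings, and scales a box to pixels. The helper names are hypothetical.

It assumes, as the example output above suggests, that each box is `(x_min, y_min, x_max, y_max)` normalized to [0, 1]. Those coordinates are relative to the image as the model saw it; because the processor may resize or crop the input, mapping onto the original image could need an extra adjustment that this sketch does not perform.

```python
from typing import Optional, Sequence

# Assumed box format: (x_min, y_min, x_max, y_max), each in [0, 1].
Box = tuple[float, float, float, float]
GroundedSentence = tuple[str, Optional[list[Box]]]


def split_prediction(prediction: Sequence[GroundedSentence]) -> tuple[str, list[tuple[str, Box]]]:
    """Join all sentences into one report string and collect (sentence, box) pairs."""
    report = " ".join(sentence for sentence, _ in prediction)
    grounded = [(sentence, box) for sentence, boxes in prediction if boxes for box in boxes]
    return report, grounded


def to_pixels(box: Box, width: int, height: int) -> tuple[int, int, int, int]:
    """Scale a normalized box to integer pixel coordinates for an image of the given size."""
    x_min, y_min, x_max, y_max = box
    return (round(x_min * width), round(y_min * height), round(x_max * width), round(y_max * height))


prediction = [
    ("There is a large right pleural effusion.", [(0.055, 0.275, 0.445, 0.665)]),
    ("The left lung is clear.", None),
    ("No pneumothorax is identified.", None),
]
report, grounded = split_prediction(prediction)
print(report)
for sentence, box in grounded:
    # 512 x 512 is a hypothetical image size chosen for illustration.
    print(sentence, to_pixels(box, width=512, height=512))
```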
Phrase Grounding
```python
processed_inputs = processor.format_and_preprocess_phrase_grounding_input(
    frontal_image=sample_data["frontal"],
    phrase=sample_data["phrase"],
    return_tensors="pt",
)

processed_inputs = processed_inputs.to(device)
with torch.no_grad():
    output_decoding = model.generate(
        **processed_inputs,
        max_new_tokens=150,
        use_cache=True,
    )

prompt_length = processed_inputs["input_ids"].shape[-1]
decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
print("Parsed prediction:", prediction)
```
Example output:

```text
('Pleural effusion.', [(0.025, 0.345, 0.425, 0.575)])
```
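The phrase-grounding box can be compared with the box attached to the pleural-effusion sentence in the report-generation output using intersection over union (IoU), a standard box-overlap measure. This is a sketch for checking agreement between the two outputs, not the grounding metric used by RadFact. It assumes both boxes share the same normalized `(x_min, y_min, x_max, y_max)` frame of the frontal image. `box_iou` is a hypothetical helper.

```python
Box = tuple[float, float, float, float]


def box_iou(a: Box, b: Box) -> float:
    """Intersection over union of two (x_min, y_min, x_max, y_max) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    intersection = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


report_box = (0.055, 0.275, 0.445, 0.665)  # from the report-generation example output
phrase_box = (0.025, 0.345, 0.425, 0.575)  # from the phrase-grounding example output
print(f"IoU = {box_iou(report_box, phrase_box):.2f}")  # IoU = 0.54
```

The two boxes overlap substantially (intersection 0.37 × 0.23 = 0.0851 over a union of 0.159), but the phrase-grounding box is shorter vertically.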