Notes on Reproducing CVPR 2024 "Unifying Top-down and Bottom-up Scanpath Prediction Using Transformers"

Preface

Owing to my limited project experience, I ran into many problems while reproducing this paper's code, and for most of them searching online did not immediately turn up the cause. Since the paper and its research direction are quite new, there are almost no reproduction write-ups on the web, so I am recording the process here. It may help others with the same need, though it is mainly for my own reference. (I am no expert and barely a beginner in this field; please point out any mistakes.)

Paper

https://openaccess.thecvf.com/content/CVPR2024/papers/Yang_Unifying_Top-down_and_Bottom-up_Scanpath_Prediction_Using_Transformers_CVPR_2024_paper.pdf


Code

cvlab-stonybrook/HAT (official codebase): https://github.com/cvlab-stonybrook/HAT

Environment Setup

I set up the environment and ran the code on a server. Due to some errors I have not yet resolved, the project's requirements.txt could not be used to build the environment with conda, so I installed the required packages directly with pip instead.

Runtime environment:

python == 3.9 (versions 3.10 and above cause problems)

pytorch == 2.0.1

CUDA ==11.7

Cython == 3.0.11

detectron2 == 0.6

numpy == 1.23.5

scikit-learn == 0.21.3 (I repeatedly ran into sklearn version issues)

scipy == 1.10.0

Installation

git clone https://github.com/cvlab-stonybrook/HAT.git
cd HAT

Install Detectron2

git clone https://github.com/facebookresearch/detectron2.git
python -m pip install -e detectron2
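
With the packages above installed and Detectron2 built, a quick import check confirms the versions and CUDA availability before going further. A minimal sketch:

# env_check.py -- sanity check for the environment listed above
import torch
import detectron2
import numpy
import scipy
import sklearn

print('torch', torch.__version__, '| cuda', torch.version.cuda,
      '| cuda available:', torch.cuda.is_available())
print('detectron2', detectron2.__version__)
print('numpy', numpy.__version__, '| scipy', scipy.__version__,
      '| scikit-learn', sklearn.__version__)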

Install MSDeformableAttn:

cd ./hat/pixel_decoder/ops
sh make.sh

Initializing with the commands above produced no errors for me, but some problems (errors of unclear cause) appeared later on. You can try the remaining steps first; if problems come up, consider the fallback below: clone the source project of this op directly and build it from there.

git clone https://github.com/fundamentalvision/deformable-detr

After cloning, enter the project and initialize it following its README.md:

pip install -r requirements.txt
cd ./models/ops
sh ./make.sh
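
Either way, it is worth verifying that the CUDA op actually compiled before moving on. The Deformable-DETR README also mentions a unit test (python test.py inside ./models/ops) whose checks should all print True. A minimal import check, assuming the default build in which the extension is named MultiScaleDeformableAttention per the ops' setup.py:

# verify the compiled deformable-attention extension is importable
import MultiScaleDeformableAttention as MSDA
print(MSDA.__file__)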

Download the pretrained model weights (ResNet-50 and Deformable Transformer) with the following Python code:

import os
import wget

if not os.path.exists("./pretrained_models/"):
    os.mkdir('./pretrained_models')

print('downloading pretrained model weights...')
url = "http://vision.cs.stonybrook.edu/~cvlab_download/HAT/pretrained_models/M2F_R50_MSDeformAttnPixelDecoder.pkl"
wget.download(url, 'pretrained_models/')
url = "http://vision.cs.stonybrook.edu/~cvlab_download/HAT/pretrained_models/M2F_R50.pkl"
wget.download(url, 'pretrained_models/')

Create a Python file with the code above (the os and wget imports are required; install the wget package via pip if it is missing) and run it to download the pretrained models.
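
To confirm the downloads completed, the saved files can be checked (filenames as in the URLs above):

import os

# verify both checkpoint files exist and are non-empty
for name in ['M2F_R50_MSDeformAttnPixelDecoder.pkl', 'M2F_R50.pkl']:
    path = os.path.join('pretrained_models', name)
    print(path, os.path.getsize(path) if os.path.exists(path) else 'MISSING')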

Data Preparation

This paper uses COCO-Search18 as its main dataset, which can be downloaded from the official site.

https://sites.google.com/view/cocosearch/home

Note that the README.md does not spell out the exact data format; instead it refers you to cvlab-stonybrook/Scanpath_Prediction (Predicting Goal-directed Human Attention Using Inverse Reinforcement Learning, CVPR 2020): https://github.com/cvlab-stonybrook/Scanpath_Prediction

However, the data format and structure there do not fully match this project. After reading the code and consulting Gazeformer, another work that uses COCO-Search18 (official codebase: https://github.com/cvlab-stonybrook/Gazeformer),

I believe the complete data layout should be as follows (a sanity-check sketch appears after the tree):

<dataset_root>
    -- bbox_annos.npy                                # bounding-box annotations for each image (available from COCO)
    -- clusters.npy                                  # fixation string clusters
    -- scene_label_dict.npy                          # scene labels
    -- all_target_ids.npy                            # task ids
    -- coco_search_fixations_512x320_on_target_allvalid.json         # all splits (train, validation and test) of human scanpaths (ground truth)
    -- ./images                                      # COCO-Search18 images (512x320)
        -- ./bottle
        -- ./bowl
        -- ...
    -- ./DCBs
        -- ./HR                                      # high-resolution belief maps for each input image (pre-computed)
        -- ./LR                                      # low-resolution belief maps for each input image (pre-computed)
    -- ./semantic_seq_full
        -- test.pkl
        -- ./segmentation_maps                       # segmentation maps for Semantic Sequence Score
            -- 000000000164.npy.gz
            -- ...
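
A short script like the following (the root path is a placeholder for your own <dataset_root>) can verify the layout before training:

# check_dataset.py -- verify the layout above; adjust root to your <dataset_root>
import os

root = './data'
required = [
    'bbox_annos.npy',
    'clusters.npy',
    'scene_label_dict.npy',
    'all_target_ids.npy',
    'coco_search_fixations_512x320_on_target_allvalid.json',
    'images',
    'DCBs/HR',
    'DCBs/LR',
    'semantic_seq_full/test.pkl',
    'semantic_seq_full/segmentation_maps',
]
for rel in required:
    mark = 'OK  ' if os.path.exists(os.path.join(root, rel)) else 'MISS'
    print(mark, rel)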

Note that the COCO-Search18 images used in this paper are rescaled to 512x320 (the original size is 1680x1050; both are 1.6:1, so a plain resize preserves the aspect ratio). You can download the originals from the COCO website and rescale them yourself.

Most of the scanpath annotation files and belief maps are provided in the Scanpath_Prediction project, while the semantic-information files can be found in the Gazeformer project. The coco_search_fixations_512x320_on_target_allvalid.json file can be produced by concatenating the three per-split .json files that Gazeformer provides. Reference code for rescaling the images and combining the json files is given below (for reference only).

#images_rescale.py

from PIL import Image
import os

# source and target root directories
root_source_dir = 'COCO-Search18/images'
root_target_dir = 'COCO-Search18_rescaled/images'

# iterate over all subdirectories (one per task category) in the source root
for subdir in os.listdir(root_source_dir):
    source_dir = os.path.join(root_source_dir, subdir)
    target_dir = os.path.join(root_target_dir, subdir)
    print("processing", subdir)
    # only process entries that are directories
    if os.path.isdir(source_dir):
        # make sure the target subdirectory exists
        os.makedirs(target_dir, exist_ok=True)

        # iterate over all files in the source subdirectory
        for filename in os.listdir(source_dir):
            if filename.endswith('.jpg'):
                # build the source and target file paths
                source_path = os.path.join(source_dir, filename)
                target_path = os.path.join(target_dir, filename)

                # open the image, resize it to 512x320, and save it
                with Image.open(source_path) as img:
                    img_resized = img.resize((512, 320))
                    img_resized.save(target_path)

print("All images processed")
#combine_json.py

import json

# Concatenate the per-split scanpath files provided by Gazeformer into a single
# ground-truth file. The train/val filenames are from Gazeformer; the test
# filename is assumed to follow the same naming pattern.
combined = []
for split in ('train', 'val', 'test'):
    fname = 'coco_search_fixations_512x320_on_target_allvalid_{}.json'.format(split)
    with open(fname, 'r', encoding='utf-8') as f:
        combined += json.load(f)

# write the combined data to a new JSON file
with open('coco_search_fixations_512x320_on_target_allvalid.json', 'w', encoding='utf-8') as f:
    json.dump(combined, f, ensure_ascii=False, indent=4)

print("Done: coco_search_fixations_512x320_on_target_allvalid.json created.")

With the data prepared, training and evaluation should now be possible.

Training and Evaluation

The training command given in README.md is:

python train.py --hparams ./configs/coco_search18_dense_SSL.json --dataset-root <dataset_root> 

其中"coco_search18_dense_SSL.json"应根据实际任务更改为其提供的json文件,以TP任务为例,应为"coco_search18_dense_SSL_TP.json"。若是只进行测试,则在命令中加上"--eval-only"即可,以下为示例。

python train.py --hparams ./configs/coco_search18_dense_SSL_TP.json --dataset-root ./data/  --eval-only

The training and evaluation results and log files are saved under ./assets.

With my setup, a problem appeared when computing the metrics: semantic_seq_score was always 0 while the other metrics were normal. The issue probably lies in how I prepared the semantic-information data; it is still unresolved, and I will update this post if I find and fix the cause.
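
One thing worth checking in that case is whether the pre-computed segmentation maps load at all. A minimal probe, assuming the .npy.gz files are gzip-compressed NumPy arrays as the extension suggests:

import gzip
import numpy as np

# probe one segmentation map (path per the layout above; adjust the root)
path = './data/semantic_seq_full/segmentation_maps/000000000164.npy.gz'
with gzip.open(path, 'rb') as f:
    seg = np.load(f)
print(seg.shape, seg.dtype, 'labels:', np.unique(seg)[:10])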

Visualizing Results


The project does not include scanpath-visualization code. Drawing on demo.ipynb and the Scanpath_Prediction project, here is reference code that reads the prediction json file and visualizes a scanpath.

#plot_scanpath.py

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np
import argparse
import json
from os.path import isfile
import os

def convert_coordinate(X, Y, im_w, im_h):
    """
    convert from display coordinate to pixel coordinate

    X - x coordinate of the fixations
    Y - y coordinate of the fixations
    im_w - image width
    im_h - image height
    """
    display_w, display_h = 1680, 1050
    target_ratio = display_w / float(display_h)
    ratio = im_w / float(im_h)

    delta_w, delta_h = 0, 0
    if ratio > target_ratio:
        new_w = display_w
        new_h = int(new_w / ratio)
        delta_h = display_h - new_h
    else:
        new_h = display_h
        new_w = int(new_h * ratio)
        delta_w = display_w - new_w
    dif_ux = delta_w // 2
    dif_uy = delta_h // 2
    scale = im_w / float(new_w)
    X = (X - dif_ux) * scale
    Y = (Y - dif_uy) * scale
    return X, Y


    
def plot_scanpath(img, xs, ys, bbox=None, title=None):
    fig, ax = plt.subplots()
    ax.imshow(img)
    # cir_rad_min, cir_rad_max = 30, 60
    # min_T, max_T = np.min(ts), np.max(ts)
    # rad_per_T = (cir_rad_max - cir_rad_min) / float(max_T - min_T)

    for i in range(len(xs)):
        if i > 0:
            # draw an arrow from the previous fixation to the current one
            ax.arrow(xs[i - 1], ys[i - 1], xs[i] - xs[i - 1],
                     ys[i] - ys[i - 1], width=3, color='yellow', alpha=0.5)

    for i in range(len(xs)):
        # cir_rad = int(25 + rad_per_T * (ts[i] - min_T))
        cir_rad = 15
        circle = plt.Circle((xs[i], ys[i]),
                            radius=cir_rad,
                            edgecolor='red',
                            facecolor='yellow',
                            alpha=0.5)
        ax.add_patch(circle)
        # number the fixations in temporal order
        ax.annotate("{}".format(i + 1), xy=(xs[i], ys[i] + 3),
                    fontsize=10, ha="center", va="center")

    # if bbox is not None:
    #     rect = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], 
    #         alpha=0.5, edgecolor='yellow', facecolor='none', linewidth=2)
    #     ax.add_patch(rect)

    ax.axis('off')
    if title is not None:
        ax.set_title(title)
    target_dir='test_visualize'
    filename='plot.jpg'
    os.makedirs(target_dir, exist_ok=True)
    plt.savefig(os.path.join(target_dir, filename))  
    plt.show()
    

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--fixation_path', type=str,
                        default='assets/R50_HAT_TP/predictions_TP.json',
                        help='path of the prediction/fixation json file')
    parser.add_argument('--image_dir', type=str,default='data/coco-search18/images', help='the directory of the image stimuli')
    parser.add_argument('--random_trial', choices=[0, 1],
                        default=1, type=int, help='randomly drawn from data (default=1)')
    parser.add_argument('--trial_id', default=0, type=int, help='trial id (default=0)')
    parser.add_argument('--subj_id', type=int, default=-1,
                        help='subject id (default=-1)')
    parser.add_argument('--task',
                        choices=['bottle', 'bowl', 'car', 'chair', 'clock',
                                 'cup', 'fork', 'keyboard', 'knife', 'laptop',
                                 'microwave', 'mouse', 'oven', 'potted plant',
                                 'sink', 'stop sign', 'toilet', 'tv'],
                        default='bottle',
                        help='search target (one of the 18 COCO-Search18 categories)')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    # load fixation (prediction) data
    with open(args.fixation_path, 'r') as f:
        scanpaths = json.load(f)
    scanpaths = list(filter(lambda x: x['task'] == args.task, scanpaths))
    if args.subj_id > 0:
        scanpaths = list(filter(lambda x: x['subject'] == args.subj_id, scanpaths))
    
    if args.random_trial == 1:
        idx = np.random.randint(len(scanpaths))
    else:
        idx = args.trial_id
    scanpath = scanpaths[idx]
    img_name = scanpath['name']
    cat_name = scanpath['task']
    # bbox = scanpath['bbox']
    img_path = './{}/{}/{}'.format(args.image_dir, cat_name, img_name)
    print("This is target-present trial")

    if not isfile(img_path):
        print("image not found at {}".format(img_path))
        exit(-1)

    # load image
    print(img_path)
    img = mpimg.imread(img_path)
    im_h, im_w = img.shape[0], img.shape[1]

    # The prediction json appears to be already in 512x320 image coordinates,
    # so the conversion is left commented out; use convert_coordinate when
    # plotting raw human fixations recorded in 1680x1050 display coordinates.
    X, Y = scanpath['X'], scanpath['Y']
    # X, Y = convert_coordinate(X, Y, im_w, im_h)

    # title = "target={}, correct={}".format(cat_name, scanpath['correct'])

    # plot_scanpath
    # plot_scanpath(img, X, Y, T, bbox, title)
    plot_scanpath(img, X, Y)
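
Run it with, for example, python plot_scanpath.py --fixation_path assets/R50_HAT_TP/predictions_TP.json --image_dir data/coco-search18/images --task bottle; the figure is saved to test_visualize/plot.jpg and also shown on screen.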

References

cvlab-stonybrook/HAT (CVPR 2024): https://github.com/cvlab-stonybrook/HAT

cvlab-stonybrook/Scanpath_Prediction (CVPR 2020): https://github.com/cvlab-stonybrook/Scanpath_Prediction

cvlab-stonybrook/Gazeformer (CVPR 2023): https://github.com/cvlab-stonybrook/Gazeformer

Unifying Top-down and Bottom-up Scanpath Prediction Using Transformers: https://openaccess.thecvf.com/content/CVPR2024/papers/Yang_Unifying_Top-down_and_Bottom-up_Scanpath_Prediction_Using_Transformers_CVPR_2024_paper.pdf

COCO-Search18: https://sites.google.com/view/cocosearch/home

If you spot any mistakes, please point them out!

 
