基于强化学习微调多模态大模型vlm-r1

1. 前言

deepseek-R1采用了两种方式微调:一种是较为传统的有监督微调(SFT)方式;第二种是基于强化学习的微调方式(GRPO)。基于此,本文采用最新的基于强化学习的微调方式对VLM-R1进行微调。

2. 强化学习微调的关键

强化学习微调的关键在于奖励函数的设计,因此,本人设计了营养健康方面的奖励函数:

2.1 内容的奖励机制。值得注意的是,我魔改了VLM-R1里面IOU候选框的奖励机制:为了更适配我的下游任务,我把原来的奖励机制改为数量奖励和内容名称奖励。

def content_reward(completions, solution, **kwargs):
    """Reward function combining numeric accuracy and content similarity.

    For each completion, extracts the text inside <answer>...</answer> and,
    when the reference solution contains digits, scores it as the sum of a
    numeric reward (exact number matching via numeric_reward) and a content
    reward (string similarity via ratio, computed after stripping digits
    from both strings with remove_numbers).

    Args:
        completions: list where each item is a list of message dicts; only
            item[0]["content"] is read.
        solution: list of reference answer strings, parallel to completions.

    Returns:
        list[float]: one reward per completion; 0.0 when there is no
        <answer> tag, the solution has no digits, or scoring raises.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    answer_tag_pattern = r'<answer>(.*?)</answer>'
    for content, sol in zip(contents, solution):
        reward = 0.0
        # BUG FIX: keep the unmodified solution for the debug log below; the
        # original logged `sol` after remove_numbers() had stripped its digits.
        original_sol = sol
        # Try symbolic verification first
        try:
            content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
            if content_answer_match:
                content_answer = content_answer_match.group(1).strip()
                has_numbers = bool(re.search(r'\d', sol))
                if has_numbers:
                    # For numeric answers, use exact matching
                    reward_num = numeric_reward(content_answer, sol)
                    print("数字奖励:", reward_num)
                    content_answer = remove_numbers(content_answer)
                    sol = remove_numbers(sol)
                    reward_content = ratio(content_answer, sol)
                    print("内容奖励:", reward_content)
                    reward = reward_content + reward_num
        except Exception:
            pass  # best-effort scoring: any parsing/scoring error yields 0.0
        rewards.append(reward)
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            with open(log_path, "a", encoding='utf-8') as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {original_sol}\n")
    return rewards

2.2 format 奖励机制

def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format.

    A completion earns 1.0 only when its content consists of exactly one
    <think>...</think> block followed (optionally after whitespace) by one
    <answer>...</answer> block; anything else scores 0.0.
    """
    expected = r"<think>.*?</think>\s*<answer>.*?</answer>"
    scores = []
    for completion in completions:
        text = completion[0]["content"]
        # fullmatch ensures there is no stray text before or after the tags
        scores.append(1.0 if re.fullmatch(expected, text, re.DOTALL) else 0.0)
    return scores

3.数据集的爬取

我爬取了4200条营养学的数据,下面展示爬虫的完整过程:

import requests
from bs4 import BeautifulSoup
import time
import csv

# Request headers that mimic a desktop Chrome browser so the target site
# does not reject the scraper's requests outright.
headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
    'Referer': 'https://www.boohee.com/'
}
def scrape_food_group(food_items, group_id):
    """Scrape pages 1-10 of one food group and append items to food_items.

    Each scraped item is a dict with keys 'name', 'heat' (the value after
    the colon in the calorie text, when present) and 'image'.

    Args:
        food_items: list that scraped item dicts are appended to (mutated
            in place and also returned).
        group_id: numeric id of the food group on the site.

    Returns:
        The same food_items list, for convenience.

    Raises:
        requests.HTTPError: when any page responds with an error status.
    """
    for page_id in range(1, 11):
        url = f'https://www.xxxxx.com/food/group/{group_id}?page={page_id}'
        # timeout added so a stalled connection cannot hang the scraper forever
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
        food_list = soup.find('ul', class_='food-list')
        if food_list:
            for item in food_list.find_all('li'):
                name = item.find('h4').get_text(strip=True) if item.find('h4') else 'N/A'
                heat = item.find('p').get_text(strip=True) if item.find('p') else 'N/A'
                img = item.find('img')['src'] if item.find('img') else 'N/A'
                # BUG FIX: the original heat.split(":")[1] raised IndexError
                # whenever the text had no colon (e.g. the 'N/A' fallback).
                heat_value = heat.split(":", 1)[1] if ":" in heat else heat
                food_items.append({
                        'name': name,
                        'heat': heat_value,
                        'image': img,
                    })
    return food_items

def save_to_csv(data, filename='boohee_food.csv'):
    """Write a list of row dicts to a CSV file.

    Column order is taken from the first row's keys; the file is written
    with a UTF-8 BOM ('utf-8-sig') so Excel opens it correctly. Does
    nothing when data is empty.
    """
    if not data:
        return
    fieldnames = list(data[0].keys())
    with open(filename, 'w', newline='', encoding='utf-8-sig') as out_file:
        writer = csv.DictWriter(out_file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(data)

if __name__ == '__main__':
    all_foods = []
    for group_id in range(1, 10):
        print(f"Scraping group {group_id}...")
        # BUG FIX: the original passed one shared list to every group and
        # re-extended its full contents into all_foods each iteration, so
        # group 1's items were duplicated 9 times, group 2's 8 times, etc.
        # Passing a fresh list per group keeps each item exactly once.
        foods = scrape_food_group([], group_id)
        all_foods.extend(foods)
        time.sleep(2)  # be polite to the server between groups

    save_to_csv(all_foods)
    print(f"Done! Saved {len(all_foods)} food items to boohee_food.csv")

爬取好后,需要把数据构建成和训练集一样的格式。在这里我们默认的输入格式是:
在这里插入图片描述
转换代码如下:

'''数据格式:{
  "id": 1,
  "image": "Clevr_CoGenT_TrainA_R1/data/images/CLEVR_trainA_000001_16885.png",
  "conversations": [
    {"from": "human", "value": "<image>What number of purple metallic balls are there?"},
    {"from": "gpt", "value": "0"}
  ]
}'''
import glob
import pandas as pd
from openpyxl import load_workbook
import os
import random
import json
import csv
# Candidate question prompts; one is picked at random for every sample.
questions = ["这道菜是怎么做的?", "图中的菜需要哪些食材?", "请问这道菜的配方是什么?", "这道菜的主要配料有哪些?", "图中的菜是怎么烹饪的?", "这道菜用了哪些原料?"]
file_path = "/www/VLM-R1/xiangha_recipes_list.csv"
results_list = []
with open(file_path, mode='r', encoding='utf-8') as file:
    for i, row in enumerate(csv.reader(file)):
        if i == 0:
            # Skip the CSV header row.
            continue
        _, url1, answer, image_url, image_path = row
        picked = random.randint(0, len(questions) - 1)
        # One conversation per recipe: a random question about the image,
        # answered with the recipe text from the CSV.
        results_list.append({
            "id": i,
            "image": image_path,
            "conversations": [
                {"from": "human", "value": "<image>" + questions[picked]},
                {"from": "gpt", "value": answer},
            ],
        })
#########################################################################################
with open("/www/VLM-R1/datasets/own_datasets/own_data.json", "w", encoding="utf-8") as f:
    json.dump(results_list, f, ensure_ascii=False, indent=4)
print("数据已经转化完毕!")

结果展示

在这里插入图片描述

没有微调之前

请添加图片描述

使用微调后的权重

在这里插入图片描述
(源码有偿提供)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值