混合数据集训练

数据集混合代码

多模态模型进行训练时,通常需要混合其他数据,以避免训练过程中对其他知识能力表现的降低。

'''
对不同的数据集进行混合,返回混合后的数据集,设置不同的数据集的权重
'''
import os
import json
import random
import matplotlib.pyplot as plt
import colorsys

import numpy as np  

# 根据json文件统计样本数量

def count_samples(data_path):
    """Count the number of samples in a dataset file.

    Supports two formats:
      - ``.jsonl``: one JSON object per line; the count is the line count.
      - ``.json``: a single JSON array; the count is the array length.

    Args:
        data_path: Path to a ``.json`` or ``.jsonl`` dataset file.

    Returns:
        int: number of samples in the file.

    Raises:
        ValueError: if the file extension is neither ``.json`` nor ``.jsonl``
            (the original silently returned ``None`` in that case).
    """
    ext = os.path.splitext(data_path)[1]
    if ext == '.jsonl':
        with open(data_path, 'r') as f:
            # One sample per line; avoid loading the whole file into memory.
            return sum(1 for _ in f)
    if ext == '.json':
        with open(data_path, 'r') as f:
            return len(json.load(f))
    raise ValueError(f"Unsupported dataset file extension: {ext!r} (expected .json or .jsonl)")


def generate_evenly_spaced_colors(num_colors):
    """Build a list of ``num_colors`` RGB tuples with evenly spaced hues.

    Hues are spread uniformly over [0, 1) while lightness (0.5) and
    saturation (0.9) stay fixed, giving visually distinct, vivid colors
    suitable for chart segments.

    Args:
        num_colors: how many colors to generate.

    Returns:
        list[tuple[float, float, float]]: RGB triples in [0, 1].
    """
    # colorsys.hls_to_rgb already returns an (r, g, b) tuple.
    return [
        colorsys.hls_to_rgb(idx / num_colors, 0.5, 0.9)
        for idx in range(num_colors)
    ]


def plot_data(counts,save_path='/home/app/examples/lb/san_nuo/datasets/fig',name='all'):
    """Render a pie chart of per-dataset sample counts and save it as a PNG.

    Args:
        counts: mapping of dataset name -> sample count.
        save_path: directory the figure is written to (created if missing).
        name: prefix of the output file ``{name}_dataset_distribution.png``.
    """
    dataset_names = counts.keys()
    sample_sizes = counts.values()
    palette = generate_evenly_spaced_colors(len(dataset_names))

    # Pie chart of the dataset mixture, one slice per dataset.
    plt.pie(sample_sizes, labels=dataset_names, colors=palette, autopct='%1.1f%%', startangle=140)
    plt.axis('equal')  # equal aspect ratio keeps the pie circular
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, f'{name}_dataset_distribution.png'))
    plt.close()
    

## 设置比例因子,
def random_split(total, n,type_name='even'):
    """Split the integer ``total`` into ``n`` integer parts.

    Args:
        total: the value to split (must be >= n for the 'random' mode).
        n: number of parts.
        type_name: 'even' for near-equal parts, 'random' for random
            positive parts.

    Returns:
        list[int]: n parts that sum exactly to ``total``.

    Raises:
        ValueError: for an unrecognized ``type_name`` (the original
            silently returned ``None``).
    """
    if type_name == 'random':
        # n == 1 needs no cut points; the original indexed splits[-1]
        # on an empty list and crashed.
        if n == 1:
            return [total]
        # n-1 sorted cut points partition (0, total) into n positive parts.
        cuts = sorted(random.sample(range(1, total), n - 1))
        parts = [cuts[0]]
        parts.extend(cuts[i] - cuts[i - 1] for i in range(1, n - 1))
        parts.append(total - cuts[-1])
        return parts
    if type_name == 'even':
        # Distribute the remainder over the first parts so the result
        # sums to total (the original dropped total % n samples).
        base, remainder = divmod(total, n)
        return [base + (1 if i < remainder else 0) for i in range(n)]
    raise ValueError(f"Unknown split type: {type_name!r} (expected 'even' or 'random')")

def set_beta(seed,counts,alphta,beta,c):
    """Decide how many samples to draw from each mixed-in dataset.

    If ``beta`` is already provided it is returned unchanged. Otherwise a
    total budget of ``alphta * c`` samples is split across the datasets in
    ``counts`` (order preserved), and the 'app' dataset is pinned to
    ``alphta`` samples afterwards.

    Args:
        seed: seed for the ``random`` module (reproducible splits).
        counts: mapping of dataset name -> available sample count.
            NOTE(review): the original comment says counts excludes 'app',
            but the caller below passes a dict that includes it — confirm.
        alphta: sample count of the primary 'app' dataset.
        beta: optional precomputed mapping of dataset name -> draw count.
        c: scale factor applied to ``alphta`` for the mixed-in budget.

    Returns:
        dict: dataset name -> number of samples to draw.
    """
    # Guard clause: an explicit beta wins outright.
    if beta is not None:
        return beta

    random.seed(seed)
    budget = alphta * c
    shares = random_split(budget, len(counts))
    # zip preserves the insertion order of counts' keys.
    beta = dict(zip(counts.keys(), shares))
    beta['app'] = alphta
    return beta

# # 抽取指定数量的样本,随机抽取
# def extract_samples(data_path, num,shuffle=True):
#     if shuffle:
#         np.random.seed(25)
    
#     with open(data_path, 'r') as f:
#         data = json.load(f)
#         if shuffle:
#             random.shuffle(data)
#         data = data[:num]
#     f.close()
#     return data,len(data)

    # 抽取指定数量的样本,随机抽取
def extract_samples(data_path, num,shuffle=True):
    """Load samples from a .json/.jsonl file and return the first ``num``.

    Bug fixes vs. the original:
      - The whole body was nested under ``if shuffle:``, so calling with
        ``shuffle=False`` read nothing and returned ``None``. The file is
        now always read; only the shuffling is conditional.
      - The original seeded ``np.random`` but shuffled with the ``random``
        module, so the shuffle was not actually reproducible. The RNG that
        is used is now the one that gets seeded (seed 25, as before).

    Args:
        data_path: path to a ``.json`` (array) or ``.jsonl`` file.
        num: maximum number of samples to return.
        shuffle: when True, deterministically shuffle before truncating.

    Returns:
        tuple[list, int]: the selected samples and their count.

    Raises:
        ValueError: for an unsupported file extension.
    """
    ext = os.path.splitext(data_path)[1]
    with open(data_path, 'r') as f:
        if ext == '.jsonl':
            # One JSON object per line.
            data = [json.loads(line) for line in f]
        elif ext == '.json':
            data = json.load(f)
        else:
            raise ValueError(f"Unsupported dataset file extension: {ext!r} (expected .json or .jsonl)")

    if shuffle:
        random.seed(25)  # fixed seed so extraction is reproducible
        random.shuffle(data)

    data = data[:num]
    return data, len(data)

def merge_data(datas:list,save_path:str,shuffle=True):
    """Concatenate multiple sample lists into one mixed dataset.

    Args:
        datas: list of sample lists to merge.
        save_path: output JSON path; skipped when ``None``.
        shuffle: when True, shuffle the merged list in place
            (uses the global ``random`` stream; seed externally for
            reproducibility).

    Returns:
        list: the merged (optionally shuffled) samples.
    """
    # Renamed from `merge_data`: the original local shadowed the function.
    merged = []
    for chunk in datas:
        merged.extend(chunk)

    if shuffle:
        random.shuffle(merged)

    if save_path is not None:
        # Explicit utf-8: with ensure_ascii=False the dump contains raw
        # non-ASCII characters, which can fail under a non-UTF-8 platform
        # default encoding.
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(merged, f, ensure_ascii=False, indent=4)

    return merged



if __name__ == '__main__':

    # Source datasets, in a fixed order. The order matters twice: it fixes
    # how the seeded RNG stream is consumed and the slice order of the pies.
    dataset_paths = {
        'app': 'database/san_nuo_app/app/train_data_v3.jsonl',
        'coco': 'database/coco2017/coco2017_swift.json',
        'chartqa': 'database/san_nuo_app_swift/chartqa_train_18k.jsonl',
        'ai2d_train_12k': 'database/san_nuo_app_swift/ai2d_train_12k.jsonl',
        'docvqa_train_10k': 'database/san_nuo_app_swift/docvqa_train_10k.jsonl',
    }
    merge_path = 'database/san_nuo_app_swift/merge_data.json'

    # Count the available samples per dataset and plot the raw distribution.
    counts = {name: count_samples(path) for name, path in dataset_paths.items()}
    plot_data(counts)

    # Scale factor: mixed-in data gets a budget of alphta*c samples in
    # total, while the 'app' dataset keeps all alphta of its own samples.
    alphta = counts['app']
    new_counts = set_beta(25, counts, alphta, beta=None, c=2)
    plot_data(new_counts, name='new')
    print(new_counts)

    # Extract the target number of samples from each dataset
    # (same call order as the counts dict).
    extracted = {
        name: extract_samples(path, new_counts[name])
        for name, path in dataset_paths.items()
    }

    # Sanity check: re-plot using the lengths that were actually extracted.
    _new_counts = {name: length for name, (_samples, length) in extracted.items()}
    plot_data(_new_counts, name='check_new')

    # Merge everything and write the mixed dataset to disk.
    merge_data([samples for samples, _length in extracted.values()],
               merge_path)

原始数据分布
(在这里插入图片描述)

混合后数据分布

(在这里插入图片描述)

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值