# Dataset mixing script.
# When fine-tuning a multimodal model, extra general-purpose data is usually
# mixed in to avoid degrading the model's other knowledge and abilities.
'''
Mix several datasets with per-dataset weights and return the mixed dataset.
'''
import os
import json
import random
import matplotlib.pyplot as plt
import colorsys
import numpy as np
# 根据json文件统计样本数量
def count_samples(data_path):
    """Return the number of samples in a ``.json`` or ``.jsonl`` file.

    Args:
        data_path: path to a ``.json`` file containing a list of samples,
            or a ``.jsonl`` file with one JSON object per line.

    Returns:
        int: the sample count.

    Raises:
        ValueError: if the extension is neither ``.json`` nor ``.jsonl``
            (the original silently returned None in that case).
    """
    ext = os.path.splitext(data_path)[1]
    if ext == '.jsonl':
        # One record per line -> count lines without parsing them.
        with open(data_path, 'r', encoding='utf-8') as f:
            return sum(1 for _ in f)
    if ext == '.json':
        with open(data_path, 'r', encoding='utf-8') as f:
            return len(json.load(f))
    raise ValueError(f'Unsupported data file extension: {ext!r}')
def generate_evenly_spaced_colors(num_colors):
    """Return *num_colors* RGB tuples whose hues are evenly spaced on the wheel.

    Lightness (0.5) and saturation (0.9) are fixed; only the hue varies.
    """
    return [
        colorsys.hls_to_rgb(idx / num_colors, 0.5, 0.9)
        for idx in range(num_colors)
    ]
def plot_data(counts, save_path='/home/app/examples/lb/san_nuo/datasets/fig', name='all'):
    """Save a pie chart of per-dataset sample counts.

    Args:
        counts: mapping of dataset name -> sample count.
        save_path: output directory (created if missing).
        name: prefix for the output file name.
    """
    dataset_names = counts.keys()
    dataset_sizes = counts.values()
    palette = generate_evenly_spaced_colors(len(dataset_names))
    # Pie chart with one percentage label per slice.
    plt.pie(dataset_sizes, labels=dataset_names, colors=palette,
            autopct='%1.1f%%', startangle=140)
    plt.axis('equal')  # equal aspect ratio keeps the pie circular
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, f'{name}_dataset_distribution.png'))
    plt.close()
## Scale-factor helpers: split an integer budget across datasets.
def random_split(total, n, type_name='even'):
    """Split the integer *total* into *n* parts.

    Args:
        total: the integer budget to split.
        n: number of parts (must be >= 1).
        type_name: 'even' for near-equal parts, 'random' for random cut points.

    Returns:
        list[int]: n parts summing exactly to total.

    Raises:
        ValueError: if n < 1 or type_name is unknown.
    """
    if n < 1:
        raise ValueError('n must be at least 1')
    if n == 1:
        # BUG FIX: the original crashed for n == 1 in 'random' mode
        # (indexing splits[-1] on an empty list).
        return [total]
    if type_name == 'random':
        # n-1 sorted cut points in (0, total) define n random positive parts.
        cuts = sorted(random.sample(range(1, total), n - 1))
        parts = [cuts[0]]
        parts += [cuts[i] - cuts[i - 1] for i in range(1, n - 1)]
        parts.append(total - cuts[-1])
        return parts
    if type_name == 'even':
        # BUG FIX: the original returned [total // n] * n, which silently
        # dropped total % n samples; distribute the remainder instead.
        base, rem = divmod(total, n)
        return [base + 1 if i < rem else base for i in range(n)]
    raise ValueError(f'Unknown split type: {type_name!r}')
def set_beta(seed, counts, alphta, beta, c):
    """Build (or pass through) the per-dataset target sample counts.

    Args:
        seed: RNG seed so the generated split is reproducible.
        counts: dataset name -> current sample count. NOTE(review): the
            original comment says this should not include 'app'; if it does,
            the 'app' share produced by the split is overwritten by *alphta*.
        alphta: target sample count for the 'app' dataset.
        beta: if not None, returned unchanged (caller-supplied targets).
        c: mix ratio — the datasets in *counts* share alphta * c samples.

    Returns:
        dict: dataset name -> target sample count, always including 'app'.
    """
    if beta is not None:
        # Caller supplied the targets explicitly; use them as-is.
        return beta
    random.seed(seed)
    targets = {}
    # Split the alphta * c budget across the datasets in `counts`.
    shares = random_split(alphta * c, len(counts))
    for key, share in zip(counts.keys(), shares):
        targets[key] = share
    # The in-domain 'app' data keeps exactly alphta samples.
    targets['app'] = alphta
    return targets
# NOTE: a commented-out earlier draft of extract_samples (json-only) was
# removed here; the live version below also handles .jsonl files.
# Extract a fixed number of samples, optionally with a reproducible shuffle.
def extract_samples(data_path, num, shuffle=True):
    """Load up to *num* samples from a ``.json`` or ``.jsonl`` file.

    Args:
        data_path: path to a ``.json`` list file or a ``.jsonl`` file.
        num: number of samples to keep (prefix after optional shuffling).
        shuffle: if True, shuffle with a fixed seed (25) before slicing so
            the extraction is reproducible.

    Returns:
        tuple[list, int]: (samples, number of samples actually returned —
        may be fewer than *num* if the file is smaller).

    Raises:
        ValueError: for an unsupported file extension.
    """
    ext = os.path.splitext(data_path)[1]
    with open(data_path, 'r', encoding='utf-8') as f:
        if ext == '.jsonl':
            data = [json.loads(line) for line in f]
        elif ext == '.json':
            data = json.load(f)
        else:
            raise ValueError(f'Unsupported data file extension: {ext!r}')
    if shuffle:
        # BUG FIX: the original seeded numpy (np.random.seed(25)) but then
        # shuffled with the stdlib `random` module, so the shuffle was never
        # actually seeded. Seed the module that does the shuffling.
        random.seed(25)
        random.shuffle(data)
    data = data[:num]
    return data, len(data)
def merge_data(datas: list, save_path: str, shuffle=True):
    """Concatenate several sample lists, optionally shuffle, optionally save.

    Args:
        datas: list of sample lists to concatenate, in order.
        save_path: if not None, the merged list is dumped there as JSON.
        shuffle: if True, shuffle the merged list in place before saving.

    Returns:
        list: the merged (and possibly shuffled) samples.
    """
    # Renamed the accumulator — the original local `merge_data` shadowed
    # the function's own name.
    merged = []
    for data in datas:
        merged.extend(data)
    if shuffle:
        random.shuffle(merged)
    if save_path is not None:
        # utf-8 is required for a portable ensure_ascii=False dump.
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(merged, f, ensure_ascii=False, indent=4)
    return merged
if __name__ == '__main__':
    # Paths of the general-purpose datasets to mix with the in-domain data.
    coco_path = 'database/coco2017/coco2017_swift.json'
    ai2d_path = 'database/san_nuo_app_swift/ai2d_train_12k.jsonl'
    chartqa_path = 'database/san_nuo_app_swift/chartqa_train_18k.jsonl'
    docvqa_path = 'database/san_nuo_app_swift/docvqa_train_10k.jsonl'
    # In-domain ("app") training data.
    app_path = 'database/san_nuo_app/app/train_data_v3.jsonl'
    # Where the merged dataset is written.
    merge_path = 'database/san_nuo_app_swift/merge_data.json'

    # Count available samples per dataset and plot the raw distribution.
    counts = {
        "app": count_samples(app_path),
        "coco": count_samples(coco_path),
        "chartqa": count_samples(chartqa_path),
        "ai2d_train_12k": count_samples(ai2d_path),
        "docvqa_train_10k": count_samples(docvqa_path),
    }
    plot_data(counts)

    # Derive target counts: 'app' keeps all of its samples, the remaining
    # datasets share alphta * c samples between them.
    alphta = counts['app']
    new_counts = set_beta(25, counts, alphta, beta=None, c=2)
    plot_data(new_counts, name='new')
    print(new_counts)

    # Subsample every dataset down to its target count.
    app_data, app_data_len = extract_samples(app_path, new_counts['app'])
    coco_data, coco_data_len = extract_samples(coco_path, new_counts['coco'])
    chartqa_data, chartqa_data_len = extract_samples(chartqa_path, new_counts['chartqa'])
    ai2d_train_12k_data, ai2d_train_12k_data_len = extract_samples(ai2d_path, new_counts['ai2d_train_12k'])
    docvqa_train_10k_data, docvqa_train_10k_data_len = extract_samples(docvqa_path, new_counts['docvqa_train_10k'])

    # Sanity check: plot the counts that were actually extracted.
    _new_counts = {
        "app": app_data_len,
        "coco": coco_data_len,
        "chartqa": chartqa_data_len,
        "ai2d_train_12k": ai2d_train_12k_data_len,
        "docvqa_train_10k": docvqa_train_10k_data_len,
    }
    plot_data(_new_counts, name='check_new')

    # Merge everything and write the result to merge_path.
    merge_data([app_data, coco_data, chartqa_data, ai2d_train_12k_data, docvqa_train_10k_data],
               merge_path)
# Figure captions for the two saved charts:
#   all_dataset_distribution.png — original data distribution
#   new_dataset_distribution.png — data distribution after mixing