【StratifiedKFold】分层抽样数据集来减少训练时长

以TubeR来说明:

ANNO_PATH: '/data1/ws/Tubelet-transformer/datasets/assets/my_ava_{}_v22.json'
train_bbox_json = json.load(open(cfg.CONFIG.DATA.ANNO_PATH.format("train")))
train_video_frame_bbox, train_frame_keys_list = train_bbox_json["video_frame_bbox"], train_bbox_json["frame_keys_list"]

在这里插入图片描述
也就是说,找到dataloader下的训练文件来源,我们对这个文件进行分层抽样。

1. 查询原配置文件信息

import json
import pandas as pd

train_bbox_json = json.load(open('/data1/ws/Tubelet-transformer/datasets/assets/ava_train_v22.json'))
train_video_frame_bbox, train_frame_keys_list = train_bbox_json["video_frame_bbox"], train_bbox_json["frame_keys_list"]
"""
train_bbox_json: 
video_frame_bbox: 'zlVkeKC6Ha8,1782': {'bboxes': [[0.042, 0.046, 0.97, 0.98]], 'acts': [[10, 73, 79]]},
frame_keys_list:  'zlVkeKC6Ha8,1782'
"""
print(list(train_bbox_json)[:2])
print(list(train_video_frame_bbox.items())[:2]) 
print(list(train_frame_keys_list)[:2]) 

['video_frame_bbox', 'frame_keys_list']
[('-5KQ66BBWC4,0902', {'bboxes': [[0.077, 0.151, 0.28300000000000003, 0.8109999999999999], [0.332, 0.19399999999999998, 0.48100000000000004, 0.8909999999999999], [0.505, 0.105, 0.653, 0.78], [0.626, 0.146, 0.805, 0.818], [0.805, 0.222, 0.997, 1.0]], 'acts': [[79, 8], [79, 8], [8], [8], [79, 8]]}), ('-5KQ66BBWC4,0903', {'bboxes': [[0.0, 0.162, 0.177, 0.804], [0.141, 0.158, 0.298, 0.825], [0.32799999999999996, 0.182, 0.484, 0.895], [0.507, 0.147, 0.6659999999999999, 0.789], [0.642, 0.158, 0.7909999999999999, 0.8590000000000001], [0.785, 0.15, 0.8859999999999999, 0.703], [0.802, 0.267, 0.9940000000000001, 0.971], [0.865, 0.158, 0.991, 0.436]], 'acts': [[79, 8], [11, 79], [79, 8], [79, 8], [8], [11, 79], [79, 8], [79, 8]]})]
['D8Vhxbho1fY,1798', 'PcFEhUKhN6g,1408']

2. 分层抽样 StratifiedKFold

from sklearn.model_selection import StratifiedGroupKFold
import pandas as pd

# Load the data from the csv file
data = pd.read_csv('/data1/ws/Tubelet-transformer/datasets/assets/ava_val_v2.2.csv')

# Get the target variable
target = data.iloc[:, -2]

# Create the StratifiedKFold object
skf = StratifiedGroupKFold(n_splits=10, shuffle=True, random_state=42)

# Get the indices of the samples to keep
_, indices = next(skf.split(data, target, groups=data.iloc[:, 0]))

# Keep only the selected samples
selected_data = data.iloc[indices]

# Save the selected samples to a new csv file
selected_data.to_csv('/data1/ws/Tubelet-transformer/datasets/selected_val_data.csv', index=False)

80万条数据变为8万。
在这里插入图片描述

/data1/ws/anaconda3/envs/pytorch/lib/python3.11/site-packages/sklearn/model_selection/_split.py:737: UserWarning: The least populated class in y has only 7 members, which is less than n_splits=10.
warnings.warn(
/data1/ws/anaconda3/envs/pytorch/lib/python3.11/site-packages/sklearn/model_selection/_split.py:737: UserWarning: The least populated class in y has only 3 members, which is less than n_splits=10.
warnings.warn(
UserWarning: y 中填充最少的类只有 3 个成员,小于 n_splits= 10.
警告.警告(

3. 清理数据 (除0和删最后一列)

最后一列数据是person id,我们任务中不需要。

import pandas as pd

original_data = pd.read_csv('/data1/ws/Tubelet-transformer/datasets/selected_data.csv')

# 删除最后一列为0的数据
filtered_data = original_data.loc[original_data.iloc[:, -1] != 0]
filtered_data.to_csv('/data1/ws/Tubelet-transformer/datasets/2.csv', index=False)

# 再删除最后一列数据
df = filtered_data.iloc[:, :-1]
df.to_csv('/data1/ws/Tubelet-transformer/datasets/3.csv', index=False)

在这里插入图片描述
在这里插入图片描述

4. video_frame_bbox键生成

import json
import csv


# 读取CSV文件并转化为字典
def csv_to_dict(csv_file):
    result_dict = {}
    with open(csv_file, 'r') as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            if len(row) >= 7:
                video_id, time, *bbox_data = row[:6]
                categories = [int(category)-1 for category in row[6:]]
                frame_key = f"{video_id},{time}"
                if frame_key in result_dict:
                    result_dict[frame_key]["bboxes"].append([float(coord) for coord in bbox_data])
                    result_dict[frame_key]["acts"].extend(categories)
                else:
                    result_dict[frame_key] = {
                        "bboxes": [[float(coord) for coord in bbox_data]],
                        "acts": categories
                    }
    # Ensure each category is in its own list
    for key, value in result_dict.items():
        result_dict[key]["acts"] = [[category] for category in value["acts"]]
        bboxes = value["bboxes"]
        for i in range(len(bboxes)):
            if bboxes[i] in bboxes[i+1:]:
                # print(f"Duplicate bbox found at index {i} in {key}")
                result_dict[key]["acts"][i] + result_dict[key]["acts"][i+1]

    return result_dict


# 保存字典为JSON文件
def save_to_json(data, output_file):
    import json
    with open(output_file, 'w') as jsonfile:
        json.dump(data, jsonfile, indent=4)

# 输入的CSV文件路径
csv_file_path = '/data1/ws/Tubelet-transformer/datasets/3.csv'

# 转化为字典
result_dict = csv_to_dict(csv_file_path)

# # 删除键中重复的字典
# for key, value in result_dict.items():
#     value['bboxes'] = list(set(tuple(box) for box in value['bboxes']))
    
print(result_dict)
# print(len(result_dict))
with open('video_frame_bbox.txt', 'w') as f:
    f.write(str(result_dict))

save_to_json(result_dict, 'video_frame_bbox.json')

在这里插入图片描述

5. frame_keys_list

import csv

csv_file = '/data1/ws/Tubelet-transformer/datasets/3.csv'
data_dict = {}

with open(csv_file, 'r') as csv_file:
    csv_reader = csv.reader(csv_file)
    next(csv_reader)  # Skip the header row if present
    for row in csv_reader:
        key = row[0] + "," + row[1]
        if key not in data_dict:
            data_dict[key] = [float(val) for val in row[2:]]
        else:
            result_dict = {"frame_keys_list": [key for key in data_dict]}
# print(result_dict)

with open('result_dict.txt', 'w') as f:
    f.write(str(result_dict))

在这里插入图片描述

把单引号变成双引号!

6. 粘贴复制到源文件

在这里插入图片描述
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值