文章目录
以TubeR来说明:
ANNO_PATH: '/data1/ws/Tubelet-transformer/datasets/assets/my_ava_{}_v22.json'
train_bbox_json = json.load(open(cfg.CONFIG.DATA.ANNO_PATH.format("train")))
train_video_frame_bbox, train_frame_keys_list = train_bbox_json["video_frame_bbox"], train_bbox_json["frame_keys_list"]
也就是说,找到dataloader下的训练文件来源,我们对这个文件进行分层抽样。
1. 查询原配置文件信息
import json
import pandas as pd
# Load the original AVA v2.2 training annotation file that TubeR's dataloader reads.
train_bbox_json = json.load(open('/data1/ws/Tubelet-transformer/datasets/assets/ava_train_v22.json'))
# The JSON has exactly two top-level keys (confirmed by the printout below):
#   "video_frame_bbox": per-frame dict of {"bboxes": [...], "acts": [...]}
#   "frame_keys_list":  list of "<video_id>,<timestamp>" keys
train_video_frame_bbox, train_frame_keys_list = train_bbox_json["video_frame_bbox"], train_bbox_json["frame_keys_list"]
"""
train_bbox_json:
video_frame_bbox: 'zlVkeKC6Ha8,1782': {'bboxes': [[0.042, 0.046, 0.97, 0.98]], 'acts': [[10, 73, 79]]},
frame_keys_list: 'zlVkeKC6Ha8,1782'
"""
# Peek at the structure: the two top-level keys, two sample entries, two sample frame keys.
print(list(train_bbox_json)[:2])
print(list(train_video_frame_bbox.items())[:2])
print(list(train_frame_keys_list)[:2])
['video_frame_bbox', 'frame_keys_list']
[('-5KQ66BBWC4,0902', {'bboxes': [[0.077, 0.151, 0.28300000000000003, 0.8109999999999999], [0.332, 0.19399999999999998, 0.48100000000000004, 0.8909999999999999], [0.505, 0.105, 0.653, 0.78], [0.626, 0.146, 0.805, 0.818], [0.805, 0.222, 0.997, 1.0]], 'acts': [[79, 8], [79, 8], [8], [8], [79, 8]]}), ('-5KQ66BBWC4,0903', {'bboxes': [[0.0, 0.162, 0.177, 0.804], [0.141, 0.158, 0.298, 0.825], [0.32799999999999996, 0.182, 0.484, 0.895], [0.507, 0.147, 0.6659999999999999, 0.789], [0.642, 0.158, 0.7909999999999999, 0.8590000000000001], [0.785, 0.15, 0.8859999999999999, 0.703], [0.802, 0.267, 0.9940000000000001, 0.971], [0.865, 0.158, 0.991, 0.436]], 'acts': [[79, 8], [11, 79], [79, 8], [79, 8], [8], [11, 79], [79, 8], [79, 8]]})]
['D8Vhxbho1fY,1798', 'PcFEhUKhN6g,1408']
2. 分层抽样 StratifiedGroupKFold(按视频 id 分组的分层抽样)
from sklearn.model_selection import StratifiedGroupKFold
import pandas as pd

# Stratified subsampling: keep one of ten folds, grouped by video id so that
# every frame of a given video lands in the same fold.
annotations = pd.read_csv('/data1/ws/Tubelet-transformer/datasets/assets/ava_val_v2.2.csv')
# Second-to-last column is the stratification target (the action label;
# the last column — person id — is dropped later in step 3).
labels = annotations.iloc[:, -2]
splitter = StratifiedGroupKFold(n_splits=10, shuffle=True, random_state=42)
# The first split yields (train_idx, test_idx); we keep only the ~10% test fold.
_, keep_idx = next(splitter.split(annotations, labels, groups=annotations.iloc[:, 0]))
subset = annotations.iloc[keep_idx]
# Write the retained ~10% subset out for the next step.
subset.to_csv('/data1/ws/Tubelet-transformer/datasets/selected_val_data.csv', index=False)
80万条数据变为8万。
/data1/ws/anaconda3/envs/pytorch/lib/python3.11/site-packages/sklearn/model_selection/_split.py:737: UserWarning: The least populated class in y has only 7 members, which is less than n_splits=10.
warnings.warn(
/data1/ws/anaconda3/envs/pytorch/lib/python3.11/site-packages/sklearn/model_selection/_split.py:737: UserWarning: The least populated class in y has only 3 members, which is less than n_splits=10.
warnings.warn(
(中文翻译)UserWarning:y 中样本最少的类别只有 3 个成员,少于 n_splits=10。含义是某些动作类别样本太少,无法保证每个折都按比例包含该类别;这不影响抽样结果,可以忽略。
3. 清理数据 (除0和删最后一列)
最后一列数据是person id,我们任务中不需要。
import pandas as pd

original_data = pd.read_csv('/data1/ws/Tubelet-transformer/datasets/selected_data.csv')
# Keep only the rows whose last column (person id) is non-zero.
mask = original_data.iloc[:, -1] != 0
kept_rows = original_data[mask]
kept_rows.to_csv('/data1/ws/Tubelet-transformer/datasets/2.csv', index=False)
# Then drop the last column entirely — person id is not used by this task.
without_last_col = kept_rows.iloc[:, :-1]
without_last_col.to_csv('/data1/ws/Tubelet-transformer/datasets/3.csv', index=False)
4. video_frame_bbox键生成
import json
import csv
# Read the CSV file and convert it into the video_frame_bbox dictionary.
def csv_to_dict(csv_file):
    """Convert an AVA-style CSV into the video_frame_bbox mapping.

    Each CSV row is: video_id, timestamp, x1, y1, x2, y2, action_id.
    Returns {"<video_id>,<timestamp>": {"bboxes": [[x1, y1, x2, y2], ...],
    "acts": [[a, ...], ...]}} with action ids shifted from 1-based to 0-based
    and one act-list per bbox; acts of duplicated bboxes are merged.
    """
    result_dict = {}
    # NOTE(review): 3.csv was written by pandas and may start with a header
    # row; this loop does not skip it — confirm the first row is real data.
    with open(csv_file, 'r') as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            if len(row) >= 7:
                video_id, time, *bbox_data = row[:6]
                categories = [int(category) - 1 for category in row[6:]]
                frame_key = f"{video_id},{time}"
                if frame_key in result_dict:
                    result_dict[frame_key]["bboxes"].append(
                        [float(coord) for coord in bbox_data])
                    result_dict[frame_key]["acts"].extend(categories)
                else:
                    result_dict[frame_key] = {
                        "bboxes": [[float(coord) for coord in bbox_data]],
                        "acts": categories,
                    }
    # Wrap each action id in its own list so acts[i] belongs to bboxes[i],
    # then merge the act lists of duplicated bboxes (same box annotated
    # with several actions appears as several CSV rows).
    for key, value in result_dict.items():
        result_dict[key]["acts"] = [[category] for category in value["acts"]]
        bboxes = value["bboxes"]
        for i in range(len(bboxes)):
            if bboxes[i] in bboxes[i + 1:]:
                # BUG FIX: the original evaluated `acts[i] + acts[i+1]` and
                # discarded the result, and indexed i+1 instead of the actual
                # duplicate's position. Merge with the real duplicate's acts.
                dup = bboxes.index(bboxes[i], i + 1)
                result_dict[key]["acts"][i] += result_dict[key]["acts"][dup]
    return result_dict
# Persist a dictionary to disk as a JSON file.
def save_to_json(data, output_file):
    """Write *data* to *output_file* as 4-space-indented JSON."""
    import json

    with open(output_file, 'w') as jsonfile:
        jsonfile.write(json.dumps(data, indent=4))
# Input CSV path (the cleaned file produced by step 3).
csv_file_path = '/data1/ws/Tubelet-transformer/datasets/3.csv'

# Build the video_frame_bbox mapping from the CSV and show it for inspection.
result_dict = csv_to_dict(csv_file_path)
print(result_dict)
# print(len(result_dict))

# Dump a plain-text copy for eyeballing, and the JSON that will replace the
# "video_frame_bbox" section of the annotation file.
with open('video_frame_bbox.txt', 'w') as f:
    f.write(str(result_dict))
save_to_json(result_dict, 'video_frame_bbox.json')
5. frame_keys_list
import csv

# Build the frame_keys_list section from the cleaned CSV.
csv_file = '/data1/ws/Tubelet-transformer/datasets/3.csv'

data_dict = {}
# BUG FIX: the original shadowed the path variable with the file handle
# (`as csv_file`) and left a dangling empty `else:` (a SyntaxError).
with open(csv_file, 'r') as csv_fh:
    csv_reader = csv.reader(csv_fh)
    next(csv_reader)  # Skip the header row if present
    for row in csv_reader:
        key = row[0] + "," + row[1]
        # First occurrence wins; later rows for the same frame key are ignored,
        # so the keys come out unique and in file order.
        if key not in data_dict:
            data_dict[key] = [float(val) for val in row[2:]]

result_dict = {"frame_keys_list": [key for key in data_dict]}
# print(result_dict)
with open('result_dict.txt', 'w') as f:
    f.write(str(result_dict))
把单引号变成双引号!(str(dict) 写出的 txt 使用 Python 的单引号,而 JSON 格式要求字符串使用双引号;用 save_to_json 生成的 .json 文件本身已是合法 JSON。)
6. 粘贴复制到源文件