Charades数据处理部分代码阅读

1.enumerate

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

split = ["train","test"]
for di,dt in enumerate(split):
    print(di)
>>>
0
1
print(dt)
>>>
train
test
print(di,dt)
>>>
0 train
1 test

2 使用Dataset和Dataloader构建以及加载数据集

已有好文

具体文件:
utils io_utils
abstract_dataset charades
common_functions

3 torch.no_grad():

torch.no_grad() 是一个上下文管理器,被该语句 wrap 起来的部分将不会track 梯度。
作为一个装饰器。
加在网络测试的函数前

def eval():
	with torch.no_grad()...
		...

4 hdf5文件读取

对dataset直接读取

import h5py
f = h5py.File("D:/code/LGI/data/charades/test_label_F1_L10_I3D.hdf5", "r")
print(f["1"][()]) 
>>>[691   1 977 561 953 620 977 287   0   0]

对group下的dataset读取

f = h5py.File("D:/code/LGI/data/charades/preprocess/grounding_info/test_labels_S128_I3D.hdf5", "r")
# print([key for key in f.keys()], "\n")  
att_mask = f["att_mask"]
end_pos = f["end_pos"]
start_pos = f["start_pos"]
print(end_pos['1'][()]) 
print(start_pos['1'][()]) 
print(att_mask['1'][()]) 
>>>0.9819121447028423
0.7848837209302325

5 关于Dataset和Dataloader究竟创建了个什么出来

实例代码

已知Charadesdataset的getitem函数返回

 instance = {
            "vids": vid,
            "qids": qid,
            "timestamps": timestamp, # GT location [s, e] (second)
            "duration": duration, # video span (second)
            "query_lengths": q_leng,
            "query_labels": torch.LongTensor(q_label).unsqueeze(0),     # [1,L_q_max]
            "query_masks": (torch.FloatTensor(q_label)>0).unsqueeze(0), # [1,L_q_max]
            "grounding_start_pos": torch.FloatTensor([start_pos]), # [1]; normalized
            "grounding_end_pos": torch.FloatTensor([end_pos]),     # [1]; normalized
            "grounding_att_masks": torch.FloatTensor(att_mask),  # [L_v]
            "nfeats": torch.FloatTensor([nfeats]),
            "video_feats": torch.FloatTensor(vid_feat), # [L_v,D_v]
            "video_masks": torch.ByteTensor(vid_mask), # [L_v,1]
        }

        return instance

又有debug函数:

def get_loader():
    conf = {
        
        "test_loader": {
            "dataset": "charades",
            "split": "test",
            "batch_size": 1,
            "data_dir": "data/charades",
            "video_feature_path": "data/charades/features/i3d_finetuned/{}.npy",
            "max_length": 10,
            "word_frequency_threshold": 1,
            "num_segment": 128,
            "feature_type": "I3D",
        }
    }
    print(json.dumps(conf, indent=4))
    dsets, L = create_loaders(["test"],
                              [conf["test_loader"]],
                              num_workers=5)
    return dsets, L

编写测试函数

if __name__ == "__main__":
    dset, l = get_loader()

得到的dset和l是什么

dset:

type(dset)
>>><class 'dict'>
print(dset)
>>>{'test': <__main__.CharadesDataset object at 0x00000243AC8787C0>}
print(dset["test"]
>>><class '__main__.CharadesDataset'>

到这里可知,dataset获得的是一个字典,值为我们编写的dataset类,故我们可调用类中的__getitem__函数进一步探究

instan = dset["test"].__getitem__(1)
print(type(instan))
>>><class 'dict'>

那么显然,instan为一个字典,且输出其键,为getitem函数中返回的instance字典

for key in instan.keys():
        print(key)
>>>
vids
qids
timestamps
duration
query_lengths
query_labels
query_masks
grounding_start_pos
grounding_end_pos
grounding_att_masks
nfeats
video_feats
video_masks
print((instan["video_feats"]).shape[0])
128

loader:

type(loader) = <class 'dict'>
loader = {'test': <torch.utils.data.dataloader.DataLoader object at 0x0000014BA41CDCD0>}

我们将batch_size设为2

 ii = 1
 loader = l["test"]
    for batch in loader:
        ii += 1
    print(ii)
>>>1860    

可见数据集被分成了1860个batch,每个batch内有两个instance

batch_size决定了在训练中一次用多少条数据进行训练
如视频特征维度为128 x 1024
batch_size为100的情况下,输入的vid_feature则为[100,128,1024]

6 数据输入

由evalu.main -> cmf.test() -> LGI.prepare_batch
在这里插入图片描述
可见,对于dataset中得到的instance及其众多特征项,这里分成了输入特征和标签
进一步比对可知
在这里插入图片描述
在这里插入图片描述
均为标签类特征。
其中att_mask为共有特征,维度为128

而作为网络输入的特征,只有编码+固定维度后的查询文本向量及mask
维度为[1x10]
在这里插入图片描述
和I3D+固定维度处理后的视频特征及mask
维度为[128 x 1024]
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值