1、主函数配置参数的修改
main.py文件
第122行、132行
parser.add_argument(‘–set_cost_screenLight’, default=1, type=float,help=“phone cost set”)
parser.add_argument(‘–screenLight_loss_coef’, default=1, type=float)
#############################################################
2、数据装载的修改
datasets/hico.py
173行从 json标注中取出 手机显示的标注
screenLight_labels.append(torch.tensor(hoi[‘screenLight’]))
190行、197行不同情况下,将screenLight标注放入target 类中
190行
target[‘screenLight_labels’] = torch.zeros((0,),dtype=torch.int64)
197行
target[‘screenLight_labels’] = torch.stack((screenLight_labels))
#############################################################
3、修改模型运行的三个主要流程
model, criterion, postprocessors三个类都在hoi.py文件中
model 模型的主要的块 都在里面定义,如模型的编码器、解码器、MLP等。
criterion 定义模型各个预测类别的一些损失函数设定
postprocessors 模型的预测结果的收集
在HOI文件的CDNHOI中,64、65行添加定义两个线性层,且将中间的隐藏层添加到interaction的向量中
screen_interaction_out = self.screenLight_embed(hopd_out)
outputs_screenLight_class = self.screen_interaction_embed(screen_interaction_out)
76行 将 预测的手机明暗放在out中。
out = {‘pred_obj_logits’: outputs_obj_class[-1], ‘pred_verb_logits’: outputs_verb_class[-1],
‘pred_sub_boxes’: outputs_sub_coord[-1], ‘pred_obj_boxes’: outputs_obj_coord[-1], ‘pred_screenLight_labels’: outputs_screenLight_class[-1]}
SetCriterionHOI 定义 损失函数
355行 def loss_screenLight_labels(self, outputs, targets, indices, num_interactions, log=True):
410行 损失函数的初始化
def get_loss(self, loss, outputs, targets, indices, num, **kwargs):
loss_map = {
'obj_labels': self.loss_obj_labels,
'obj_cardinality': self.loss_obj_cardinality,
'verb_labels': self.loss_verb_labels,
'sub_obj_boxes': self.loss_sub_obj_boxes,
'screenLight_labels': self.loss_screenLight_labels,
'matching_labels': self.loss_matching_labels
}
class PostProcessHOI 收集 预测的结果
466 行
out_screenLight_labels = outputs[‘pred_screenLight_labels’]
对预测的手机明暗进行一个softmax,得到概率
对于输入的二维数组x,进行softmax(x, -1)操作时,会对每一行进行softmax计算。因此,[…,1]表示取结果中每一行的第二个元素
screenLight_scores = F.softmax(out_screenLight_labels, -1)[…, 1]
549行、560行、564行一些损失的规范化
#################################################################
engine.py 定义 训练一个epoch的流程 和评估hoi的准确度的代码
def train_one_epoch
在34行取出 预测的结果
outputs = model(samples)
36行根据定义的损失函数进行计算损失,进行权重的更新
loss_dict = criterion(outputs, targets)