【1】原始代码:
def __getitem__(self, index):
wt_feature = self.wt_features[index]
mt_feature = self.mt_features[index]
label = self.true_ddg[index]
# 将特征和标签转换为张量类型
wt_feature = torch.tensor(wt_feature, dtype=torch.float32)
mt_feature = torch.tensor(mt_feature, dtype=torch.float32)
label = torch.tensor(label, dtype=torch.float32)
return {"wt_feature": wt_feature, "mt_feature": mt_feature, "label": label}
在之后训练过程中,使用dataloader 在for batch 的时候出现报错:
raise keyerror (key) from err
【解释】:该报错的原因是存在超过范围的索引
【原因】:
wt_feature = self.wt_features[index]
mt_feature = self.mt_features[index]
label = self.true_ddg[index]
这里输入的wt_features mt_features 是dataframe 类型,取值应该换为以下:
wt_feature = self.wt