class dataset(Dataset):
def __init__(self):
df=pd.read_csv()
feat=df.iloc[].values
label=df.iloc[].values
self.x=torch.from_numpy(feat)
self.y=torch.from_numpy(label)
def __len(self):
return len(self.y)
def __getitem__(self,index):
return self.x[index],self.y[index]
- CE weight 用于类别不均衡情况,ignore_index 用于padding 部分,reduction 是对 batch 的操作