# 一、数据加载
# 预训练数据集的加载
class PretrainDataSet(Dataset):
def __init__(self, filepath="data/pre_train.csv"):
# print(f"reading{filepath}") # 打印列名
df = pandas.read_csv(
filepath, header=0, index_col=0,
encoding='utf-8',
names=['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'target'],
dtype={"sepal_length": np.float32, "sepal_width": np.float32, "petal_length": np.float32, "petal_width": np.float32, "target": int},
skip_blank_lines=True, # 跳过空白行
)
print(f"前训练数据的形状{df.shape}")
# print(df.head())
feat = df.iloc[:, :4].values # 前四列
label = df.iloc[:, 4].values # 第五列
self.x = torch.from_numpy(feat)
self.x = self.x.view(-1, 1, 4, 1)
self.y = torch.from_n