初学pytorch,往往是有点难度,在几次的练习后,逐渐熟练起来,这里介绍几个使用编程的技巧,也使后人少走弯路。
1、数据集的划分
通常情况下,机器学习的过程必然少不了数据集的清洗和划分,在划分的过程是存在技巧的。
以1万条数据的数据集为例,我们选取8000条数据进行训练,2000条数据进行测试,那么在这8000条的数据中,我们每一个训练的epoch都会完整的使用到这8000条数据,在不同的epoch中,对这8000条数据进行shuffle操作,即打乱训练顺序,可以有更好的训练效果。
2、torch中的形状
对于不同的torch网络模型,往往要求不同的输入格式,如(sen_len, batch_size, embedding_size)等。我们一上来数据集的数据通常都是混在一起的一个大数组形式,这是便可以使用torch的view方法,将tensor变成想要的形状。
但是要注意的是,tensor.view方法是原来tensor的浅拷贝,并且该方法会打乱tensor在内存中的位置,因此对该tensor要进行后续的有tensor连续性要求的操作时会报错,这时就需要对view过的tensor进行连接操作tensor.contiguous(),把分散的内存连续化。
3、数据文件操作
大多数的文件数据如果是csv或者xlsx文件时,我只能说pandas是yyds!