基础
安装tfts
pip install tfts
数据sine函数
train, valid = tfts.get_data(args.use_data, args.train_length, args.predict_length, test_size=0.2)
如果想用docker的话,在tfts地址的docker文件夹下载Dockerfile
然后进入交互模式
训练
Note
注意该数据集训练集和测试集虽然数据不同,但模式相同,所以可以认为几乎是过拟合的。
但一个强大的模型,应该先能够过拟合简单数据
RNN
cd examples
python run_prediction.py --use_model rnn
seq2seq
cd examples
python run_prediction.py --use_model seq2seq
TCN
cd examples
python run_prediction.py --use_model tcn
wavenet
cd examples
python run_prediction.py --use_model wavenet
Bert
cd examples
python run_prediction.py --use_model bert
Transformer
cd examples
python run_prediction.py --use_model transformer
informer
cd examples
python run_prediction.py --use_model informer
如果 custom_params.update({"skip_connect_circle": True})
如果 custom_params.update({"skip_connect_circle": False})
Nbeats
cd examples
python run_prediction.py --use_model nbeats