pytorch 回归预测(时间序列)

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from sklearn import preprocessing


features = pd.read_csv('temps.csv')
# year,moth,day,week 时间
# temp_2:前天的最高温度值
# temp_1:昨天的最高温度值
# average:在历史中,每年这一天的平均最高温度值
# actual:当天实际温度值
features = pd.get_dummies(features) # 将星期进行one-hot编码


labels = np.array(features['actual']) # 预测值 Y
features= features.drop('actual', axis = 1) # 特征 X
feature_list = list(features.columns) # 名字单独保存一下,以备后患
features = np.array(features) # 转换成合适的格式
input_features = preprocessing.StandardScaler().fit_transform(features)# 标准化
x = torch.tensor(input_features, dtype = float)
y = torch.tensor(labels, dtype = float)

input_size = input_features.shape[1]
hidden_size = 128
output_size = 1
batch_size = 16
my_nn = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden_size, output_size),
)
cost = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(my_nn.parameters(), lr = 0.001)
# 训练网络
losses = []
for i in range(1000):
    batch_loss = []
    # MINI-Batch方法来进行训练
    for start in range(0, len(input_features), batch_size):
        end = start + batch_size if start + batch_size < len(input_features) else len(input_features)
        xx = torch.tensor(input_features[start:end], dtype = torch.float, requires_grad = True)
        yy = torch.tensor(labels[start:end], dtype = torch.float, requires_grad = True)
        prediction = my_nn(xx)
        loss = cost(prediction, yy)
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        batch_loss.append(loss.data.numpy())
    
    # 打印损失
    if i % 100==0:
        losses.append(np.mean(batch_loss))
        print(i, np.mean(batch_loss))

# 进行预测
x = torch.tensor(input_features, dtype = torch.float)
predict = my_nn(x).data.numpy()

数据

year,month,day,week,temp_2,temp_1,average,actual
2016,1,1,Fri,45,45,45.6,45
2016,1,2,Sat,44,45,45.7,44
2016,1,3,Sun,45,44,45.8,41
2016,1,4,Mon,44,41,45.9,40
2016,1,5,Tues,41,40,46.0,44
2016,1,6,Wed,40,44,46.1,51
2016,1,7,Thurs,44,51,46.2,45
2016,1,8,Fri,51,45,46.3,48
2016,1,9,Sat,45,48,46.4,50
2016,1,10,Sun,48,50,46.5,52
2016,1,11,Mon,50,52,46.7,45
2016,1,12,Tues,52,45,46.8,49
2016,1,13,Wed,45,49,46.9,55
2016,1,14,Thurs,49,55,47.0,49
2016,1,15,Fri,55,49,47.1,48
2016,1,16,Sat,49,48,47.3,54
2016,1,17,Sun,48,54,47.4,50
2016,1,18,Mon,54,50,47.5,54
2016,1,19,Tues,50,54,47.6,48
2016,1,20,Wed,54,48,47.7,52
2016,1,21,Thurs,48,52,47.8,52
2016,1,22,Fri,52,52,47.9,57
2016,1,23,Sat,52,57,48.0,48
2016,1,24,Sun,57,48,48.1,51
2016,1,25,Mon,48,51,48.2,54
2016,1,26,Tues,51,54,48.3,56
2016,1,27,Wed,54,56,48.4,57
2016,1,28,Thurs,56,57,48.4,56
2016,1,29,Fri,57,56,48.5,52
2016,1,30,Sat,56,52,48.6,48
2016,1,31,Sun,52,48,48.7,47
2016,2,1,Mon,48,47,48.8,46
2016,2,2,Tues,47,46,48.8,51
2016,2,3,Wed,46,51,48.9,49
2016,2,4,Thurs,51,49,49.0,49
2016,2,5,Fri,49,49,49.1,53
2016,2,6,Sat,49,53,49.1,49
2016,2,7,Sun,53,49,49.2,51
2016,2,8,Mon,49,51,49.3,57
2016,2,9,Tues,51,57,49.4,62
2016,2,10,Wed,57,62,49.4,56
2016,2,11,Thurs,62,56,49.5,55
2016,2,12,Fri,56,55,49.6,58
2016,2,15,Mon,55,58,49.9,55
2016,2,16,Tues,58,55,49.9,56
2016,2,17,Wed,55,56,50.0,57
2016,2,18,Thurs,56,57,50.1,53
2016,2,19,Fri,57,53,50.2,51
2016,2,20,Sat,53,51,50.4,53
2016,2,21,Sun,51,53,50.5,51
2016,2,22,Mon,53,51,50.6,51
2016,2,23,Tues,51,51,50.7,60
2016,2,24,Wed,51,60,50.8,59
2016,2,25,Thurs,60,59,50.9,61
2016,2,26,Fri,59,61,51.1,60
2016,2,27,Sat,61,60,51.2,57
2016,2,28,Sun,60,57,51.3,53
2016,3,1,Tues,53,54,51.5,58
2016,3,2,Wed,54,58,51.6,55
2016,3,3,Thurs,58,55,51.8,59
2016,3,4,Fri,55,59,51.9,57
2016,3,5,Sat,59,57,52.1,64
2016,3,6,Sun,57,64,52.2,60
2016,3,7,Mon,64,60,52.4,53
2016,3,8,Tues,60,53,52.5,54
2016,3,9,Wed,53,54,52.7,55
2016,3,10,Thurs,54,55,52.8,56
2016,3,11,Fri,55,56,53.0,55
2016,3,12,Sat,56,55,53.1,52
2016,3,13,Sun,55,52,53.3,54
2016,3,14,Mon,52,54,53.4,49
2016,3,15,Tues,54,49,53.6,51
2016,3,16,Wed,49,51,53.7,53
2016,3,17,Thurs,51,53,53.9,58
2016,3,18,Fri,53,58,54.0,63
2016,3,19,Sat,58,63,54.2,61
2016,3,20,Sun,63,61,54.3,55
2016,3,21,Mon,61,55,54.5,56
2016,3,22,Tues,55,56,54.6,57
2016,3,23,Wed,56,57,54.7,53
2016,3,24,Thurs,57,53,54.9,54
2016,3,25,Fri,53,54,55.0,57
2016,3,26,Sat,54,57,55.2,59
2016,3,27,Sun,57,59,55.3,51
2016,3,28,Mon,59,51,55.5,56
2016,3,29,Tues,51,56,55.6,64
2016,3,30,Wed,56,64,55.7,68
2016,3,31,Thurs,64,68,55.9,73
2016,4,1,Fri,68,73,56.0,71
2016,4,2,Sat,73,71,56.2,63
2016,4,3,Sun,71,63,56.3,69
2016,4,4,Mon,63,69,56.5,60
2016,4,5,Tues,69,60,56.6,57
2016,4,6,Wed,60,57,56.8,68
2016,4,7,Thurs,57,68,56.9,77
2016,4,8,Fri,68,77,57.1,76
2016,4,9,Sat,77,76,57.2,66
2016,4,10,Sun,76,66,57.4,59
2016,4,11,Mon,66,59,57.6,58
2016,4,12,Tues,59,58,57.7,60
2016,4,13,Wed,58,60,57.9,59
2016,4,14,Thurs,60,59,58.1,59
2016,4,15,Fri,59,59,58.3,60
2016,4,16,Sat,59,60,58.5,68
2016,4,17,Sun,60,68,58.6,77
2016,4,18,Mon,68,77,58.8,89
2016,4,19,Tues,77,89,59.0,81
2016,4,20,Wed,89,81,59.2,81
2016,4,21,Thurs,81,81,59.4,73
2016,4,22,Fri,81,73,59.7,64
2016,4,23,Sat,73,64,59.9,65
2016,4,24,Sun,64,65,60.1,55
2016,4,25,Mon,65,55,60.3,59
2016,4,26,Tues,55,59,60.5,60
2016,4,27,Wed,59,60,60.7,61
2016,4,28,Thurs,60,61,61.0,64
2016,4,29,Fri,61,64,61.2,61
2016,4,30,Sat,64,61,61.4,68
2016,5,1,Sun,61,68,61.6,77
2016,5,2,Mon,68,77,61.9,87
2016,5,3,Tues,77,87,62.1,74
2016,5,4,Wed,87,74,62.3,60
2016,5,5,Thurs,74,60,62.5,68
2016,5,6,Fri,60,68,62.8,77
2016,5,7,Sat,68,77,63.0,82
2016,5,8,Sun,77,82,63.2,63
2016,5,9,Mon,82,63,63.4,67
2016,5,10,Tues,63,67,63.6,75
2016,5,11,Wed,67,75,63.8,81
2016,5,12,Thurs,75,81,64.1,77
2016,5,13,Fri,81,77,64.3,82
2016,5,14,Sat,77,82,64.5,65
2016,5,15,Sun,82,65,64.7,57
2016,5,16,Mon,65,57,64.8,60
2016,5,17,Tues,57,60,65.0,71
2016,5,18,Wed,60,71,65.2,64
2016,5,19,Thurs,71,64,65.4,63
2016,5,20,Fri,64,63,65.6,66
2016,5,21,Sat,63,66,65.7,59
2016,5,22,Sun,66,59,65.9,66
2016,5,23,Mon,59,66,66.1,65
2016,5,24,Tues,66,65,66.2,66
2016,5,25,Wed,65,66,66.4,66
2016,5,26,Thurs,66,66,66.5,65
2016,5,27,Fri,66,65,66.7,64
2016,5,28,Sat,65,64,66.8,64
2016,5,29,Sun,64,64,67.0,64
2016,5,30,Mon,64,64,67.1,71
2016,5,31,Tues,64,71,67.3,79
2016,6,1,Wed,71,79,67.4,75
2016,6,2,Thurs,79,75,67.6,71
2016,6,3,Fri,75,71,67.7,80
2016,6,4,Sat,71,80,67.9,81
2016,6,5,Sun,80,81,68.0,92
2016,6,6,Mon,81,92,68.2,86
2016,6,7,Tues,92,86,68.3,85
2016,6,8,Wed,86,85,68.5,67
2016,6,9,Thurs,85,67,68.6,65
2016,6,10,Fri,67,65,68.8,67
2016,6,11,Sat,65,67,69.0,65
2016,6,12,Sun,67,65,69.1,70
2016,6,13,Mon,65,70,69.3,66
2016,6,14,Tues,70,66,69.5,60
2016,6,15,Wed,66,60,69.7,67
2016,6,16,Thurs,60,67,69.8,71
2016,6,17,Fri,67,71,70.0,67
2016,6,18,Sat,71,67,70.2,65
2016,6,19,Sun,67,65,70.4,70
2016,6,20,Mon,65,70,70.6,76
2016,6,21,Tues,70,76,70.8,73
2016,6,22,Wed,76,73,71.0,75
2016,6,23,Thurs,73,75,71.3,68
2016,6,24,Fri,75,68,71.5,69
2016,6,25,Sat,68,69,71.7,71
2016,6,26,Sun,69,71,71.9,78
2016,6,27,Mon,71,78,72.2,85
2016,6,28,Tues,78,85,72.4,79
2016,6,29,Wed,85,79,72.6,74
2016,6,30,Thurs,79,74,72.8,73
2016,7,1,Fri,74,73,73.1,76
2016,7,2,Sat,73,76,73.3,76
2016,7,3,Sun,76,76,73.5,71
2016,7,4,Mon,76,71,73.8,68
2016,7,5,Tues,71,68,74.0,69
2016,7,6,Wed,68,69,74.2,76
2016,7,7,Thurs,69,76,74.4,68
2016,7,8,Fri,76,68,74.6,74
2016,7,9,Sat,68,74,74.9,71
2016,7,10,Sun,74,71,75.1,74
2016,7,11,Mon,71,74,75.3,74
2016,7,12,Tues,74,74,75.4,77
2016,7,13,Wed,74,77,75.6,75
2016,7,14,Thurs,77,75,75.8,77
2016,7,15,Fri,75,77,76.0,76
2016,7,16,Sat,77,76,76.1,72
2016,7,17,Sun,76,72,76.3,80
2016,7,18,Mon,72,80,76.4,73
2016,7,19,Tues,80,73,76.6,78
2016,7,20,Wed,73,78,76.7,82
2016,7,21,Thurs,78,82,76.8,81
2016,7,22,Fri,82,81,76.9,71
2016,7,23,Sat,81,71,77.0,75
2016,7,24,Sun,71,75,77.1,80
2016,7,25,Mon,75,80,77.1,85
2016,7,26,Tues,80,85,77.2,79
2016,7,27,Wed,85,79,77.3,83
2016,7,28,Thurs,79,83,77.3,85
2016,7,29,Fri,83,85,77.3,88
2016,7,30,Sat,85,88,77.3,76
2016,7,31,Sun,88,76,77.4,73
2016,8,1,Mon,76,73,77.4,77
2016,8,2,Tues,73,77,77.4,73
2016,8,3,Wed,77,73,77.3,75
2016,8,4,Thurs,73,75,77.3,80
2016,8,5,Fri,75,80,77.3,79
2016,8,6,Sat,80,79,77.2,72
2016,8,7,Sun,79,72,77.2,72
2016,8,8,Mon,72,72,77.1,73
2016,8,9,Tues,72,73,77.1,72
2016,8,10,Wed,73,72,77.0,76
2016,8,11,Thurs,72,76,76.9,80
2016,8,12,Fri,76,80,76.9,87
2016,8,13,Sat,80,87,76.8,90
2016,8,14,Sun,87,90,76.7,83
2016,8,15,Mon,90,83,76.6,84
2016,8,16,Tues,83,84,76.5,81
2016,8,23,Tues,84,81,75.7,79
2016,8,28,Sun,81,79,75.0,75
2016,8,30,Tues,79,75,74.6,70
2016,9,3,Sat,75,70,73.9,67
2016,9,4,Sun,70,67,73.7,68
2016,9,5,Mon,67,68,73.5,68
2016,9,6,Tues,68,68,73.3,68
2016,9,7,Wed,68,68,73.0,67
2016,9,8,Thurs,68,67,72.8,72
2016,9,9,Fri,67,72,72.6,74
2016,9,10,Sat,72,74,72.3,77
2016,9,11,Sun,74,77,72.1,70
2016,9,12,Mon,77,70,71.8,74
2016,9,13,Tues,70,74,71.5,75
2016,9,14,Wed,74,75,71.2,79
2016,9,15,Thurs,75,79,71.0,71
2016,9,16,Fri,79,71,70.7,75
2016,9,17,Sat,71,75,70.3,68
2016,9,18,Sun,75,68,70.0,69
2016,9,19,Mon,68,69,69.7,71
2016,9,20,Tues,69,71,69.4,67
2016,9,21,Wed,71,67,69.0,68
2016,9,22,Thurs,67,68,68.7,67
2016,9,23,Fri,68,67,68.3,64
2016,9,24,Sat,67,64,68.0,67
2016,9,25,Sun,64,67,67.6,76
2016,9,26,Mon,67,76,67.2,77
2016,9,27,Tues,76,77,66.8,69
2016,9,28,Wed,77,69,66.5,68
2016,9,29,Thurs,69,68,66.1,66
2016,9,30,Fri,68,66,65.7,67
2016,10,1,Sat,66,67,65.3,63
2016,10,2,Sun,67,63,64.9,65
2016,10,3,Mon,63,65,64.5,61
2016,10,4,Tues,65,61,64.1,63
2016,10,5,Wed,61,63,63.7,66
2016,10,6,Thurs,63,66,63.3,63
2016,10,7,Fri,66,63,62.9,64
2016,10,8,Sat,63,64,62.5,68
2016,10,9,Sun,64,68,62.1,57
2016,10,10,Mon,68,57,61.8,60
2016,10,11,Tues,57,60,61.4,62
2016,10,12,Wed,60,62,61.0,66
2016,10,13,Thurs,62,66,60.6,60
2016,10,14,Fri,66,60,60.2,60
2016,10,15,Sat,60,60,59.9,62
2016,10,16,Sun,60,62,59.5,60
2016,10,17,Mon,62,60,59.1,60
2016,10,18,Tues,60,60,58.8,61
2016,10,19,Wed,60,61,58.4,58
2016,10,20,Thurs,61,58,58.1,62
2016,10,21,Fri,58,62,57.8,59
2016,10,22,Sat,62,59,57.4,62
2016,10,23,Sun,59,62,57.1,62
2016,10,24,Mon,62,62,56.8,61
2016,10,25,Tues,62,61,56.5,65
2016,10,26,Wed,61,65,56.2,58
2016,10,27,Thurs,65,58,55.9,60
2016,10,28,Fri,58,60,55.6,65
2016,10,29,Sat,60,65,55.3,68
2016,10,31,Mon,65,68,54.8,59
2016,11,1,Tues,68,59,54.5,57
2016,11,2,Wed,59,57,54.2,57
2016,11,3,Thurs,57,57,53.9,65
2016,11,4,Fri,57,65,53.7,65
2016,11,5,Sat,65,65,53.4,58
2016,11,6,Sun,65,58,53.2,61
2016,11,7,Mon,58,61,52.9,63
2016,11,8,Tues,61,63,52.7,71
2016,11,9,Wed,63,71,52.4,65
2016,11,10,Thurs,71,65,52.2,64
2016,11,11,Fri,65,64,51.9,63
2016,11,12,Sat,64,63,51.7,59
2016,11,13,Sun,63,59,51.4,55
2016,11,14,Mon,59,55,51.2,57
2016,11,15,Tues,55,57,51.0,55
2016,11,16,Wed,57,55,50.7,50
2016,11,17,Thurs,55,50,50.5,52
2016,11,18,Fri,50,52,50.3,55
2016,11,19,Sat,52,55,50.0,57
2016,11,20,Sun,55,57,49.8,55
2016,11,21,Mon,57,55,49.5,54
2016,11,22,Tues,55,54,49.3,54
2016,11,23,Wed,54,54,49.1,49
2016,11,24,Thurs,54,49,48.9,52
2016,11,25,Fri,49,52,48.6,52
2016,11,26,Sat,52,52,48.4,53
2016,11,27,Sun,52,53,48.2,48
2016,11,28,Mon,53,48,48.0,52
2016,11,29,Tues,48,52,47.8,52
2016,11,30,Wed,52,52,47.6,52
2016,12,1,Thurs,52,52,47.4,46
2016,12,2,Fri,52,46,47.2,50
2016,12,3,Sat,46,50,47.0,49
2016,12,4,Sun,50,49,46.8,46
2016,12,5,Mon,49,46,46.6,40
2016,12,6,Tues,46,40,46.4,42
2016,12,7,Wed,40,42,46.3,40
2016,12,8,Thurs,42,40,46.1,41
2016,12,9,Fri,40,41,46.0,36
2016,12,10,Sat,41,36,45.9,44
2016,12,11,Sun,36,44,45.7,44
2016,12,12,Mon,44,44,45.6,43
2016,12,13,Tues,44,43,45.5,40
2016,12,14,Wed,43,40,45.4,39
2016,12,15,Thurs,40,39,45.3,39
2016,12,16,Fri,39,39,45.3,35
2016,12,17,Sat,39,35,45.2,35
2016,12,18,Sun,35,35,45.2,39
2016,12,19,Mon,35,39,45.1,46
2016,12,20,Tues,39,46,45.1,51
2016,12,21,Wed,46,51,45.1,49
2016,12,22,Thurs,51,49,45.1,45
2016,12,23,Fri,49,45,45.1,40
2016,12,24,Sat,45,40,45.1,41
2016,12,25,Sun,40,41,45.1,42
2016,12,26,Mon,41,42,45.2,42
2016,12,27,Tues,42,42,45.2,47
2016,12,28,Wed,42,47,45.3,48
2016,12,29,Thurs,47,48,45.3,48
2016,12,30,Fri,48,48,45.4,57
2016,12,31,Sat,48,57,45.5,40
  • 5
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
对于时间序列预测,PyTorch提供了一些强大的工具和库。你可以使用PyTorch构建和训练深度学习模型来进行时间序列预测。下面是一个使用PyTorch进行时间序列预测的基本步骤: 1. 数据准备:首先,你需要准备你的时间序列数据。将数据分为训练集和测试集,并进行标准化以便更好地训练模型。 2. 创建模型:使用PyTorch创建一个适合时间序列预测的模型。常见的模型包括循环神经网络(RNN)、长短期记忆网络(LSTM)和变体,如GRU等。 3. 定义损失函数:选择适当的损失函数来评估模型的性能。对于回归问题,通常使用均方误差(MSE)或平均绝对误差(MAE)作为损失函数。 4. 训练模型:使用训练数据对模型进行训练。通过反向传播和优化算法(如随机梯度下降)来更新模型的权重和偏置,以最小化损失函数。 5. 模型评估:使用测试数据评估模型的性能。计算预测值与真实值之间的误差,并使用适当的指标(如均方根误差,R-squared等)衡量模型的准确性。 6. 进行预测:使用已训练的模型对未来的时间序列进行预测。将模型应用于新的输入数据,并获得预测结果。 PyTorch提供了丰富的库和工具来支持这些步骤,包括torch.nn模块用于构建模型,torch.optim模块用于优化算法,以及许多其他工具和函数用于数据处理和评估模型性能。你可以根据具体的时间序列预测任务选择合适的模型和方法。希望这些步骤对你有所帮助!如果你对某个具体部分有更多的问题,可以进一步提问。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值