Ne Zha 2 (哪吒2) currently sits at No. 8 on the global box office chart and will likely climb further before long, so today let's try a simple LSTM to forecast its box office.
The prediction uses the daily box office of the first 27 days (in 亿, i.e. units of 100 million CNY):
[4.88, 4.80, 6.20, 7.32, 8.13, 8.44, 8.67, 6.50, 5.86, 5.42, 6.20, 7.62, 4.82, 4.82, 5.33, 3.60, 5.81, 7.88, 6.14, 1.94, 1.66, 1.46, 1.28, 2.29, 5.22, 3.52, 0.74]
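As a quick sanity check (simple arithmetic on the figures above, not part of the original script), these 27 values already add up to roughly 136.55 亿, which is the baseline that the total forecast at the end of this post builds on:

daily = [4.88, 4.80, 6.20, 7.32, 8.13, 8.44, 8.67, 6.50, 5.86, 5.42, 6.20, 7.62, 4.82, 4.82,
         5.33, 3.60, 5.81, 7.88, 6.14, 1.94, 1.66, 1.46, 1.28, 2.29, 5.22, 3.52, 0.74]
print(f"Cumulative box office over the first 27 days: {sum(daily):.2f} 亿")  # 136.55 亿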
The prediction code is as follows:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Configure matplotlib so that CJK characters (e.g. the 亿 in the axis label) render correctly
plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei; swap in another CJK-capable font if your system lacks it
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
# Historical daily box office (in 亿)
box_office = [4.88, 4.80, 6.20, 7.32, 8.13, 8.44, 8.67, 6.50, 5.86, 5.42, 6.20, 7.62, 4.82, 4.82, 5.33, 3.60, 5.81,
7.88, 6.14, 1.94, 1.66, 1.46, 1.28, 2.29, 5.22, 3.52, 0.74]
box_office = np.array(box_office, dtype=np.float32)
# Data preprocessing: build sliding-window (input, target) pairs
def create_sequences(data, seq_length):
    xs, ys = [], []
    for i in range(len(data) - seq_length):
        x = data[i:i + seq_length]   # seq_length consecutive days as input
        y = data[i + seq_length]     # the following day is the target
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)
seq_length = 5  # sequence (window) length in days
X, y = create_sequences(box_office, seq_length)
# Convert to PyTorch tensors
X = torch.from_numpy(X).unsqueeze(2)
y = torch.from_numpy(y)
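# Shape check based on the 27-day series above: 27 - 5 = 22 samples, so X is
# (22, 5, 1) after unsqueeze(2) and y is (22,). The first sample is the window
# [4.88, 4.80, 6.20, 7.32, 8.13] with target 8.44.
# print(X.shape, y.shape)  # optional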
# Define the LSTM model
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Fresh zero-initialised hidden and cell states for each forward pass
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])  # use only the last time step's output
        return out
# Model hyperparameters
input_size = 1
hidden_size = 32
num_layers = 1
output_size = 1
model = LSTMModel(input_size, hidden_size, num_layers, output_size)
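# With these settings the model maps a (batch, 5, 1) input to a (batch, 1) output;
# e.g. model(X) on the 22 training samples returns a tensor of shape (22, 1).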
# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Train the model
num_epochs = 1500
for epoch in range(num_epochs):
    outputs = model(X)
    optimizer.zero_grad()
    loss = criterion(outputs.squeeze(), y)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
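# Optional (not in the original script): wrap the forecast below in torch.no_grad()
# to skip gradient tracking; the model has no dropout or batch norm, so the
# predicted values themselves are unchanged.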
# Autoregressively forecast the next 34 days of box office
last_sequence = box_office[-seq_length:].reshape(1, seq_length, 1)
last_sequence = torch.from_numpy(last_sequence).float()
future_predictions = []
for _ in range(34):
    next_day_prediction = model(last_sequence).item()
    future_predictions.append(next_day_prediction)
    # Slide the window: drop the oldest day and append the new prediction
    new_sequence = np.roll(last_sequence.numpy(), -1, axis=1)
    new_sequence[0, -1, 0] = next_day_prediction
    last_sequence = torch.from_numpy(new_sequence).float()
# Total forecast box office, including the first 27 days of actuals
total_forecast_box_office = np.sum(np.concatenate((box_office, np.array(future_predictions))))
# Print the results
print("Daily box office forecast for the next 34 days:", future_predictions)
print("Forecast total box office:", total_forecast_box_office, "亿")
# Plot the historical data and the forecast
plt.figure(figsize=(12, 6))
plt.plot(box_office, label='Historical box office')
plt.plot(range(len(box_office), len(box_office) + 34), future_predictions, label='Forecast box office', color='red')
plt.title('Box office forecast')
plt.xlabel('Day')
plt.ylabel('Box office (亿)')
plt.legend()
plt.show()
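One caveat: the series is fed to the LSTM in raw units (亿), which is why the early losses are large and training is sensitive to the learning rate. A common refinement, sketched below under the assumption that scikit-learn is available (it is not used in the script above), is to min-max scale the series before windowing and invert the scaling on the forecasts:

from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()
scaled = scaler.fit_transform(box_office.reshape(-1, 1)).flatten().astype(np.float32)

# Build windows and train exactly as above, but on the scaled series
X_s, y_s = create_sequences(scaled, seq_length)

# After forecasting in scaled space, map the predictions back to 亿:
# preds_real = scaler.inverse_transform(np.array(future_predictions).reshape(-1, 1)).flatten()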
The training log is as follows:
Epoch [10/1500], Loss: 24.4515
Epoch [20/1500], Loss: 22.1718
Epoch [30/1500], Loss: 19.5055
Epoch [40/1500], Loss: 16.2970
Epoch [50/1500], Loss: 12.9819
Epoch [60/1500], Loss: 10.2338
Epoch [70/1500], Loss: 8.3034
Epoch [80/1500], Loss: 7.0466
Epoch [90/1500], Loss: 6.2346
Epoch [100/1500], Loss: 5.7214
Epoch [110/1500], Loss: 5.4366
Epoch [120/1500], Loss: 5.2936
Epoch [130/1500], Loss: 5.2122
Epoch [140/1500], Loss: 5.1438
Epoch [150/1500], Loss: 5.0620
Epoch [160/1500], Loss: 4.9547
Epoch [170/1500], Loss: 4.8231
Epoch [180/1500], Loss: 4.6558
Epoch [190/1500], Loss: 4.4610
Epoch [200/1500], Loss: 4.2507
Epoch [210/1500], Loss: 4.0944
Epoch [220/1500], Loss: 3.9582
Epoch [230/1500], Loss: 3.8379
Epoch [240/1500], Loss: 3.7219
Epoch [250/1500], Loss: 3.6206
Epoch [260/1500], Loss: 3.5343
Epoch [270/1500], Loss: 3.4657
Epoch [280/1500], Loss: 3.4031
Epoch [290/1500], Loss: 3.3474
Epoch [300/1500], Loss: 3.3001
Epoch [310/1500], Loss: 3.2587
Epoch [320/1500], Loss: 3.2208
Epoch [330/1500], Loss: 3.1852
Epoch [340/1500], Loss: 3.1507
Epoch [350/1500], Loss: 3.1162
Epoch [360/1500], Loss: 3.0800
Epoch [370/1500], Loss: 3.0409
Epoch [380/1500], Loss: 2.9984
Epoch [390/1500], Loss: 2.9520
Epoch [400/1500], Loss: 2.9020
Epoch [410/1500], Loss: 2.8506
Epoch [420/1500], Loss: 2.7983
Epoch [430/1500], Loss: 2.7430
Epoch [440/1500], Loss: 2.6849
Epoch [450/1500], Loss: 2.6235
Epoch [460/1500], Loss: 2.5583
Epoch [470/1500], Loss: 2.4880
Epoch [480/1500], Loss: 2.4103
Epoch [490/1500], Loss: 2.3240
Epoch [500/1500], Loss: 2.2276
Epoch [510/1500], Loss: 2.1202
Epoch [520/1500], Loss: 2.0029
Epoch [530/1500], Loss: 1.8804
Epoch [540/1500], Loss: 1.7615
Epoch [550/1500], Loss: 1.6531
Epoch [560/1500], Loss: 1.5558
Epoch [570/1500], Loss: 1.4668
Epoch [580/1500], Loss: 1.3832
Epoch [590/1500], Loss: 1.3035
Epoch [600/1500], Loss: 1.2263
Epoch [610/1500], Loss: 1.1509
Epoch [620/1500], Loss: 1.0780
Epoch [630/1500], Loss: 1.0098
Epoch [640/1500], Loss: 0.9482
Epoch [650/1500], Loss: 0.8947
Epoch [660/1500], Loss: 0.8492
Epoch [670/1500], Loss: 0.8103
Epoch [680/1500], Loss: 0.7765
Epoch [690/1500], Loss: 0.7469
Epoch [700/1500], Loss: 0.7208
Epoch [710/1500], Loss: 0.6978
Epoch [720/1500], Loss: 0.6773
Epoch [730/1500], Loss: 0.6591
Epoch [740/1500], Loss: 0.6426
Epoch [750/1500], Loss: 0.6279
Epoch [760/1500], Loss: 0.6147
Epoch [770/1500], Loss: 0.6040
Epoch [780/1500], Loss: 0.5932
Epoch [790/1500], Loss: 0.5830
Epoch [800/1500], Loss: 0.5736
Epoch [810/1500], Loss: 0.5648
Epoch [820/1500], Loss: 0.5563
Epoch [830/1500], Loss: 0.5481
Epoch [840/1500], Loss: 0.5403
Epoch [850/1500], Loss: 0.5327
Epoch [860/1500], Loss: 0.5255
Epoch [870/1500], Loss: 0.5186
Epoch [880/1500], Loss: 0.5120
Epoch [890/1500], Loss: 0.5057
Epoch [900/1500], Loss: 0.4995
Epoch [910/1500], Loss: 0.4935
Epoch [920/1500], Loss: 0.4877
Epoch [930/1500], Loss: 0.4824
Epoch [940/1500], Loss: 0.4781
Epoch [950/1500], Loss: 0.4723
Epoch [960/1500], Loss: 0.4666
Epoch [970/1500], Loss: 0.4617
Epoch [980/1500], Loss: 0.4569
Epoch [990/1500], Loss: 0.4522
Epoch [1000/1500], Loss: 0.4475
Epoch [1010/1500], Loss: 0.4430
Epoch [1020/1500], Loss: 0.4384
Epoch [1030/1500], Loss: 0.4339
Epoch [1040/1500], Loss: 0.4293
Epoch [1050/1500], Loss: 0.4247
Epoch [1060/1500], Loss: 0.4200
Epoch [1070/1500], Loss: 0.4153
Epoch [1080/1500], Loss: 0.4103
Epoch [1090/1500], Loss: 0.4052
Epoch [1100/1500], Loss: 0.3998
Epoch [1110/1500], Loss: 0.3938
Epoch [1120/1500], Loss: 0.3870
Epoch [1130/1500], Loss: 0.3848
Epoch [1140/1500], Loss: 0.3726
Epoch [1150/1500], Loss: 0.3619
Epoch [1160/1500], Loss: 0.3534
Epoch [1170/1500], Loss: 0.3454
Epoch [1180/1500], Loss: 0.3375
Epoch [1190/1500], Loss: 0.3296
Epoch [1200/1500], Loss: 0.3216
Epoch [1210/1500], Loss: 0.3134
Epoch [1220/1500], Loss: 0.3049
Epoch [1230/1500], Loss: 0.2962
Epoch [1240/1500], Loss: 0.2870
Epoch [1250/1500], Loss: 0.2774
Epoch [1260/1500], Loss: 0.2672
Epoch [1270/1500], Loss: 0.2566
Epoch [1280/1500], Loss: 0.2457
Epoch [1290/1500], Loss: 0.2348
Epoch [1300/1500], Loss: 0.2240
Epoch [1310/1500], Loss: 0.2135
Epoch [1320/1500], Loss: 0.2132
Epoch [1330/1500], Loss: 0.1970
Epoch [1340/1500], Loss: 0.1866
Epoch [1350/1500], Loss: 0.1773
Epoch [1360/1500], Loss: 0.1688
Epoch [1370/1500], Loss: 0.1607
Epoch [1380/1500], Loss: 0.1529
Epoch [1390/1500], Loss: 0.1455
Epoch [1400/1500], Loss: 0.1383
Epoch [1410/1500], Loss: 0.1313
Epoch [1420/1500], Loss: 0.1245
Epoch [1430/1500], Loss: 0.1178
Epoch [1440/1500], Loss: 0.1111
Epoch [1450/1500], Loss: 0.1042
Epoch [1460/1500], Loss: 0.0970
Epoch [1470/1500], Loss: 0.0893
Epoch [1480/1500], Loss: 0.0813
Epoch [1490/1500], Loss: 0.0736
Epoch [1500/1500], Loss: 0.0664
Prediction results:
Daily box office forecast for the next 34 days: [1.1393108367919922, 1.0481380224227905, 2.6242482662200928, 2.4993085861206055, 0.6360362768173218, 1.1742253303527832, 3.7346150875091553, 5.752073287963867, 2.163248300552368, 0.6004949808120728, 1.3627697229385376, 0.5352745056152344, 1.5377495288848877, 2.271090507507324, 0.693607747554779, 0.8373626470565796, 2.386218547821045, 2.779088258743286, 1.3350541591644287, 1.0107192993164062, 2.3213491439819336, 4.748796463012695, 3.0593085289001465, 0.7969074249267578, 1.3265624046325684, 1.6689274311065674, 3.2373130321502686, 2.5052967071533203, 0.6550310850143433, 1.921600580215454, 4.785950660705566, 6.217737197875977, 2.214029312133789, 0.635899007320404]
Forecast total box office: 208.76534271240234 亿
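Breaking that total down (simple arithmetic on the numbers above): the 27 days of actuals sum to about 136.55 亿 and the 34 forecast days to about 72.21 亿, which together give the ~208.77 亿 figure. A one-line check, reusing box_office and future_predictions from the script:

print(sum(box_office) + sum(future_predictions))  # ≈ 208.77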
Trend chart (historical vs. predicted daily box office):