PyTorch Transformer 预测股票价格,虚拟数据

#
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


num_days = 200
stock_prices = np.random.rand(num_days) * 100



input_seq_len = 10
output_seq_len = 5
num_samples = num_days - input_seq_len - output_seq_len + 1

src_data = torch.tensor([stock_prices[i:i+input_seq_len] for i in range(num_samples)]).unsqueeze(-1).float()
tgt_data = torch.tensor([stock_prices[i+input_seq_len:i+input_seq_len+output_seq_len] for i in range(num_samples)]).unsqueeze(-1).float()


class StockPriceTransformer(nn.Module):
    def __init__(self, d_model, nhead, num_layers, dropout):
        super(StockPriceTransformer, self).__init__()
        self.input_linear = nn.Linear(1, d_model)
        self.transformer = nn.Transformer(d_model, nhead, num_layers, dropout=dropout)
        self.output_linear = nn.Linear(d_model, 1)

    def forward(self, src, tgt):
        src = self.input_linear(src)
        tgt = self.input_linear(tgt)
        output = self.transformer(src, tgt)
        output = self.output_linear(output)
        return output

d_model = 64
nhead = 4
num_layers = 2
dropout = 0.1

model = StockPriceTransformer(d_model, nhead, num_layers, dropout=dropout)


epochs = 100
lr = 0.001
batch_size = 16

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(epochs):
    for i in range(0, num_samples, batch_size):
        src_batch = src_data[i:i+batch_size].transpose(0, 1)
        tgt_batch = tgt_data[i:i+batch_size].transpose(0, 1)
        
        optimizer.zero_grad()
        output = model(src_batch, tgt_batch[:-1])
        loss = criterion(output, tgt_batch[1:])
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")



src = torch.tensor(stock_prices[-input_seq_len:]).unsqueeze(-1).unsqueeze(1).float()
tgt = torch.zeros(output_seq_len, 1, 1)

with torch.no_grad():
    for i in range(output_seq_len):
        prediction = model(src, tgt[:i+1])
        tgt[i] = prediction[-1]

output = tgt.squeeze().tolist()
print("Next 5 days of stock prices:", output)


  • 5
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值