使用python搭建一个股票训练程序

环境:
win10
python3.75
使用到数据: https://download.csdn.net/download/xy3233/82009700
样式
在这里插入图片描述
方向键 → ↑ 进入下一个交易日
数字键1是买 0是卖 (s是独立窗口的保存快捷键,所以没有用s/b)
主要参考代码: https://blog.csdn.net/xy3233/article/details/123083341

完整代码如下

# -*- coding: utf-8 -*-
import datetime

import pandas as pd
import mplfinance as mpf
import numpy as np
import matplotlib.pyplot as plt

# 独立窗口
import matplotlib

matplotlib.use('tkagg')

# 自定义风格和颜色
# 设置mplfinance的蜡烛颜色,up为阳线颜色,down为阴线颜色
my_color = mpf.make_marketcolors(up='r',  # 上涨K线的柱子的内部填充色
                                 down='g',  # 下跌K线的柱子的内部填充色
                                 edge='inherit',  # 边框设置“inherit”代表使用主配色 不设置则为黑色
                                 wick='inherit',  # wick设置的就是上下影线的颜色
                                 volume='inherit')  # volume设置的是交易量柱子的颜色
# 设置图表的背景色
my_style = mpf.make_mpf_style(marketcolors=my_color,
                              figcolor='(0.82, 0.83, 0.85)',
                              gridcolor='(0.82, 0.83, 0.85)')

# 标题格式,字体为中文字体,颜色为黑色,粗体,水平中心对齐
title_font = {'fontname': 'STZhongsong',
              'size': '16',
              'color': 'black',
              'weight': 'bold',
              'va': 'bottom',
              'ha': 'center'}

# 标签格式,可以显示中文,普通黑色12号字
normal_label_font = {'fontname': 'STZhongsong',
                     'size': '12',
                     'color': 'black',
                     'va': 'bottom',
                     'ha': 'right'}


def get_stock_data(file_path):
    '''
    数据来源
    :param file_path:
    :return:
    '''
    data = pd.read_csv(file_path, index_col=0)
    data['open'] = data['open_price']
    data['high'] = data['high_price']
    data['low'] = data['low_price']
    data['close'] = data['close_price']
    data['volume'] = data['deal_quantity']
    data['change'] = data['change_rate']
    data.index = pd.to_datetime(data['date'], format='%Y-%m-%d')
    data.rename(index=pd.Timestamp)
    return data


class StockSB:
    def __init__(self, stock_data, money=10000):
        self.all_data = pd.DataFrame(stock_data)
        # DataFrame 初始化  添加相应列
        # 添加列  original(投入本金)   quantity(持有数量)   cost_price(成本价) stock_value(市值) usable_money(可用金额)
        self.all_data.insert(self.all_data.shape[1], 'original', money, allow_duplicates=False)
        self.all_data.insert(self.all_data.shape[1], 'quantity', 0, allow_duplicates=False)
        self.all_data.insert(self.all_data.shape[1], 'cost_price', 0, allow_duplicates=False)
        # self.all_data.insert(self.all_data.shape[1], 'stock_value', 0, allow_duplicates=False)
        self.all_data.insert(self.all_data.shape[1], 'usable_money', money, allow_duplicates=False)
        self.all_data.insert(self.all_data.shape[1], 'b/s', 0, allow_duplicates=False)  # 买1 卖 -1
        self.start = 0
        self.len = 50
        self.plot_data = self.all_data.iloc[self.start:self.start + self.len]
        # 初始化 今日数据
        self.current_data = self.plot_data.iloc[-1]

        # 添加三个图表,四个数字分别代表图表左下角在figure中的坐标,以及图表的宽(0.88)、高(0.60)
        self.fig = mpf.figure(figsize=(12, 8), facecolor=(0.82, 0.83, 0.85))
        # 添加三个图表,四个数字分别代表图表左下角在figure中的坐标,以及图表的宽(0.88)、高(0.60)
        self.price_axe = self.fig.add_axes([0.06, 0.25, 0.88, 0.60])  # 添加价格图表 K线图
        # 添加第二、三张图表时,使用sharex关键字指明与ax1在x轴上对齐,且共用x轴
        self.volume_axe = self.fig.add_axes([0.06, 0.15, 0.88, 0.10], sharex=self.price_axe)  # 添加成交量
        self.macd_axe = self.fig.add_axes([0.06, 0.05, 0.88, 0.10], sharex=self.price_axe)  # 添加macd
        # 设置三张图表的Y轴标签
        self.price_axe.set_ylabel('price')
        self.volume_axe.set_ylabel('volume')
        self.macd_axe.set_ylabel('macd')

        # 标题等文本
        # 初始化figure对象,在figure上预先放置文本并设置格式,文本内容根据需要显示的数据实时更新
        self.t1 = self.fig.text(0.50, 0.95, '000001.SH - 平安保险', **title_font)
        self.t2 = self.fig.text(0.10, 0.90, '开/收: ', **normal_label_font)
        self.t2_1 = self.fig.text(0.20, 0.90, f'', **normal_label_font)
        self.t3 = self.fig.text(0.25, 0.90, '高: ', **normal_label_font)
        self.t3_1 = self.fig.text(0.30, 0.90, f'', **normal_label_font)
        self.t4 = self.fig.text(0.35, 0.90, '低: ', **normal_label_font)
        self.t4_1 = self.fig.text(0.40, 0.90, f'', **normal_label_font)
        self.t5 = self.fig.text(0.50, 0.90, '量(万手): ', **normal_label_font)
        self.t5_1 = self.fig.text(0.55, 0.90, f'', **normal_label_font)
        self.t6 = self.fig.text(0.65, 0.90, '当前时间: ', **normal_label_font)
        self.t6_1 = self.fig.text(0.75, 0.90, f'', **normal_label_font)

        self.t7 = self.fig.text(0.09, 0.87, f'本金', **normal_label_font)
        self.t7_1 = self.fig.text(0.20, 0.87, f' ', **normal_label_font)

        self.t8 = self.fig.text(0.25, 0.87, f'成本', **normal_label_font)
        self.t8_1 = self.fig.text(0.30, 0.87, f' ', **normal_label_font)

        self.t9 = self.fig.text(0.35, 0.87, f'总手', **normal_label_font)
        self.t9_1 = self.fig.text(0.40, 0.87, f' ', **normal_label_font)

        self.t10 = self.fig.text(0.50, 0.87, f'利润', **normal_label_font)
        self.t10_1 = self.fig.text(0.55, 0.87, f' ', **normal_label_font)

        self.t11 = self.fig.text(0.65, 0.87, f'可用金额', **normal_label_font)
        self.t11_1 = self.fig.text(0.75, 0.87, f' ', **normal_label_font)

        self.t12 = self.fig.text(0.85, 0.87, f'市值', **normal_label_font)
        self.t12_1 = self.fig.text(0.95, 0.87, f' ', **normal_label_font)

        self.fig.canvas.mpl_connect('key_press_event', self.on_key_press)

    def save_stock_result(self, file_path=None):
        # 保存数据
        result_path = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.all_data.to_csv(result_path + '.csv')

    def get_plot_data(self):
        return self.plot_data

    def get_current_data(self):
        return self.current_data

    def next_day(self):
        print(self.current_data)
        # 下一天
        if self.start + self.len < self.all_data.shape[0]:
            self.start = self.start + 1
        self._refresh_data()
        print(self.current_data)

    def last_day(self):
        # 前一天
        if self.start > 1:
            self.start = self.start - 1
        self._refresh_data()

    def _refresh_data(self):
        # 刷新数据
        self.plot_data = self.all_data.iloc[self.start:self.start + self.len]
        self.current_data = self.plot_data.iloc[-1]
        plot_data = self.plot_data
        # 刷新图
        # 读取显示区间最后一个交易日的数据
        last_data = self.current_data
        # 将这些数据分别填入figure对象上的文本中
        self.t2_1.set_text(f'{np.round(last_data["open"], 3)} / {np.round(last_data["close"], 3)}')
        self.t3_1.set_text(f'{last_data["high"]}')
        self.t4_1.set_text(f'{last_data["low"]}')
        self.t5_1.set_text(f'{np.round(last_data["volume"] / 10000, 3)}')
        self.t6_1.set_text(f'{last_data["date"]}')
        self.t7_1.set_text(f'{last_data["original"]}')
        self.t8_1.set_text(f'{last_data["cost_price"]}')
        self.t9_1.set_text(f'{last_data["quantity"]}')
        profit = last_data['usable_money'] + last_data['quantity'] * last_data['close'] - self.current_data['original']
        # 利润
        self.t10_1.set_text(str(round(profit, 2)))
        # 可用金额
        self.t11_1.set_text(f'{round(last_data["usable_money"],2)}')
        self.t12_1.set_text(str(round(last_data['quantity'] * last_data['close'], 2)))

        # 生成一个空列表用于存储多个addplot
        ap = []
        # 添加均线
        ap.append(mpf.make_addplot(plot_data[['ma_5', 'ma_20', 'ma_60']], ax=self.price_axe))
        # 添加 diff 和 dea
        ap.append(mpf.make_addplot(plot_data[['diff']], color='black', ax=self.macd_axe))
        ap.append(mpf.make_addplot(plot_data[['dea']], color='orange', ax=self.macd_axe))
        # 添加macd
        ap.append(mpf.make_addplot(plot_data[['macd']], type='bar', color='green', ax=self.macd_axe))
        #  调用mpf.plot()函数,这里需要指定ax=price_axe,volume=ax2,将K线图显示在ax1中,交易量显示在ax2中
        mpf.plot(plot_data, ax=self.price_axe, addplot=ap, volume=self.volume_axe, type='candle', style=my_style,
                 xrotation=0)
        mpf.show()
        # self.show_stock()

    def buy_stock(self, part=10):
        # 买
        last_stock_data = self.plot_data.iloc[-2]
        # 可用金额
        last_usable_money = last_stock_data['usable_money'] * part / 10
        # 可购买数量 = 可用金额 / (收盘价*100) 取整
        buy_quantity = (last_usable_money // (self.current_data['close'] * 100)) * 100
        if buy_quantity < 100:
            return
        pd_index = self.current_data['date']
        # 整体数量
        quantity = last_stock_data['quantity'] + buy_quantity
        # 今日消费
        buy_money = (self.current_data['close'] * buy_quantity) + 5

        # 今日成本价 = 原始成本价* 数量 +  今天成本价* 数量  + 5 / 总数量
        cost_price = round((last_stock_data['cost_price'] * last_stock_data['quantity'] + buy_money) / quantity, 2)
        self.all_data.loc[pd_index:, 'cost_price'] = cost_price

        # 当前数量
        self.all_data.loc[pd_index:, 'quantity'] = quantity
        # 可用金额
        self.all_data.loc[pd_index:, 'usable_money'] = last_stock_data['usable_money'] - buy_money
        self.all_data.loc[pd_index, 'b/s'] = 1
        self.next_day()

    def sell_stock(self, part=10):
        # 卖
        # 数量
        temp_quantity = (self.current_data['quantity'] * part / 10) // 100
        if temp_quantity < 1:
            return
        sell_quantity = temp_quantity * 100
        pd_index = self.current_data['date']
        quantity = self.current_data['quantity'] - sell_quantity
        self.all_data.loc[pd_index:, 'quantity'] = quantity
        # 可用金额
        self.all_data.loc[pd_index:, 'usable_money'] = self.current_data['usable_money'] + \
                                                       sell_quantity * self.current_data['close'] - 5
        if quantity == 0:
            self.all_data.loc[pd_index:, 'cost_price'] = 0
        self.all_data.loc[pd_index, 'b/s'] = -1
        self.next_day()

    def show_stock(self):
        profit = self.current_data['usable_money'] + self.current_data['quantity'] * self.current_data['close']
        # 可用金额 + 市值 - 本金
        print("利润" + str(profit - self.current_data['original']))
        print("可用金额" + str(self.current_data['usable_money']))

    def on_key_press(self, event):
        key = event.key
        print("key: " + key)
        if key == 'enter':
            # 保存图片
            img_path = str(self.start) + '-' + str(self.start + self.len)
            plt.savefig(img_path + '.jpg')
            return
        elif key == 'left' or key == 'down':
            self.last_day()
        elif key == 'right' or key == 'up':
            self.next_day()
        elif key == '1':
            self.buy_stock()
        elif key == '0':
            self.sell_stock()
        self.price_axe.clear()
        self.macd_axe.clear()
        self.volume_axe.clear()
        if key != '1' or key != '0':
            self._refresh_data()


if __name__ == '__main__':
    f_path = './../0_data/000001_stock.csv'
    data = get_stock_data(f_path)
    # 选取我需要的数据
    all_data = data[
        ['open', 'high', 'low', 'close', 'volume', 'ma_5', 'ma_20', 'ma_60', 'macd', 'diff', 'dea', 'date']]

    start = all_data.shape[0] - 1000  # 开始序号
    len = 500  # 显示长度
    train_data = all_data.iloc[start:start + len]
    sb = StockSB(train_data)
    sb.next_day()
  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值