Stock correlation clustering chart

Stocks are grouped according to the correlation of their historical quote data.

The URL of the original example is included in the code comments; quote data are fetched with baostock.

The resulting group assignments are printed on the command line.

# -*- coding: utf-8 -*-
# Adapted from the scikit-learn stock-market example:
# https://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html
# Stock clustering chart

import datetime
import sys
from pathlib import Path

import baostock as bs
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection
from sklearn import cluster, covariance, manifold

mpl.rcParams['font.sans-serif'] = ['SimHei']  # use the SimHei font so Chinese labels display correctly
mpl.rcParams['axes.unicode_minus'] = False  # keep the minus sign rendering correctly with a Chinese font

symbol_dict = {
    "000659": "珠海中富",
    "600222": "太龙药业",
    "600527": "江南高纤",
    "600572": "康恩贝",
    "002007": "华兰生物",
    "600532": "宏达矿业",
    "601991": "大唐发电",
    "000869": "张 裕A",
    "601811": "新华文轩",
    "002282": "博深股份",
    "600078": "澄星股份",
    "002133": "广宇集团",
    "600678": "四川金顶",
    "601958": "金钼股份",
    "000830": "鲁西化工",
    "000935": "四川双马",
    "603833": "欧派家居",
    "600605": "汇通能源",
    "603578": "三星新材",
    "601086": "国芳集团",
    "600836": "*ST界龙",
    "600163": "中闽能源",
    "002160": "常铝股份",
    "002283": "天润工业",
    "002265": "西仪股份",
    "002652": "扬子新材",
    "600743": "华远地产",
    "002612": "朗姿股份",
    "600567": "山鹰纸业",
    "002277": "友阿股份",
    "603009": "北特科技",
    "000525": "红 太 阳",
    "002593": "日上集团",
    "601377": "兴业证券",
    "002387": "维信诺",
    "601555": "东吴证券",
    "002409": "雅克科技",
    "600265": "ST景谷",
    "002025": "航天电器",
    "600756": "浪潮软件",
    "002279": "久其软件",
    "002577": "雷柏科技",
    "002599": "盛通股份",
    "002777": "久远银海",
    "002446": "盛路通信",
    "000070": "特发信息",
    "002131": "利欧股份"
}

symbols, names = np.array(sorted(symbol_dict.items())).T
quotes = []

lg = bs.login()
print('login response error_code:' + lg.error_code)
print('login response error_msg:' + lg.error_msg)
end_date = datetime.datetime.now().strftime('%Y-%m-%d')  # today's date (not used below; the query uses a fixed window)

date_info = []
for i, symbol in enumerate(symbols):
    file_name = "csv/day/" + symbol + "_day_bs.csv"
    my_file = Path(file_name)
    if my_file.exists():
        code_date = pd.read_csv(file_name)
        code_date = code_date.head(625)
        if len(code_date) == 625:
            quotes.append(code_date)
            date_info.append([symbol, names[i]])
    else:
        first_code = symbol[0:1]
        if first_code == '0':
            code_bs = 'sz.' + symbol
        elif first_code == '6':
            code_bs = 'sh.' + symbol
        else:
            continue
        print('Fetching quote history for %r' % symbol, file=sys.stderr)
        rs = bs.query_history_k_data_plus(code_bs,
                                          "date,open,close",
                                          start_date='2018-01-01', end_date='2020-07-29',
                                          frequency="d", adjustflag="3")

        data_list = []
        while (rs.error_code == '0') and rs.next():
            # fetch one row at a time and collect the rows
            data_list.append(rs.get_row_data())
        code_date = pd.DataFrame(data_list, columns=rs.fields)
        code_date.to_csv(file_name, index=False)
        code_date = code_date.head(625)
        if len(code_date) == 625:
            quotes.append(code_date)
            date_info.append([symbol, names[i]])
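
# End the baostock session now that all quotes have been fetched or read from cache
bs.logout()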

date_info = np.array(date_info)
symbols = date_info[:, 0]
names = date_info[:, 1]

close_prices = np.vstack([q['close'] for q in quotes])
open_prices = np.vstack([q['open'] for q in quotes])

close_prices = close_prices.astype(np.float64)
open_prices = open_prices.astype(np.float64)

# The daily variations of the quotes are what carry most information
variation = close_prices - open_prices
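# Working with open-to-close variations (rather than raw prices) removes the
# overall price level of each stock, so stocks trading at very different
# prices can be compared on the same scale.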

# #############################################################################
# Learn a graphical structure from the correlations
edge_model = covariance.GraphicalLassoCV()

# standardize the time series: using correlations rather than covariance
# is more efficient for structure recovery
X = variation.copy().T
X /= X.std(axis=0)
edge_model.fit(X)
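# After fitting, edge_model.precision_ holds the estimated sparse inverse
# covariance; its non-zero off-diagonal entries are the conditional
# dependencies that will be drawn as edges in the graph below.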

# #############################################################################
# Cluster using affinity propagation

_, labels = cluster.affinity_propagation(edge_model.covariance_,
                                         random_state=0)
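# Affinity propagation chooses the number of clusters from the data itself;
# labels[i] is the cluster index assigned to stock i.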
n_labels = labels.max()
group_data = pd.DataFrame({"level": labels, "name": names, "symbol": symbols})
group_relate = group_data.groupby('level')
for name, group in group_relate:
    print("group: "+str(name))
    print(group)

# #############################################################################
# Find a low-dimension embedding for visualization: find the best position of
# the nodes (the stocks) on a 2D plane

# We use a dense eigen_solver to achieve reproducibility (arpack is
# initiated with random vectors that we don't control). In addition, we
# use a large number of neighbors to capture the large-scale structure.
node_position_model = manifold.LocallyLinearEmbedding(
    n_components=2, eigen_solver='dense', n_neighbors=6)

embedding = node_position_model.fit_transform(X.T).T

# #############################################################################
# Visualization
plt.figure(1, facecolor='w', figsize=(10, 8))
plt.clf()
ax = plt.axes([0., 0., 1., 1.])
plt.axis('off')

# Display a graph of the partial correlations
partial_correlations = edge_model.precision_.copy()
d = 1 / np.sqrt(np.diag(partial_correlations))
partial_correlations *= d
partial_correlations *= d[:, np.newaxis]
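# The precision matrix is rescaled by 1/sqrt of its diagonal so that the
# off-diagonal entries have the magnitude of partial correlations; their sign
# is discarded below, only the strength of each link is used.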
non_zero = (np.abs(np.triu(partial_correlations, k=1)) > 0.02)

# Plot the nodes using the coordinates of our embedding
plt.scatter(embedding[0], embedding[1], s=100 * d ** 2, c=labels,
            cmap=plt.cm.nipy_spectral)

# Plot the edges
start_idx, end_idx = np.where(non_zero)
# segments is a sequence of lines; each line is a pair of (x, y) endpoints,
# one pair per non-zero partial correlation
segments = [[embedding[:, start], embedding[:, stop]]
            for start, stop in zip(start_idx, end_idx)]
values = np.abs(partial_correlations[non_zero])
lc = LineCollection(segments,
                    zorder=0, cmap=plt.cm.hot_r,
                    norm=plt.Normalize(0, .7 * values.max()))
lc.set_array(values)
lc.set_linewidths(15 * values)
ax.add_collection(lc)

# Add a label to each node. The challenge here is that we want to
# position the labels to avoid overlap with other labels
for index, (name, label, (x, y)) in enumerate(
        zip(names, labels, embedding.T)):

    dx = x - embedding[0]
    dx[index] = 1
    dy = y - embedding[1]
    dy[index] = 1
    this_dx = dx[np.argmin(np.abs(dy))]
    this_dy = dy[np.argmin(np.abs(dx))]
    if this_dx > 0:
        horizontalalignment = 'left'
        x = x + .002
    else:
        horizontalalignment = 'right'
        x = x - .002
    if this_dy > 0:
        verticalalignment = 'bottom'
        y = y + .002
    else:
        verticalalignment = 'top'
        y = y - .002
    plt.text(x, y, name, size=10,
             horizontalalignment=horizontalalignment,
             verticalalignment=verticalalignment,
             bbox=dict(facecolor='w',
                       edgecolor=plt.cm.nipy_spectral(label / float(n_labels)),
                       alpha=.6))

plt.xlim(embedding[0].min() - .15 * np.ptp(embedding[0]),
         embedding[0].max() + .10 * np.ptp(embedding[0]))
plt.ylim(embedding[1].min() - .03 * np.ptp(embedding[1]),
         embedding[1].max() + .03 * np.ptp(embedding[1]))

plt.show()
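
If you want to reuse the grouping later, a minimal follow-up sketch (assuming the script above has already run, so group_data is still in scope; the output file name here is only an example) is to write the assignments to a CSV file:

# Hypothetical follow-up: persist the cluster assignments alongside the cached quote files.
group_data.sort_values("level").to_csv("csv/stock_groups.csv", index=False, encoding="utf-8-sig")

The utf-8-sig encoding keeps the Chinese stock names readable when the file is opened in Excel.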
