《利用Python 进行数据分析》第八章:绘图和可视化

       对《利用Python 进行数据分析》(Wes Mckinney著)一书中的第八章中绘图和可视化进行代码实验。原书中采用的是Python2.7,而我采用的Python3.7在Pycharm调试的,因此对书中源代码进行了一定的修改,每步结果与原文校验对照一致(除了随机函数外,输出结果在注释中,简单的输出就没写结果),全手工敲写,供参考。

       Pdf文档和数据集参见:《利用Python 进行数据分析》第二章:引言中的分析代码(含pdf和数据集下载链接)

       因为代码过长,放在一个代码段中显得冗长,因此进行了拆分,如下的库引入每个代码段中均可能有必要。

# -*-coding:utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
import numpy as pd

1、matplotlib api入门

1.1 Figure和Subplot

# plt.figure创建一个新的Figure
fig = plt.figure()
# add_subplot创建一个或多个subplot
ax1 = fig.add_subplot(2,2,1)  # 2×2的图像

# 把后面两个subplot创建出来
ax2 = fig.add_subplot(2,2,2)
ax3 = fig.add_subplot(2,2,3)

# 调用plot命令,matplotlib会在最后一个用过的subplot上进行绘制
from numpy.random import randn
plt.plot(randn(50).cumsum(), 'k--')
# plt.show()

# 绘制过两次之后可以分别在其他空着的格子里得到不同的图像
_=ax1.hist(randn(100), bins=20, color='k', alpha=0.3)
ax2.scatter(np.arange(30), np.arange(30) + 3 * randn(30))
plt.show()

# plt.subplots创建一个新的Figure,并返回一个含有已创建的subplot对象的Numpy数组
fig,axes = plt.subplots(2,3)
print(axes)
'''
[[<matplotlib.axes._subplots.AxesSubplot object at 0x00000218C8AC4E88>
  <matplotlib.axes._subplots.AxesSubplot object at 0x00000218C8B15348>
  <matplotlib.axes._subplots.AxesSubplot object at 0x00000218C9D7BB48>]
 [<matplotlib.axes._subplots.AxesSubplot object at 0x00000218C9DB6348>
  <matplotlib.axes._subplots.AxesSubplot object at 0x00000218C9DF0208>
  <matplotlib.axes._subplots.AxesSubplot object at 0x00000218C9E2B208>]]
'''

# 默认情况,matplotlib会在subplot外围留下一定边距,并在subplot之间留下一定的间距
# Figure的subplots_adjust方法可以轻易地修改间距
# subplots_adjust(left=None, bottom=None, riht = None, top = None, wspace=None,hspace=None)
# wspace和hspace用于控制宽度和高度的百分比,可以用作subplot之间的间距
fig, axes = plt.subplots(2,2, sharex=True, sharey=True)
for i in range(2):
    for j in range(2):
        axes[i,j].hist(randn(500), bins=50, color='k', alpha=0.5)
plt.subplots_adjust(wspace=0,hspace=0)
plt.show()

       以下图片对应原书中的图片序号,按代码输出顺序给出(其他段落同样):
       图8-3 绘制一次之后的图像:
在这里插入图片描述
       图8-4 绘制两次之后的图像:
在这里插入图片描述

       图8-5 各subplot之间没有间距:
在这里插入图片描述

1.2 颜色、标记和线型

# plot函数除了接受x和y坐标,还接受一个表示颜色和线型的字符串缩写
x = np.arange(5)
y = np.arange(5) + 5
# plt.plot(x, y, 'g--')

# 通过显示方式指定也可以达到同样的效果
# plt.plot(x,y, linestyle='--', color='g')

# 还可以在线上增加标记,以此强调实际的数据点
plt.plot(randn(30).cumsum(), 'ko--')

# 也可以写成更明确的方式
# plt.plot(randn(30).cumsum(), color='k', linestyle='dashed', marker='o')
plt.show()

# 线形图中,非数据数据点默认按线性插值的,可以通过drawstyle选项修改
data=randn(30).cumsum()
plt.plot(data, 'k--', label = 'Default')

plt.plot(data, 'k-', drawstyle='steps-post', label='steps-post')

plt.legend(loc='best')
plt.show()

       图8-6 带有标记的现型图示例:
在这里插入图片描述

       图8-7 不同drawstyle选项的线型图:
在这里插入图片描述

1.3 刻度、标签和图例

'''
xlim、xticks和xticklabel分别控制图标的范围、刻度的位置、刻度标签:
(1) 调用时不带参数,则返回当前的参数值,例如plt.xlim()返回当前X轴绘图范围
(2) 调用时带参数,则设置参数值。因此plt.xlim([0,10])会将X轴的范围设置为0到10
'''
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.plot(randn(1000).cumsum())

# set_xticks 告诉matplotlib将刻度放在数据范围中的哪些位置
ticks = ax.set_xticks([0, 250, 500, 750, 1000])
# set_xticklabels可以将任何其他的值作为标签
labesl = ax.set_xticklabels(['one', 'two', 'three', 'four', 'five'],
                            rotation = 30, fontsize = 'small')
plt.show()

# 添加图例
# legend函数是用于标识图标元素的重要工具
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.plot(randn(1000).cumsum(), 'k', label='one')
ax.plot(randn(1000).cumsum(), 'k--', label='two')
ax.plot(randn(1000).cumsum(), 'k.', label='three')
# 可以调用ax.legend()或plt.legend()自动创建图例
ax.legend(loc='best')
plt.show()

       图8-8 用于演示xticks的简单线型图:
在这里插入图片描述

       图8-9 用于演示xticks的简单线型图:
在这里插入图片描述

       图8-10 带有三条线及图例的简单线型图:
在这里插入图片描述

1.4 注解及在Subplot上绘图

# 通过text、arrow和annotate等函数可以为图片添加注解
# text可以将文本绘制在图标的指定坐标(x,y),还可以加上一些自定义格式
from datetime import datetime
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
data = pd.read_csv('python_data/ch08/spx.csv', index_col=0, parse_dates=True)
spx = data['SPX']
spx.plot(ax=ax, style='k-')
crisis_data=[
    (datetime(2007,10,11), 'Peak of bull market'),
    (datetime(2008,3,12), 'Bear Stearns Fails'),
    (datetime(2008,9,15), 'Lehman Bankruptcy')
]
for date, label in crisis_data:
    ax.annotate(label, xy=(date, spx.asof(date) + 50),
                xytext=(date,spx.asof(date)+ 200),
                arrowprops = dict(facecolor='black'),
                horizontalalignment = 'left', verticalalignment='top')
# 放大到2007-2010
ax.set_xlim(['1/1/2007','1/1/2011'])
ax.set_ylim([600, 1800])
ax.set_title('Important dates in 2008-2009 finalcial crisis')
plt.show()

# 在图表中添加一个图形,如圆形、矩形、三角形
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
rect = plt.Rectangle((0.2,0.75), 0.4, 0.15, color='k',alpha=0.3)
circ = plt.Circle((0.7, 0.2), 0.15, color='b', alpha=0.3)
pgon = plt.Polygon([[0.15,0.15], [0.35,0.4], [0.2, 0.6]], color = 'g', alpha = 0.5)
ax.add_patch(rect)
ax.add_patch(circ)
ax.add_patch(pgon)
plt.show()

       图8-11 2008-2009年金融危机期间的重要日期:
在这里插入图片描述

       图8-12 由三个块图像组成的图:
在这里插入图片描述

1.5 将图标保存到文件

# plt.savefigure('figpath.svg') 可以将当前图表保存到文件
# 发布图片时,有两个重要参数dpi(控制“每英寸点数”分辨率)和bbox_inches(可以剪除当前图表周围的空白部分)
plt.savefig('figpath.png', dpi=400, bbox_inchex= 'tight')

# savefig并非一定要写入磁盘,也可以写入任何文件型的对象
from io import BytesIO
buffer = BytesIO()  # 用StringIO会报错
plt.savefig(buffer)
plot_data = buffer.getvalue()

1.6 matplotlib配置

# 对全局的图像默认大小设置为10×10,rc的第一个参数是希望自定义的对象
plt.rc('figure',figsize=(10,10))

# 可以将这些选项写成一个字典
font_options = {'family':'monospace',
                'weight':'bold',
                'size':20}
plt.rc('font',**font_options)

2、pandas中绘图函数

2.1 线形图

s = pd.Series(np.random.randn(10).cumsum(), index = np.arange(0, 100, 10))
s.plot()
plt.grid(linestyle = '-.')
plt.show()

# DataFrame的plot方法会在一个subplot中为各列绘制一条线,并自动创建图例
df = pd.DataFrame(np.random.randn(10,4).cumsum(0),
                  columns = ['A', 'B', 'C', 'D'],
                  index = np.arange(0,100,10))
plt.plot(df)
plt.show()

       图8-13 简单Series图表示例:
在这里插入图片描述

       图8-14 简单DataFrame图表示例:
在这里插入图片描述

2.2 柱状图

fig,axes = plt.subplots(2,1)
data = pd.Series(np.random.rand(16), index = list('abcdefghijklmnop'))
data.plot(kind='bar',ax=axes[0],color='k',alpha=0.7)
data.plot(kind='barh',ax=axes[1],color='k',alpha=0.7)
# 以下可以实现与上述相同的效果
# axes[0].bar(data.index,data.values, color='k',alpha=0.7)
# axes[1].barh(data.index,data.values, color='k', alpha=0.7)
plt.yticks(rotation=45)
# plt.show()

# 对于DataFrame,柱状图将每一行的值为分为一组
df = pd.DataFrame(np.random.rand(6,4), index=['one','two','three','four','five','six'],
                  columns=pd.Index(['A','B','C','D'],name='Genus'))
print(df)
'''
Genus         A         B         C         D
one    0.312715  0.883614  0.712203  0.797839
two    0.158892  0.214012  0.702087  0.698235
three  0.460272  0.639971  0.863585  0.671799
four   0.346089  0.254262  0.129439  0.675211
five   0.160448  0.500737  0.924986  0.875125
six    0.312943  0.164694  0.161089  0.321563
'''
df.plot(kind='bar')
plt.xticks(rotation = 25)
plt.show()

# stacker=True即可为DataFrame生成堆积柱状图,每行的值就会被堆积在一起
df.plot(kind='barh', stacked = True, alpha = 0.5)
plt.show()

# 读取有关小费的数据集
tips = pd.read_csv('python_data/ch08/tips.csv')
party_counts = pd.crosstab(tips.day, tips['size']) # 不能用tips.size,这样会得到数据的总数据点数(行×列)
print(party_counts)
'''
day                      
Fri   1  16   1   1  0  0
Sat   2  53  18  13  1  0
Sun   0  39  15  18  3  1
Thur  1  48   4   5  1  3
'''
# 1个人和6个人聚会较少
party_counts = party_counts.iloc[:, 1:5]
# 对数据进行规格化,使得各行的和为1
party_pcts = party_counts.div(party_counts.sum(1).astype(float), axis =0)
print(party_pcts)
'''
size         2         3         4         5
day                                         
Fri   0.888889  0.055556  0.055556  0.000000
Sat   0.623529  0.211765  0.152941  0.011765
Sun   0.520000  0.200000  0.240000  0.040000
Thur  0.827586  0.068966  0.086207  0.017241
'''
party_pcts.plot(kind = 'bar', stacked = True)
plt.xticks(rotation = 25)
plt.show()

       图8-15 水平和垂直柱状图示例:
在这里插入图片描述

       图8-16 DataFrame柱状图示例:
在这里插入图片描述

       图8-17 DataFrame堆积柱状图示例:
在这里插入图片描述

       图8-18 每天各种聚会规模的比例:
在这里插入图片描述

2.3 直方图和密度图

tips['tip_pct'] = tips['tip']/tips['total_bill'] # tips数据参见2.2小节
# 消费百分比的直方图
# t.hist(bins=50)
plt.hist(tips['tip_pct'],bins=50)
plt.grid(alpha=0.3,linestyle='-.')
plt.show()

# plot时加上kind = 'kde'可以生成一张密度图(标准混合正态分布图KDE)
tips['tip_pct'].plot(kind = 'kde')
plt.grid(alpha=0.3,linestyle='-.')
plt.show()

# 直方图和密度图常常画在一起
# 由两个不同的正态分布组成的双峰分布
comp1 = np.random.normal(0,1,size=200) # N(0,1)
comp2 = np.random.normal(10,2, size = 200) # N(10,4)
values = pd.Series(np.concatenate([comp1, comp2]))
values.hist(bins=100, alpha=0.3, color='k', density=True)  # 新版本中归一化由normed换成了density
values.plot(kind='kde',style='k--')
plt.grid(alpha=0.3,linestyle='-.')
plt.show()

       图8-19 消费百分比直方图:
在这里插入图片描述

       图8-20 消费百分比的密度图:
在这里插入图片描述

       图8-21 带有密度估计的规格化直方图:
在这里插入图片描述

2.4 散布图

# 散点图使用plt.scatter()得到
macro = pd.read_csv('python_data/ch08/macrodata.csv')
data = macro[['cpi','m1','unemp']]
trans_data = np.log(data).diff().dropna()
print(trans_data[-5:])
'''
          cpi        m1     unemp
198 -0.007904  0.045361  0.105361
199 -0.021979  0.066753  0.139762
200  0.002340  0.010286  0.160343
201  0.008419  0.037461  0.127339
202  0.008894  0.012202  0.042560
'''

plt.scatter(trans_data['m1'],trans_data['unemp'])
plt.title('Changes in log %s vs. log %s' %('m1','unemp'))
plt.show()

# scatter_matrix函数支持在对角上放置各变量的直方图或密度图
pd.plotting.scatter_matrix(trans_data, diagonal='kde', color='k', alpha=0.3)
plt.show()

       图8-22 一张简单的散布图:
在这里插入图片描述

       图8-23 statsmodels macro data的散布图矩阵:
在这里插入图片描述

3、绘制地图:图形化显示海地地震危机数据

# 导入2010年海地地震及余震期间搜集的数据
data = pd.read_csv('python_data/ch08/Haiti.csv')
print(data.columns)
'''
Index(['Serial', 'INCIDENT TITLE', 'INCIDENT DATE', 'LOCATION', 'DESCRIPTION',
       'CATEGORY', 'LATITUDE', 'LONGITUDE', 'APPROVED', 'VERIFIED'],
      dtype='object')
'''

# 查看我们想要的数据的前10条
print(data[['INCIDENT DATE', 'LATITUDE', 'LONGITUDE']][:10])
'''
      INCIDENT DATE   LATITUDE   LONGITUDE
0  05/07/2010 17:26  18.233333  -72.533333
1  28/06/2010 23:06  50.226029    5.729886
2  24/06/2010 16:21  22.278381  114.174287
3  20/06/2010 21:59  44.407062    8.933989
4  18/05/2010 16:26  18.571084  -72.334671
5  26/04/2010 13:14  18.593707  -72.310079
6  26/04/2010 14:19  18.482800  -73.638800
7  26/04/2010 14:27  18.415000  -73.195000
8  15/03/2010 10:58  18.517443  -72.236841
9  15/03/2010 11:00  18.547790  -72.410010
'''

# CATEGOTY 字段含有一组以逗号分隔的代码,代码表示消息的类型
print(data['CATEGORY'][:6])
'''
0          1. Urgences | Emergency, 3. Public Health, 
1    1. Urgences | Emergency, 2. Urgences logistiqu...
2    2. Urgences logistiques | Vital Lines, 8. Autr...
3                            1. Urgences | Emergency, 
4                            1. Urgences | Emergency, 
5                       5e. Communication lines down, 
Name: CATEGORY, dtype: object
'''
# 通过describe可以发现数据中存在一些异常的地理位置
print(data.describe())
'''
            Serial     LATITUDE    LONGITUDE
count  3593.000000  3593.000000  3593.000000
mean   2080.277484    18.611495   -72.322680
std    1171.100360     0.738572     3.650776
min       4.000000    18.041313   -74.452757
25%    1074.000000    18.524070   -72.417500
50%    2163.000000    18.539269   -72.335000
75%    3088.000000    18.561820   -72.293570
max    4052.000000    50.226029   114.174287
'''

# 清楚错误位置并移除分类信息
data = data[(data.LATITUDE > 18)&(data.LATITUDE < 20)&
            (data.LONGITUDE > -75)&(data.LONGITUDE < -70)
            &data.CATEGORY.notnull()]

def to_cat_list(catstr):
    stripped = (x.strip() for x in catstr.split(','))
    return [x for x in stripped if x]

# 获取所有分类的列表
def get_all_categories(cat_series):
    cat_sets = (set(to_cat_list(x)) for x in cat_series)
    return sorted(set.union(*cat_sets))

# 将各个分类信息拆分成编码和英语名称
def get_english(cat):
    code, names = cat.split('.')
    if '|' in names:
        names = names.split('|')[1]
    return code, names.strip()

# 测试get_english函数是否正常工作
ret = get_english('2. Urgences logistiques | Vital Llies')
print(ret) # ('2', 'Vital Llies')

all_cats =get_all_categories(data.CATEGORY)

# 生成器表达式
english_mapping = dict(get_english(x) for x in all_cats)
print(english_mapping['2a']) # Food Shortage
print(english_mapping['6c']) # Earthquake and aftershocks

# 抽取出为宜的分类编码,并构造一个全零DataFrame(列为分类编码,索引跟data的索引一样)
def get_code(seq):
    return [x.split('.')[0] for x in seq if x]

all_codes = get_code(all_cats)
code_inex = pd.Index(np.unique(all_codes))
dummy_frame = pd.DataFrame(np.zeros((len(data), len(code_inex))), index = data.index,
                           columns=code_inex)
print(dummy_frame.head())
'''
     1   1a   1b   1c   1d    2   2a  ...   7h    8   8a   8c   8d   8e   8f
0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0  0.0  0.0  0.0
4  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0  0.0  0.0  0.0
5  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0  0.0  0.0  0.0
6  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0  0.0  0.0  0.0
7  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0  0.0  0.0  0.0
'''

for row,cat in zip(data.index, data.CATEGORY):
    codes = get_code(to_cat_list(cat))
    dummy_frame.loc[row, codes] = 1

data = data.join(dummy_frame.add_prefix('category_'))
print(data.head())
'''
   Serial  ... category_8f
0    4052  ...         0.0
4    4042  ...         0.0
5    4041  ...         0.0
6    4040  ...         0.0
7    4039  ...         0.0
'''

from mpl_toolkits.basemap import Basemap
import matplotlib.pyplot as plt

def basic_haiti_map(ax=None, lllat=11.25, urlat=20.25,
                    lllon=75,urlon=-71):
    # 创建极球面投影的Basemap实例
    m=Basemap(ax=ax,projection='stere',
              lon_0=(urlon + lllon) /2,
              lat_0=(urlat + lllat)/2,
              llcrnrlat=lllat, urcrnrlat=urlat,
              llcrnrlon=lllon, urcrnrlon=urlon,
              resolution='f')
    # 绘制海岸线、州界、国界以及地图边界
    m.drawcoastlines()
    m.drawstates()
    m.drawcountries()
    return m

# 让返回的Basemap对象的坐标返回到画布上
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12,10))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
to_plot=['2a','1','3c','7a']
lllat = 17.25
urlat = 20.25
lllon = -75
urlon = -71
for code, ax in zip(to_plot, axes.flat):
    m = basic_haiti_map(ax, lllat=lllat, urlat=urlat,
                        lllon=lllon, urlon=urlon)
    cat_data = data[data['category_%s' % code] == 1]
    # 计算地图的投影坐标
    x, y = m(list(cat_data.LONGITUDE), list(cat_data.LATITUDE))
    m.plot(x, y, 'k.', alpha=0.5)
    ax.set_title('%s:%s'%(code, english_mapping[code]), fontsize=12)
plt.show()

       图8-24 海地地震的4类数据:
在这里插入图片描述

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

南洲.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值