DataFrame进行数据分组运算并筛选指定条件的group

需求

假设有个股票行情文件,内容如下,数据为虚构:

code,time,open,high,low
000001.SZ,095000,2,3,2.5
000001.SZ,095300,2,3,2.5
000001.SZ,095600,2,3,2.5
000002.SZ,095000,2,3,2.5
000003.SZ,095600,2,3,2.5
000003.SZ,095900,2,3,2.5

现在要计算每支股票high和low的均值,如果某股票的行情条数不足2条,则忽略不计。

实现

问题不难,方法也有多种。

这里介绍一种使用DataFrame分组groupby和筛选filter满足条件group的方式。

关于groupby的使用可以参考:

https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.groupby.html?highlight=groupby#pandas.DataFrame.groupby

原型如下:

DataFrame.groupby(by=None, axis=0, level=None, as_index=True, sort=True, group_keys=True, squeeze=NoDefault.no_default, observed=False, dropna=True)

它返回包含各组信息的groupby对象,可对该对象进行各种聚合运算。

关于详细使用和参数说明,可参考:

https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html

代码如下:

import pandas as pd

df = pd.read_csv('stock.csv')

# 根据标的分组,并对分组后的对象的指定列应用聚合函数mean,求均值
grouped_stock = df.groupby('code')[['high','low']].mean()
print(grouped_stock)

输出如下:

>>> print(grouped_stock)
           high  low
code                
000001.SZ   3.0  2.5
000002.SZ   3.0  2.5
000003.SZ   3.0  2.5

此时,标的000002.SZ的行情只有1条,需要过滤过,增加一行代码做条件筛选,如下:

import pandas as pd

df = pd.read_csv('stock.csv')

# 根据行情条数筛选
filtered_df = df.groupby('code').filter(lambda x: len(x) >= 2)

grouped_stock = filtered_df.groupby('code')[['high','low']].mean()
print(grouped_stock)

输出如下:

>>> print(grouped_stock)
           high  low
code                
000001.SZ   3.0  2.5
000003.SZ   3.0  2.5

可以看到,不符合条件的标的已经被过滤。

关于groupby的一些使用总结

下面是实现上述需求后,作的一些技术上的总结,当作复习使用。

直接以代码形式展示:

import pandas as pd

df = pd.read_csv('stock.csv')

# 根据标的分组
grouped = df.groupby('code')

# 分组后对象的类型
print(type(grouped))
# <class 'pandas.core.groupby.generic.DataFrameGroupBy'>

# 遍历分组后对象
for name, group in grouped:
	print(name, '\n', group)

# 以下为输出:
000001.SZ 
         code   time  open  high  low
0  000001.SZ  95000     2     3  2.5
1  000001.SZ  95300     2     3  2.5
2  000001.SZ  95600     2     3  2.5
000002.SZ 
         code   time  open  high  low
3  000002.SZ  95000     2     3  2.5
000003.SZ 
         code   time  open  high  low
4  000003.SZ  95600     2     3  2.5
5  000003.SZ  95900     2     3  2.5
# 输出结束

# 获取指定分组
grouped.get_group('000001.SZ')
# 以下为输出
        code   time  open  high  low
0  000001.SZ  95000     2     3  2.5
1  000001.SZ  95300     2     3  2.5
2  000001.SZ  95600     2     3  2.5
# 输出结束

# 应用聚合运算
grouped.aggregate(np.sum)
# 输出开始
             time  open  high  low
code                              
000001.SZ  285900     6     9  7.5
000002.SZ   95000     2     3  2.5
000003.SZ  191500     4     6  5.0
# 输出结束

grouped.size()
# 输出开始
code
000001.SZ    3
000002.SZ    1
000003.SZ    2
dtype: int64
# 输出结束

grouped.count()
# 输出开始
           time  open  high  low
code                            
000001.SZ     3     3     3    3
000002.SZ     1     1     1    1
000003.SZ     2     2     2    2
# 输出结束

grouped.describe()
#输出开始
           time                                                           ...  low                              
          count     mean         std      min      25%      50%      75%  ... mean  std  min  25%  50%  75%  max
code                                                                      ...                                   
000001.SZ   3.0  95300.0  300.000000  95000.0  95150.0  95300.0  95450.0  ...  2.5  0.0  2.5  2.5  2.5  2.5  2.5
000002.SZ   1.0  95000.0         NaN  95000.0  95000.0  95000.0  95000.0  ...  2.5  NaN  2.5  2.5  2.5  2.5  2.5
000003.SZ   2.0  95750.0  212.132034  95600.0  95675.0  95750.0  95825.0  ...  2.5  0.0  2.5  2.5  2.5  2.5  2.5

[3 rows x 32 columns]
# 输出结束

grouped.agg([np.sum, np.mean, np.std])
# 输出开始
             time                      open           high            low          
              sum     mean         std  sum mean  std  sum mean  std  sum mean  std
code                                                                               
000001.SZ  285900  95300.0  300.000000    6  2.0  0.0    9  3.0  0.0  7.5  2.5  0.0
000002.SZ   95000  95000.0         NaN    2  2.0  NaN    3  3.0  NaN  2.5  2.5  NaN
000003.SZ  191500  95750.0  212.132034    4  2.0  0.0    6  3.0  0.0  5.0  2.5  0.0
# 输出结束

grouped.filter(lambda x: len(x) >= 2)
# 输出开始
        code   time  open  high  low
0  000001.SZ  95000     2     3  2.5
1  000001.SZ  95300     2     3  2.5
2  000001.SZ  95600     2     3  2.5
4  000003.SZ  95600     2     3  2.5
5  000003.SZ  95900     2     3  2.5
# 输出结束

grouped.filter(lambda x: len(x) >= 2).groupby('code').mean()
# 输出开始
              time  open  high  low
code                               
000001.SZ  95300.0   2.0   3.0  2.5
000003.SZ  95750.0   2.0   3.0  2.5
# 输出结束
小结

对于数据分组和聚合运算的需求,使用DataFrame非常方便。

本文主要使用了:

  • groupby
  • filter

要注意区分各个方法的返回对象,再进行后续操作。

参考

https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.groupby.html?highlight=groupby#pandas.DataFrame.groupby
https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html

  • 4
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值