前段时间老师让我用样本熵(Sample Entropy)来分析一段时间序列。
目前能找到的代码好像只有一个版本,被各个网站、博主转来转去,我尝试了以后发现好像并不能解决我的问题,报错也不知道是因为什么。
后来查阅资料发现 PyPI 上早就有相应的库(sampen),所以在这里和大家分享介绍一下。
关于样本熵的定义、物理意义和计算方法等,可以参考《信号处理算法(2):样本熵(SampEn)》一文。
本文不对这方面内容展开描述。
这是官网上对该库的介绍:
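附上该库的主要代码(sampen2 函数)。注意其中的 `from .normalize_data import normalize_data` 是包内的相对导入,所以这段代码不能单独复制运行,直接安装包使用即可: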
import math
from .normalize_data import normalize_data
def sampen2(data, mm=2, r=0.2, normalize=False):
    """
    Calculates an estimate of sample entropy and the variance of the estimate.

    :param data: The data set (time series) as a list of floats.
    :type data: list
    :param mm: Maximum length of epoch (subseries).
    :type mm: int
    :param r: Tolerance. Typically 0.1 or 0.2. Two samples match when their
        absolute difference is strictly less than ``r``; ``r`` is compared
        against the data as given (or as normalized, see ``normalize``).
    :type r: float
    :param normalize: Normalize such that the mean of the input is 0 and
        the sample variance is 1.
    :type normalize: bool
    :return: List[(Int, Float/None, Float/None)...]
        Where the first (Int) value is the Epoch length.
        The second (Float or None) value is the SampEn.
        The third (Float or None) value is the Standard Deviation.
        The outputs are the sample entropies of the input, for all epoch
        lengths of 0 to a specified maximum length, m (``mm + 1`` entries).
        If there are no matches for an epoch length (e.g. the data set is
        unique) the sample entropy and standard deviation for that length
        are None.
    :rtype: list
    :raises ValueError: If ``data`` is empty, or if the series is too short
        for the requested ``mm`` (``(mm + 1) * 2 > len(data)``).
    """
    n = len(data)
    if n == 0:
        raise ValueError("Parameter `data` contains an empty list")
    if mm > n / 2:
        raise ValueError(
            "Maximum epoch length of %d too large for time series of length "
            "%d (mm > n / 2)" % (mm, n)
        )
    # Validate before ``mm`` is incremented below, so the error message
    # reports the value the caller actually passed.
    if (mm + 1) * 2 > n:
        raise ValueError(
            "Maximum epoch length of %d too large for time series of length "
            "%d ((mm + 1) * 2 > n)" % (mm, n)
        )

    # From here on ``mm`` is the number of epoch lengths evaluated (0..mm).
    mm += 1
    mm_dbld = 2 * mm

    if normalize:
        data = normalize_data(data)

    # 2-D tables are stored as flat lists indexed ``[i + n * m]``.
    # run[jj]: consecutive matches ending at the pair (i, i + jj + 1).
    run = [0] * n
    run1 = run[:]  # ``run`` as it was for the previous row i
    # r1[i + n*j]: run length at row i for offset j + 1 (first mm_dbld offsets)
    r1 = [0] * (n * mm_dbld)
    # r2[i + n*j]: run length of the pair (i - j - 1, i)
    r2 = r1[:]
    # f[i + n*m]: length-(m+1) matches in which sample i takes part
    f = r1[:]
    f1 = [0] * (n * mm)  # ... where i is the earlier sample of the pair
    f2 = f1[:]  # ... where i is the later sample of the pair
    k = [0] * ((mm + 1) * mm)  # accumulators for the variance correction
    a = [0] * mm  # a[m]: matched pairs of length m + 1
    b = a[:]  # b[m]: matched pairs of length m (after the shift below)
    p = a[:]  # p[m] = a[m] / b[m]
    v1 = a[:]
    v2 = a[:]
    s1 = a[:]
    n1 = a[:]
    n2 = a[:]

    for i in range(n - 1):
        nj = n - i - 1
        y1 = data[i]
        for jj in range(nj):
            j = jj + i + 1
            if data[j] - y1 < r and y1 - data[j] < r:
                run[jj] = run1[jj] + 1
                m1 = mm if mm < run[jj] else run[jj]
                for m in range(m1):
                    a[m] += 1
                    # Only pairs whose second index can still be extended
                    # by one sample count towards the template total.
                    if j < n - 1:
                        b[m] += 1
                    f1[i + n * m] += 1
                    f[i + n * m] += 1
                    f[j + n * m] += 1
            else:
                run[jj] = 0
        for j in range(mm_dbld):
            run1[j] = run[j]
            r1[i + n * j] = run[j]
        if nj > mm_dbld - 1:
            for j in range(mm_dbld, nj):
                run1[j] = run[j]

    for i in range(1, mm_dbld):
        for j in range(i - 1):
            r2[i + n * j] = r1[i - j - 1 + n * j]
    for i in range(mm_dbld, n):
        for j in range(mm_dbld):
            r2[i + n * j] = r1[i - j - 1 + n * j]

    for i in range(n):
        for m in range(mm):
            ff = f[i + n * m]
            f2[i + n * m] = ff - f1[i + n * m]
            k[(mm + 1) * m] += ff * (ff - 1)

    # Shift so b[m] counts length-m matches; b[0] is the number of all pairs.
    b[1:] = b[:-1]
    b[0] = float(n) * (n - 1.0) / 2.0

    for m in range(mm):
        if b[m] == 0:
            # No length-m template matches at all, hence none of length m + 1
            # either: leave p[m] = 0 so the result is reported as None rather
            # than raising ZeroDivisionError.
            continue
        p[m] = float(a[m]) / float(b[m])
        v2[m] = p[m] * (1.0 - p[m]) / b[m]

    # For each lag d + 1, accumulate products of the match counts of samples
    # i1 and i1 - d - 1; these feed the correction term ``dv`` below.
    for m in range(mm):
        d2 = m + 1 if m + 1 < mm - 1 else mm - 1
        for d in range(d2):
            for i1 in range(d + 1, n):
                i2 = i1 - d - 1
                nm1 = f1[i1 + n * m]
                nm3 = f1[i2 + n * m]
                nm2 = f2[i1 + n * m]
                nm4 = f2[i2 + n * m]
                for j in range(2 * (d + 1)):
                    if r2[i1 + n * j] >= m + 1:
                        nm2 -= 1
                for j in range(2 * d + 1):
                    if r1[i2 + n * j] >= m + 1:
                        nm3 -= 1
                k[d + 1 + (mm + 1) * m] += float(2 * (nm1 + nm2) * (nm3 + nm4))

    n1[0] = float(n * (n - 1) * (n - 2))
    for m in range(mm - 1):
        for j in range(m + 2):
            n1[m + 1] += k[j + (mm + 1) * m]
    for m in range(mm):
        for j in range(m + 1):
            n2[m] += k[j + (mm + 1) * m]

    # calculate standard deviation for the set
    for m in range(mm):
        if b[m] == 0:
            continue  # undefined estimate; reported as None below
        v1[m] = v2[m]
        dv = (n2[m] - n1[m] * p[m] * p[m]) / (b[m] * b[m])
        if dv > 0:
            v1[m] += dv
        s1[m] = math.sqrt(v1[m])

    # assemble and return the response
    response = []
    for m in range(mm):
        if p[m] == 0:
            # Infimum, the data set is unique, there were no matches.
            response.append((m, None, None))
        else:
            response.append((m, -math.log(p[m]), s1[m]))
    return response
其实只要安装这个包就可以直接使用:
$ pip install sampen
用法
from sampen import sampen2
# Read the time series from a text file: one number per line.
# Blank lines (e.g. a trailing newline at the end of the file) are skipped,
# because float('') would raise ValueError.
series_data = []
with open('relative/path/to/file.txt', 'r', encoding='utf-8') as fh:
    for row in fh:
        value = row.strip()
        if value:
            series_data.append(float(value))
# Calculate the sample entropy with the default parameters (mm=2, r=0.2).
sampen_of_series = sampen2(series_data)
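默认最大模板长度(即嵌入维数,参数 mm)为 2,默认容限(参数 r)为 0.2。注意 r 是直接与数据差值比较的绝对阈值:默认 normalize=False 时不会对数据做标准化。如果想使用常见的 r = 0.2×标准差,可以传入 normalize=True,或者按数据的标准差自行设定 r。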
检查返回的数据:
>>> sampen_of_series
[
(0, 2.140629540027156, 0.0028357991885715863),
(1, 2.162868347337613, 0.004903248034526253),
(
# Epoch length for max epoch(最大长度)
2,
# SampEn(样本熵的值)
2.123328492035711,
# Standard Deviation(标准偏差)
0.007596323621379352
),
]
如果序列太短就会报错:函数要求 (mm + 1) × 2 ≤ n,默认 mm = 2 时序列长度至少为 6。例如 n = 3 时已经满足 mm > n/2,函数会直接抛出 ValueError,无法继续计算。