两个半月环的数据集
为什么要写这个?
还能问什么啊???肯定是老师要求的啊啊啊啊啊啊。
但是从本质上来说,还是学习,学习都是自己的,应该是自发的、自主的,通过自己一步一步的做了,去学习了才能真正收获到东西,学习目的. 学习的目的是掌握知识,为自己的将来打好基础,作好铺垫。. 学习仅仅是一个提高自己过程。. 正如人们常说的"学以致用",学习就是为了将来的发展。. 因此头脑清醒的人,学习会有的放矢,目标非常明确。. 而头脑糊涂的人,学习则无的放矢,忙于应付,非常被动。. 接触新事物,只要抓住了机会,调整好心态,在哪里都可以学习到有价值的东西。说了这么多我竟然又在给自己灌毒鸡汤了。。。
好了好了,回归正题,在不拉回来就又要跑偏了~
数据集的构成
分为两个半月行数据,并且给每一个数据做一个标签,这样我们才能进行二分类问题,下一篇博客就使用这个双半月环数据放到神经网络最初的原型感知器perception中进行训练。
话不多说,先给大家看看生成好数据的样子:
生成的数据还可以进行调整,可以对数据进行方向旋转,垂直,水平距离等进行设置。
下面给出旋转后的样子:
好了,我们现在看到现象了,接下来就来说说是怎么实现的吧。
我们可以采用极坐标的思想半径radius = radius + [一个随机数]
, 这个随机数就是半月型的宽度,角度slope = [一个随机数]
,这个随机数是在0-180°
,注意在这里为了限定角度是个半月型,必须在代码里写上这样一句slope %= 180
,对slope用180取余数后再赋值给slope。
在食用下面代码之前,一定要看清楚这个定义:
"""生成双半月形的随机数,并且做好数据标签"""
"""origin_x,origin_y:第一个半圆的圆心
radius:s指定生成圆形的半径
width:半圆行的线宽
hor_distance:两个半圆行圆心之间的距离
ver_distance:两个半圆行圆心之间的垂直距离
sample_data:两个圆行生成的总的样本点数
slope:指定生成半圆行的旋转角度
positive_value:上半圆的label
negative_value:下半圆的label
"""
下面上代码:
import matplotlib.pyplot as plt
from random import uniform, shuffle, seed
import math
import numpy
import logging
class half_moon_train_test_split(object):
"""定义分离函数"""
def train_test_split(self, sample_set, label_set, test_rate=0.2):
if 1 <= test_rate <= 1:
logging.exception('the test rate must be between 0 and 1,we will assign test rate of 0.2')
test_rate = 0.2
data = numpy.column_stack((sample_set, label_set))
shuffle(data) # 打乱一次所有的数据
sample_count = int(len(data)*test_rate)
Test_sample = data[0:sample_count]
Train_sample = data[sample_count:]
#
# Train_x = Train_sample[:, 0:2]
# Train_label = Train_sample[:, 2]
#
# Test_x = Test_sample[:, 0:2]
# Test_label = Test_sample[:, 2]
return Train_sample[:, 0:2], Train_sample[:, 2], Test_sample[:, 0:2], Test_sample[:, 2]
class Sj_Hal_moonDataSet(object):
"""将half_moon_train_test_split类实例化成Sj_Hal_moonDataSet类的一个属性,以便外面好访问"""
moon_train_test_split = half_moon_train_test_split()
"""生成随机数种子"""
def random_seed(self, random_seed):
seed(random_seed)
"""生成双半月形的随机数,并且做好数据标签"""
"""origin_x,origin_y:第一个半圆的圆心
radius:s指定生成圆形的半径
width:半圆行的线宽
hor_distance:两个半圆行圆心之间的距离
ver_distance:两个半圆行圆心之间的垂直距离
sample_data:两个圆行生成的总的样本点数
slope:指定生成半圆行的旋转角度
positive_value:上半圆的label
negative_value:下半圆的label
"""
def double_moon(self, origin_x=1.0, origin_y=1.0, radius=4.0, width=1.0, hor_distance=4.0, ver_distance=0.0,
sample_data=1000, slope=180, positive_value=1, negative_value=-1):
each_m = sample_data // 2
slope %= 180 # 限制slope角度值在0~180°
n_sample = []
p_sample = []
for i in range(each_m):
radius_l = radius + uniform(0, width)
temp_angle = uniform(slope, slope+180)
p_point_x = origin_x + radius_l * math.cos(math.pi / 180 * temp_angle)
p_point_y = origin_y + radius_l * math.sin(math.pi / 180 * temp_angle)
p_sample.append([p_point_x, p_point_y, positive_value])
for i in range(each_m):
radius_l = radius + uniform(0, width)
temp_angle = uniform(slope+180, slope + 360)
n_point_x = origin_x + hor_distance + radius_l * math.cos(math.pi / 180 * temp_angle)
n_point_y = origin_y - ver_distance + radius_l * math.sin(math.pi / 180 * temp_angle)
n_sample.append([n_point_x, n_point_y, negative_value])
sample_point = p_sample + n_sample
shuffle(sample_point)
sample_point = numpy.array(sample_point)
return sample_point[:, 0:2], sample_point[:, 2]
if __name__ == '__main__':
random_seed = 52 # 指定随机数种子,用于复现随机数样本
makeData = Sj_Hal_moonDataSet()
makeData.random_seed(random_seed)
np_data, label = makeData.double_moon(origin_x=1, origin_y=1, sample_data=2000, ver_distance=-1, width=1, hor_distance=3, slope=10)
# 测试分离出训练集和测试集函数是否正确
Train_x, Train_label, Test_x, Test_label = makeData.moon_train_test_split.train_test_split(sample_set=np_data, label_set=label, test_rate=0.2)
print(len(Train_x), len(Test_x))
p_point_x = [np_data[i][0] for i in range(len(np_data)) if label[i] == 1]
p_point_y = [np_data[i][1] for i in range(len(np_data)) if label[i] == 1]
n_point_x = [np_data[i][0] for i in range(len(np_data)) if label[i] == -1]
n_point_y = [np_data[i][1] for i in range(len(np_data)) if label[i] == -1]
fig = plt.figure(num='HalfMoon', figsize=(8, 8))
ax1 = fig.add_subplot(111)
ax1.scatter(p_point_x, p_point_y, c='red')
ax1.scatter(n_point_x, n_point_y, c='blue')
plt.show()
print(np_data)
参考博客:
https://blog.csdn.net/qq_40454401/article/details/121505962?spm=1001.2014.3001.5501