简介
由于课题原因,最近在学习torch对于轴承故障检测的相关知识,但第一步读取数据以及数据的划分就难到了我,在网上查找相关资料,也没有完整的代码,于是只能东拼西凑,修修改改,最后勉强凑出来一个可以一用的代码,放在这里保存一下。
读取函数
def open_data(bath_path,key_num):
path = bath_path + str(key_num) + ".mat"
str1 = "X" + "%03d"%key_num + "_DE_time"
data = scio.loadmat(path)
data = data[str1]
return data
数据处理函数
def deal_data(data,length,label):
data = np.reshape(data,(-1))
num = len(data)//length
data = data[0:num*length]
data = np.reshape(data,(num,length))
"""
最大值绝对值标准化(MaxAbs)即根据最大值的绝对值进行标准化,假设原转换的数据为x,新数据为x'
那么x'=x/|max|,其中max为x所在列的最大值。
MaxAbs方法跟Max-Min用法类似,也是将数据落入一定区间,但该方法的数据区间为[-1,1]。
MaxAbs也具有不破坏原有数据分布结构的特点,因此也可以用于稀疏数据、稀疏的CSR或CSC矩阵。
"""
maxabs_scaler = preprocessing.MaxAbsScaler()
data = maxabs_scaler.fit_transform(np.transpose(data,[1,0]))
data = np.transpose(data,[1,0])
label = np.ones((num,1))*label
return np.column_stack((data,label))
数据集划分函数
def split_data(data,split_rate):
length = len(data)
num1 = int(length*split_rate[0])
num2 = int(length*split_rate[1])
index1 = random.sample(range(num1),num1)
train = data[index1]
data = np.delete(data,index1,axis=0)
index2 = random.sample(range(num2),num2)
eval = data[index2]
test = np.delete(data,index2,axis=0)
return train,eval,test
数据加载函数
def load_data(num = 90,length = 1280,hp = [0,1,2],fault_diameter = [0.007,0.028,0.021],split_rate = [0.7,0.2,0.1]):
#num 为每类故障样本数量,length为样本长度,hp为负载大小,可取[0,1,2,3],fauit_diameter为故障程度,可取[0.007,0.014,0.021]
#split_rate为训练集,验证集和测试集划分比例。取值从0-1。
#bath_path1 为西储大学数据集中,正常数据的文件夹路径
#bath_path2 为西储大学数据集中,12K采频数据的文件夹路径
bath_path1 = r"F:\data\cwru\Normal Baseline Data\\"
bath_path2 = r"F:\data\cwru\12k Drive End Bearing Fault Data\\"
data_list = []
label = 0
# 正常数据
# path1 = bath_path1 + str(97+i) + ".mat"
# normal_data = scio.loadmat(path1)
# str1 = "X0" + str(97+i) + "_DE_time"
normal_data = open_data(bath_path1, 97)
data = deal_data(normal_data, length, label=label)
data_list.append(data)
for i in hp:
#故障数据
for j in fault_diameter:
if j == 0.007:
inner_num = 105
ball_num = 118
outer_num = 130
elif j == 0.014:
inner_num = 169
ball_num = 185
outer_num = 197
else:
inner_num = 209
ball_num = 222
outer_num = 234
inner_data = open_data(bath_path2,inner_num + i)
inner_data = deal_data(inner_data,length,label + 1)
data_list.append(inner_data)
ball_data = open_data(bath_path2,ball_num + i)
ball_data = deal_data(ball_data,length,label + 2)
data_list.append(ball_data)
outer_data = open_data(bath_path2,outer_num + i)
outer_data = deal_data(outer_data,length,label + 3)
data_list.append(outer_data)
label = label + 3
#保持每类数据数据量相同
num_list = []
for i in data_list:
num_list.append(len(i))
min_num = min(num_list)
if num > min_num:
print("每类数量超出上限,最大数量为:%d" %min_num)
min_num = min(num,min_num)
#划分训练集,验证集和测试集,并随机打乱顺序
train = []
eval = []
test = []
for data in data_list:
data = data[0:min_num,:]
a,b,c = split_data(data,split_rate)
train.append(a)
eval.append(b)
test.append(c)
train = np.reshape(train,(-1,length+1))
train = train[random.sample(range(len(train)),len(train))]
train_data = train[:,0:length]
train_label = torch.zeros(len(train),10).scatter_(1,torch.LongTensor(train[:,length]).unsqueeze(1),1)
eval = np.reshape(eval,(-1,length+1))
eval = eval[random.sample(range(len(eval)),len(eval))]
eval_data = eval[:,0:length]
eval_label = torch.zeros(len(eval), 10).scatter_(1, torch.LongTensor(eval[:,length]).unsqueeze(1),1)
test = np.reshape(test,(-1,length+1))
test = test[random.sample(range(len(test)),len(test))]
test_data = test[:,0:length]
test_label = torch.zeros(len(test), 10).scatter_(1, torch.LongTensor(test[:,length]).unsqueeze(1),1)
return train_data,train_label,eval_data,eval_label,test_data,test_label