西储大学(CWRU)轴承数据集的数据读取与划分

简介

由于课题原因,最近在学习torch对于轴承故障检测的相关知识,但第一步读取数据以及数据的划分就难到了我,在网上查找相关资料,也没有完整的代码,于是只能东拼西凑,修修改改,最后勉强凑出来一个可以一用的代码,放在这里保存一下。

读取函数

def open_data(bath_path,key_num):
    path = bath_path + str(key_num) + ".mat"
    str1 =  "X" + "%03d"%key_num + "_DE_time"
    data = scio.loadmat(path)
    data = data[str1]
    return data

数据处理函数

def deal_data(data,length,label):
    data = np.reshape(data,(-1))
    num = len(data)//length
    data = data[0:num*length]
    data = np.reshape(data,(num,length))

    """
        最大值绝对值标准化(MaxAbs)即根据最大值的绝对值进行标准化,假设原转换的数据为x,新数据为x'
        那么x'=x/|max|,其中max为x所在列的最大值。
        MaxAbs方法跟Max-Min用法类似,也是将数据落入一定区间,但该方法的数据区间为[-1,1]。
        MaxAbs也具有不破坏原有数据分布结构的特点,因此也可以用于稀疏数据、稀疏的CSR或CSC矩阵。
    """

    maxabs_scaler = preprocessing.MaxAbsScaler()
    data = maxabs_scaler.fit_transform(np.transpose(data,[1,0]))


    data = np.transpose(data,[1,0])
    label = np.ones((num,1))*label
    return np.column_stack((data,label))

数据集划分函数

def split_data(data,split_rate):
    length = len(data)
    num1 = int(length*split_rate[0])
    num2 = int(length*split_rate[1])

    index1 = random.sample(range(num1),num1)
    train = data[index1]
    data = np.delete(data,index1,axis=0)
    index2 = random.sample(range(num2),num2)
    eval = data[index2]
    test = np.delete(data,index2,axis=0)
    return train,eval,test

数据加载函数

def load_data(num = 90,length = 1280,hp = [0,1,2],fault_diameter = [0.007,0.028,0.021],split_rate = [0.7,0.2,0.1]):
    #num 为每类故障样本数量,length为样本长度,hp为负载大小,可取[0,1,2,3],fauit_diameter为故障程度,可取[0.007,0.014,0.021]
    #split_rate为训练集,验证集和测试集划分比例。取值从0-1。
    #bath_path1 为西储大学数据集中,正常数据的文件夹路径
    #bath_path2 为西储大学数据集中,12K采频数据的文件夹路径
    bath_path1 = r"F:\data\cwru\Normal Baseline Data\\"
    bath_path2 = r"F:\data\cwru\12k Drive End Bearing Fault Data\\"
    data_list = []
    label = 0
    # 正常数据
    # path1 = bath_path1 + str(97+i) + ".mat"
    # normal_data = scio.loadmat(path1)
    # str1 = "X0" + str(97+i) + "_DE_time"
    normal_data = open_data(bath_path1, 97)
    data = deal_data(normal_data, length, label=label)
    data_list.append(data)
    for i in hp:
        #故障数据
        for j in fault_diameter:
            if j == 0.007:
                inner_num = 105
                ball_num = 118
                outer_num = 130
            elif j == 0.014:
                inner_num = 169
                ball_num = 185
                outer_num = 197
            else:
                inner_num = 209
                ball_num = 222
                outer_num = 234

            inner_data = open_data(bath_path2,inner_num + i)
            inner_data = deal_data(inner_data,length,label + 1)
            data_list.append(inner_data)

            ball_data = open_data(bath_path2,ball_num + i)
            ball_data = deal_data(ball_data,length,label + 2)
            data_list.append(ball_data)

            outer_data = open_data(bath_path2,outer_num + i)
            outer_data = deal_data(outer_data,length,label + 3)
            data_list.append(outer_data)

        label = label + 3

    #保持每类数据数据量相同
    num_list = []
    for i in data_list:
        num_list.append(len(i))
    min_num = min(num_list)

    if num > min_num:
        print("每类数量超出上限,最大数量为:%d" %min_num)

    min_num = min(num,min_num)
    #划分训练集,验证集和测试集,并随机打乱顺序
    train = []
    eval = []
    test = []
    for data in data_list:
        data = data[0:min_num,:]
        a,b,c = split_data(data,split_rate)
        train.append(a)
        eval.append(b)
        test.append(c)

    train = np.reshape(train,(-1,length+1))
    train = train[random.sample(range(len(train)),len(train))]
    train_data = train[:,0:length]
    train_label = torch.zeros(len(train),10).scatter_(1,torch.LongTensor(train[:,length]).unsqueeze(1),1)

    eval = np.reshape(eval,(-1,length+1))
    eval = eval[random.sample(range(len(eval)),len(eval))]
    eval_data = eval[:,0:length]
    eval_label = torch.zeros(len(eval), 10).scatter_(1, torch.LongTensor(eval[:,length]).unsqueeze(1),1)

    test = np.reshape(test,(-1,length+1))
    test = test[random.sample(range(len(test)),len(test))]
    test_data = test[:,0:length]
    test_label = torch.zeros(len(test), 10).scatter_(1, torch.LongTensor(test[:,length]).unsqueeze(1),1)

    return train_data,train_label,eval_data,eval_label,test_data,test_label

参考文献

参考资料1
参考资料2
参考资料3

好的,以下是使用PyTorch进行训练的代码,以CWRU轴承故障振动信号数据库为训练集,对电机轴承进行一维振动信号的故障诊断: 首先,需要导入PyTorch和相关库: ``` import torch import torch.nn as nn import torch.optim as optim import numpy as np import pandas as pd ``` 接着,需要读取CWRU轴承故障振动信号数据库的数据。我们可以使用NumPy库读取CSV文件中的数据,并将其转换为PyTorch张量: ``` data = pd.read_csv('data.csv', header=None) data = np.array(data) data = torch.from_numpy(data).float() ``` 假设我们的数据集包括了轴承的振动信号以及相应的故障类型标签,例如正常、内圈故障、外圈故障等。我们可以将数据集划分为训练集和测试集: ``` train_data = data[:8000, :-1] train_labels = data[:8000, -1].long() test_data = data[8000:, :-1] test_labels = data[8000:, -1].long() ``` 在训练之前,我们需要定义一个CNN模型。下面是一个简单的模型,包括两个卷积层和两个全连接层: ``` class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=2) self.conv2 = nn.Conv1d(16, 32, kernel_size=5, stride=2) self.fc1 = nn.Linear(2880, 256) self.fc2 = nn.Linear(256, 10) def forward(self, x): x = self.conv1(x) x = nn.functional.relu(x) x = nn.functional.max_pool1d(x, kernel_size=2, stride=2) x = self.conv2(x) x = nn.functional.relu(x) x = nn.functional.max_pool1d(x, kernel_size=2, stride=2) x = x.view(x.size(0), -1) x = self.fc1(x) x = nn.functional.relu(x) x = self.fc2(x) return x ``` 然后,我们可以定义优化器和损失函数。这里我们选择Adam优化器和交叉熵损失函数: ``` model = CNN() optimizer = optim.Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() ``` 接下来是训练模型的过程。我们将训练集分批进行训练,并在测试集上进行评估: ``` epochs = 10 batch_size = 32 for epoch in range(epochs): running_loss = 0.0 for i in range(0, len(train_data), batch_size): inputs = train_data[i:i+batch_size].unsqueeze(1) labels = train_labels[i:i+batch_size] optimizer.zero_grad() outputs = model(inputs)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值