CREST代码阅读

CREST代码阅读


在成功运行CREST后说一下自己对代码的有关理解。

特征提取网络
代码首先是利用VGG16构建一个特征提取网络,其中构建特征提取网络文件是initVGG16Net.m,可以看到作者在VGG16中取了conv1_1至conv4_3,可能考虑到太多的pooling会丢失过多结构信息,导致跟踪效果变差,所以去掉了conv4_3中的pooling层和conv3_3的pooling层,因此在该特征提取网络中共计2个pooling层。

在进行取跟踪目标patch块,减均值等操作后将patch块送入VGG中进行提取特征,并将conv4_3经过relu后的特征作为在线跟踪网络的输入。其中patch块如下图所示。


在这里插入图片描述

经过PCA降维减少通道数,能够给快收敛。根据跟踪目标位置形成一个高斯形label,用在训练在线跟踪网络的label。其中高斯label如下所示。


在这里插入图片描述

在线跟踪网络
构建的在线跟踪网络主要文件是initNet.m。对应论文中的图4.
在这里插入图片描述
分为三部分,第一部分对应论文中的base网络,是用于模拟DCF操作,输入是当前预测帧的conv4_3特征图,其网络结构是一个和目标尺寸一样大的卷积核,使用这个大卷积核得到一个响应图。其核心代码是

rw=ceil(target_sz1(2)/2);
rh=ceil(target_sz1(1)/2);
fw=2*rw+1;
fh=2*rh+1;
net_online.addLayer('conv11', dagnn.Conv('size', [fw,fh,channel,1],...
    'hasBias', true, 'pad',...
[rh,rh,rw,rw], 'stride', [1,1]), 'input1', 'conv_11', {'conv11_f', 'conv11_b'});

第二部分是构建一个空域残差网络,对应论文Spatial residual,输入是当前预测帧的conv4_3特征图,网络结构是3个1×1卷积核构成的,经过三个1×1卷积后生成一个响应图,核心代码是

net_online.addLayer('conv21', dagnn.Conv('size', [1,1,channel,channel],...
    'hasBias', true, 'pad',...
[0,0,0,0], 'stride', [1,1]), 'input1', 'conv_21', {'conv21_f', 'conv21_b'});
net_online.addLayer('conv22', dagnn.Conv('size', [1,1,channel,channel],...
    'hasBias', true, 'pad',...
[0,0,0,0], 'stride', [1,1]), 'relu_1', 'conv_22', {'conv22_f', 'conv22_b'});
net_online.addLayer('conv23', dagnn.Conv('size', [1,1,channel,1],...
    'hasBias', true, 'pad',...
[0,0,0,0], 'stride', [1,1]), 'relu_2', 'conv_23', {'conv23_f', 'conv23_b'});

第三部分是构建一个空域残差网络,对应论文Temporal residual,输入是第一帧的conv4_3特征图,网络结构是1个1×1卷积核,经过该层后,输出得到一个响应图,核心代码是

net_online.addLayer('conv23', dagnn.Conv('size', [1,1,channel,1],...
    'hasBias', true, 'pad',...
[0,0,0,0], 'stride', [1,1]), 'relu_2', 'conv_23', {'conv23_f', 'conv23_b'});

在得到三个不同的响应图后进行合并得到最终响应,核心代码是

net_online.addLayer('sum1',dagnn.Sum(),{'conv_23','conv_11','conv_31'},'sum_1');

在训练过程中的loss函数表达式是
在这里插入图片描述
核心代码是

function Y = L2normLoss(X, C, dzdy)
assert(numel(X) == numel(C));
n = size(X,1) * size(X,2);
if nargin <= 2     
	  diff=X(:)-C(:);
	  idx=find(abs(diff)<0.1);
	  X(idx)=C(idx);    
	  Y = sum((exp(C(:)).*(X(:)-C(:))).^2) ;  
else
	  assert(numel(dzdy) == 1);    
	  diff=X(:)-C(:);
	  idx=find(abs(diff)<0.1);
	  X(idx)=C(idx);    
	  Y = reshape((dzdy / n) * 2*(exp(C(:)).*(X(:)-C(:))), size(X));  
end
end

尺度估计
尺度估计在scale_estimation.m中进行,分别在1、0.95、1.05三个尺度计算响应图最大响应,得到效果最好的尺度,并更新目标尺寸,更新目标尺寸代码为。

target_szU=0.4target_sz+0.6round(target_sz*scale(id));

在线更新

在线更新阶段主要每隔2帧进行一次模型更新,每次更新训练2个epoch,其核心代码所在函数是cnn_train_dag_update.m。

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
轴承故障诊断是机械故障诊断中的一个重要方向,常用的方法是通过振动信号分析,利用机器学习算法进行故障诊断。以下是一个简单的基于Python的轴承故障诊断代码示例: ```python import pandas as pd import numpy as np import matplotlib.pyplot as plt from sklearn.model_selection import train_test_split from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import confusion_matrix # 加载数据 data = pd.read_csv("bearing_dataset.csv") # 特征提取 features = ['RMS', 'crest_factor', 'kurtosis', 'skewness'] X = data[features] y = data['fault'] # 数据集划分 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) # 随机森林分类器 rf_clf = RandomForestClassifier(n_estimators=100, random_state=42) rf_clf.fit(X_train, y_train) y_pred = rf_clf.predict(X_test) # 混淆矩阵 conf_mat = confusion_matrix(y_test, y_pred) print(conf_mat) # 可视化 fig, ax = plt.subplots(figsize=(5, 5)) ax.imshow(conf_mat) ax.grid(False) ax.set_xlabel('Predicted outputs', fontsize=12, color='black') ax.set_ylabel('Actual outputs', fontsize=12, color='black') ax.xaxis.set(ticks=range(2)) ax.yaxis.set(ticks=range(2)) for i in range(2): for j in range(2): ax.text(j, i, conf_mat[i, j], ha='center', va='center', color='white') plt.show() ``` 这个代码示例中,我们使用了pandas库来加载数据,使用sklearn库进行数据集划分、随机森林分类器的训练和预测,使用matplotlib库进行混淆矩阵的可视化。当然,具体实现还需要根据数据集的具体情况进行相应的调整。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值