Matconvnet框架中实现欧式距离损失函数代码。主要涉及的模块是vl_nnsoftmaxloss函数和processEpoch函数。vl_nnsoftmaxloss函数中实现了自己的欧式距离损失函数代码,相关的算法推导请见欧式距离损失函数的算法推导。processEpoch函数中的[im,labels] = params.getBatch(params.imdb, batch)被删掉,采用自己编写的图像样本读取代码读取图像样本。
1. vl_nnsoftmaxloss
function y = vl_nnsoftmaxloss(x, c, dzdy)
%VL_NNSOFTMAXLOSS Euclidean-distance loss (customized MatConvNet loss layer).
%   Y = VL_NNSOFTMAXLOSS(X, C) computes the log of the squared Euclidean
%   distance between the network output X (1 x 1 x D x N) and a target
%   matrix derived from the labels C (forward pass).
%   Y = VL_NNSOFTMAXLOSS(X, C, DZDY) returns the derivative of the loss
%   with respect to X, with the same 1 x 1 x D x N shape as X (backward pass).
%
%   NOTE(review): the target matrix `groundTruth` is hard-coded for one
%   specific experiment with D = 3 outputs and N = 7 samples; generalize
%   (e.g. build it from C) before reusing this layer elsewhere.

switch class(x)
  case 'single', cast = @(z) single(z) ;
  case 'double', cast = @(z) double(z) ;
end

v_x = squeeze(x) ;   % D x N matrix of network outputs

% Hard-coded regression targets for the 7 training samples (rows = D = 3).
groundTruth = [0.68 0.49 0.30 0.51  0.523 0.447 0.395
               0    0    0    0.32  0.179 0.265 0.252
               0    0    0    0     0     0     0    ] ;

% Residual between targets and predictions; drives both the forward loss
% value and the backward gradient.
y_error = groundTruth - v_x ;
new_sita = -y_error ;                         % d(0.5*||e||^2)/dx per element
[sw, sh] = size(new_sita) ;
v_y_my = reshape(new_sita, [1, 1, sw, sh]) ;  % back to 1 x 1 x D x N

sz = [size(x,1) size(x,2) size(x,3) size(x,4)] ;

if numel(c) == sz(4)
  % one label per image
  c = reshape(c, [1 1 1 sz(4)]) ;
end
if size(c,1) == 1 && size(c,2) == 1
  c = repmat(c, [sz(1) sz(2)]) ;
end

% one label per spatial location
% (fixed: the original listing had the spaces between size() calls lost,
% producing a syntax error)
sz_ = [size(c,1) size(c,2) size(c,3) size(c,4)] ;
assert(isequal(sz_, [sz(1) sz(2) sz_(3) sz(4)])) ;
assert(sz_(3) == 1 | sz_(3) == 2) ;

% class c = 0 skips a spatial location
mass = cast(c(:,:,1,:) > 0) ;
v_mass = squeeze(mass) ;
if ~isempty(find(v_mass ~= 1, 1))
  % with the regression-style labels used here every mass entry should be 1
  disp('*******In the vl_nnsoftmaxloss function!*****')
  pause
end

if sz_(3) == 2
  % the second channel of c (if present) is used as weights
  mass = mass .* c(:,:,2,:) ;
  c(:,:,2,:) = [] ;
  disp('*******unexpected behaviro***********************')
  pause
end

% convert labels to linear indexes into x
c = c - 1 ;
c_ = 0:numel(c)-1 ;
c_ = 1 + ...
  mod(c_, sz(1)*sz(2)) + ...
  (sz(1)*sz(2)) * max(c(:), 0)' + ...
  (sz(1)*sz(2)*sz(3)) * floor(c_/(sz(1)*sz(2))) ;

% softmax machinery retained from the stock implementation (numerically
% stabilized by subtracting the per-location max); its result is then
% overwritten by the Euclidean-distance quantities below.
xmax = max(x,[],3) ;
ex = exp(bsxfun(@minus, x, xmax)) ;

if nargin <= 2
  % forward pass: report log of the squared Euclidean distance
  t = xmax + log(sum(ex,3)) - reshape(x(c_), [sz(1:2) 1 sz(4)]) ;
  y = sum(sum(sum(mass .* t,1),2),4) ;   % stock softmax loss (discarded)
  y1 = sum(sum(y_error.^2)) ;            % squared Euclidean distance
  y = log(y1) ;                          % log keeps the reported loss small
else
  % backward pass: gradient of the Euclidean loss, -(groundTruth - x)
  y = bsxfun(@rdivide, ex, sum(ex,3)) ;  % stock softmax gradient (discarded)
  y(c_) = y(c_) - 1 ;
  y = v_y_my ;
end
2. processEpoch
function [net, state] = processEpoch(net, state, params, mode)
%PROCESSEPOCH Run one training/validation epoch (modified MatConvNet cnn_train).
%   [NET, STATE] = PROCESSEPOCH(NET, STATE, PARAMS, MODE) processes the
%   subset PARAMS.(MODE), where MODE is 'train' or 'val'.  This variant
%   replaces the stock `params.getBatch(params.imdb, batch)` call with a
%   direct read of every image in the 'sample_imgs' directory and a
%   hard-coded label vector (one target value per image).
%
%   NOTE(review): several fused keywords from the original listing
%   (`ifisempty`, `fort=`, `ifparams`, `&¶ms`) have been repaired here.

% initialise momentum on the first call
if isempty(state) || isempty(state.momentum)
  for i = 1:numel(net.layers)
    for j = 1:numel(net.layers{i}.weights)
      state.momentum{i}{j} = 0 ;
    end
  end
end

% move CNN (and momentum buffers) to GPU as needed
numGpus = numel(params.gpus) ;
if numGpus >= 1
  net = vl_simplenn_move(net, 'gpu') ;
  for i = 1:numel(state.momentum)
    for j = 1:numel(state.momentum{i})
      state.momentum{i}{j} = gpuArray(state.momentum{i}{j}) ;
    end
  end
end
if numGpus > 1
  parserv = ParameterServer(params.parameterServer) ;
  vl_simplenn_start_parserv(net, parserv) ;
else
  parserv = [] ;
end

% profile
if params.profile
  if numGpus <= 1
    profile clear ;
    profile on ;
  else
    mpiprofile reset ;
    mpiprofile on ;
  end
end

subset = params.(mode) ;
num = 0 ;
stats.num = 0 ;  % return something even if subset = []
stats.time = 0 ;
adjustTime = 0 ;
res = [] ;
error = [] ;     % NOTE: intentionally shadows builtin error() in this scope
start = tic ;

for t = 1:params.batchSize:numel(subset)
  fprintf('%s: epoch %02d: %3d/%3d:', mode, params.epoch, ...
          fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize)) ;
  batchSize = min(params.batchSize, numel(subset) - t + 1) ;

  for s = 1:params.numSubBatches
    % get this image batch
    batchStart = t + (labindex-1) + (s-1) * numlabs ;
    batchEnd = min(t+params.batchSize-1, numel(subset)) ;
    batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
    num = num + numel(batch) ;
    if numel(batch) == 0, continue ; end

    % Custom sample loading: read every file in 'sample_imgs' instead of
    % calling params.getBatch(params.imdb, batch).  dir() returns '.' and
    % '..' first, hence the loop starts at index 3.
    % NOTE(review): assumes all images share the same size and that the
    % directory contains exactly the 7 samples matching `labels` — confirm.
    d = dir(fullfile('sample_imgs')) ;
    for i = 3:numel(d)
      im(:,:,:,i-2) = single(imread(fullfile('sample_imgs', d(i).name))) ;
    end
    % hard-coded regression targets, one per image in 'sample_imgs'
    labels = [0.68 0.49 0.30 0.51 0.523 0.447 0.395] ;

    if strcmp(mode, 'train')
      dzdy = 1 ;
      evalMode = 'normal' ;
    else
      dzdy = [] ;
      evalMode = 'test' ;
    end

    % attach the targets to the loss layer and run forward/backward
    net.layers{end}.class = labels ;
    res = vl_simplenn(net, im, dzdy, res) ;

    % accumulate errors
    error = sum([error, [...
      sum(double(gather(res(end).x))) ;
      reshape(params.errorFunction(params, labels, res),[],1) ; ]], 2) ;
  end

  % accumulate gradient
  if strcmp(mode, 'train')
    if ~isempty(parserv), parserv.sync() ; end
    [net, res, state] = accumulateGradients(net, res, state, params, batchSize, parserv) ;
  end

  % get statistics
  time = toc(start) + adjustTime ;
  batchTime = time - stats.time ;
  stats = extractStats(net, params, error / num) ;
  stats.num = num ;
  stats.time = time ;
  currentSpeed = batchSize / batchTime ;
  averageSpeed = (t + batchSize - 1) / time ;
  if t == 3*params.batchSize + 1
    % compensate for the first three iterations, which are outliers
    adjustTime = 4*batchTime - time ;
    stats.time = time + adjustTime ;
  end

  fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
  for f = setdiff(fieldnames(stats)', {'num', 'time'})
    f = char(f) ;
    fprintf(' %s: %.3f', f, stats.(f)) ;
  end
  fprintf('\n') ;

  % collect diagnostic statistics
  if strcmp(mode, 'train') && params.plotDiagnostics
    switchFigure(2) ; clf ;
    diagn = [res.stats] ;
    diagnvar = horzcat(diagn.variation) ;
    diagnpow = horzcat(diagn.power) ;
    subplot(2,2,1) ; barh(diagnvar) ;
    set(gca, 'TickLabelInterpreter', 'none', ...
        'YTick', 1:numel(diagnvar), ...
        'YTickLabel', horzcat(diagn.label), ...
        'YDir', 'reverse', ...
        'XScale', 'log', ...
        'XLim', [1e-5 1], ...
        'XTick', 10.^(-5:1)) ;
    grid on ;
    subplot(2,2,2) ; barh(sqrt(diagnpow)) ;
    set(gca, 'TickLabelInterpreter', 'none', ...
        'YTick', 1:numel(diagnpow), ...
        'YTickLabel', {diagn.powerLabel}, ...
        'YDir', 'reverse', ...
        'XScale', 'log', ...
        'XLim', [1e-5 1e5], ...
        'XTick', 10.^(-5:5)) ;
    grid on ;
    subplot(2,2,3) ; plot(squeeze(res(end-1).x)) ;
    drawnow ;
  end
end

% Save back to state.
state.stats.(mode) = stats ;
if params.profile
  if numGpus <= 1
    state.prof.(mode) = profile('info') ;
    profile off ;
  else
    state.prof.(mode) = mpiprofile('info') ;
    mpiprofile off ;
  end
end
if ~params.saveMomentum
  state.momentum = [] ;
else
  % gather momentum back from the GPU before saving
  for i = 1:numel(state.momentum)
    for j = 1:numel(state.momentum{i})
      state.momentum{i}{j} = gather(state.momentum{i}{j}) ;
    end
  end
end

net = vl_simplenn_move(net, 'cpu') ;