% findLayersToReplace(lgraph) finds the single classification layer and the
% preceding learnable (fully connected or convolutional) layer of the layer
% graph lgraph.
%
%   [learnableLayer, classLayer] = findLayersToReplace(lgraph)
%
%   Inputs:
%     lgraph         - nnet.cnn.LayerGraph to search.
%
%   Outputs:
%     learnableLayer - the first nnet.cnn.layer.FullyConnectedLayer or
%                      nnet.cnn.layer.Convolution2DLayer reached when walking
%                      backwards along connections from the classification
%                      layer (the classification layer itself is checked
%                      first).
%     classLayer     - the unique layer that is an
%                      nnet.cnn.layer.ClassificationOutputLayer or an
%                      nnet.layer.ClassificationLayer.
%
%   Errors when lgraph is not a LayerGraph, when it does not contain exactly
%   one classification layer, or when the backward walk reaches a layer with
%   zero or several matching predecessors before a learnable layer is found.
function [learnableLayer,classLayer] = findLayersToReplace(lgraph)
if ~isa(lgraph,'nnet.cnn.LayerGraph')
    error('Argument must be a LayerGraph object.')
end

% Connection endpoints and layer names as string arrays so they can be
% compared element-wise with ==.
src = string(lgraph.Connections.Source);
dst = string(lgraph.Connections.Destination);
layerNames = string({lgraph.Layers.Name}');

% Find the classification layer. The layer graph must have exactly one.
isClassificationLayer = arrayfun(@(l) ...
    isa(l,'nnet.cnn.layer.ClassificationOutputLayer') || ...
    isa(l,'nnet.layer.ClassificationLayer'), ...
    lgraph.Layers);
if sum(isClassificationLayer) ~= 1
    error('Layer graph must have a single classification layer.')
end
classLayer = lgraph.Layers(isClassificationLayer);

% Class names accepted as the learnable layer. This must be a cell array:
% with square brackets the two char vectors are concatenated into one long
% char vector, and ismember would then test single characters of the class
% name instead of comparing whole class names.
learnableTypes = {'nnet.cnn.layer.FullyConnectedLayer', ...
    'nnet.cnn.layer.Convolution2DLayer'};

% Traverse the layer graph in reverse, starting from the classification
% layer. If the walk branches or runs out of predecessors, throw an error.
currentLayerIdx = find(isClassificationLayer);
while true
    if numel(currentLayerIdx) ~= 1
        error('Layer graph must have a single learnable layer preceding the classification layer.')
    end

    currentLayer = lgraph.Layers(currentLayerIdx);
    if ismember(class(currentLayer), learnableTypes)
        learnableLayer = currentLayer;
        return
    end

    % Connections whose destination is exactly the current layer's name, and
    % the layers they originate from. ismember (rather than ==) keeps this
    % well-defined when there are zero or several incoming connections; the
    % numel check above then reports the problem. Destinations that do not
    % equal the bare layer name (presumably port-qualified names such as
    % 'layer/in1' — confirm) are not matched and also end the walk with an
    % error.
    incoming = dst == layerNames(currentLayerIdx);
    currentLayerIdx = find(ismember(layerNames, src(incoming)));
end
end
一键复制
编辑
Web IDE
原始数据
按行查看
历史