matlab 对数据进行分类,使用深度学习对文本数据进行分类

导入数据

导入工厂报告数据。该数据包含已标注的工厂事件文本描述。要将文本数据作为字符串导入,请将文本类型指定为 'string'。

filename = "factoryReports.csv";

data = readtable(filename,'TextType','string');

head(data)

ans=8×5 table

Description Category Urgency Resolution Cost

_____________________________________________________________________ ____________________ ________ ____________________ _____

"Items are occasionally getting stuck in the scanner spools." "Mechanical Failure" "Medium" "Readjust Machine" 45

"Loud rattling and banging sounds are coming from assembler pistons." "Mechanical Failure" "Medium" "Readjust Machine" 35

"There are cuts to the power when starting the plant." "Electronic Failure" "High" "Full Replacement" 16200

"Fried capacitors in the assembler." "Electronic Failure" "High" "Replace Components" 352

"Mixer tripped the fuses." "Electronic Failure" "Low" "Add to Watch List" 55

"Burst pipe in the constructing agent is spraying coolant." "Leak" "High" "Replace Components" 371

"A fuse is blown in the mixer." "Electronic Failure" "Low" "Replace Components" 441

"Things continue to tumble off of the belt." "Mechanical Failure" "Low" "Readjust Machine" 38

此示例的目标是按 Category 列中的标签对事件进行分类。要将数据划分到各个类,请将这些标签转换为分类。

data.Category = categorical(data.Category);

使用直方图查看数据中类的分布。

figure

histogram(data.Category);

xlabel("Class")

ylabel("Frequency")

title("Class Distribution")

c441367bdc734583435609d948f37e19.png

下一步是将其划分为训练集和验证集。将数据划分为训练分区和用于验证和测试的保留分区。将保留百分比指定为 20%。

cvp = cvpartition(data.Category,'Holdout',0.2);

dataTrain = data(training(cvp),:);

dataValidation = data(test(cvp),:);

从分区后的表中提取文本数据和标签。

textDataTrain = dataTrain.Description;

textDataValidation = dataValidation.Description;

YTrain = dataTrain.Category;

YValidation = dataValidation.Category;

要检查是否已正确导入数据,请使用文字云将训练文本数据可视化。

figure

wordcloud(textDataTrain);

title("Training Data")

b85e69a77a37314752093477b4073a7a.png

预处理文本数据

创建一个对文本数据进行分词和预处理的函数。在示例末尾列出的函数 preprocessText 执行以下步骤:

使用 tokenizedDocument 对文本进行分词。

使用 lower 将文本转换为小写。

使用 erasePunctuation 删除标点符号。

使用 preprocessText 函数预处理训练数据和验证数据。

documentsTrain = preprocessText(textDataTrain);

documentsValidation = preprocessText(textDataValidation);

查看前几个预处理的训练文档。

documentsTrain(1:5)

ans =

5×1 tokenizedDocument:

9 tokens: items are occasionally getting stuck in the scanner spools

10 tokens: loud rattling and banging sounds are coming from assembler pistons

10 tokens: there are cuts to the power when starting the plant

5 tokens: fried capacitors in the assembler

4 tokens: mixer tripped the fuses

将文档转换为序列

要将文档输入到 LSTM 网络中,请使用单词编码将文档转换为数值索引序列。

要创建单词编码,请使用 wordEncoding 函数。

enc = wordEncoding(documentsTrain);

下一个转换步骤是填充和截断文档,使全部文档的长度相同。trainingOptions 函数提供了自动填充和截断输入序列的选项。但是,这些选项不太适合单词向量序列。请改为手动填充和截断序列。如果对单词向量序列进行左填充和截断,训练效果可能会得到改善。

要填充和截断文档,请先选择目标长度,然后对长于它的文档进行截断,并对短于它的文档进行左填充。为获得最佳结果,目标长度应该较短,但又不至于丢弃大量数据。要找到合适的目标长度,请查看训练文档长度的直方图。

documentLengths = doclength(documentsTrain);

figure

histogram(documentLengths)

title("Document Lengths")

xlabel("Length")

ylabel("Number of Documents")

aa43d340ee5a6b13a6c02508247ce4a3.png

大多数训练文档的词数少于 10 个。将此数字用作截断和填充的目标长度。

使用 doc2sequence 将文档转换为数值索引序列。要对长度为 10 的序列进行截断或左填充,请将 'Length' 选项设置为 10。

sequenceLength = 10;

XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength);

XTrain(1:5)

ans=5×1 cell array

{1×10 double}

{1×10 double}

{1×10 double}

{1×10 double}

{1×10 double}

使用相同选项将验证文档转换为序列。

XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);

创建和训练 LSTM 网络

定义 LSTM 网络架构。要将序列数据输入到网络中,请包含一个序列输入层并将输入大小设置为 1。接下来,包含一个维度为 50 且与单词编码具有相同单词数的单词嵌入层。然后,包含一个 LSTM 层并将隐含单元个数设置为 80。要将该 LSTM 层用于“序列到标签”分类问题,请将输出模式设置为 'last'。最后,添加一个大小与类数相同的全连接层、一个 softmax 层和一个分类层。

inputSize = 1;

embeddingDimension = 50;

numHiddenUnits = 80;

numWords = enc.NumWords;

numClasses = numel(categories(YTrain));

layers = [ ...

sequenceInputLayer(inputSize)

wordEmbeddingLayer(embeddingDimension,numWords)

lstmLayer(numHiddenUnits,'OutputMode','last')

fullyConnectedLayer(numClasses)

softmaxLayer

classificationLayer]

layers =

6x1 Layer array with layers:

1 '' Sequence Input Sequence input with 1 dimensions

2 '' Word Embedding Layer Word embedding layer with 50 dimensions and 423 unique words

3 '' LSTM LSTM with 80 hidden units

4 '' Fully Connected 4 fully connected layer

5 '' Softmax softmax

6 '' Classification Output crossentropyex

指定训练选项

指定训练选项:

使用 Adam 求解器进行训练。

指定小批量大小为 16。

每轮训练都会打乱数据。

通过将 'Plots' 选项设置为 'training-progress',监控训练进度。

使用 'ValidationData' 选项指定验证数据。

通过将 'Verbose' 选项设置为 false,隐藏详细输出。

默认情况下,如果有 GPU 可用,trainNetwork 就会使用 GPU(需要 Parallel Computing Toolbox™ 和具有 3.0 或更高计算能力的支持 CUDA® 的 GPU)。否则将使用 CPU。要手动指定执行环境,请使用 trainingOptions 的 'ExecutionEnvironment' 名称-值对组参数。在 CPU 上进行训练所需的时间要明显长于在 GPU 上进行训练所需的时间。

options = trainingOptions('adam', ...

'MiniBatchSize',16, ...

'GradientThreshold',2, ...

'Shuffle','every-epoch', ...

'ValidationData',{XValidation,YValidation}, ...

'Plots','training-progress', ...

'Verbose',false);

使用 trainNetwork 函数训练 LSTM 网络。

net = trainNetwork(XTrain,YTrain,layers,options);

21611c3ac391f372c3b1f54542db9860.png

使用新数据进行预测

对三个新报告的事件类型进行分类。创建包含新报告的字符串数组。

reportsNew = [ ...

"Coolant is pooling underneath sorter."

"Sorter blows fuses at start up."

"There are some very loud rattling sounds coming from the assembler."];

使用与预处理训练文档相同的步骤来预处理文本数据。

documentsNew = preprocessText(reportsNew);

使用 doc2sequence 将文本数据转换为序列,所用选项与创建训练序列时的选项相同。

XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);

使用经过训练的 LSTM 网络对新序列进行分类。

labelsNew = classify(net,XNew)

labelsNew = 3×1 categorical

Leak

Electronic Failure

Mechanical Failure

预处理函数

函数 preprocessText 执行以下步骤:

使用 tokenizedDocument 对文本进行分词。

使用 lower 将文本转换为小写。

使用 erasePunctuation 删除标点符号。

function documents = preprocessText(textData)

% Tokenize the text.

documents = tokenizedDocument(textData);

% Convert to lowercase.

documents = lower(documents);

% Erase punctuation.

documents = erasePunctuation(documents);

end

  • 1
    点赞
  • 3
    收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
©️2022 CSDN 皮肤主题:数字20 设计师:CSDN官方博客 返回首页
评论
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值