Project Background
- Alzheimer's disease (AD) is a neurodegenerative disease with an insidious onset and a progressive course. The disease process is irreversible, and to date no effective drug or treatment exists. As with cancer, the key to managing AD is therefore early screening and diagnosis, so that intervention can begin in the early stage and slow the disease's progression.
- Clinically, the Montreal Cognitive Assessment (MoCA) is commonly used for preliminary AD screening. The paper-based MoCA, however, is inefficient to administer, and its results depend heavily on the examiner's subjective judgment and experience. As a result, current early screening for AD is slow, injects considerable human uncertainty, and covers few people per unit of time.
- Notably, as of 2019 more than 10 million people in China had Alzheimer's disease, the largest patient population of any country, and the resulting social burden continues to grow. Designing and building a screening system that is simpler to administer, more objective and scientific in its scoring, and more trustworthy in its results is therefore of lasting social value.
- (Figure: pooled prevalence of Alzheimer's disease in China by age group, and the projected trend over the next 30 years, in units of 10,000 people.)
Technical Approach
- The system implements the core functions of the MoCA screening scale through AI-driven interaction based on speech recognition and image processing, and is deployed on mobile devices such as phones and tablets.
- For speech interaction, the system integrates a speech module built on the Baidu speech recognition API, covering accurate audio capture, track recording, intelligent recognition, and prompt playback during human-machine interaction.
- For image processing, the system uses the Baidu PaddlePaddle and Paddle Lite frameworks together with deep neural network techniques for medical image processing, building and training image recognition models around the characteristics of the self-built Sketch4IAS dataset.
- The network architecture is visualized below:
- The system comprises a speech-based Interactive Unit Test Module (IUTM) and a ground-truth-based Intelligent Analysis Module (IAM), both developed with deep learning techniques.
- The IUTM collects test data interactively.
- The IAM performs intelligent processing and analysis of the multimodal data.
- Throughout, the two core pieces are the speech-recognition-based intelligent interaction system and the ground-truth-based image recognition and classification algorithm; both are introduced in detail below, together with their code.
I. Intelligent Interaction System Based on Speech Recognition
The speech interaction pipeline consists of the following steps:
- Design and configure the speech recognition API
- Produce and wire up the background prompt audio
- Capture test speech in real time
- Call the speech recognition API, perform the speech-to-text conversion in the cloud, and return the result to the mobile client
- Apply decision logic and return the result
1. Designing and configuring the speech recognition API
- Register a Baidu Cloud account, log in, and create an application to obtain the "API Key" and "Secret Key" values. Then configure the speech recognition API parameters for the application scenario, as in the sample below.
globalData: {
baiduai:{
apiKey: 'Your_apiKey',
secretKey: 'Your_secretKey',
url: "https://aip.baidubce.com/oauth/2.0/token"
},
baiduyuyin:{
apiKey: 'Your_apiKey',
secretKey: 'Your_secretKey',
url: 'https://openapi.baidu.com/oauth/2.0/token'
},
}
- Obtain the AccessToken. Building on the previous step, the credentials issued for the application ("AppID", "API Key", "Secret Key") are used to generate an AccessToken; this is the authentication step.
- The project embeds this authentication step at the appropriate point in the flow. This design guarantees the AccessToken is available before speech recognition starts and avoids complicated asynchronous logic.
// Fetch the Baidu AI access token and cache it (token and timestamp) in local storage.
getBaiduAiAccessToken: function(){
var that = this;
var baiduai = that.globalData.baiduai;
wx.request({
url: baiduai.url,
data: {
grant_type: 'client_credentials',
client_id: baiduai.apiKey,
client_secret: baiduai.secretKey
},
method: 'POST',
header: {
'content-type': 'application/x-www-form-urlencoded'
},
success(res) {
wx.setStorageSync("baidu_ai_access_token", res.data.access_token);
wx.setStorageSync("baidu_ai_time", new Date().getTime());
}
})
},
// Fetch the Baidu speech access token and cache it in local storage.
getBaiduYuyinAccessToken: function(){
var that = this;
var baiduyuyin = that.globalData.baiduyuyin;
wx.request({
url: baiduyuyin.url,
data: {
grant_type: 'client_credentials',
client_id: baiduyuyin.apiKey,
client_secret: baiduyuyin.secretKey
},
method: 'POST',
header: {
'content-type': 'application/x-www-form-urlencoded'
},
success(res) {
wx.setStorageSync("baidu_yuyin_access_token", res.data.access_token);
wx.setStorageSync("baidu_yuyin_time", new Date().getTime());
}
})
},
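- For testing the token exchange outside the mini-program, here is a minimal server-side Python sketch of the same OAuth call, assuming the `requests` library; the key values are placeholders, and the cache mirrors what the JavaScript above stores via wx.setStorageSync.
import time
import requests

TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
API_KEY = "Your_apiKey"        # placeholder
SECRET_KEY = "Your_secretKey"  # placeholder

_cache = {"token": None, "expires_at": 0}

def get_access_token():
    # reuse the cached token until shortly before it expires
    if _cache["token"] and time.time() < _cache["expires_at"] - 60:
        return _cache["token"]
    resp = requests.post(TOKEN_URL, params={
        "grant_type": "client_credentials",
        "client_id": API_KEY,
        "client_secret": SECRET_KEY,
    })
    data = resp.json()
    _cache["token"] = data["access_token"]
    # the endpoint returns expires_in in seconds
    _cache["expires_at"] = time.time() + data.get("expires_in", 0)
    return _cache["token"]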
2. Producing the background prompt audio
- A spoken prompt opens each test screen.
- Prompt playback is implemented with wx.createInnerAudioContext().
- Setting innerAudioContext.autoplay = true starts the prompt automatically.
const innerAudioContext = wx.createInnerAudioContext()
onLoad: function(options){
var myDate = new Date()
this.setData({start_ms: myDate.getTime()})
innerAudioContext.autoplay = true
innerAudioContext.src = 'pages/video/X.mp3'
innerAudioContext.onPlay(() => {
console.log('开始播放')
})
innerAudioContext.onError((res) => {
console.log(res.errMsg)
console.log(res.errCode)
})
},
- innerAudioContext.stop() stops the prompt automatically when the page switches.
gotoX:function()
{
innerAudioContext.stop();
app.globalData.global_sum += this.data.count;
wx.redirectTo({
url: '../page_X/page_X',
})
},
3. Capturing test speech in real time
- Speech is captured in real time with a simple gesture: touch to start recording, release to stop. In detail:
- Based on the application scenario, the recording parameters are set as follows.
sampleRate: 16000,
numberOfChannels: 1,
encodeBitRate: 48000,
format: 'PCM',
- recorderManager.start() starts the recording.
// recorderManager is obtained once via wx.getRecorderManager()
handleTouchStart: function(e){
const options = {
sampleRate: 16000,
numberOfChannels: 1,
encodeBitRate: 48000,
format: 'PCM'
}
recorderManager.start(options);
wx.showLoading({
title: '正在录音中...',
})
},
- recorderManager.stop() ends the recording and produces the audio file.
handleTouchEnd: function(e){
wx.hideLoading();
this.setData({
buttonName: e.target.dataset.name
});
recorderManager.stop();
},
4. Calling the speech recognition API: cloud speech-to-text conversion with results returned to the mobile client
- The speech API expects the audio encoded as base64.
- fs.readFile reads the recorded audio file.
- wx.arrayBufferToBase64() performs the base64 encoding.
- The speech API is then called to carry out the speech-to-text conversion.
bindRecorderStopEvent: function(){
var that = this;
recorderManager.onStop((res) => {
wx.showLoading({
title: '正在识别中...',
})
var baiduAccessToken = wx.getStorageSync("baidu_yuyin_access_token");
var tempFilePath = res.tempFilePath;
const fs = wx.getFileSystemManager();
fs.readFile({
filePath: tempFilePath,
success(res) {
const base64 = wx.arrayBufferToBase64(res.data);
var fileSize = res.data.byteLength;
wx.request({
url: 'https://vop.baidu.com/server_api',
data: {
format: 'pcm',
rate: 16000,
channel: 1,
cuid: 'sdfdfdfsfs', // any unique device identifier
token: baiduAccessToken,
speech: base64,
len: fileSize
},
method: 'POST',
header: {
'content-type': 'application/json'
},
success(res) {
console.log("语音识别结束")
wx.hideLoading();
console.log(res.data);
if(res.data.err_no == 0){
var result = res.data.result;
if(result.length == 0){
wx.showToast({
title: "未识别到语音信息!",
icon: 'none',
duration: 3000
})
return;
}
}
})
}
})
})
}
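- For reference, the same speech-to-text request can also be issued from Python; a minimal sketch, assuming the `requests` library and a 16 kHz mono 16-bit PCM file (get_access_token is the helper sketched earlier):
import base64
import requests

def recognize_pcm(path, token):
    with open(path, "rb") as f:
        raw = f.read()
    payload = {
        "format": "pcm",
        "rate": 16000,
        "channel": 1,
        "cuid": "any-unique-device-id",  # placeholder device identifier
        "token": token,
        "speech": base64.b64encode(raw).decode("utf-8"),
        # len is the size of the raw audio in bytes, not of the base64 string
        "len": len(raw),
    }
    resp = requests.post("https://vop.baidu.com/server_api", json=payload)
    data = resp.json()
    if data.get("err_no") == 0:
        return data["result"]  # list of candidate transcripts
    raise RuntimeError(data.get("err_msg", "recognition failed"))

# usage: print(recognize_pcm("test.pcm", get_access_token()))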
5. Decision logic and returning the result
- The recognized text is compared against the standard answer; a decision unit computes the test score and returns the result. A sketch of this step follows.
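- The in-app comparison lives in the WeChat page handlers; purely as an illustration, a hypothetical Python sketch of the normalize-and-match step (all names here are invented for the example):
import re

def normalize(text):
    # strip whitespace and punctuation that the recognizer may insert
    return re.sub(r"[\s,。,.!?!?]", "", text)

def score_item(recognized, accepted_answers, points=1):
    # award the item's points if any accepted answer appears in the transcript
    hypo = normalize(recognized)
    return points if any(normalize(a) in hypo for a in accepted_answers) else 0

# example: a naming item that accepts the single keyword "狮子" (lion)
print(score_item("这是一只狮子", ["狮子"]))  # -> 1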
II. Ground-Truth-Based Image Recognition and Classification Algorithm
This part is organized around five topics:
- Installing the required libraries
- Loading the Sketch4IAS dataset
- Training the hand-drawn cube recognition model
- Training the hand-drawn clock (specified time) recognition model
- Model deployment
1. Installing the required libraries
- Install the libraries this scenario requires.
!python -m pip install paddlepaddle-gpu==2.1.1.post110 -f https://paddlepaddle.org.cn/whl/mkl/stable.html
!pip install pandas
!pip install Pillow
!pip install numpy
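- An optional quick check that PaddlePaddle installed correctly and can see the GPU:
import paddle
print(paddle.__version__)               # expect 2.1.1
print(paddle.is_compiled_with_cuda())   # expect True for the GPU build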
2. Loading the Sketch4IAS dataset
- We built Sketch4IAS, a clinical MoCA dataset containing 400 Screen-Images and 1,850 Camera-Images.
- Sketch4IAS is the first MoCA hand-drawing dataset. It consists of two image classes, hand-drawn cubes and hand-drawn clocks showing a specified time, and every image is annotated with an assessment score.
- A sample of the dataset is shown below: the four rows are Camera-Image, Screen-Image, preprocessed image, and annotated score; the three columns on the left are hand-drawn cubes and the three on the right are hand-drawn clocks.
- Matching the dataset samples shown above, the corresponding preprocessed images are generated. The code for each stage follows.
- Sketch preprocessing: the raw captured image is converted to grayscale, Gaussian-smoothed, and then binarized with a suitable thresholding method.
import cv2 as cv
import os
blur_blk_size = (5, 5)
sigma_x = 2
thres_blk_size = (11, 12)  # (blockSize, C): blockSize must be odd; C is subtracted from the local mean
def convert(src_img_path, tar_img_path):
src_img = cv.imread(src_img_path)
# convert to grayscale
gray_img = cv.cvtColor(src_img, cv.COLOR_BGR2GRAY)
# Gaussian smoothing
gray_img = cv.GaussianBlur(gray_img, blur_blk_size, sigma_x)
# Method 1: fixed-threshold binarization
# r, b = cv.threshold(gray_img, 140, 255, cv.THRESH_BINARY)
# Method 2: adaptive thresholding
b = cv.adaptiveThreshold(gray_img, 255,
cv.ADAPTIVE_THRESH_MEAN_C,
cv.THRESH_BINARY,
thres_blk_size[0],
thres_blk_size[1])
cv.imwrite(tar_img_path, b)
# Convert the hand-drawn cube (Cube) and clock (Watch) images.
# src_dir / tar_dir are the source and target root directories (placeholders: set these to your own paths).
src_dir = "path/to/raw_images"
tar_dir = "path/to/pre_images"
src_cube = os.path.join(src_dir, "image_cube")
src_watch = os.path.join(src_dir, "image_watch")
tar_cube = os.path.join(tar_dir, "image_cube")
tar_watch = os.path.join(tar_dir, "image_watch")
cube_names = os.listdir(src_cube)
for name in cube_names:
scp_path = os.path.join(src_cube, name)
tcp_path = os.path.join(tar_cube, name)
convert(scp_path, tcp_path)
watch_names = os.listdir(src_watch)
for name in watch_names:
scp_path = os.path.join(src_watch, name)
tcp_path = os.path.join(tar_watch, name)
convert(scp_path, tcp_path)
# see https://www.tutorialspoint.com/opencv/opencv_adaptive_threshold.htm for details on adaptive thresholding
- Next, the preprocessed Sketch4IAS data is split into training and test sets, recording each image's location and label.
import codecs
import os
import random
import shutil
from PIL import Image
import pandas as pd
import argparse
def init_parser():
# initialize command-line parameters
parser = argparse.ArgumentParser()
parser.add_argument('--train_ratio', default=4.0 / 5)
parser.add_argument('--all_file_dir', default='/mnt/hdd1/miaolanxin/program/flowers/work/')
parser.add_argument('--label_files', default='/mnt/hdd1/miaolanxin/program/dataset_image_label_pre/labels.csv')
parser.add_argument('--image_files', default='/mnt/hdd1/miaolanxin/program/dataset_image_label_pre/pre_images_blur/image_cube/')
parser.add_argument('--input', default='cube')
return parser.parse_args()
opts = init_parser()
train_ratio = opts.train_ratio
all_file_dir = opts.all_file_dir
label_files = opts.label_files
image_files = opts.image_files
# Split the dataset: copy training and eval images into separate folders and record each file path with its label.
def preprocess(target_name: str):
label = pd.read_csv(label_files)
label_cube = list(label['cube'])
label_shape = list(label['shape'])
label_number = list(label['number'])
label_pointer = list(label['pointer'])
# Step 1: create the train/eval folders and the text files that store the image-label pairs
train_name = str(target_name) + "_trainImageSet"
eval_name = str(target_name) + "_evalImageSet"
images = os.path.join(all_file_dir, str(target_name))
train_image_dir = os.path.join(all_file_dir, str(target_name), train_name)
eval_image_dir = os.path.join(all_file_dir, str(target_name), eval_name)
if not os.path.exists(train_image_dir):
os.makedirs(train_image_dir)
if not os.path.exists(eval_image_dir):
os.makedirs(eval_image_dir)
train_file_list = "train_" + str(target_name) + ".txt"
eval_file_list = "eval_" + str(target_name) + ".txt"
train_file = codecs.open(os.path.join(images, train_file_list), 'w')
eval_file = codecs.open(os.path.join(images, eval_file_list), 'w')
image_file = image_files
# Step 2: generate the label for each image
for file in os.listdir(image_file):
image_id = int(file[:-4])
if target_name == 'cube':
label_id = str(int(label_cube[image_id - 1]))
if target_name == 'clock':
# the three clock labels: shape, number, and pointer
label_id1 = str(int(label_shape[image_id - 1]))
label_id2 = str(int(label_number[image_id - 1]))
label_id3 = str(int(label_pointer[image_id - 1]))
label_id = label_id1 + label_id2 + label_id3
try:
if random.uniform(0, 1) <= train_ratio:
shutil.copyfile(os.path.join(image_files, file), os.path.join(train_image_dir, file))
train_file.write("{0}\t{1}\n".format(os.path.join(train_image_dir, file), label_id))
else:
shutil.copyfile(os.path.join(image_files, file), os.path.join(eval_image_dir, file))
eval_file.write("{0}\t{1}\n".format(os.path.join(eval_image_dir, file), label_id))
except Exception as e:
pass
train_file.close()
eval_file.close()
name = opts.input
preprocess(name)
- Reading the data stream.
# read the preprocessed images
import os
import math
import codecs
import numpy as np
from PIL import Image, ImageEnhance
from config import train_parameters
import paddle.fluid as fluid
# Hand-drawn sketches are not necessarily centered on the canvas, so the preprocessed image is pasted at a random offset onto a larger white canvas as data augmentation. This keeps the drawing intact while letting it appear anywhere reasonable on the canvas.
def random_padding(img, scale_ratio):
w, h = img.size
length = int(math.sqrt(w * h * scale_ratio))
length = max(w, h, length)
img_new = Image.new(img.mode, (length, length), "white")
width = length - w
height = length - h
width = int(np.random.uniform(0, width))
height = int(np.random.uniform(0, height))
img_new.paste(img, (width, height))
return img_new
# Randomly stretch the image along its width or height as data augmentation.
def random_scaling(img):
w, h = img.size
prob = np.random.uniform(0, 1)
if prob > 0.5:
img = img.resize((w, 500), Image.ANTIALIAS)
else:
img = img.resize((500, h), Image.ANTIALIAS)
return img
# Rotate the image about its geometric center by a random angle as data augmentation; matching the application scenario, the angle is limited to ±14°.
def rotate_image(img):
angle = np.random.randint(-14, 15)
img = img.rotate(angle)
return img
# Apply color distortion as data augmentation.
# (random_brightness / random_contrast / random_saturation / random_hue are standard augmentation helpers, assumed to be defined elsewhere in this module.)
def distort_color(img):
img = random_brightness(img)
img = random_contrast(img)
img = random_saturation(img)
img = random_hue(img)
return img
# reader for the cube recognition task
def image_reader(file_list, mode):
with codecs.open(file_list) as flist:
lines = [line.strip() for line in flist]
def reader():
np.random.shuffle(lines)
for line in lines:
if mode == 'train':
img_path, label = line.split()
key = img_path[-8:-4]
img = Image.open(img_path)
try:
# Four augmentation steps are applied to the training input, mirroring real capture conditions.
if img.mode != 'L':
# convert to grayscale
img = img.convert('L')
if train_parameters['image_enhance_strategy']['need_distort']:
img = distort_color(img)
if train_parameters['image_enhance_strategy']['need_rotate']:
img = rotate_image(img)
if train_parameters['image_enhance_strategy']['need_scaling']:
img = random_scaling(img)
if train_parameters['image_enhance_strategy']['need_padding']:
img = random_padding(img, 1.5)
# downstream processing uses this grayscale image
img = img.resize((225, 225), Image.BILINEAR)
img = img.convert('L')
img = np.array(img).astype('float32')
img *= 0.007843
yield img, int(label), key
except Exception as e:
pass
if mode == 'val':
img_path, label = line.split()
key = img_path[-8:-4]
img = Image.open(img_path)
if img.mode != 'L':
img = img.convert('L')
img = img.resize((225, 225), Image.BILINEAR)
img = img.convert('L')
img = np.array(img).astype('float32')
img *= 0.007843
yield img, int(label), key
return reader
# (the function returns a generator)
# reader for the clock recognition task
def custom_image_reader(file_list, data_dir, mode):
with codecs.open(file_list) as flist:
lines = [line.strip() for line in flist]
def reader():
np.random.shuffle(lines)
for line in lines:
if mode == 'train':
img_path, label = line.split()
label1 = label[0]
label2 = label[1]
label3 = label[2]
key = img_path[-8:-4]
img = Image.open(img_path)
try:
# Augmentation: padding keeps the drawing intact and places it at a random position on the canvas.
if img.mode != 'L':
img = img.convert('L')
if train_parameters_clock['image_enhance_strategy']['need_distort']:
img = distort_color(img)
if train_parameters_clock['image_enhance_strategy']['need_rotate']:
img = rotate_image(img)
if train_parameters_clock['image_enhance_strategy']['need_scaling']:
img = random_scaling(img)
if train_parameters_clock['image_enhance_strategy']['need_padding']:
img = random_padding(img, 1.5)
img = img.resize((224, 224), Image.BILINEAR)
img = np.array(img).astype('float32')
# grayscale input: add the channel axis explicitly (a 2-D array cannot be transposed with (2, 0, 1))
img = img[np.newaxis, :, :]
img *= 0.007843
yield img, int(label1), int(label2), int(label3), key
except Exception as e:
pass
if mode == 'val':
    img_path, label = line.split()
    # parse the three clock labels and the image key, as in the train branch
    label1, label2, label3 = label[0], label[1], label[2]
    key = img_path[-8:-4]
    if not os.path.isfile(img_path):
        img_path = os.path.join(data_dir, img_path)
    img = Image.open(img_path)
    if img.mode != 'L':
        img = img.convert('L')
    img = img.resize((224, 224), Image.BILINEAR)
    img = np.array(img).astype('float32')
    img = img[np.newaxis, :, :]
    img *= 0.007843
    yield img, int(label1), int(label2), int(label3), key
elif mode == 'test':
    img_path = line
    key = img_path[-8:-4]
    if not os.path.isfile(img_path):
        img_path = os.path.join(data_dir, img_path)
    img = Image.open(img_path)
    if img.mode != 'L':
        img = img.convert('L')
    img = img.resize((224, 224), Image.BILINEAR)
    img = np.array(img).astype('float32')
    img = img[np.newaxis, :, :]
    img *= 0.007843
    # test mode has no labels: yield only the image and its key
    yield img, key
return reader
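- Before training, the reader can be smoke-tested; a minimal sketch, assuming train_cube.txt was produced by the split script above:
import numpy as np
import paddle

# batch the cube reader and pull a single batch to check shapes
train_reader = paddle.batch(image_reader("train_cube.txt", mode="train"),
                            batch_size=4, drop_last=True)
for data in train_reader():
    imgs = np.array([x[0] for x in data])
    labels = [x[1] for x in data]
    break
print(imgs.shape)   # expected (4, 225, 225); forward() adds the channel axis
print(labels)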
3. Training the hand-drawn cube recognition model
- Based on the application scenario and our experiments, Sketch-a-Net was chosen as the backbone, so the model can better learn the 3-D spatial relations among the points, lines, and faces of a hand-drawn cube. The network is simple in form and easy to deploy.
# Sketch-a-Net backbone for the hand-drawn cube recognition task; the input tensor has shape [Batch_size, 1, 225, 225] (the reader yields 225x225 grayscale images) and the output tensor has shape [Batch_size, class_dim]
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
import paddle.nn.functional as F
class SketchANet(fluid.dygraph.Layer):
def __init__(self, num_classes=10):
super().__init__()
self.num_classes = num_classes
self.conv1 = Conv2D(num_channels=1, num_filters=64, filter_size=15, stride=3, padding=0)
self.conv2 = Conv2D(num_channels=64, num_filters=128, filter_size=5, stride=1, padding=0)
self.conv3 = Conv2D(num_channels=128, num_filters=256, filter_size=3, stride=1, padding=1)
self.conv4 = Conv2D(num_channels=256, num_filters=256, filter_size=3, stride=1, padding=1)
self.conv5 = Conv2D(num_channels=256, num_filters=256, filter_size=3, stride=1, padding=1)
self.conv6 = Conv2D(num_channels=256, num_filters=512, filter_size=7, stride=1, padding=0)
self.conv7 = Conv2D(num_channels=512, num_filters=512, filter_size=1, stride=1, padding=0)
self.dropout = Dropout(p=0.2)
self.linear = Linear(512, self.num_classes)
def forward(self, x):
x = fluid.layers.unsqueeze(input=x, axes=[1])
x = fluid.layers.relu(self.conv1(x))
x = F.max_pool2d(x, (3, 3), stride=2)
x = fluid.layers.relu(self.conv2(x))
x = F.max_pool2d(x, (3, 3), stride=2)
x = fluid.layers.relu(self.conv3(x))
x = fluid.layers.relu(self.conv4(x))
x = fluid.layers.relu(self.conv5(x))
x = F.max_pool2d(x, (3, 3), stride=2)
x = self.dropout(fluid.layers.relu(self.conv6(x)))
x = self.dropout(fluid.layers.relu(self.conv7(x)))
x = fluid.layers.reshape(x, shape=[-1, 512])
return self.linear(x)
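- A quick sanity-check sketch of the input/output contract, confirming the size arithmetic above: a [N, 225, 225] grayscale batch goes in and a [N, class_dim] score matrix comes out.
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    net = SketchANet(num_classes=10)
    x = fluid.dygraph.to_variable(np.random.rand(2, 225, 225).astype('float32'))
    y = net(x)  # forward() unsqueezes the channel axis itself
    print(y.shape)  # [2, 10]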
- Train the hand-drawn cube recognition network.
import paddle.fluid as fluid
import numpy as np
import paddle
import reader
import os
import utils
import config
from config import __init_train_parameters
from utils import init_log_config
from network import MobileNet, SketchANet
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from eval import eval_model
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--num_epochs', default=50, type=int)
parser.add_argument('--train_batch_size', default=64, type=int)
parser.add_argument('--input', default='cube')
opts = parser.parse_args()
config.train_parameters["num_epochs"] = opts.num_epochs
config.train_parameters["train_batch_size"] = opts.train_batch_size
__init_train_parameters(opts.input)
init_log_config(opts.input, is_train=config.train_parameters["mode"])
def build_optimizer(parameter_list=None):
epoch = config.train_parameters["num_epochs"]
batch_size = config.train_parameters["train_batch_size"]
iters = config.train_parameters["train_image_count"] // batch_size
learning_strategy = config.train_parameters['sgd_strategy']
lr = learning_strategy['learning_rate']
boundaries = [int(epoch * i * iters) for i in learning_strategy["lr_epochs"]]
values = [i * lr for i in learning_strategy["lr_decay"]]
optimizer = fluid.optimizer.SGDOptimizer(learning_rate=fluid.layers.piecewise_decay(boundaries, values),
regularization=fluid.regularizer.L2Decay(0.00005),
parameter_list=parameter_list)
utils.logger.info("use Adam optimizer")
return optimizer
def load_params(model, optimizer):
if config.train_parameters["continue_train"] and os.path.exists(config.train_parameters['save_model_dir']+'.pdparams'):
utils.logger.info("load params from {}".format(config.train_parameters['save_model_dir']))
para_dict, opti_dict = fluid.dygraph.load_dygraph(config.train_parameters['save_model_dir'])
model.set_dict(para_dict)
optimizer.set_dict(opti_dict)
def train():
utils.logger.info("start train")
with fluid.dygraph.guard():
epoch_num = config.train_parameters["num_epochs"]
SketchNet = SketchANet(config.train_parameters['class_dim'])
optimizer = build_optimizer(parameter_list=SketchNet.parameters())
file_list = config.train_parameters['train_file_list']
custom_reader = reader.image_reader(file_list, mode='train')
train_reader = paddle.batch(custom_reader,
batch_size=config.train_parameters['train_batch_size'],
drop_last=True)
current_acc = 0.0
for current_epoch in range(epoch_num):
epoch_acc = 0.0
batch_count = 0
epoch_recall = 0.0
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array([x[0] for x in data]).astype('float32')
y_data = np.array([[x[1]] for x in data]).astype('int')
key = np.array([[x[2]] for x in data]).astype('str')
img = fluid.dygraph.to_variable(dy_x_data)
label = fluid.dygraph.to_variable(y_data)
label.stop_gradient = True
out = SketchNet(img)
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
acc_test = fluid.layers.accuracy(input=softmax_out, label=label)
pred_label = paddle.argmax(softmax_out, axis=1)
comp = fluid.metrics.CompositeMetric()
precision = fluid.metrics.Precision()
recall = fluid.metrics.Recall()
comp.add_metric(precision)
comp.add_metric(recall)
comp.update(preds=pred_label.numpy(), labels=label.numpy())
acc, recall = comp.eval()  # CompositeMetric returns (precision, recall)
loss = fluid.layers.cross_entropy(softmax_out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
optimizer.minimize(avg_loss)
SketchNet.clear_gradients()
batch_count += 1
epoch_acc += acc
epoch_recall += recall
if batch_id % 5 == 0 and batch_id != 0:
utils.logger.info("train: loss at epoch {} step {}: avg_loss: {}, acc: {}, recall: {}"
.format(current_epoch, batch_id, avg_loss, acc, recall))
epoch_acc /= batch_count
epoch_recall /= batch_count
utils.logger.info("epoch {} acc: {} recall: {}".format(current_epoch, epoch_acc, epoch_recall))
if epoch_acc >= current_acc:
utils.logger.info("current epoch {} acc: {} better than last acc: {}, save model"
.format(current_epoch, epoch_acc, current_acc))
current_acc = epoch_acc
fluid.dygraph.save_dygraph(SketchNet.state_dict(), config.train_parameters['save_model_dir'])
fluid.dygraph.save_dygraph(optimizer.state_dict(), config.train_parameters['save_model_dir'])
eval_model()
utils.logger.info("train till end")
if __name__ == "__main__":
train()
4. Training the hand-drawn clock (specified time) recognition model
- The three score components of the "draw a clock at a specified time" item were analyzed one by one: is the clock face a closed circle, are the digits in the right order, and do the hands point correctly. Based on the experimental results, the lightweight MobileNet was chosen as the backbone and optimized for this scenario; the model is simple to deploy and returns results quickly. The corresponding code and data are in network.py; a sketch of the multi-head idea is given below.
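- Purely as an illustration (the actual model is defined in network.py), a hypothetical sketch of how the three sub-scores can share one backbone:
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear

class ThreeHeadClassifier(fluid.dygraph.Layer):
    # hypothetical: one shared feature extractor, three binary heads (one per clock criterion)
    def __init__(self, backbone, feature_dim=1024):
        super().__init__()
        self.backbone = backbone                     # e.g. a MobileNet feature extractor
        self.head_shape = Linear(feature_dim, 2)     # is the clock face a closed circle?
        self.head_number = Linear(feature_dim, 2)    # are the digits present and in order?
        self.head_pointer = Linear(feature_dim, 2)   # do the hands point at the target time?

    def forward(self, x):
        feat = self.backbone(x)
        return self.head_shape(feat), self.head_number(feat), self.head_pointer(feat)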
5. Model deployment
- Deployment of the trained cube and clock recognition models is detailed in train_clock.py and train_cube_test.py.
- Taking cube recognition as an example, the model learned on the training set is tested as follows.
import utils
import paddle.fluid as fluid
import paddle
import reader
from network import SketchANet, MobileNet
import numpy as np
import os
import config
import time
from config import __init_train_parameters_clock, __init_train_parameters
__init_train_parameters('cube')
# __init_train_parameters_clock('clock')  # switch to this when evaluating the clock model
# load the cube evaluation file list
def eval_model():
utils.logging.info("start eval")
file_list = os.path.join(config.train_parameters['data_dir'], config.train_parameters['eval_file_list'])
with fluid.dygraph.guard():
params, _ = fluid.dygraph.load_dygraph(config.train_parameters['save_model_dir'])
SketchNet = SketchANet(config.train_parameters['class_dim'])
SketchNet.set_dict(params)
SketchNet.eval()
test_reader = paddle.batch(reader.image_reader(file_list, 'val'),
batch_size=1,
drop_last=True)
accs = []
recall_list = []
start_time = time.time()
for batch_id, data in enumerate(test_reader()):
dy_x_data = np.array([x[0] for x in data]).astype('float32')
y_data = np.array([[x[1]] for x in data]).astype('int')
img = fluid.dygraph.to_variable(dy_x_data)
label = fluid.dygraph.to_variable(y_data)
label.stop_gradient = True
out = SketchNet(img)
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
pred_label = paddle.argmax(softmax_out, axis=1)
comp = fluid.metrics.CompositeMetric()
precision = fluid.metrics.Precision()
recall = fluid.metrics.Recall()
comp.add_metric(precision)
comp.add_metric(recall)
comp.update(preds=pred_label.numpy(), labels=label.numpy())
acc, recall = comp.eval()  # CompositeMetric returns (precision, recall)
accs.append(acc)
recall_list.append(recall)
utils.logger.info("test count: {} , acc: {}, recall: {}, cost time: {}"
.format(config.train_parameters['eval_image_count'], np.mean(accs), np.mean(recall_list), time.time() - start_time))
Project Demo
Summary and Outlook
- The Intelligent Interactive Alzheimer's Screening Assistant tackles the main shortcomings of current preliminary AD screening: cumbersome test procedures, overly subjective scoring, and unreliable results.
- Built on the mature Baidu PaddlePaddle deep learning framework, the image recognition models complete hand-drawn image recognition quickly in practice and reliably return high-confidence test scores.
- Built on the mature Baidu speech recognition API, the intelligent voice interaction system lets a test-taker complete the MoCA test through voice interaction alone, with no medical staff involved.
- Phase two will expand the interaction languages the system supports (adding Cantonese, Shanghainese, English, German, and more) to broaden the population the screening can cover.
- After the competition, the system will be deployed in the relevant department of a Grade III-A hospital for clinical Alzheimer's screening.