智能交互的阿尔兹海默症辅助筛查系统

项目背景

  • 阿尔茨海默病(Alzheimer’s Disease, AD)是一种起病隐匿的、进行性发展的神经系统退行性疾病。该疾病病程是一个不可逆过程,迄今为止没有有效的治疗药物和手段,因此与癌症一样,阿尔茨海默病治疗的关键是早期筛查诊断,并在疾病早期对阿尔兹海默症进行干预和延缓。
  • 临床上,通常借助蒙特利尔认知评估(Montreal Cognitive Assessmen, MoCA)测试来初步筛查阿尔茨海默病。受限于纸质MoCA测试临床测试效率低、测试结果严重依赖医生的主观判断、经验水平等诸多不利因素,当前阿尔茨海默病的早期筛查存在效率低下、人为不确定因素掺杂较多、单位时间筛查人数较少等缺陷。
  • 值得注意的是,截止2019年中国阿尔兹海默症患病人数超过1000万,是全球阿尔兹海默症患者数量最多的国家,由此造成的不利社会影响正在愈演愈烈。因此,设计、构建一款测试方式上简洁、评价更加客观科学、置信度更高的阿尔茨海默症筛查系统,对缓解社会实质性矛盾意义深远。
  • 中国阿尔兹海默症不同年龄段合并患病率、未来30年中国阿尔兹海默症发展趋势(万人)

技术方案

  • 项目实现上,以基于语音识别、图像处理的人工智能交互方式,实现MoCA筛查量表基本功能,部署在手机、平板电脑等移动终端。
  • 语音交互上,借助百度语音识别(API)转换技术,系统集成语音交互功能模块,实现人机交互中语音的精确采集、音轨录入、智能识别、反馈播放等功能。
  • 图像处理上,借助百度PaddlePaddle、PaddleLite框架平台,利用医学图像处理深度神经网络技术,用自建Sketch4IAS数据集图像特性来构建、训练图像识别深度神经网络模型。
  • 网络架构可视化如下:

  • 系统涵盖基于语音的可交互单元测试模块(IUTM)、基于真值的智能分析模块(IAM),两个模块均通过深度学习技术开发。

  • IUTM用于交互式地采集测试数据。

  • IAM负责对多模态数据进行智能化处理和分析。

  • 整个过程中,重点实现了基于语音识别的智能交互系统、基于事实的图像识别分类算法,接下来分别进行详细介绍、代码实现。

一、基于语音识别的智能交互系统

项目实现中的语音交互过程主要为:

  • 设计、处理语音识别API接口
  • 制作、搭建背景提示音
  • 实时采集测试语音数据
  • 调用语音识别API接口,云端实施"语音-文本"转换,将结果返回移动终端
  • 逻辑判断,返回结果

1. 设计、处理语音识别API接口

  • 注册百度云账号、登录并创建应用,获取"API Key""Secret Key"值。根据应用场景配置语音识别API参数,样例图如下。
globalData: {
    baiduai:{
        apiKey: 'Your_apiKey ',
        secretKey: 'Your_secretKey',
        url: "https://aip.baidubce.com/oauth/2.0/token"
        },
    baiduyuyin:{
        apiKey: 'Your_apiKey',
        secretKey: 'Your_secretKey',
        url: 'https://openapi.baidu.com/oauth/2.0/token'
        },
}
  • 获取AccessToken。在上阶段基础上,系统分配此接口相关凭证–“AppID”“API Key”“Secret Key”–用以生成AccessToken,此为鉴权认证。

  • 项目将鉴认证关键信息嵌入对应环节。该设计确保了语音识别前AccessToken的获取,避免了复杂逻辑的异步操作。

getBaiduAiAccessToken: function(){
    var that = this;
    var baiduai = that.globalData.baiduai;
    wx.request({
        url: baiduai.url,
        data: {
            grant_type: 'client_credentials',
            client_id: baiduai.apiKey,
            client_secret: baiduai.secretKey
        },
        method: 'POST',
        header: {
            'content-type': 'application/x-www-form-urlencoded'
            },
        success(res) {
            wx.setStorageSync("baidu_ai_access_token", res.data.access_token);
            wx.setStorageSync("baidu_ai_time", new Date().getTime());
        }
    })
},

getBaiduYuyinAccessToken: function(){
    var that = this;
    var baiduyuyin = that.globalData.baiduyuyin;
    wx.request({
        url: baiduyuyin.url,
        data: {
            grant_type: 'client_credentials',
            client_id: baiduyuyin.apiKey,
            client_secret: baiduyuyin.secretKey
        },
        method: 'POST',
        header: {
            'content-type': 'application/x-www-form-urlencoded'
        },
        success(res) {
            wx.setStorageSync("baidu_yuyin_access_token", res.data.access_token);
            wx.setStorageSync("baidu_yuyin_time", new Date().getTime());
        }
    })
},

2. 制作、搭建背景提示音

  • 用语音引导开启每个测试界面。
  • 借助wx.createInnerAudioContext()方法来实现语音起始交互。
  • 设置innerAudioContext.autoplay = true,来实现智能自动语音交互。
const innerAudioContext = wx.createInnerAudioContext()
onLoad: function(options){
    var myDate = new Date()
    this.setData({start_ms: myDate.getTime()})
    innerAudioContext.autoplay = true
    innerAudioContext.src = 'pages/video/X.mp3'
    innerAudioContext.onPlay(() => {
        console.log('开始播放')
    })
    innerAudioContext.onError((res) => {
        console.log(res.errMsg)
        console.log(res.errCode)
    })
},
  • 通过innerAudioContext.stop()来实现页面转换时语音的自动停止播放。
gotoX:function()
{
    innerAudioContext.stop();
    app.globalData.global_sum += that.data.count;    
    wx.redirectTo({
        url: '../page_X/page_X',
    })
},

3. 实时采集测试语音数据

  • 用简洁的手势动作–轻触录音、移开结束–来实施实时语音采集。细化讲:

  • 分析应用场景,设置录音参数如下。

sampleRate: 16000,
numberOfChannels: 1,
encodeBitRate: 48000,
format: 'PCM',
  • 借助recorderManager.start()开始录音。
handleTouchStart: function(e){
    const options = {
        sampleRate: 16000,
        numberOfChannels: 1,  
        encodeBitRate: 48000,
        format: 'PCM'
    }
    recorderManager.start(options);
    wx.showLoading({
        title: '正在录音中...',
    })
},
  • 借助recorderManager.stop()结束录音,生成音频文件。
handleTouchEnd: function(e){
    wx.hideLoading();
    this.setData({
        buttonName: e.target.dataset.name
    });
    recorderManager.stop();
},

4. 调用语音识别API接口,云端实施"语音-文本"转换,将结果返回移动终端

  • 语音API接口端音频格式须为base64。
  • 借助fs.readFile读取文件。
  • 利用wx.arrayBufferToBase64()实现格式转换。
  • 调用语音API接口,智能化实施"语音-文本"转换。
bindRecorderStopEvent: function(){
    var that = this;
    recorderManager.onStop((res) => {
        wx.showLoading({
            title: '正在识别中...',
        })
        var baiduBccessToken = wx.getStorageSync("baidu_yuyin_access_token");
        var tempFilePath = res.tempFilePath;
        const fs = wx.getFileSystemManager();
        fs.readFile({
            filePath: tempFilePath,
            success(res) {
                const base64 = wx.arrayBufferToBase64(res.data);
                var fileSize = res.data.byteLength; 
                wx.request({
                    url: 'http://vop.baidu.com/server_api',
                    data: {
                        format: 'pcm',
                        rate: 16000,
                        channel: 1,
                        cuid: 'sdfdfdfsfs',
                        token: baiduBccessToken,
                        speech: base64,
                        len: fileSize
                    },
                    method: 'POST',
                    header: {
                        'content-type': 'application/json' 
                    },
                    success(res) {
                        console.log("语音识别结束")
                        wx.hideLoading();
                        console.log(res.data);
                        if(res.data.err_no == 0){
                            var result = res.data.result;
                            if(result.length == 0){
                                wx.showToast({
                                    title: "未识别到语音信息!",
                                    icon: 'none',
                                    duration: 3000
                                })
                                return;
                            }
                    }
                })
            }
        })
    })
}

5. 逻辑判断,返回结果

  • 将智能识别后的文字信息与标准答案进行对比,借助逻辑判断单元,得到测试分数,并返回结果。

二、基于事实的图像识别分类算法

主要围绕以下5个方面,进行详细阐述:

  • 安装功能库
  • 载入Sketch4IAS数据集
  • 训练手绘立方体图像识别模型
  • 训练手绘指定时刻钟表图识别模型
  • 模型部署

1. 安装功能库

  • 分析应用场景,安装部署必要的功能库。
!python -m pip install paddlepaddle-gpu==2.1.1.post110 -f https://paddlepaddle.org.cn/whl/mkl/stable.html
!pip install pandas
!pip install Pillow  
!pip install numpy

2. 载入Sketch4IAS数据集

  • 构建临床MoCA数据集Sketch4IAS,包含400张Screen-Image、1850张Camera-Image。

  • Sketch4IAS数据集是第一个MoCA手绘图像数据集,由“手绘立方体、手绘指定时刻钟表”2类图像构成,所有图像都标注了评估分数。

  • 数据集部分内容展示如下:四行表格分别表示Camera-Image、Screen-Image、预处理图像、标记分数。左侧、右侧各三列分别表示手绘立方体、手绘指定时刻钟表图像。

  • 对照数据集Sketch4IAS展示图像,生成对应预处理图像。各部分详细代码如下:

  • 手绘sketch图片预处理。高斯平滑处理原始采集图像,选择合适的二值化处理方式得到灰度图。

import cv2 as cv
import os
blur_blk_size = (5, 5)
sigma_x = 2
thres_blk_size = (11, 12)
def convert(src_img_path, tar_img_path):
    src_img = cv.imread(src_img_path)
    # 得到灰度图
    gray_img = cv.cvtColor(src_img, cv.COLOR_BGR2GRAY)
    # 平滑处理
    gray_img = cv.GaussianBlur(gray_img, blur_blk_size, sigma_x)
    # Method1,阈值处理的方式
    # r, b = cv.threshold(gray_img, 140, 255, cv.THRESH_BINARY)
    # Method2,自适应阈值确定的方式
    b = cv.adaptiveThreshold(gray_img, 255, 
                             cv.ADAPTIVE_THRESH_MEAN_C, 
                             cv.THRESH_BINARY, 
                             thres_blk_size[0], 
                             thres_blk_size[1])
    cv.imwrite(tar_img_path, b)
    
# 转换手绘正方体(Cube)和钟表(Watch)图像
src_cube = os.path.join(src_dir, "image_cube")
src_watch = os.path.join(src_dir, "image_watch")
tar_cube = os.path.join(tar_dir, "image_cube")
tar_watch = os.path.join(tar_dir, "image_watch")

cube_names = os.listdir(src_cube)
for name in cube_names:
    scp_path = os.path.join(src_cube, name)
    tcp_path = os.path.join(tar_cube, name)
    convert(scp_path, tcp_path)
    
watch_names = os.listdir(src_watch)
for name in cube_names:
    scp_path = os.path.join(src_watch, name)
    tcp_path = os.path.join(tar_watch, name)
convert(scp_path, tcp_path)

# 可参考https://www.tutorialspoint.com/opencv/opencv_adaptive_threshold.htm 
  • 继续处理Sketch4IAS预处理数据。划分训练集、测试集,记录对应图片所在位置、标签。
import codecs
import os
import random
import shutil
from PIL import Image
import pandas as pd
import argparse

def init_parser():
    # 初始化参数
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_ratio', default=4.0 / 5)
    parser.add_argument('--all_file_dir', default='/mnt/hdd1/miaolanxin/program/flowers/work/')
    parser.add_argument('--label_files', default='/mnt/hdd1/miaolanxin/program/dataset_image_label_pre/labels.csv')
    parser.add_argument('--image_files', default='/mnt/hdd1/miaolanxin/program/dataset_image_label_pre/pre_images_blur/image_cube/')
    parser.add_argument('--input', default='cube')
    return parser.parse_args()

opts = init_parser()
train_ratio = opts.train_ratio
all_file_dir = opts.all_file_dir
label_files = opts.label_files
image_files = opts.image_files

# 对数据集进行划分, 将训练和测试集图片分别存放到两个文件夹中,并记录对应的文件路径和标签
def preprocess(target_name: str):
    label = pd.read_csv(label_files)
    label_cube = list(label['cube'])
    label_shape = list(label['shape'])
    label_number = list(label['number'])
    label_pointer = list(label['pointer'])

    # 步骤一,生成训练集、测试集文件夹,以及存储图像标签信息的文本
    train_name = str(target_name) + "_trainImageSet"
    eval_name = str(target_name) + "_evalImageSet"
    images = os.path.join(all_file_dir + str(target_name))
    train_image_dir = os.path.join(all_file_dir, str(target_name), train_name)
    eval_image_dir = os.path.join(all_file_dir, str(target_name), eval_name)
    if not os.path.exists(train_image_dir):
        os.makedirs(train_image_dir)
    if not os.path.exists(eval_image_dir):
        os.makedirs(eval_image_dir)
    train_file_list = "train_" + str(target_name) + ".txt"
    eval_file_list = "eval_" + str(target_name) + ".txt"
    train_file = codecs.open(os.path.join(images, train_file_list), 'w')
    eval_file = codecs.open(os.path.join(images, eval_file_list), 'w')
    image_file = image_files

    # 步骤二,生成对应标签
    for file in os.listdir(image_file):
        image_id = int(file[:-4])
        if target_name == 'cube':
            label_id = str(int(label_cube[image_id - 1]))

        if target_name == 'clock':
            # 分别对应着有关钟表的shape, number, pointer三个标签
            label_id1 = str(int(label_shape[image_id - 1]))
            label_id2 = str(int(label_number[image_id - 1]))
            label_id3 = str(int(label_pointer[image_id - 1]))
            label_id = label_id1 + label_id2 + label_id3
        try:
            if random.uniform(0, 1) <= train_ratio:
                shutil.copyfile(os.path.join(image_files, file), os.path.join(train_image_dir, file))
                train_file.write("{0}\t{1}\n".format(os.path.join(train_image_dir, file), label_id))
            else:
                shutil.copyfile(os.path.join(image_files, file), os.path.join(eval_image_dir, file))
                eval_file.write("{0}\t{1}\n".format(os.path.join(eval_image_dir, file), label_id))
        except Exception as e:
            pass
    train_file.close()
    eval_file.close()

name = opts.input
preprocess(name)
  • 数据流读取
# 读取预处理图像

import os
import math
import codecs
import numpy as np
from PIL import Image, ImageEnhance
from config import train_parameters
import paddle.fluid as fluid

# 考虑到手绘图像不一定处画布中央,对预处理图像进行随机填充处理,完成图像数据增强。该过程保证了图像完整性,实现了画布合适位置均可出现该图像。
def random_padding(img, scale_ratio):
    w, h = img.size
    length = int(math.sqrt(w * h * scale_ratio))
    length = max(w, h, length)
    img_new = Image.new(img.mode, (length, length), "white")
    width = length - w
    height = length - h
    width = int(np.random.uniform(0, width))
    height = int(np.random.uniform(0, height))
    img_new.paste(img, (width, height))
    return img_new

# 在图像长或宽方向上进行随机伸缩处理,完成图像数据增强。
def random_scaling(img):
    w, h = img.size
    prob = np.random.uniform(0, 1)
    if prob > 0.5:
        img = img.resize((w, 500), Image.ANTIALIAS)
    else:
        img = img.resize((500, h), Image.ANTIALIAS)
    return img

# 关于图像几何中心随机旋转一定角度,完成图像数据增强。参照应用场景的实际,将旋转角度范围控制在±14°
def rotate_image(img):
    angle = np.random.randint(-14, 15)  
    img = img.rotate(angle)
    return img

# 对图像采取扭曲处理,完成图像数据增强。
def distort_color(img):
    img = random_brightness(img)
    img = random_contrast(img)
    img = random_saturation(img)
    img = random_hue(img)
    return img

# 立方体识别任务的reader
def image_reader(file_list, mode):
    with codecs.open(file_list) as flist:
        lines = [line.strip() for line in flist]
    def reader():
        np.random.shuffle(lines)  
        for line in lines:
            if mode == 'train':
                img_path, label = line.split()
                key = img_path[-8:-4]
                img = Image.open(img_path)
                try:
                    # 参照实际应用环境,采取四种方式对输入图像进行数据增强。
                    if img.mode != 'L':
                        # 转换为灰度图。
                        img = img.convert('L')  
                    if train_parameters['image_enhance_strategy']['need_distort']:
                        img = distort_color(img)
                    if train_parameters['image_enhance_strategy']['need_rotate']:
                        img = rotate_image(img)
                    if train_parameters['image_enhance_strategy']['need_scaling']:
                        img = random_scaling(img)
                    if train_parameters['image_enhance_strategy']['need_padding']:
                        img = random_padding(img, 1.5)
                    # 后续处理基于此灰度图。
                    img = img.resize((225, 225), Image.BILINEAR)
                    img = img.convert('L')
                    img = np.array(img).astype('float32')
                    img *= 0.007843
                    yield img, int(label), key
                except Exception as e:
                    pass
            if mode == 'val':
                img_path, label = line.split()
                key = img_path[-8:-4]
                img = Image.open(img_path)
                if img.mode != 'L':
                    img = img.convert('L')
                img = img.resize((225, 225), Image.BILINEAR)
                img = img.convert('L')
                img = np.array(img).astype('float32')
                img *= 0.007843 
                yield img, int(label), key
    return reader
    # 返回一个产生器
    
# 钟表识别任务的reader
def custom_image_reader(file_list, data_dir, mode):
    with codecs.open(file_list) as flist:
        lines = [line.strip() for line in flist]
    def reader():
        np.random.shuffle(lines)
        for line in lines:
            if mode == 'train':
                img_path, label = line.split()
                label1 = label[0]
                label2 = label[1]
                label3 = label[2]
                key = img_path[-8:-4]
                img = Image.open(img_path)
                try:
                    # 在数据增强上面,进行了一系列处理,padding是为了保证图像的完整性以及随机出现在画布的某个位置;
                    if img.mode != 'L':
                        img = img.convert('L')
                    if train_parameters_clock['image_enhance_strategy']['need_distort']:
                        img = distort_color(img)
                    if train_parameters_clock['image_enhance_strategy']['need_rotate']:
                        img = rotate_image(img)
                    if train_parameters_clock['image_enhance_strategy']['need_scaling']:
                        img = random_scaling(img)
                    if train_parameters_clock['image_enhance_strategy']['need_padding']:
                        img = random_padding(img, 1.5)
                    img = img.resize((224, 224), Image.BILINEAR)
                    img = np.array(img).astype('float32')
                    img = img.transpose((2, 0, 1))
                    img *= 0.007843
                    yield img, int(label1), int(label2), int(label3), key
                except Exception as e:
                    pass
            if mode == 'val':
                img_path, label = line.split()
                if not os.path.isfile(img_path):
                    img_path = os.path.join(data_dir, img_path)
                img = Image.open(img_path)
                if img.mode != 'L':
                    img = img.convert('L')
                img = img.resize((224, 224), Image.BILINEAR)
                img = np.array(img).astype('float32')
                img = img.transpose((2, 0, 1))
                img *= 0.007843  
                yield img, int(label1), int(label2), int(label3), key
            elif mode == 'test':
                img_path = line
                if not os.path.isfile(img_path):
                    img_path = os.path.join(data_dir, img_path)
                img = Image.open(img_path)
                if img.mode != 'L':
                    img = img.convert('L')
                img = img.resize((224, 224), Image.BILINEAR)
                img = np.array(img).astype('float32')
                img = img.transpose((2, 0, 1))
                img *= 0.007843  
                yield img, int(label1), int(label2), int(label3), key
    return reader

3. 训练手绘立方体图像识别模型

  • 分析应用场景、实验情况,选取Sketch-a-Net网络作为基础,让模型更好地习得手绘立方体的点、线、面3维空间关系。网络形式简单、部署简便。
# 训练Sketch-a-Net基础网络,用于手绘立方体识别任务,输入tensor的大小为[Batch_size, 1, 255, 255],输出tensor的大小为[Batch_size, class_dim]

import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
import paddle.nn.functional as F
class SketchANet(fluid.dygraph.Layer):
    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.conv1 = Conv2D(num_channels=1, num_filters=64, filter_size=15, stride=3, padding=0)
        self.conv2 = Conv2D(num_channels=64, num_filters=128, filter_size=5, stride=1, padding=0)
        self.conv3 = Conv2D(num_channels=128, num_filters=256, filter_size=3, stride=1, padding=1)
        self.conv4 = Conv2D(num_channels=256, num_filters=256, filter_size=3, stride=1, padding=1)
        self.conv5 = Conv2D(num_channels=256, num_filters=256, filter_size=3, stride=1, padding=1)
        self.conv6 = Conv2D(num_channels=256, num_filters=512, filter_size=7, stride=1, padding=0)
        self.conv7 = Conv2D(num_channels=512, num_filters=512, filter_size=1, stride=1, padding=0)
        self.dropout = Dropout(p=0.2)
        self.linear = Linear(512, self.num_classes)
    def forward(self, x):
        x = fluid.layers.unsqueeze(input=x, axes=[1])  
        x = fluid.layers.relu(self.conv1(x))
        x = F.max_pool2d(x, (3, 3), stride=2) 
        x = fluid.layers.relu(self.conv2(x))
        x = F.max_pool2d(x, (3, 3), stride=2)
        x = fluid.layers.relu(self.conv3(x)) 
        x = fluid.layers.relu(self.conv4(x))
        x = fluid.layers.relu(self.conv5(x)) 
        x = F.max_pool2d(x, (3, 3), stride=2)
        x = self.dropout(fluid.layers.relu(self.conv6(x))) 
        x = self.dropout(fluid.layers.relu(self.conv7(x))) 
        x = fluid.layers.reshape(x, shape=[-1, 512])
        return self.linear(x)
  • 训练手绘立方体图像识别网络。
import paddle.fluid as fluidimport paddle.fluid as fluid
import numpy as np
import paddle
import reader
import os
import utils
import config
from config import __init_train_parameters
from utils import init_log_config
from network import MobileNet, SketchANet
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from eval import eval_model
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num_epochs', default=50)
parser.add_argument('--train_batch_size', default=64)
parser.add_argument('--input', default='cube')
opts = parser.parse_args()

config.train_parameters["num_epochs"] = opts.num_epochs
config.train_parameters["train_batch_size"] = opts.train_batch_size
__init_train_parameters(opts.input)
init_log_config(opts.input, is_train=config.train_parameters["mode"])

def build_optimizer(parameter_list=None):
    epoch = config.train_parameters["num_epochs"]
    batch_size = config.train_parameters["train_batch_size"]
    iters = config.train_parameters["train_image_count"] // batch_size
    learning_strategy = config.train_parameters['sgd_strategy']
    lr = learning_strategy['learning_rate']
    boundaries = [int(epoch * i * iters) for i in learning_strategy["lr_epochs"]]
    values = [i * lr for i in learning_strategy["lr_decay"]]
    optimizer = fluid.optimizer.SGDOptimizer(learning_rate=fluid.layers.piecewise_decay(boundaries, values),
                                             regularization=fluid.regularizer.L2Decay(0.00005),
                                             parameter_list=parameter_list)
    utils.logger.info("use Adam optimizer")
    return optimizer

def load_params(model, optimizer):
    if config.train_parameters["continue_train"] and os.path.exists(config.train_parameters['save_model_dir']+'.pdparams'):
        utils.logger.info("load params from {}".format(config.train_parameters['save_model_dir']))
        para_dict, opti_dict = fluid.dygraph.load_dygraph(config.train_parameters['save_model_dir'])
        model.set_dict(para_dict)
        optimizer.set_dict(opti_dict)

def train():
    utils.logger.info("start train")
    with fluid.dygraph.guard():
        epoch_num = config.train_parameters["num_epochs"]
        SketchNet = SketchANet(config.train_parameters['class_dim'])
        optimizer = build_optimizer(parameter_list=SketchNet.parameters())
        file_list = config.train_parameters['train_file_list']
        custom_reader = reader.image_reader(file_list, mode='train')
        train_reader = paddle.batch(custom_reader,
                                    batch_size=config.train_parameters['train_batch_size'],
                                    drop_last=True)
        current_acc = 0.0
        for current_epoch in range(epoch_num):
            epoch_acc = 0.0
            batch_count = 0
            epoch_recall = 0.0
            for batch_id, data in enumerate(train_reader()):
                dy_x_data = np.array([x[0] for x in data]).astype('float32')
                y_data = np.array([[x[1]] for x in data]).astype('int')
                key = np.array([[x[2]] for x in data]).astype('str')

                img = fluid.dygraph.to_variable(dy_x_data)
                label = fluid.dygraph.to_variable(y_data)
                label.stop_gradient = True

                out = SketchNet(img)
                softmax_out = fluid.layers.softmax(out, use_cudnn=False)
                acc_test = fluid.layers.accuracy(input=softmax_out, label=label)

                pred_label = paddle.argmax(softmax_out, axis=1)
                comp = fluid.metrics.CompositeMetric()
                precision = fluid.metrics.Precision()
                recall = fluid.metrics.Recall()
                comp.add_metric(precision)
                comp.add_metric(recall)
                comp.update(preds=np.array(pred_label), labels=np.array(label))
                acc, recall = comp.eval()
                loss = fluid.layers.cross_entropy(softmax_out, label)
                avg_loss = fluid.layers.mean(loss)
                avg_loss.backward()

                optimizer.minimize(avg_loss)
                SketchNet.clear_gradients()
                batch_count += 1
                epoch_acc += acc
                epoch_recall += recall
                if batch_id % 5 == 0 and batch_id != 0:
                    utils.logger.info("train: loss at epoch {} step {}: avg_loss: {}, acc: {}, recall: {}"
                                      .format(current_epoch, batch_id, avg_loss, acc, recall))

            epoch_acc /= batch_count
            epoch_recall /= batch_count
            utils.logger.info("epoch {} acc: {} recall: {}".format(current_epoch, epoch_acc, epoch_recall))
            if epoch_acc >= current_acc:
                utils.logger.info("current epoch {} acc: {} better than last acc: {}, save model"
                                  .format(current_epoch, epoch_acc, current_acc))
                current_acc = epoch_acc
                fluid.dygraph.save_dygraph(SketchNet.state_dict(), config.train_parameters['save_model_dir'])
                fluid.dygraph.save_dygraph(optimizer.state_dict(), config.train_parameters['save_model_dir'])
                eval_model()
        utils.logger.info("train till end")

if __name__ == "__main__":
    train()

4. 训练手绘指定时刻钟表图像识别模型

  • 逐一分析”手绘指定时刻钟表“测试项目中3分构成(钟表是否满圆、数字是否依序、指针朝向是否正确),依据实验结果,选取MobileNet轻量级网络作为基础,针对实际应用场景进行优化。模型部署简单、结果反馈迅速。对应代码、数据,详见network.py文件。

5. 模型部署

  • 立方体识别、手绘钟表识别训练模型的部署,具体代码详见train_clock.py、train_cube_test.py文件。

  • 以cube识别为例,测试在训练集上已习得的模型。

import utils
import paddle.fluid as fluid
import paddle
import reader
from network import SketchANet, MobileNet
import numpy as np
import os
import config
import time
from config import __init_train_parameters_clock, __init_train_parameters

# __init_train_parameters('cube')
__init_train_parameters_clock('clock')

# 加载cube识别评估文件
def eval_model():
    utils.logging.info("start eval")
    file_list = os.path.join(config.train_parameters['data_dir'], config.train_parameters['eval_file_list'])
    with fluid.dygraph.guard():
        params, _ = fluid.dygraph.load_persistables(config.train_parameters['save_model_dir'])
        params, _ = fluid.load_dygraph(config.train_parameters['save_model_dir'])
        SketchNet = SketchANet(config.train_parameters['class_dim'])
        SketchNet.set_dict(params)
        SketchNet.eval()
        test_reader = paddle.batch(reader.image_reader(file_list, 'val'),
                                   batch_size=1,
                                   drop_last=True)

        accs = []
        recall_list = []
        start_time = time.time()
        for batch_id, data in enumerate(test_reader()):
            dy_x_data = np.array([x[0] for x in data]).astype('float32')
            y_data = np.array([[x[1]] for x in data]).astype('int')

            img = fluid.dygraph.to_variable(dy_x_data)
            label = fluid.dygraph.to_variable(y_data)
            label.stop_gradient = True

            out = SketchNet(img)
            softmax_out = fluid.layers.softmax(out, use_cudnn=False)
            pred_label = paddle.argmax(softmax_out, axis=1)

            comp = fluid.metrics.CompositeMetric()
            precision = fluid.metrics.Precision()
            recall = fluid.metrics.Recall()
            comp.add_metric(precision)
            comp.add_metric(recall)
            comp.update(preds=np.array(pred_label), labels=np.array(label))
            acc, recall = comp.eval()

            accs.append(acc)
            recall_list.append(recall)
    utils.logger.info("test count: {} , acc: {}, recall: {}, cost time: {}"
                      .format(config.train_parameters['eval_image_count'], np.mean(accs), np.mean(recall_list), time.time() - start_time))

项目展示

总结展望

  • “智能交互的阿尔兹海默症辅助筛查系统”智能化地解决当前阿尔兹海默症初步筛查方式中测试方式繁琐、评价方式过于主观、筛查结果可靠性不高的问题。
  • 借助成熟的百度深度学习框架PaddlePaddle,我们稳健地搭建了图像识别模型。实际应用中,能快速完成手绘图像识别任务,可靠给出高置信度测试分数。
  • 借助成熟的百度语音识别API,我们系统地构建起智能语音交互系统。筛查过程,测试者可在没有医护人员参与情况下,以基于语音交互的方式实现MoCA测试。
  • 二期任务中,扩充系统支持的交互语言,增加广东话、上海话、英语、德语等,从系统支持语言端着手,泛化筛查可覆盖人群。
  • 赛后,该系统将部署于某三甲医院相关科室,实际用于阿尔茨海默病的临床筛查。

请点击此处查看本环境基本用法.

Please click here for more detailed instructions.

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值