self_drive car_学习笔记--第9课:预测系统

前言:这节课主要是介绍无人车里面的预测系统。水平有限,有些理解错误地方,还望大大们不吝赐教。觉得写得还行,麻烦赏个赞哈。好了,不废话,开始主题。

概要:
1 PNC OVERVIEW
2 PREDICTION TASK
3 VEHICLE PREDICT
4 PEDESTRIAN PREDICT

补充,PNC【planning and control】(预测系统是其组成之一)、感知是L4级别的两大重点模块

1 PNC OVERVIEW

1.1自动驾驶软件框架
在这里插入图片描述
Prediction属于planing的组成部分,其负责将感知传过来的信息,对这些信息进行加工处理,让感知对这些信息的变化产生更加敏感反应;对于理解障碍物的意图,是非常重要的。如果只是对静态物体进行规划,是很难达到一个可行的方案。

1.2 PNC

1)PNC:Planning and Control

2)Prediction->Routing->planning(decision、path、speed)->Control
【预测,选择,规划(决策,(横向)轨迹,(纵向)速度),控制】

3)淡绿色框是PNC所涵盖的内容
在这里插入图片描述
Canbus车辆底盘开发总线【控制与开发总线进行交互,从而实现无人驾驶】

4)平均跑多远需要一次人工接管(业界称为MPI=行驶总里程/人工干预次数)
在这里插入图片描述
【该指标可以一定程度上体现自动驾驶的能力大小,但不是绝对的。】
【后面介绍该报告,重点介绍遥遥领先的Waymo(代表自动驾驶的后期阶段)接管的原因,公开数据说接管的50%原因是PNC,25%是硬件,15%是天气,5%是感知】

【PNC在自动驾驶研发的后期阶段,研发难度会越来越高】【无人驾驶公司,初期阶段,主要专注于感知模块,研发到后期,PNC模块的重要性会提升】

【最难处理问题是,数据量很小,也就是极端情况下,怎么处理,业界内称为corner case,也就是回到PNC怎么解决这个问题】
【corner case,举例子,当十字路口,直行和左转都是绿灯时候,都是可行驶的,虽然优先是直行的,但是,对于无人车来说,很难确定直线车辆以及左转车辆的运动意图,也就是动态障碍物的意向,这总情况就是corner case】

【无保护左转,拥堵式汇入,是waymo说的两大难点】

2 PREDICTION TASK

在这里插入图片描述
1)为啥预测
----多类型【指的是障碍物类型,动态、静态判断,主要是动态需要预测】
【prediction和trash有区别的,trash只是判断接下来一两帧,也就是200-300ms的移动判断;而prediction是一个长时间推理预测,5-8s,其结合的信息比较丰富】【】

2)必要性
----实时性【障碍物很多,能给的判断时间很短的】、准确【感知过滤掉不合理的信息,获取得到的信息更加准确,同时感知也会对场景进行一个深加工,让后面的模块可以更容易处理】

3)理解corner case【一些不规律动态障碍物移动轨迹,比如怎么判断一个马路边上的人,怎么预测到那人是否要过马路呢】

4)人怎么考虑的(指的是考虑行走路线)呢?【比如,corner case 场景,人过马路,会实时根据自己走的方向,看有没有来车,是不是绿灯,就选择过马路与否】

2.1 两种方法(two methods)
【主流预测方法】
在这里插入图片描述
在这里插入图片描述
1)基于模型方法:【更适合物理、机械背景的同学使用】
–结合人类经验
–【有利于也便于】解决corner case 问题

2)基于数据驱动的方法:【数据越多,效果越来愈好;但是,代码可读性会降低】
–机器学习
–越多数据(效果越好)

3)特点:
–基于模型的方法,更容易结合人类经验,对于corner case的处理更加方便
–基于数据驱动方法,随着功能的完善,代码的可读性会降低
–应该结合两者的进行使用

4)图片的案例就是将两种方法结合起来考虑

5)如何知道业界怎么处理预测问题
–上业界最牛的公司招聘官网,看他们的招聘要求。比如
在这里插入图片描述
在这里插入图片描述
【从上面的招聘来看,waymo主要采用数据驱动方式来做预测】【这里做预测的数据,是从复杂的情况中,抽取特征来使用,相对感知使用比较困难】

3 VEHICLE PREDICT

案例介绍:阿波罗
步骤:
----git clone https://github.com/ApolloAuto/apollo.git【下载代码】
----Git checkout r3.5.0【切换当时最新分支,该版本去除ros了】
----Modules/prediction【找到预测模块】

预测模块截图
在这里插入图片描述
【Apollo项目对于L4级的学习有很好的效果】

3.1 lane model
【道路建模方法】
在这里插入图片描述
1)lane sequence(车道顺序)
–HD map(高精地图)【高精地图在车道建模中位置很重要】
–Junction(拐点)
–Off lane(车道外,这里理解是编号lane的区域外意思,比如斑马线区域就是off lane)
【当车辆处于off lane状态时,采用的是根据车头朝向,就近原则来归属车辆所在的lane(延长线)】

2)classification(分类)【结合上面的车道模型来看,列举出车所有可能走的路线;这是车道建模的思想,预测车辆行走的可能性】
----线:0->1->3->7
----线:0->2->6【为了便于建模,所有不把2区域内的建模情况计算进来】
----线:0->4->5
----更多可能

【道路建模思想,就是分类出车辆所可能的行走可能性】

3)道路建模考虑点【将上述的复杂状态,做出简化,方便分析处理;将连续空间离散化处理】

3-1)lane feature(车道特征)【如下左侧图片】:
----lane S【车前进方向的,这里S,L是指的是车道线坐标系里面的两轴线】
----lane L【车与车道边界线距离】
----reference lane【跟道路中心线差多远】
----curvature(曲率)【车道的弯曲情况】
----traffic law【交通信号】

3-2)vehicle state(车辆状态)【如图下右侧】【速度、加速度、车头角度】
----velocity【速度】
----Acc【Adaptive cruise control ,自适应巡航控制】
----Heading【朝向】
----Heading rate【车头速度】
----Type【车类型,如私家小车、卡车;车类型不同,行为也不一样】
----Size【大小】
----等等

3-3)environment(环境信息)【比如,无人车感知到左侧有车,那么变道就不选择左侧了】
【上述三特征是lane model考虑的重点】
在这里插入图片描述
3-4)Network【网络】
----behavior【车辆行走的行为】

3-5)Sequence data【序列化数据】【右图表示有,车道线输入个数或长度,障碍物信息时长;表示这两者要根据模型来调整对应的参数】
----RNN【使用神经网络来处理这些输入信息(属于序列化任务)】
在这里插入图片描述
【车道线+障碍物状态->模型->输出走的路线概率;建立好模型后,需要选择合适的输入输出;输入HD map和感知信息,注意,车道线个数或长度以及障碍物时长等等;输出时候,根据无人车的应用场景来输出预测,比如高速上,需要输出8s的预测;而对于园区的物流小车,可能2秒就行了】

3.2 SEQUENCE DATA NETWORK【序列化数据网络】

1)常用处理方式:
----MLP
----CNN

----RNN【RNN运作时候,需要一步一步推进,CPU、GPU存在等待的过程,是其一个缺点】
--------LSTM【这种属于处理序列化数据比较好方式】
--------GRU

----Attention【注意力机制】
----TCN【时间上序列卷积,是CNN的改进版本;也是RNN的替代版本,比较热门】

2)这里介绍的是RNN方式【预测任务可理解为一种监督学习】
在这里插入图片描述
来源:http://colah.github.io/posts/2015-08-Understanding-LSTMs/
在这里插入图片描述
【上图,x表示输入;h表示输出,但是也是作为下一次输入的反馈,也属于输入层面】
在这里插入图片描述
【上图,Ct可理解为中间状态的记录,表示哪些信息需要的,哪些信息需要舍去的】
在这里插入图片描述
【上图,Ot表示每次循环输出】

【概率论是无人驾驶的基础课程之一,因为现在都是使用数据驱动的;无人车涉及的知识很多,需要根据你选择其中的方向来选择学习】【车规级的要求,是当前无人车的一大难点】

3)感知模块输入样式,以阿波罗为例:

https://github.com/ApolloAuto/apollo/blob/r3.5.0/modules/perception/proto/perception_obstacle.proto

【注意感知输入跟传感器数据输入有区别的,感知是前面经过处理后的信息,而不是传感器输入数据】
【阿波罗感知输入时,使用到ID,位置,角度和速度等信息】

【感知数据是实时变化的,但是,感知处理时候,会使用timestamp来保存输入感知信息的,而不是传一个就处理一个或者过了很久的也拿来用,也就是处理的是对应的时间的感知数据】

4)预测模块输出样式:

https://github.com/ApolloAuto/apollo/blob/r3.5.0/modules/prediction/proto/prediction_obstacle.proto

【输出:轨迹+概率;轨迹使用轨迹点来代替,这个点的定义函数TrajectoryPoint里面,id,速度,加速度,相对时间(相对给出预测信息的时间),转向角信息等等;】

3.3 APOLLO MODEL
【介绍阿波罗如何做预测】

1)MLP model
2)RNN model
在这里插入图片描述
【上面是其预测整体处理的流程图】

https://github.com/ApolloAuto/apollo/tree/r3.5.0/modules/prediction

链接里面的部分文件夹包含内容介绍:
common:包含模块使用数据结构;
Conf : 配置文件;
Container:记录抽特征过程,并保存所抽取的特征;预测信息也是保存在这里面的;

Dag:是阿波罗用起来代替ros系统框架;
Data:放已经训练好的模型;

Evaluator:放模型推理的过程,【把特征提取出来,往模型里面一放,输出一个概率来】
Images:【开源用的,个人感觉是docker镜像】

Launch:启动文件
Network:阿波罗自家的推理框架【主流的推理框架如tensorflow】
Predictor:就是当拿到了上面的概率后,怎么得到轨迹,理解为根据概率获取planing使用的轨迹的文件

Proto:一些数据结构【Google为了方便存储数据结构所开发的一种格式】
Scenario:场景的分类框架

Testdata:测试数据
Util:一些工具类
Prediction_component.cc : 等效于ros node的创建文件

重点解析:
【阿波罗的代码不一定效率最好,但是对应学习者很友好,采用模式化方式编写并管理代码】
1)container:

----container_manager.h : 声明adc_trajectory【航线】、obstacles【障碍物】、pose【位置】相关信息在这里面声明
----obstacles:将感知获取的信息【主要指障碍物信息】进行抽取特征操作

—pose:保存无人车自身定位信息,如从车imu获取自身位姿【由于无人车相对旁边的车辆(障碍物)而言,也是障碍物,旁边车运动也会受无人车影响调整自己的运动状态,所以无人车也需要保存自身的位姿信息来做预测处理】
----adc_trajectory:将上一帧规划的轨迹保存下来,方便在有优先路权情况下做交互处理

2)evaluator【评价,这里推理器】
【里面代码采用统一的样式,对学习者优化,可读性也很强】

----cyclist:自动车模型推理【里面有多种场景的推理】
----vehicle:汽车模型推理【包含多种场景推理】
【其中,cost_开头的,是基于model方式,也就是上面说的建模方式进行推理;其余的都是基于数据驱动方式进行推理】

3)predictor
【很重要,但是知道其目的是根据evaluation?(获取的信息)绘制预测的轨迹信息就行了;里面有很多函数,也就是方法,针对不同场景也会选择不同的方法】

3.4 DATA PIPELINE
【数据管道,其是机器学习一个重要的环】

【预测,主流操作就是作为一个机器学习任务来做】【这里展示的是预测例子,在机器学习走的基本流程】

感知机器学习数据管道样例1:

1)Sampling engineering【样本工程】
样本清洗:滤除不合适的样本【感知,在超出能力视距范围时,获取的有效数据需要考虑,比如有效量程100m的64线激光雷达,虽然可以提供100m以外的数据,但是可用与否需要分析才知道】

Label:根据上面历史信息,推断出接下来的运动情况,如轨迹,称为label获取【由于机器学习,随着数据量的增加,模型效果越好,预测的轨迹更好】

样本平衡:正样本和负样本比例【比如前进的车概率,转弯车概率不一样,使用样本时,也应该根据比例调整样本比例】

【RNN预测模型产生场景:采集数据(开环方式:先采集数据后面再处理)–感知工程师标注处理(训练前准备)—使用模型推理感知结果—将感知结果交给PNC工程师—PNC工程师利用该结果刷对应的任务(训练模型)—刷出来的任务就是抽象出来的样本,就是最后跑的样本】

2)Model optimization【模型优化】
3)Feature engineering【特征工程】
在这里插入图片描述
【先通过label和特征选择训练感知RNN模型,也就是offline离线学习方式训练模型;然后得到模型后,开始结合online在线方式(代码跑在无人车上意思)开始测试训练得到的模型的效果;通过线上线下方式结合,可以让模型越来越好】

【训练过程就是线上,预测过程就是在线下做的,线下拿到数据回来线上继续训练,训练得到的模型(参数),再部署到线下预测】

感知机器学习里面的数据管道样例2:【waymo公司】
在这里插入图片描述
【data w/sensor logs获取到数据;对简单的数据进行标注auto-label;labelers工程师标注不简单的数据;从而快速大量产生一批数据labeled data;接着进行模型models的训练;模型训练时,使用了auto-tunes optimizes自动调参方式;将训练好的模型,放到仿真机机test&validation测试;仿真机通过后,就部署releases到线上;接着继续收集collects数据,产生data w/sensor logs;循环往复上面过程;waymo就是这么搞,使得自身实现叠加式进展。】

1)Auto labeling(自动标注)【对简单的数据进行自动标注】
–History data(历史数据)
–label data(标注数据)
–Positive and negative samples(优、劣样本)

2)label cleaning(标注清除)
3)label balance(标注平衡)

【上述是模型预测介绍】

3.5 TRAJECTORY BUILDER【轨迹生成】
在这里插入图片描述
【5秒内从A到B点过程所可能采用线路方式】

1)kalman filter(卡尔曼滤波)

【卡尔曼滤波轨迹应用:叠加平移到达目标点位置】
【拟合方法,比如,在5秒钟内,以每一步5%的拟合度,从A拟合到B点;那么kalman filter里面观测值,以观测结果的目标值乘以95%的方式往B点状态转移,每一次叠加以后,5秒后达到B点;这就是卡尔曼滤波在轨迹上面的运用样式】

【阿波罗项目也有类似使用https://github.com/ApolloAuto/apollo/blob/r3.5.0/modules/prediction/predictor/regional/regional_predictor.cc

2)polynomial(多项式拟合)

【轨迹使用多项式拟合出来的;假设利用的是三次多项式(ax(3立方)+bx(2次方)+cx+d)拟合平移过程轨迹,多项式的参数可以人为设定;此外,也可以使用模型输出启发项(常数项d),根据a\b\c和d的对应关系,计算出a\b\c,从而获取多项式表达式,从而可以算出,走多远才能在规定时间内到达B点】

【多项式拟合,属于计算机图形学的简单知识】

3)Velocity(速度)【这里的速度理解为速率比较合适】
【速度,属于动力学推理;公式可以理解为s=s0+vt,从而获取运动轨迹】

【上述关于阿波罗的预测方案案例,算是一个主流预测方案,其他的差异主要就是在特征的提取上,以及一些场景预测处理差异】

3.6 STATE OF THE ART
【比较新的预测方案(19年初)】

较新关于预测方法样例1:

Uber: short-term Motion Prediction of Traffic Actors for Autonomous
Driving using Deep Convolutional Networks
【论文来源https://arxiv.org/pdf/1808.05819.pdf

【预测公开数据集较少,而且一般不通用;不通用原因有,抽取特征不一致,不会把所有信息公开出来的】

----CNN model【CNN模型】

----Graph generate【生成图】
--------full information【信息完整,图一般包括很多信息,包括交规,道路信息等等;】

----Trajectory generate【路线生成】
--------regression problem【回归问题】【上面介绍,属于分类方式解决预测;这里采用的是回归方式解决预测】
在这里插入图片描述
【红色指的是预测的目标车辆,红色变化也表示的是历史帧;黄色是周围障碍物;黄色的变化该障碍物的历史帧;蓝色是label;绿色是trajectory路线;】【模型最后输出的是轨迹】
【该预测方法存在问题:由于障碍物很多,每一个都需要用一个图来处理,这会导致计算量很大;回归不稳定,得到的结果有抖动,可能是不可用的】

较新关于预测方法样例2:

Waymo: ChauffeurNet: Learning to Drive by Imitating the Best and Synthesizing the Worst
论文来源:https://arxiv.org/pdf/1812.03079.pdf
在这里插入图片描述
来源:https://sites.google.com/view/waymo-learn-to-drive
【该论文思路,规划和预测相结合】

----Mid-to-mid learning 【中到中的学习】
----Multi-task【多任务】
----imitation learning【模仿学习】
----combine prediction and planning【规划和预测结合方式】
在这里插入图片描述
在这里插入图片描述
【在仿真环境里面,让无人车一直跑,进行训练】
【label model不一定是未来的做预测的主流方式】

4 PEDESTRIAN PREDICT

4.1 简介

【行人预测,主要介绍李飞飞团队的文章来介绍】
【行人预测在无人驾驶中很重要,也很难;轨迹难以预测,不存在高精地图里面;行人安全受交通规则保护】
在这里插入图片描述
行人预测特点或者难点:

1)high randomness【高随机性】

2)low traffic constraints【交通规则约束少】

3)no kinematics model【缺少动力学模型】

4)benchmark【基准】
----ETH
----UCY

4.2行人预测方法介绍

方法一:
Li Feifei: Social LSTM: Human Trajectory Prediction in Crowded Spaces
下载链接:

https://openaccess.thecvf.com/content_cvpr_2016/papers/Alahi_Social_LSTM_Human_CVPR_2016_paper.pdf

----LSTM【long short-term memory,长短期记忆?】
----Social-pooling:Human-Human【人和人】
在这里插入图片描述
在这里插入图片描述
该论文对应的代码:https://github.com/xuerenlv/social-lstm-tf
【h表示输入,这里输入时候,不仅是自己信息,也把别人的状态一起输入,生成新的lstm】

文章方法二:
Li Feifei: Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks
下载链接:

https://openaccess.thecvf.com/content_cvpr_2018/papers/Gupta_Social_GAN_Socially_CVPR_2018_paper.pdf

----Generative Adversary Networks,GAN【gen网络方法,waymo也有在用;】
----Human-human interaction【人人交互】
在这里插入图片描述
在这里插入图片描述
论文对应的代码https://github.com/agrimgupta92/sgan
【中间大框就是生成一个预测轨迹,最右侧大框来判断轨迹的合理性,两者交互,不断训练,从而使得最后生成的结果更符合实际情况】

文章方法三:
Lifeifie:Peeking into the Future: Predicting Future Person Activities and Locations in Videos
下载:https://arxiv.org/pdf/1902.03748.pdf
----multi-task learning【多任务学习】【网络非常复杂】
----Grid mapping【网格地图】
----Combine perception and prediction【感知和预测的结合】
在这里插入图片描述
在这里插入图片描述
【该方法,将感知和预测做了一个结合】
在这里插入图片描述
在这里插入图片描述
【主要思想是,把人姿态信息、环境信息、分割、识别都加在一起,做了一套感知和预测的结合,其输出并不只是轨迹,而是输出人要去做的事情,做的任务】

在这里插入图片描述
【如图a,判断该人在车旁边,打算做的事情是开门】

5 小结

在这里插入图片描述
【这节课内容,主要是讲预测,分成4部分。PNC是什么东西,预测任务是做什么的;车辆预测;阿波罗的车辆预测,也包含了行人和自行车,介绍他们的预测具体怎么做的;简单介绍了行人预测】
【对于无人车来说,行人预测场景相对较少的,除非是在拥挤的有人的马路上行走;相对而言,小机器人跟行人交互多一些,其更加注重行人预测】

老师答疑环节

感知特征信息要求有哪些?
这个要根据实际的任务需要来定的,并没有一个绝对的结构化标准。

【感知、预测规划组成的PNC是L4级别自动驾驶的两大核心,也是非常难的两块】

思考:交通违规怎么办?
交互预测怎么处理?或者说,让行概率有多大,是一个博弈问题。【该问题也是业界一大难题,包括waymo在内,都做得不好】

个人总结:
这一节就是介绍预测,预测到底在无人车里面位置怎么样。区分预测和轨迹跟踪区别。预测采用的方式有模型驱动和数据驱动。重点介绍了数据驱动,也就是机器学习。介绍了,关于预测的机器学习到底怎么做的。接着,介绍车辆预测方法,并以阿波罗项目为例子,介绍了里面一些源码得含义。最后以李飞飞团队的三篇文章为例子,简单介绍行人预测的方法。个人感觉,听得很迷惑。没错,菜是原罪。

#####################
感恩授课老师的付出
图片版权归原作者所有
不积硅步,无以至千里
好记性不如烂笔头
感觉有点收获的话,麻烦大大们点赞收藏哈

### 回答1: 以下是一个 Python 代码示例,用于实现 multi-head self-attention: ```python import torch import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super(MultiHeadAttention, self).__init__() self.num_heads = num_heads self.d_model = d_model self.depth = d_model // num_heads self.query_linear = nn.Linear(d_model, d_model) self.key_linear = nn.Linear(d_model, d_model) self.value_linear = nn.Linear(d_model, d_model) self.output_linear = nn.Linear(d_model, d_model) def forward(self, query, key, value, mask=None): batch_size = query.size() # Linear transformations query = self.query_linear(query) key = self.key_linear(key) value = self.value_linear(value) # Split into heads query = query.view(batch_size * self.num_heads, -1, self.depth) key = key.view(batch_size * self.num_heads, -1, self.depth) value = value.view(batch_size * self.num_heads, -1, self.depth) # Transpose for matrix multiplication query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) # Calculate scores scores = torch.matmul(query, key.transpose(-2, -1)) scores = scores / torch.sqrt(torch.tensor(self.depth).float()) # Apply mask (if provided) if mask is not None: mask = mask.unsqueeze(1) scores = scores.masked_fill(mask == , -1e9) # Softmax attention_weights = nn.Softmax(dim=-1)(scores) # Dropout attention_weights = nn.Dropout(p=.1)(attention_weights) # Multiply by values context = torch.matmul(attention_weights, value) # Reshape and concatenate context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.depth) # Linear transformation output = self.output_linear(context) return output ``` 希望对你有所帮助! ### 回答2: 下面是使用Python语言实现multi-head self-attention的一个示例代码: ``` import torch import torch.nn as nn import torch.nn.functional as F class MultiHeadSelfAttention(nn.Module): def __init__(self, d_model, num_heads): super(MultiHeadSelfAttention, self).__init__() self.num_heads = num_heads self.d_head = d_model // num_heads self.fc_query = nn.Linear(d_model, d_model) self.fc_key = nn.Linear(d_model, d_model) self.fc_value = nn.Linear(d_model, d_model) self.fc_concat = nn.Linear(d_model, d_model) def forward(self, x): batch_size, seq_len, d_model = x.size() h = self.num_heads # Split input into multiple heads query = self.fc_query(x).view(batch_size, seq_len, h, self.d_head) key = self.fc_key(x).view(batch_size, seq_len, h, self.d_head) value = self.fc_value(x).view(batch_size, seq_len, h, self.d_head) # Compute attention scores scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d_head ** 0.5) attn_weights = F.softmax(scores, dim=-1) # Apply attention weights to value vectors attended_values = torch.matmul(attn_weights, value) attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) # Concatenate and linearly transform attended values output = self.fc_concat(attended_values) return output # 使用示例 d_model = 128 num_heads = 8 seq_len = 10 batch_size = 4 input_tensor = torch.randn(batch_size, seq_len, d_model) attention = MultiHeadSelfAttention(d_model, num_heads) output = attention(input_tensor) print("Input Shape: ", input_tensor.shape) print("Output Shape: ", output.shape) ``` 上述代码定义了一个`MultiHeadSelfAttention`的类,其中`forward`函数实现了multi-head self-attention的计算过程。在使用示例中,我们输入一个大小为`(batch_size, seq_len, d_model)`的张量,经过multi-head self-attention计算后输出一个大小为`(batch_size, seq_len, d_model)`的张量。其中`d_model`表示输入的特征维度,`num_heads`表示attention头的数量。 ### 回答3: 下面是使用Python实现multi-head self-attention示例的代码: ```python import torch import torch.nn as nn class MultiHeadSelfAttention(nn.Module): def __init__(self, embed_size, num_heads): super(MultiHeadSelfAttention, self).__init__() self.embed_size = embed_size self.num_heads = num_heads self.head_size = embed_size // num_heads self.query = nn.Linear(embed_size, embed_size) self.key = nn.Linear(embed_size, embed_size) self.value = nn.Linear(embed_size, embed_size) self.out = nn.Linear(embed_size, embed_size) def forward(self, x): batch_size, seq_len, embed_size = x.size() # Split the embedding into num_heads and reshape x = x.view(batch_size, seq_len, self.num_heads, self.head_size) x = x.permute(0, 2, 1, 3) # Apply linear transformations to obtain query, key, and value query = self.query(x) key = self.key(x) value = self.value(x) # Compute scaled dot product attention scores scores = torch.matmul(query, key.permute(0, 1, 3, 2)) scores = scores / self.head_size**0.5 # Apply softmax to obtain attention probabilities attn_probs = nn.Softmax(dim=-1)(scores) # Apply attention weights to value and sum across heads attended = torch.matmul(attn_probs, value) attended = attended.permute(0, 2, 1, 3) attended = attended.contiguous().view(batch_size, seq_len, self.embed_size) # Apply output linear transformation output = self.out(attended) return output ``` 上述代码中定义了一个名为MultiHeadSelfAttention的类,继承自nn.Module,可以通过指定嵌入大小(embed_size)和头部数量(num_heads)来创建多头自注意力层。在前向传播方法forward中,先通过线性变换将输入张量分别变换为查询(query)、键(key)和值(value)张量。然后计算缩放点积注意力得分,将其作为注意力概率经过softmax函数进行归一化。通过注意力概率权重对值进行加权求和,并应用线性变换得到最终的输出张量。最后返回输出张量。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值