softmax函数的硬件实现

softmax作用

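首先简单地介绍一下 softmax 的用途及其函数形式。维基百科对 softmax 的解释是:“The softmax function is a generalization of the logistic function that squashes a K-dimensional vector of arbitrary real values to a K-dimensional vector of real values in the range (0, 1) that add up to 1.”简单地说,它把任意实数向量映射为一个各分量都在 (0,1) 内、总和为 1 的概率分布,常用作多分类神经网络的输出层(通常与交叉熵损失函数配合使用)。公式表示如下: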
$$\sigma(\mathbf{z})_j = \frac{e^{z_j}}{\sum_{k=1}^{K} e^{z_k}}, \quad j = 1, \dots, K$$
可以看到,要用硬件实现 softmax 函数,首先必须用硬件实现 exp(x)。实现指数函数有哪些思路呢?一种是泰勒级数展开,但为了保证一定的精度需要使用很多乘法器,每个指数函数都要耗费大量的时钟周期;因此我们选择更简单的分段线性拟合。我的思路如下:用若干条线段分段拟合指数函数 exp(x)。在硬件实现之前,我做了大量关于输入阈值和数据精度对神经网络训练精度影响的实验,发现把 softmax 的输入限制在一定范围内(例如 (-10,10))对训练正确率的影响很小。这样就可以用一个 ROM 存储每条线段的斜率 K 和截距 B,计算时每个指数函数只需一次乘法和一次加法(即 K·x+B),大大提高了计算速度。(需要注意的是:我们的硬件结构全部使用 16 位半精度浮点数进行计算,乘除法直接调用 IP 核即可。)

ROM 中存储的 K、B 数据(COE 格式:偶数地址为线段斜率 K,奇数地址为对应截距 B)

memory_initialization_radix=2;
memory_initialization_vector=
0000100000000000
0001100000000010
0001000000000000
0001110000000101
0001010000000001
0010000000001100
0001110000000100
0010010000011100
0010000000001011
0010110001000010
0010100000100000
0011000010010011
0010110001010111
0011010100111001
0011000011101110
0011101001100110
0011101010000111
0011110000000000
0011111011011111
0011110000000000
0100010010101011
1010100000110000
0100101001011001
1011101111111100
0101000001010000
1011101000110000
0101010111011101
1011010101011010
0101101111111000
1110010001100110
0110000101101010
1110101101010101
0110011101011100
1111000111100111
0110110100000000
1111100010100011
0111001011001100
1111111100100111;

根据输入 x 计算 K、B 在 ROM 中的地址

`timescale 1ns / 1ps
//
// Company: 
// Engineer: 
// 
// Create Date: 2017/01/12 20:52:46
// Design Name: 
// Module Name: getaddr
// Project Name: 
// Target Devices: 
// Tool Versions: 
// Description: 
// 
// Dependencies: 
// 
// Revision:
// Revision 0.01 - File Created
// Additional Comments:
// 
//


module getaddr(aclk,x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,
addra0,addra1,addra2,addra3,addra4,addra5,addra6,addra7,addra8,addra9,
addrb0,addrb1,addrb2,addrb3,addrb4,addrb5,addrb6,addrb7,addrb8,addrb9);
input aclk;
input [15:0] x0,x1,x2,x3,x4,x5,x6,x7,x8,x9;
output[5:0] addra0,addra1,addra2,addra3,addra4,addra5,addra6,addra7,addra8,addra9;
output[5:0] addrb0,addrb1,addrb2,addrb3,addrb4,addrb5,addrb6,addrb7,addrb8,addrb9;
reg [5:0] addra0,addra1,addra2,addra3,addra4,addra5,addra6,addra7,addra8,addra9;
reg [5:0] addrb0,addrb1,addrb2,addrb3,addrb4,addrb5,addrb6,addrb7,addrb8,addrb9;
// Piecewise-linear exp() segment lookup for input x0.
//
// x0 is assumed to be IEEE-754 binary16 (the text states 16-bit floats, and
// the boundary constants below decode as binary16). Bit 15 is the sign; bits
// [14:0] are the magnitude. For non-negative binary16 values the raw bit
// pattern orders the same way as the value, so boundaries are compared as
// unsigned integers:
//   15'b011110000000000 = 1.0    15'b100000000000000 = 2.0
//   15'b100001000000000 = 3.0    15'b100010000000000 = 4.0
//   15'b100010100000000 = 5.0    15'b100011000000000 = 6.0
//   15'b100011100000000 = 7.0    15'b100100000000000 = 8.0
//   15'b100100010000000 = 9.0    15'b100100100000000 = 10.0
//
// Segment i drives addra0 = 2*i (even ROM word) and addrb0 = 2*i+1 (odd ROM
// word). Per the ROM contents, the even word is the slope K and the odd word
// is the intercept B (e.g. words 18/19 hold ~1.718 and 1.0 for [0,1]).
//   segments 0..8  : x in [-9,-8), [-8,-7), ..., [-1,0)
//   segments 9..18 : x in [0,1], (1,2], ..., (9,10]
//
// +0.0 and -0.0 are mapped to the [0,1] segment (B = 1.0 there, so
// K*0 + B = exp(0) = 1). Without this check a zero input would match no
// branch and the registers would silently keep last cycle's addresses.
//
// NOTE(review): inputs outside [-9,10] (including Inf/NaN) still match no
// branch and leave addra0/addrb0 unchanged from the previous cycle.
always @(posedge aclk)
begin
    if (x0[14:0] == 15'b000000000000000)
    begin
        // +0.0 / -0.0 -> [0,1] segment (ROM words 18/19)
        addra0 <= 6'b010010;
        addrb0 <= 6'b010011;
    end
    else if (x0[15:15] == 1)
    begin
        // Negative input: compare the magnitude; segment comments give x.
        if ((x0[14:0] <= 15'b100100010000000) && (x0[14:0] > 15'b100100000000000))
        begin // [-9,-8)
            addra0 <= 6'b000000;
            addrb0 <= 6'b000001;
        end
        else if ((x0[14:0] <= 15'b100100000000000) && (x0[14:0] > 15'b100011100000000))
        begin // [-8,-7)
            addra0 <= 6'b000010;
            addrb0 <= 6'b000011;
        end
        else if ((x0[14:0] <= 15'b100011100000000) && (x0[14:0] > 15'b100011000000000))
        begin // [-7,-6)
            addra0 <= 6'b000100;
            addrb0 <= 6'b000101;
        end
        else if ((x0[14:0] <= 15'b100011000000000) && (x0[14:0] > 15'b100010100000000))
        begin // [-6,-5)
            addra0 <= 6'b000110;
            addrb0 <= 6'b000111;
        end
        else if ((x0[14:0] <= 15'b100010100000000) && (x0[14:0] > 15'b100010000000000))
        begin // [-5,-4)
            addra0 <= 6'b001000;
            addrb0 <= 6'b001001;
        end
        else if ((x0[14:0] <= 15'b100010000000000) && (x0[14:0] > 15'b100001000000000))
        begin // [-4,-3)
            addra0 <= 6'b001010;
            addrb0 <= 6'b001011;
        end
        else if ((x0[14:0] <= 15'b100001000000000) && (x0[14:0] > 15'b100000000000000))
        begin // [-3,-2)
            addra0 <= 6'b001100;
            addrb0 <= 6'b001101;
        end
        else if ((x0[14:0] <= 15'b100000000000000) && (x0[14:0] > 15'b011110000000000))
        begin // [-2,-1)
            addra0 <= 6'b001110;
            addrb0 <= 6'b001111;
        end
        else if ((x0[14:0] <= 15'b011110000000000) && (x0[14:0] > 15'b000000000000000))
        begin // [-1,0)
            addra0 <= 6'b010000;
            addrb0 <= 6'b010001;
        end
    end
    else
    begin
        // Positive input (zero already handled above).
        if ((x0[14:0] <= 15'b011110000000000) && (x0[14:0] > 15'b000000000000000))
        begin // (0,1]
            addra0 <= 6'b010010;
            addrb0 <= 6'b010011;
        end
        else if ((x0[14:0] <= 15'b100000000000000) && (x0[14:0] > 15'b011110000000000))
        begin // (1,2]
            addra0 <= 6'b010100;
            addrb0 <= 6'b010101;
        end
        else if ((x0[14:0] <= 15'b100001000000000) && (x0[14:0] > 15'b100000000000000))
        begin // (2,3]
            addra0 <= 6'b010110;
            addrb0 <= 6'b010111;
        end
        else if ((x0[14:0] <= 15'b100010000000000) && (x0[14:0] > 15'b100001000000000))
        begin // (3,4]
            addra0 <= 6'b011000;
            addrb0 <= 6'b011001;
        end
        else if ((x0[14:0] <= 15'b100010100000000) && (x0[14:0] > 15'b100010000000000))
        begin // (4,5]
            addra0 <= 6'b011010;
            addrb0 <= 6'b011011;
        end
        else if ((x0[14:0] <= 15'b100011000000000) && (x0[14:0] > 15'b100010100000000))
        begin // (5,6]
            addra0 <= 6'b011100;
            addrb0 <= 6'b011101;
        end
        else if ((x0[14:0] <= 15'b100011100000000) && (x0[14:0] > 15'b100011000000000))
        begin // (6,7]
            addra0 <= 6'b011110;
            addrb0 <= 6'b011111;
        end
        else if ((x0[14:0] <= 15'b100100000000000) && (x0[14:0] > 15'b100011100000000))
        begin // (7,8]
            addra0 <= 6'b100000;
            addrb0 <= 6'b100001;
        end
        else if ((x0[14:0] <= 15'b100100010000000) && (x0[14:0] > 15'b100100000000000))
        begin // (8,9]
            addra0 <= 6'b100010;
            addrb0 <= 6'b100011;
        end
        else if ((x0[14:0] <= 15'b100100100000000) && (x0[14:0] > 15'b100100010000000))
        begin // (9,10]
            addra0 <= 6'b100100;
            addrb0 <= 6'b100101;
        end
    end
end
// Piecewise-linear exp() segment lookup for input x1.
//
// x1 is assumed to be IEEE-754 binary16 (the text states 16-bit floats, and
// the boundary constants below decode as binary16). Bit 15 is the sign; bits
// [14:0] are the magnitude. For non-negative binary16 values the raw bit
// pattern orders the same way as the value, so boundaries are compared as
// unsigned integers:
//   15'b011110000000000 = 1.0    15'b100000000000000 = 2.0
//   15'b100001000000000 = 3.0    15'b100010000000000 = 4.0
//   15'b100010100000000 = 5.0    15'b100011000000000 = 6.0
//   15'b100011100000000 = 7.0    15'b100100000000000 = 8.0
//   15'b100100010000000 = 9.0    15'b100100100000000 = 10.0
//
// Segment i drives addra1 = 2*i (even ROM word) and addrb1 = 2*i+1 (odd ROM
// word). Per the ROM contents, the even word is the slope K and the odd word
// is the intercept B (e.g. words 18/19 hold ~1.718 and 1.0 for [0,1]).
//   segments 0..8  : x in [-9,-8), [-8,-7), ..., [-1,0)
//   segments 9..18 : x in [0,1], (1,2], ..., (9,10]
//
// +0.0 and -0.0 are mapped to the [0,1] segment (B = 1.0 there, so
// K*0 + B = exp(0) = 1). Without this check a zero input would match no
// branch and the registers would silently keep last cycle's addresses.
//
// NOTE(review): inputs outside [-9,10] (including Inf/NaN) still match no
// branch and leave addra1/addrb1 unchanged from the previous cycle.
always @(posedge aclk)
begin
    if (x1[14:0] == 15'b000000000000000)
    begin
        // +0.0 / -0.0 -> [0,1] segment (ROM words 18/19)
        addra1 <= 6'b010010;
        addrb1 <= 6'b010011;
    end
    else if (x1[15:15] == 1)
    begin
        // Negative input: compare the magnitude; segment comments give x.
        if ((x1[14:0] <= 15'b100100010000000) && (x1[14:0] > 15'b100100000000000))
        begin // [-9,-8)
            addra1 <= 6'b000000;
            addrb1 <= 6'b000001;
        end
        else if ((x1[14:0] <= 15'b100100000000000) && (x1[14:0] > 15'b100011100000000))
        begin // [-8,-7)
            addra1 <= 6'b000010;
            addrb1 <= 6'b000011;
        end
        else if ((x1[14:0] <= 15'b100011100000000) && (x1[14:0] > 15'b100011000000000))
        begin // [-7,-6)
            addra1 <= 6'b000100;
            addrb1 <= 6'b000101;
        end
        else if ((x1[14:0] <= 15'b100011000000000) && (x1[14:0] > 15'b100010100000000))
        begin // [-6,-5)
            addra1 <= 6'b000110;
            addrb1 <= 6'b000111;
        end
        else if ((x1[14:0] <= 15'b100010100000000) && (x1[14:0] > 15'b100010000000000))
        begin // [-5,-4)
            addra1 <= 6'b001000;
            addrb1 <= 6'b001001;
        end
        else if ((x1[14:0] <= 15'b100010000000000) && (x1[14:0] > 15'b100001000000000))
        begin // [-4,-3)
            addra1 <= 6'b001010;
            addrb1 <= 6'b001011;
        end
        else if ((x1[14:0] <= 15'b100001000000000) && (x1[14:0] > 15'b100000000000000))
        begin // [-3,-2)
            addra1 <= 6'b001100;
            addrb1 <= 6'b001101;
        end
        else if ((x1[14:0] <= 15'b100000000000000) && (x1[14:0] > 15'b011110000000000))
        begin // [-2,-1)
            addra1 <= 6'b001110;
            addrb1 <= 6'b001111;
        end
        else if ((x1[14:0] <= 15'b011110000000000) && (x1[14:0] > 15'b000000000000000))
        begin // [-1,0)
            addra1 <= 6'b010000;
            addrb1 <= 6'b010001;
        end
    end
    else
    begin
        // Positive input (zero already handled above).
        if ((x1[14:0] <= 15'b011110000000000) && (x1[14:0] > 15'b000000000000000))
        begin // (0,1]
            addra1 <= 6'b010010;
            addrb1 <= 6'b010011;
        end
        else if ((x1[14:0] <= 15'b100000000000000) && (x1[14:0] > 15'b011110000000000))
        begin // (1,2]
            addra1 <= 6'b010100;
            addrb1 <= 6'b010101;
        end
        else if ((x1[14:0] <= 15'b100001000000000) && (x1[14:0] > 15'b100000000000000))
        begin // (2,3]
            addra1 <= 6'b010110;
            addrb1 <= 6'b010111;
        end
        else if ((x1[14:0] <= 15'b100010000000000) && (x1[14:0] > 15'b100001000000000))
        begin // (3,4]
            addra1 <= 6'b011000;
            addrb1 <= 6'b011001;
        end
        else if ((x1[14:0] <= 15'b100010100000000) && (x1[14:0] > 15'b100010000000000))
        begin // (4,5]
            addra1 <= 6'b011010;
            addrb1 <= 6'b011011;
        end
        else if ((x1[14:0] <= 15'b100011000000000) && (x1[14:0] > 15'b100010100000000))
        begin // (5,6]
            addra1 <= 6'b011100;
            addrb1 <= 6'b011101;
        end
        else if ((x1[14:0] <= 15'b100011100000000) && (x1[14:0] > 15'b100011000000000))
        begin // (6,7]
            addra1 <= 6'b011110;
            addrb1 <= 6'b011111;
        end
        else if ((x1[14:0] <= 15'b100100000000000) && (x1[14:0] > 15'b100011100000000))
        begin // (7,8]
            addra1 <= 6'b100000;
            addrb1 <= 6'b100001;
        end
        else if ((x1[14:0] <= 15'b100100010000000) && (x1[14:0] > 15'b100100000000000))
        begin // (8,9]
            addra1 <= 6'b100010;
            addrb1 <= 6'b100011;
        end
        else if ((x1[14:0] <= 15'b100100100000000) && (x1[14:0] > 15'b100100010000000))
        begin // (9,10]
            addra1 <= 6'b100100;
            addrb1 <= 6'b100101;
        end
    end
end
    always @(posedge aclk)
    begin
    if(x2[15:15]==1)
            begin
            if((x2[14:0]<=15'b100100010000000)&&(x2[14:0]>15'b100100000000000)) 
                begin
                    addra2=6'b000000;
                    addrb2=6'b000001;
                end
            else if((x2[14:0]<=15'b100100000000000)&&(x2[14:0]>15'b100011100000000))
                begin
                    addra2=6'b000010;
                    addrb2=6'b000011;
                end
            else if((x2[14:0]<=15'b100011100000000)&&(x2[14:0]>15'b100011000000000))
                begin
                    addra2=6'b000100;
                    addrb2=6'b000101;
                end
            else if((x2[14:0]<=15'b100011000000000)&&(x2[14:0]>15'b100010100000000))
                begin
                    addra2=6'b000110;
                    addrb2=6'b000111;
                end
            else if((x2[14:0]<=15'b100010100000000)&&(x2[14:0]>15'b100010000000000))
                begin
                    addra2=6'b001000;
                    addrb2=6'b001001;
                end
            else if((x2[14:0]<=15'b100010000000000)&&(x2[14:0]>15'b100001000000000))
                begin
                    addra2=6'b001010;
                    addrb2=6'b001011;
                end
            else if((x2[14:0]<=15'b100001000000000)&&(x2[14:0]>15'b100000000000000))
                begin
                    addra2=6'b001100;
                    addrb2=6'b001101;
                end
            else if((x2[14:0]<=15'b1000000000000
  • 2
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值