矩阵相乘nxn block的计算过程

/* Create macros so that the matrices are stored in column-major order */

#define A(i,j) a[ (j)*lda + (i) ]
#define B(i,j) b[ (j)*ldb + (i) ]
#define C(i,j) c[ (j)*ldc + (i) ]

/* Routine for computing C = A * B + C */

void AddDot4x4( int, float *, int, float *, int, float *, int );

void MY_MMult( int m, int n, int k, float *a, int lda, 
                                    float *b, int ldb,
                                    float *c, int ldc )
{
  int i, j;

  for ( j=0; j<n; j+=4 ){        /* Loop over the columns of C, unrolled by 4 */
    for ( i=0; i<m; i+=4 ){        /* Loop over the rows of C */
      /* Update C( i,j ), C( i,j+1 ), C( i,j+2 ), and C( i,j+3 ) in
         one routine (four inner products) */

      AddDot4x4( k, &A( i,0 ), lda, &B( 0,j ), ldb, &C( i,j ), ldc );
    }
  }
}


void AddDot4x4( int k, float *a, int lda,  float *b, int ldb, float *c, int ldc )
{
  /* So, this routine computes a 4x4 block of matrix A

           C( 0, 0 ), C( 0, 1 ), C( 0, 2 ), C( 0, 3 ).  
           C( 1, 0 ), C( 1, 1 ), C( 1, 2 ), C( 1, 3 ).  
           C( 2, 0 ), C( 2, 1 ), C( 2, 2 ), C( 2, 3 ).  
           C( 3, 0 ), C( 3, 1 ), C( 3, 2 ), C( 3, 3 ).  

     Notice that this routine is called with c = C( i, j ) in the
     previous routine, so these are actually the elements 

           C( i  , j ), C( i  , j+1 ), C( i  , j+2 ), C( i  , j+3 ) 
           C( i+1, j ), C( i+1, j+1 ), C( i+1, j+2 ), C( i+1, j+3 ) 
           C( i+2, j ), C( i+2, j+1 ), C( i+2, j+2 ), C( i+2, j+3 ) 
           C( i+3, j ), C( i+3, j+1 ), C( i+3, j+2 ), C( i+3, j+3 ) 
          
     in the original matrix C 

     A simple rearrangement to prepare for the use of vector registers */

  int p;
  register float 
    /* hold contributions to
       C( 0, 0 ), C( 0, 1 ), C( 0, 2 ), C( 0, 3 ) 
       C( 1, 0 ), C( 1, 1 ), C( 1, 2 ), C( 1, 3 ) 
       C( 2, 0 ), C( 2, 1 ), C( 2, 2 ), C( 2, 3 ) 
       C( 3, 0 ), C( 3, 1 ), C( 3, 2 ), C( 3, 3 )   */
       c_00_reg,   c_01_reg,   c_02_reg,   c_03_reg,  
       c_10_reg,   c_11_reg,   c_12_reg,   c_13_reg,  
       c_20_reg,   c_21_reg,   c_22_reg,   c_23_reg,  
       c_30_reg,   c_31_reg,   c_32_reg,   c_33_reg,
    /* hold 
       A( 0, p ) 
       A( 1, p ) 
       A( 2, p ) 
       A( 3, p ) */
       a_0p_reg,
       a_1p_reg,
       a_2p_reg,
       a_3p_reg,
       b_p0_reg,
       b_p1_reg,
       b_p2_reg,
       b_p3_reg;

  float 
    /* Point to the current elements in the four columns of B */
    *b_p0_pntr, *b_p1_pntr, *b_p2_pntr, *b_p3_pntr; 
    
  b_p0_pntr = &B( 0, 0 );
  b_p1_pntr = &B( 0, 1 );
  b_p2_pntr = &B( 0, 2 );
  b_p3_pntr = &B( 0, 3 );

  c_00_reg = 0.0;   c_01_reg = 0.0;   c_02_reg = 0.0;   c_03_reg = 0.0;
  c_10_reg = 0.0;   c_11_reg = 0.0;   c_12_reg = 0.0;   c_13_reg = 0.0;
  c_20_reg = 0.0;   c_21_reg = 0.0;   c_22_reg = 0.0;   c_23_reg = 0.0;
  c_30_reg = 0.0;   c_31_reg = 0.0;   c_32_reg = 0.0;   c_33_reg = 0.0;

	/*
	TODO:
	process of calculation
	
	c00 = a00xb00
	c10 = a10xb00
	c01 = a00xb01
	c11 = a10xb01

	c00 = a00xb00 + a01xb10
	c10 = a10xb00 + a11xb10	
	c01 = a00xb01 + a01xb11
	c11 = a10xb01 + a11xb11

	c00 = a00xb00 + a01xb10 + a02xb20
	c10 = a10xb00 + a11xb10 + a12xb20	
	c01 = a00xb01 + a01xb11 + a02xb21
	c11 = a10xb01 + a11xb11 + a12xb21


	c00 = a00xb00 + a01xb10 + a02xb20 + a03xb30
	c10 = a10xb00 + a11xb10 + a12xb20 + a13xb30	
	c01 = a00xb01 + a01xb11 + a02xb21 + a03xb31
	c11 = a10xb01 + a11xb11 + a12xb21 + a13xb31

	...
	...
	...
	k = k -1

	c00 = a00xb00 + a01xb10 + a02xb20 + a03xb30 + ... + a0k * bk0
	c10 = a10xb00 + a11xb10 + a12xb20 + a13xb30 + ... + a1k * bk0
	c01 = a00xb01 + a01xb11 + a02xb21 + a03xb31 + ... + a0k * bk1
	c11 = a10xb01 + a11xb11 + a12xb21 + a13xb31 + ... + a1k * bk1
		
	
	*/
  for ( p=0; p<k; p++ ){
    a_0p_reg = A( 0, p );
    a_1p_reg = A( 1, p );
    a_2p_reg = A( 2, p );
    a_3p_reg = A( 3, p );

    b_p0_reg = *b_p0_pntr++;
    b_p1_reg = *b_p1_pntr++;
    b_p2_reg = *b_p2_pntr++;
    b_p3_reg = *b_p3_pntr++;

    /* First row and second rows */
    c_00_reg += a_0p_reg * b_p0_reg;
    c_10_reg += a_1p_reg * b_p0_reg;

    c_01_reg += a_0p_reg * b_p1_reg;
    c_11_reg += a_1p_reg * b_p1_reg;

    c_02_reg += a_0p_reg * b_p2_reg;
    c_12_reg += a_1p_reg * b_p2_reg;

    c_03_reg += a_0p_reg * b_p3_reg;
    c_13_reg += a_1p_reg * b_p3_reg;

    /* Third and fourth rows */
    c_20_reg += a_2p_reg * b_p0_reg;
    c_30_reg += a_3p_reg * b_p0_reg;

    c_21_reg += a_2p_reg * b_p1_reg;
    c_31_reg += a_3p_reg * b_p1_reg;

    c_22_reg += a_2p_reg * b_p2_reg;
    c_32_reg += a_3p_reg * b_p2_reg;

    c_23_reg += a_2p_reg * b_p3_reg;
    c_33_reg += a_3p_reg * b_p3_reg;
  }

  C( 0, 0 ) += c_00_reg;   C( 0, 1 ) += c_01_reg;   C( 0, 2 ) += c_02_reg;   C( 0, 3 ) += c_03_reg;
  C( 1, 0 ) += c_10_reg;   C( 1, 1 ) += c_11_reg;   C( 1, 2 ) += c_12_reg;   C( 1, 3 ) += c_13_reg;
  C( 2, 0 ) += c_20_reg;   C( 2, 1 ) += c_21_reg;   C( 2, 2 ) += c_22_reg;   C( 2, 3 ) += c_23_reg;
  C( 3, 0 ) += c_30_reg;   C( 3, 1 ) += c_31_reg;   C( 3, 2 ) += c_32_reg;   C( 3, 3 ) += c_33_reg;
}



注意其中的计算过程,迭代过程。


展开阅读全文

没有更多推荐了,返回首页