有时候需要自己重写 BLAS 库函数。
下面是我写的一个矩阵乘法（gemm）程序：
#include <iostream>
using namespace std;
// Hand-written single-precision GEMM on row-major data:
//   C = alpha * A * B + beta * C,  A is M x K, B is K x N, C is M x N.
// Prints C one row per line, elements separated by a space.
int main() {
  const int M = 4;  // rows of A and of C
  const int N = 2;  // columns of B and of C
  const int K = 3;  // columns of A, rows of B
  const float alpha = 1;
  const float beta = 0;
  const float A[K*M] = { 1,2,3,4,5,6,7,8,9,8,7,6 };
  const float B[K*N] = { 5,4,3,2,1,0 };
  // Zero-initialize C: the update below may read it (beta * C), and reading
  // an uninitialized array is undefined behavior.
  float C[M*N] = {};
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      float sum = 0;
      for (int k = 0; k < K; k++) {
        sum += A[i*K + k] * B[k*N + j];
      }
      // Same convention as BLAS: when beta is exactly 0 the old contents of C
      // are not read, so garbage or NaN already in C cannot leak into C.
      C[i*N + j] = (beta == 0) ? alpha * sum : alpha * sum + beta * C[i*N + j];
    }
  }
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      std::cout << C[i*N + j] << " ";
    }
    std::cout << std::endl;
  }
  return 0;
}
但是这样写有问题，需要用现成的 BLAS 实现来对照验证。
在 Caffe 的 tools 文件夹下新建 keshan.cpp，用 cblas_dgemm 计算同一组矩阵作为对照：
#include <cblas.h>
#include <iostream>
using namespace std;
// Cross-check: compute the same product with CBLAS in double precision
// (row-major, no transposes) and print C one row per line, tab-separated.
int main(){
  const int M = 4;  // rows of A and of C
  const int N = 2;  // columns of B and of C
  const int K = 3;  // columns of A, rows of B
  const double alpha = 1;
  const double beta = 0;
  const double A[K*M] = { 1,2,3,4,5,6,7,8,9,8,7,6 };
  const double B[K*N] = { 5,4,3,2,1,0 };
  // Not initialized: per the reference BLAS docs, C need not be set on input
  // when beta is zero.
  double C[M*N];
  // In row-major storage the leading dimension is the stored row length.
  const int strideA = K;
  const int strideB = N;
  const int strideC = N;
  cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
              M, N, K, alpha, A, strideA, B, strideB, beta, C, strideC);
  for (int row = 0; row < M; ++row) {
    const double* rowPtr = C + row * N;
    for (int col = 0; col < N; ++col) {
      std::cout << rowPtr[col] << "\t";
    }
    std::cout << std::endl;
  }
  return 0;
}
输出结果为
14 8
41 26
68 44
67 46
结果和我手写的版本一致，看起来没什么错误；但这只验证了一种情况（不转置、beta = 0），并不能说明实现在其他情况下也正确。
在 Caffe 中，caffe_cpu_gemm 定义在 math_functions.cpp 中。
要重写 caffe_cpu_gemm，只需要在其中增加以下代码：
// Naive (BLAS-free) replacement for caffe_cpu_gemm<float>.
// Computes C = alpha * op(A) * op(B) + beta * C on row-major matrices, where
// op(A) is M x K, op(B) is K x N and C is M x N:
//   TransA == CblasNoTrans : A stored M x K (leading dim K); else K x M (leading dim M)
//   TransB == CblasNoTrans : B stored K x N (leading dim N); else N x K (leading dim K)
// As in BLAS, when beta == 0 the previous contents of C are not read, so C may
// be uninitialized and NaN/Inf already in C cannot reach the result.
// `flag` only selects this overload over the CBLAS-backed one; it is not read.
template<>
void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C,int flag) {
  (void)flag;
  const bool transA = (TransA != CblasNoTrans);
  const bool transB = (TransB != CblasNoTrans);
  const int lda = transA ? M : K;
  const int ldb = transB ? K : N;
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float sum = 0;
      for (int k = 0; k < K; ++k) {
        // Index the stored layout directly instead of copying B into a
        // temporary: avoids a variable-length stack array (non-standard C++
        // and a stack-overflow risk for large layers), and also honors TransA.
        const float a = transA ? A[k * lda + i] : A[i * lda + k];
        const float b = transB ? B[j * ldb + k] : B[k * ldb + j];
        sum += a * b;
      }
      C[i * N + j] = (beta == 0) ? alpha * sum
                                 : alpha * sum + beta * C[i * N + j];
    }
  }
}
// Naive (BLAS-free) replacement for caffe_cpu_gemm<double>.
// Computes C = alpha * op(A) * op(B) + beta * C on row-major matrices, where
// op(A) is M x K, op(B) is K x N and C is M x N:
//   TransA == CblasNoTrans : A stored M x K (leading dim K); else K x M (leading dim M)
//   TransB == CblasNoTrans : B stored K x N (leading dim N); else N x K (leading dim K)
// As in BLAS, when beta == 0 the previous contents of C are not read, so C may
// be uninitialized and NaN/Inf already in C cannot reach the result.
// `flag` only selects this overload over the CBLAS-backed one; it is not read.
template<>
void caffe_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C,int flag) {
  (void)flag;
  const bool transA = (TransA != CblasNoTrans);
  const bool transB = (TransB != CblasNoTrans);
  const int lda = transA ? M : K;
  const int ldb = transB ? K : N;
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      double sum = 0;
      for (int k = 0; k < K; ++k) {
        // Index the stored layout directly instead of copying B into a
        // temporary: avoids a variable-length stack array (non-standard C++
        // and a stack-overflow risk for large layers), and also honors TransA.
        const double a = transA ? A[k * lda + i] : A[i * lda + k];
        const double b = transB ? B[j * ldb + k] : B[k * ldb + j];
        sum += a * b;
      }
      C[i * N + j] = (beta == 0) ? alpha * sum
                                 : alpha * sum + beta * C[i * N + j];
    }
  }
}
这样就对 caffe_cpu_gemm 函数进行了重载（多了一个 int flag 参数）。还需要在 caffe_cpu_gemm 的声明处（math_functions.hpp）增加：
// Declaration of the caffe_cpu_gemm overload that takes an extra `int flag`.
// The extra parameter only distinguishes this overload from the original
// caffe_cpu_gemm declaration; the float/double specializations shown above
// do not read its value. Arguments otherwise mirror the original overload.
template <typename Dtype>
void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta,
Dtype* C,int flag);
然后把 base_conv_layer.cpp 和 inner_product_layer.cpp 中对 caffe_cpu_gemm 的调用都改为带 flag 参数的版本即可。
下面这个矩阵小框架写得不错，值得参考：
//transpote转置矩阵
#include "Stdio.h"
#include "memory.h"
// Prints a single element. Only the explicit specializations defined later
// (int and float) have bodies; other element types fail to link if print()
// is instantiated for them.
template<typename T>
void TypePrint(T v);

// Fixed-size M x N matrix stored row-major. Dimensions are template
// parameters, so shape mismatches in +, -, ==, != are compile-time errors.
//
// Elements are stored inline in the object (no new/delete), so the implicitly
// generated copy/move/destructor are correct: copies are deep and nothing
// leaks. Consequence: a large M*N makes a large object, so avoid very large
// matrices as stack locals.
template<typename T, int M, int N>
class Matrix
{
public:
  // Every element starts value-initialized (0 for arithmetic T).
  Matrix(void) : data() {}

  // Number of rows (M).
  int getVIndex() const
  {
    return M;
  }
  // Number of columns (N).
  int getHIndex() const
  {
    return N;
  }
  // Element at row x, column y. No bounds checking.
  T getxy(int x, int y) const
  {
    return data[x*N + y];
  }
  // Sets the element at row x, column y to f. No bounds checking.
  void setxy(int x, int y, T f)
  {
    data[x*N + y] = f;
  }
  // Copies raw row-major data. NOTE: `size` is in BYTES (e.g. sizeof(array)).
  // The copy is clamped to this matrix's storage so an oversized source cannot
  // overflow it; a null pointer or non-positive size copies nothing.
  void setdata(const T* datap, int size)
  {
    if (datap == nullptr || size <= 0)
      return;
    size_t bytes = static_cast<size_t>(size);
    if (bytes > sizeof(data))
      bytes = sizeof(data);
    memcpy(data, datap, bytes);
  }
  // Returns the N x M transpose (name kept as "transpote" for existing callers).
  Matrix<T, N, M> transpote() const
  {
    Matrix<T, N, M> m;
    for (int i = 0; i < M; i++)
    {
      for (int j = 0; j < N; j++)
      {
        m.setxy(j, i, getxy(i, j));
      }
    }
    return m;
  }
  // Element-wise sum; the result has the same M x N shape.
  Matrix<T, M, N> operator+(const Matrix<T, M, N> &adv) const
  {
    Matrix<T, M, N> m;
    for (int i = 0; i < M; i++)
    {
      for (int j = 0; j < N; j++)
      {
        m.setxy(i, j, getxy(i, j) + adv.getxy(i, j));
      }
    }
    return m;
  }
  // Element-wise difference; the result has the same M x N shape.
  Matrix<T, M, N> operator-(const Matrix<T, M, N> &adv) const
  {
    Matrix<T, M, N> m;
    for (int i = 0; i < M; i++)
    {
      for (int j = 0; j < N; j++)
      {
        m.setxy(i, j, getxy(i, j) - adv.getxy(i, j));
      }
    }
    return m;
  }
  // True when every element compares equal with operator!= false.
  bool operator==(const Matrix<T, M, N> &adv) const
  {
    for (int i = 0; i < M; i++)
    {
      for (int j = 0; j < N; j++)
      {
        if (getxy(i, j) != adv.getxy(i, j)) return false;
      }
    }
    return true;
  }
  // Logical negation of operator==.
  bool operator!=(const Matrix<T, M, N> &adv) const
  {
    return !(*this == adv);
  }
  // Prints a leading newline, then each row as "v,\t" entries followed by a newline.
  void print() const
  {
    printf("\n");
    for (int i = 0; i < M; i++)
    {
      for (int j = 0; j < N; j++)
      {
        TypePrint(getxy(i, j));
        printf(",\t");
      }
      printf("\n");
    }
  }
private:
  T data[M * N];  // row-major: element (x, y) lives at index x*N + y
};
// Matrix product: (M x N) * (N x P) -> (M x P). The shared inner dimension N
// is enforced by the template parameters, so mismatches fail to compile.
template<typename T, int M, int N, int P>
Matrix<T, M, P> operator*(Matrix<T, M, N> &x, Matrix<T, N, P> &y)
{
  Matrix<T, M, P> product;
  for (int row = 0; row < M; ++row)
  {
    for (int col = 0; col < P; ++col)
    {
      // Dot product of row `row` of x with column `col` of y.
      T dot = 0;
      for (int t = 0; t < N; ++t)
        dot += x.getxy(row, t) * y.getxy(t, col);
      product.setxy(row, col, dot);
    }
  }
  return product;
}
// Matrix * scalar: returns a new M x N matrix whose elements are x's elements
// multiplied by y. x itself is not modified.
template<typename T, int M, int N>
Matrix<T, M, N> operator*(Matrix<T, M, N> &x, T y)
{
  Matrix<T, M, N> m;
  for (int i = 0; i < M; i++)
  {
    for (int j = 0; j < N; j++)
    {
      // Scale the operand x. The previous code read m (the freshly constructed
      // result) here, so x was never used and the result was garbage/zeros.
      m.setxy(i, j, x.getxy(i, j) * y);
    }
  }
  return m;
}
// Scalar * matrix: delegates to the matrix * scalar overload, so y * x and
// x * y produce the same result.
template<typename T, int M, int N>
Matrix<T, M, N> operator*(T y, Matrix<T, M, N> &x)
{
  Matrix<T, M, N> scaled = x * y;
  return scaled;
}
template<>
void TypePrint(float v)
{
printf("%f", v);
}
template<>
void TypePrint(int v)
{
printf("%d", v);
}
// NOTE(review): `type` is not used in this snippet; as a macro it would
// rewrite every later `type` token, so consider removing it.
#define type float
// 3 x 2 operand for the demo, row-major.
int d1[] = { 1, 2,   2, 3,   3, 0 };
// 3 x 3 operand for the demo, row-major.
int d2[] = { 2, -3, 0,   0, 1, -2,   -4, 5, 10 };
int main()
{
Matrix<int, 3, 2> s;
s.setdata(d1, sizeof(d1));
Matrix<int, 3, 3> s1;
s1.setdata(d2, sizeof(d2));
Matrix<int, 3, 2> s2 = s1*s;
Matrix<int, 2, 3> s3 = s2.transpote();//转置了
s.print();
s1.print();
s2.print();
s3.print();
return 0;
}
可以看看cblas的源代码[2],perl写的...
[1] 对cblas_sgemm的说明
[2] linux安装cblas