出的蛮有意思的题目
7-1 矩阵乘法及其梯度 (20 分)
矩阵乘法大家都不陌生,给定维度分别为m×n和n×p的两个矩阵A和B,其乘积可以表示C=A×B,C的维度为m×p,其中的元素可以表示为Ci,j=∑k=0n−1Ai,kBk,j。
如果我们将矩阵A看作自变量,矩阵乘法可以看做是一个Rm×n↦Rm×p的函数,对于其中每一项Ai,j,它仅参与了形成矩阵C中第i行的运算,我们可以将其详细地写出来,即: Ci,0=∑k=0n−1Ai,kBk,0,Ci,1=∑k=0n−1Ai,kBk,1……
设函数f对矩阵C有梯度G=∇Cf,即它是一个与C维度相同的矩阵,其每一项元素为Gi,j=∂Ci,j∂f。根据链式求导法则,∂Ai,j∂f=∑k=0n−1∂Ci,k∂f∂Ai,j∂Ci,k=∑k=0n−1Gi,jBk,j,即∇Af=GBT。
同理,对于矩阵B你也可以求得对它的梯度∇Bf。
现在给出矩阵A、B以及函数f对C的梯度G,求要求依次输出乘积AB、函数f对A的梯度、函数f对B的梯度。
输入格式:
第一行给出三个整数m,n和p,均不大于102。然后跟随着一个空行。
随后m行以行主序给出矩阵A中的元素,每行n个元素,元素间用空格分隔。然后跟随着一个空行。
随后n行以行主序给出矩阵B中的元素,每行p个元素,元素间用空格分隔。然后跟随着一个空行。
最后m行以行主序给出f对矩阵C的梯度,即矩阵G中的元素,每行p个元素,元素间用空格分割。
输出格式:
首先输出乘积矩阵C,共m行,每行p个元素, 每个元素间后都紧随着一个空格,每行以换行结尾。最后额外输出一个空行。
然后输出矩阵∇Af,共m行,每行n个元素, 每个元素间后都紧随着一个空格,每行以换行结尾。最后额外输出一个空行。
最后输出矩阵∇Bf,共n行,每行p个元素, 每个元素间后都紧随着一个空格,每行以换行结尾。
输出元素均保留小数点后2位。
输入样例:
2 2 3
1 2
3 4
1 2 3
4 5 6
1 1 1
1 1 1
输出样例:
9.00 12.00 15.00
19.00 26.00 33.00
6.00 15.00
6.00 15.00
4.00 4.00 4.00
6.00 6.00 6.00
#include<stdio.h>
int a[100][100],b[100][100],c[100][100],d[100][100],aT[100][100],bT[100][100],af[100][100],bf[100][100];
int main()
{
int m,n,p;
scanf("%d%d%d",&m,&n,&p);
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
scanf("%d",&a[i][j]);
//求A的转置
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
aT[j][i]=a[i][j];
for(int i=1;i<=n;i++)
for(int j=1;j<=p;j++)
scanf("%d",&b[i][j]);
//求B的转置
for(int i=1;i<=n;i++)
for(int j=1;j<=p;j++)
bT[j][i]=b[i][j];
//G矩阵
for(int i=1;i<=m;i++)
for(int j=1;j<=p;j++)
scanf("%d",&d[i][j]);
for(int i=1;i<=m;i++)
for(int j=1;j<=p;j++)
for(int t=1;t<=n;t++)
c[i][j]+=a[i][t]*b[t][j];
for(int i=1;i<=m;i++)
{
for(int j=1;j<=p;j++)
printf("%.2lf ",(double)c[i][j]);
printf("\n");
}
printf("\n");
//A的梯度等于G乘B的转置
for(int i=1;i<=m;i++)
for(int j=1;j<=n;j++)
for(int t=1;t<=p;t++)
af[i][j]+=d[i][t]*bT[t][j];
for(int i=1;i<=m;i++)
{
for(int j=1;j<=n;j++)
printf("%.2lf ",(double)af[i][j]);
printf("\n");
}
printf("\n");
//B的梯度等于G乘A的转置
for(int i=1;i<=n;i++)
for(int j=1;j<=p;j++)
for(int t=1;t<=m;t++)
bf[i][j]+=aT[i][t]*d[t][j];
for(int i=1;i<=n;i++)
{
for(int j=1;j<=p;j++)
printf("%.2lf ",(double)bf[i][j]);
printf("\n");
}
return 0;
}