CSP 矩阵运算

题目链接

http://118.190.20.162/view.page?gpid=T169

运行结果:

在这里插入图片描述

思想

  1. 根据数据范围推断出,最终结果矩阵中的元素数字可能超过int,需要用long long.当然,用python的话不用考虑这个问题。
  2. 考虑改变矩阵运算的次序,从而减低时间复杂度。
    如果不用long long,会导致结果错误。最好全部矩阵的存储都用long long.
    如果不调整矩阵运算次序,会超时。在这里插入图片描述
    这里简单说明一下为啥交换次序能优化:
    在这里插入图片描述Q:n*d
    K:d*n
    V:n*d
    从前向后算的话,复杂度是:在这里插入图片描述
    如果先算Temp = K*V,再算Q*Temp, 复杂度是在这里插入图片描述

满分代码

#include <bits/stdc++.h>
using namespace std;  
typedef long long ll;
// 通过矩阵运算交换律就可以做很大的优化 
void cross_multiply(vector<vector<ll>>& left, vector<vector<ll>>& right, 
            vector<vector<ll>>& ans, int a, int b, int c){
    // left: a*b right : b*c  ans: a*c 
    ll temp = 0;
    for(int i=0;i<a;++i){
        for(int j=0;j<c;++j){
            // get ans[i][j] left的第i行 right的第j列 对应相乘再相加
            temp = 0;
            for(int k=0;k<b;++k){
                temp += left[i][k]*right[k][j];
            }
            ans[i][j] = temp;
        }
    }
}

void dot_multiply(vector<ll>&w, vector<vector<ll>>& ans, int n, int d){
    // w:n*1  ans:n*d 
    for(int i=0;i<n;++i){
        for(int j=0;j<d;++j){
            ans[i][j]*=w[i];
        }
    }
}

int main(){
    // read data 
    // cout<<sizeof(int)<<" "<<sizeof(long long)<<endl;
    int n, d;
    cin >> n >> d;
    // q, k, v : n*d   w : n * 1
    vector<vector<ll>> q (n, vector<ll>(d, 0));
    vector<vector<ll>> kt (d, vector<ll>(n, 0));//kt for transpose of k
    vector<vector<ll>> v (n, vector<ll>(d, 0));
    vector<ll> w(n, 0);
    
    for(int i=0;i<n;++i){
        for(int j=0;j<d;++j)
            cin >> q[i][j];
    }
    // kt : d * n  
    for(int i=0;i<n;++i){
        for(int j=0;j<d;++j)
            cin >> kt[j][i];
    }
    for(int i=0;i<n;++i){
        for(int j=0;j<d;++j)
            cin >> v[i][j];
    }
    // read w, n*1
    for(int i=0;i<n;++i){
        cin >> w[i];
    }
    
    // 不调整顺序,只过70
    // vector<vector<ll>> qkt(n, vector<ll>(n, 0));
    // cross_multiply(q, kt, qkt, n, d, n);
    // dot_multiply(w, qkt, n, n);
    // vector<vector<ll>> ans(n, vector<ll>(d, 0));
    // cross_multiply(qkt, v, ans, n, n, d);
    
 
    
    // 调整计算顺序后的,过100
    vector<vector<ll>> ktv(d, vector<ll>(d,0));
    cross_multiply(kt, v, ktv, d, n, d);
    vector<vector<ll>> ans(n, vector<ll>(d, 0));
    cross_multiply(q, ktv, ans, n, d, d);
    dot_multiply(w, ans, n, d);
    for(int i=0;i<n;++i){
        for(int j=0;j<d;++j) cout<<ans[i][j]<<" ";
        cout<<endl;
    }
    
    return 0;
}

python:

n, d = map(int, input().split())
q, k, v, w = [], [], [], []

for i in range(n):
    row = list(map(int, input().split()))
    q.append(row)    

for i in range(n):
    row = list(map(int, input().split()))
    k.append(row)
   
for i in range(n):
    row = list(map(int, input().split()))
    v.append(row)
    
w = list(map(int, input().split()))

def cross_multiply(left, right, a, b, c):
    # left:a*b right:b*c return ans
    ans = [[0]*c for i in range(a)]
    for i in range(a):
        for j in range(c):
            # get ans[i][j]
            temp = 0
            for k in range(b):
                temp += left[i][k]*right[k][j]
            ans[i][j]=temp
    return ans 

# transpose k
kt = []
for i in range(d):
    row = []
    for j in range(n):
        row.append(k[j][i])
    kt.append(row)
    
ktv = cross_multiply(kt, v, d, n, d)
qktv = cross_multiply(q, ktv, n, d, d)

# 对 qktv 和 w 进行点运算
for i in range(n):
    for j in range(d):
        qktv[i][j] *= w[i]
        
for i in range(n):
    for j in range(d):
        print(qktv[i][j], end=' ')
    print()

总结

在这个题目的特殊条件(n<=10**4且d<=20)下,由于n的范围远大于d,所以可以用交换计算顺序的技巧来优化。当然,如果n<=10**4 and d <= 10**4, 那么这个优化就没有意义了。
在这里朴实无华的想法却是最强大的。

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值