ccfcsp 202305-2 矩阵运算

文章描述了一个编程问题,涉及Transformer模型中的矩阵运算简化过程,要求计算给定矩阵和向量按特定公式的结果。输入是三个矩阵和一个向量,输出是计算后的矩阵结果。
摘要由CSDN通过智能技术生成
试题编号:202305-2
试题名称:矩阵运算
时间限制:5.0s
内存限制:512.0MB
问题描述:

题目背景

Softmax(𝑄×𝐾𝑇𝑑)×𝑉 是 Transformer 中注意力模块的核心算式,其中 𝑄、𝐾 和 𝑉 均是 𝑛 行 𝑑 列的矩阵,𝐾𝑇 表示矩阵 𝐾 的转置,× 表示矩阵乘法。

问题描述

为了方便计算,顿顿同学将 Softmax 简化为了点乘一个大小为 𝑛 的一维向量 𝑊:
(𝑊⋅(𝑄×𝐾𝑇))×𝑉
点乘即对应位相乘,记 𝑊(𝑖) 为向量 𝑊 的第 𝑖 个元素,即将 (𝑄×𝐾𝑇) 第 𝑖 行中的每个元素都与 𝑊(𝑖) 相乘。

现给出矩阵 𝑄、𝐾 和 𝑉 和向量 𝑊,试计算顿顿按简化的算式计算的结果。

输入格式

从标准输入读入数据。

输入的第一行包含空格分隔的两个正整数 𝑛 和 𝑑,表示矩阵的大小。

接下来依次输入矩阵 𝑄、𝐾 和 𝑉。每个矩阵输入 𝑛 行,每行包含空格分隔的 𝑑 个整数,其中第 𝑖 行的第 𝑗 个数对应矩阵的第 𝑖 行、第 𝑗 列。

最后一行输入 𝑛 个整数,表示向量 𝑊。

输出格式

输出到标准输出中。

输出共 𝑛 行,每行包含空格分隔的 𝑑 个整数,表示计算的结果。

样例输入

3 2
1 2
3 4
5 6
10 10
-20 -20
30 30
6 5
4 3
2 1
4 0 -5

Data

样例输出

480 240
0 0
-2200 -1100

Data

子任务

70 的测试数据满足:𝑛≤100 且 𝑑≤10;输入矩阵、向量中的元素均为整数,且绝对值均不超过 30。

全部的测试数据满足:𝑛≤104 且 𝑑≤20;输入矩阵、向量中的元素均为整数,且绝对值均不超过 1000。

提示

请谨慎评估矩阵乘法运算后的数值范围,并使用适当数据类型存储矩阵中的整数。

import java.io.*;

public class ty {

    static PrintWriter out = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
    static QuickInput in = new QuickInput();


    public static void main(String[] args) throws IOException {
        int n = in.nextInt();
        int d = in.nextInt();
        int[][] Q = new int[n][d];
        int[][] K = new int[n][d];
        int[][] V = new int[n][d];
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < n; j++) {
                for (int k = 0; k < d; k++) {
                    if(i == 0){
                        Q[j][k] = in.nextInt();
                    }else if(i == 1){
                        K[j][k] = in.nextInt();
                    }else {
                        V[j][k] = in.nextInt();
                    }
                }
            }
        }
        int[] arr = new int[n];
        for (int i = 0; i < n; i++) {
            arr[i] = in.nextInt();
        }
        long[][] re1 = new long[d][d];
        for (int i = 0; i < d; i++) {
            for (int j = 0; j < d; j++) {
                for (int k = 0; k < n; k++) {
                    re1[i][j] += (long) V[k][j] * K[k][i];
                }
            }
        }

        long[][] re2 = new long[n][d];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < d; j++) {
                for (int k = 0; k < d; k++) {
                    re2[i][j] += (long)Q[i][k]*re1[k][j];
                }
                //这里的向量W其中的第i个值就要和矩阵中的第i行所有元素进行数乘运算。
                re2[i][j]*=arr[i];
                System.out.print(re2[i][j] + " ");
            }
            System.out.println();
        }
        out.flush();

    }

    static class QuickInput {
        StreamTokenizer input = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));

        int nextInt() throws IOException {
            input.nextToken();
            return (int) input.nval;
        }

    }
}

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值