试题编号: | 202305-2 |
试题名称: | 矩阵运算 |
时间限制: | 5.0s |
内存限制: | 512.0MB |
问题描述: | 题目背景Softmax(𝑄×𝐾𝑇𝑑)×𝑉 是 Transformer 中注意力模块的核心算式,其中 𝑄、𝐾 和 𝑉 均是 𝑛 行 𝑑 列的矩阵,𝐾𝑇 表示矩阵 𝐾 的转置,× 表示矩阵乘法。 问题描述为了方便计算,顿顿同学将 Softmax 简化为了点乘一个大小为 𝑛 的一维向量 𝑊: 现给出矩阵 𝑄、𝐾 和 𝑉 和向量 𝑊,试计算顿顿按简化的算式计算的结果。 输入格式从标准输入读入数据。 输入的第一行包含空格分隔的两个正整数 𝑛 和 𝑑,表示矩阵的大小。 接下来依次输入矩阵 𝑄、𝐾 和 𝑉。每个矩阵输入 𝑛 行,每行包含空格分隔的 𝑑 个整数,其中第 𝑖 行的第 𝑗 个数对应矩阵的第 𝑖 行、第 𝑗 列。 最后一行输入 𝑛 个整数,表示向量 𝑊。 输出格式输出到标准输出中。 输出共 𝑛 行,每行包含空格分隔的 𝑑 个整数,表示计算的结果。 样例输入 Data 样例输出 Data 子任务70 的测试数据满足:𝑛≤100 且 𝑑≤10;输入矩阵、向量中的元素均为整数,且绝对值均不超过 30。 全部的测试数据满足:𝑛≤104 且 𝑑≤20;输入矩阵、向量中的元素均为整数,且绝对值均不超过 1000。 提示请谨慎评估矩阵乘法运算后的数值范围,并使用适当数据类型存储矩阵中的整数。 |
import java.io.*;
public class ty {
static PrintWriter out = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
static QuickInput in = new QuickInput();
public static void main(String[] args) throws IOException {
int n = in.nextInt();
int d = in.nextInt();
int[][] Q = new int[n][d];
int[][] K = new int[n][d];
int[][] V = new int[n][d];
for (int i = 0; i < 3; i++) {
for (int j = 0; j < n; j++) {
for (int k = 0; k < d; k++) {
if(i == 0){
Q[j][k] = in.nextInt();
}else if(i == 1){
K[j][k] = in.nextInt();
}else {
V[j][k] = in.nextInt();
}
}
}
}
int[] arr = new int[n];
for (int i = 0; i < n; i++) {
arr[i] = in.nextInt();
}
long[][] re1 = new long[d][d];
for (int i = 0; i < d; i++) {
for (int j = 0; j < d; j++) {
for (int k = 0; k < n; k++) {
re1[i][j] += (long) V[k][j] * K[k][i];
}
}
}
long[][] re2 = new long[n][d];
for (int i = 0; i < n; i++) {
for (int j = 0; j < d; j++) {
for (int k = 0; k < d; k++) {
re2[i][j] += (long)Q[i][k]*re1[k][j];
}
//这里的向量W其中的第i个值就要和矩阵中的第i行所有元素进行数乘运算。
re2[i][j]*=arr[i];
System.out.print(re2[i][j] + " ");
}
System.out.println();
}
out.flush();
}
static class QuickInput {
StreamTokenizer input = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
int nextInt() throws IOException {
input.nextToken();
return (int) input.nval;
}
}
}