简化Attention输出的元素总和
题意
给定三个正整数 (均小于 100),按照简化 Attention 的流程构造矩阵并求输出矩阵
的元素总和(四舍五入取整)。
构造规则:
- 输入特征矩阵
:
,全 1
- 权重矩阵
:
的上三角矩阵(主对角线及以上为 1,其余为 0)
,
,
- 对
逐行做 softmax
思路
这题看上去是矩阵乘法模拟题,但如果直接暴力模拟, 的矩阵乘法虽然跑得过,却不够优雅。不妨先手推一下每个矩阵长什么样。
第一步:Q、K、V 长什么样?
全是 1,所以
的第
行第
列就是
第
列的元素之和。
是
的上三角矩阵,第
列(0-indexed)在行
处为 1,共
个 1。
所以 的每一行都一样:
。同理
和
也是如此——三者完全相同。
第二步:softmax(S) 是什么?
。因为
的所有行都相同,
的所有行也都相同,
就是一个
的矩阵,每个元素都相等。
对一行全相等的向量做 softmax 会怎样?每个值 。所以
是一个所有元素都等于
的矩阵。
第三步:求 Y 的元素总和
$$
是
的
矩阵,乘以
(
),
的每一行就是
所有行的平均值。但
的每一行都一样,所以平均值还是那一行本身。
$$
最终答案:
$$
这个求和也可以分段计算:当 时,求和
;当
时,求和
。
用样例验证:,因为
,求和
,答案
。
时间复杂度 ,空间复杂度
。
代码
import math
n, m, h = map(int, input().split())
total = 0
for j in range(h):
total += min(j + 1, m)
print(round(n * total))
#include <iostream>
#include <cmath>
#include <algorithm>
using namespace std;
int main() {
int n, m, h;
cin >> n >> m >> h;
long long total = 0;
for (int j = 0; j < h; j++) {
total += min(j + 1, m);
}
cout << (long long)round((double)n * total) << endl;
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt(), m = sc.nextInt(), h = sc.nextInt();
long total = 0;
for (int j = 0; j < h; j++) {
total += Math.min(j + 1, m);
}
System.out.println(Math.round((double) n * total));
}
}

京公网安备 11010502036488号