简化Attention输出的元素总和

题意

给定三个正整数 (均小于 100),按照简化 Attention 的流程构造矩阵并求输出矩阵 的元素总和(四舍五入取整)。

构造规则:

  • 输入特征矩阵 ,全 1
  • 权重矩阵 的上三角矩阵(主对角线及以上为 1,其余为 0)
  • 逐行做 softmax

思路

这题看上去是矩阵乘法模拟题,但如果直接暴力模拟, 的矩阵乘法虽然跑得过,却不够优雅。不妨先手推一下每个矩阵长什么样。

第一步:Q、K、V 长什么样?

全是 1,所以 的第 行第 列就是 列的元素之和。 的上三角矩阵,第 列(0-indexed)在行 处为 1,共 个 1。

所以 的每一行都一样:。同理 也是如此——三者完全相同。

第二步:softmax(S) 是什么?

。因为 的所有行都相同, 的所有行也都相同, 就是一个 的矩阵,每个元素都相等

对一行全相等的向量做 softmax 会怎样?每个值 。所以 是一个所有元素都等于 的矩阵。

第三步:求 Y 的元素总和

$$

矩阵,乘以 ), 的每一行就是 所有行的平均值。但 的每一行都一样,所以平均值还是那一行本身。

$$

最终答案:

$$

这个求和也可以分段计算:当 时,求和 ;当 时,求和

用样例验证:,因为 ,求和 ,答案

时间复杂度 ,空间复杂度

代码

import math

n, m, h = map(int, input().split())

total = 0
for j in range(h):
    total += min(j + 1, m)

print(round(n * total))
#include <iostream>
#include <cmath>
#include <algorithm>
using namespace std;

int main() {
    int n, m, h;
    cin >> n >> m >> h;
    long long total = 0;
    for (int j = 0; j < h; j++) {
        total += min(j + 1, m);
    }
    cout << (long long)round((double)n * total) << endl;
    return 0;
}
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(), m = sc.nextInt(), h = sc.nextInt();
        long total = 0;
        for (int j = 0; j < h; j++) {
            total += Math.min(j + 1, m);
        }
        System.out.println(Math.round((double) n * total));
    }
}