基于空间连续块的稀疏注意力机制
题意
给定长度为 、维度为
的向量序列
,以及块大小
和两个
维向量
。按以下流程计算:
- 分块:将序列分成
个连续块,第
块包含
(最后一块可能不满)。
- 块均值:对每个块
,计算维度均值
。
- 评分:
- (点积加 2)
-
- (标量乘法,每个分量再加 1)
-
- 二段划分:将序列
分成恰好两段非空连续子段,两段和为
,最大化
。
输出 ,其中
是最优的
。
思路
这道题本质上就是"按部就班地模拟",不需要什么高深算法,但有几个小地方值得理一理。
评分公式怎么化简?
先看 的表达式。
是一个
维向量,第
个分量是
。把所有分量加起来:
$$
所以 ,其中
只需要预先算一次。这样每个块只需要算一次点积,不用真的构造
向量。
二段划分怎么做?
把 从某个位置切一刀,左边求和得
,右边求和得
。我们要最大化
。
枚举所有 个切割点,维护前缀和就行了。当
从小到大增长时,
从大到小,
先增后减,所以取最大值即可。
复杂度
- 时间:
,计算各块均值和点积
- 空间:
,存储输入向量
代码
import sys
import math
def main():
data = sys.stdin.read().split()
idx = 0
n = int(data[idx]); idx += 1
d = int(data[idx]); idx += 1
b = int(data[idx]); idx += 1
X = []
for i in range(n):
vec = [float(data[idx + j]) for j in range(d)]
idx += d
X.append(vec)
W1 = [float(data[idx + j]) for j in range(d)]; idx += d
W2 = [float(data[idx + j]) for j in range(d)]; idx += d
sumW2 = sum(W2)
m = (n + b - 1) // b
sqrt_d = math.sqrt(d)
A = []
for k in range(m):
start = k * b
end = min(start + b, n)
bs = end - start
dot = sum(W1[j] * sum(X[i][j] for i in range(start, end)) / bs for j in range(d))
s = dot + 2.0
z = max(0.0, s)
A.append((z * sumW2 + d) / sqrt_d)
total = sum(A)
best = float('-inf')
prefix = 0.0
for i in range(m - 1):
prefix += A[i]
best = max(best, min(prefix, total - prefix))
print(round(100 * best))
main()
#include <bits/stdc++.h>
using namespace std;
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, d, b;
cin >> n >> d >> b;
vector<vector<double>> X(n, vector<double>(d));
for(int i = 0; i < n; i++)
for(int j = 0; j < d; j++)
cin >> X[i][j];
vector<double> W1(d), W2(d);
for(int j = 0; j < d; j++) cin >> W1[j];
for(int j = 0; j < d; j++) cin >> W2[j];
double sumW2 = 0;
for(int j = 0; j < d; j++) sumW2 += W2[j];
int m = (n + b - 1) / b;
double sqrt_d = sqrt((double)d);
vector<double> A(m);
for(int k = 0; k < m; k++){
int start = k * b;
int end = min(start + b, n);
int bs = end - start;
double dot = 0;
for(int j = 0; j < d; j++){
double col_sum = 0;
for(int i = start; i < end; i++)
col_sum += X[i][j];
dot += W1[j] * col_sum / bs;
}
double s = dot + 2.0;
double z = max(0.0, s);
A[k] = (z * sumW2 + d) / sqrt_d;
}
double total = 0;
for(int k = 0; k < m; k++) total += A[k];
double best = -1e18;
double prefix = 0;
for(int k = 0; k < m - 1; k++){
prefix += A[k];
best = max(best, min(prefix, total - prefix));
}
cout << (long long)round(100.0 * best) << endl;
return 0;
}

京公网安备 11010502036488号