题目链接

基于空间连续块的稀疏注意力机制

题目描述

给定一个 的向量序列 ,块大小为 ,以及两个 维向量 ,执行以下流程:

  1. 分块与压缩

    • 将序列 按顺序切分为 个块。
    • 对每个块 ,计算其“块内均值”向量
  2. 打分

    • 对每个 ,根据以下公式计算标量打分
      • (点积后加偏置)
      • (ReLU激活)
      • (标量乘向量后加偏置)
      • (向量各分量求和后缩放)
    • 得到打分序列
  3. 二段划分

    • 将序列 划分为两个非空的连续子段,其和分别为
    • 找到一种划分方式,使得 的值最大。设这个最大值为
  4. 输出

    • 输出 round(100 * S) 的整数结果。

解题思路

这是一个流程清晰的模拟计算题,可以分解为三个独立的步骤。

第一步:分块与计算均值向量

  1. 分块:总共有 个块。我们可以遍历块的索引 从 0 到
  2. 提取块:对于每个块 ,其对应的原始向量在序列 中的起始索引是 k * b,结束索引是 min((k + 1) * b, n)
  3. 计算均值
    • 创建一个 维的零向量 sum_vec
    • 遍历当前块内的所有向量,将它们逐分量地累加到 sum_vec
    • 获取当前块的实际大小 block_size
    • sum_vec 的每一维都除以 block_size,得到均值向量

第二步:计算打分序列

这一步是严格的数学公式翻译。对每个均值向量

  1. 计算 :计算 的点积,然后加上偏置
  2. 计算 :如果 ,则 ;否则
  3. 计算 :这是一个向量。将向量 的每个分量都乘以标量 ,然后再给每个分量都加上偏置
  4. 计算 :计算向量 的所有分量之和,然后除以
  5. 将所有计算出的 存入一个列表,形成打分序列

第三步:寻找最优二段划分

  • 目标:给定序列 ,找到一个切分点,最大化 min(S1, S2)

  • 问题分析:随着切分点向右移动, 单调递增, 单调递减。它们的最小值 min(S1, S2) 会先增后减。最大值通常出现在 最接近的地方。

  • 前缀和优化:这是一个经典问题,可以通过前缀和高效解决。

    1. 计算序列 的总和 total_sum
    2. 初始化一个 left_sum = 0 和一个 max_min_sum = -infinity
    3. 遍历所有可能的切分点(即遍历 ):
      • left_sum += a_i (当前左段的和 )
      • right_sum = total_sum - left_sum (当前右段的和 )
      • current_min = min(left_sum, right_sum)
      • max_min_sum = max(max_min_sum, current_min)
    4. 遍历结束后,max_min_sum 就是我们要求的最优值
  • 输出:根据题目要求,对 进行计算和四舍五入:round(100 * S)

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <algorithm>
#include <iomanip>

using namespace std;

// 向量点积
double dot_product(const vector<double>& v1, const vector<double>& v2) {
    double result = 0.0;
    for (size_t i = 0; i < v1.size(); ++i) {
        result += v1[i] * v2[i];
    }
    return result;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n, d, b;
    cin >> n >> d >> b;

    vector<vector<double>> x(n, vector<double>(d));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            cin >> x[i][j];
        }
    }

    vector<double> w1(d), w2(d);
    for (int i = 0; i < d; ++i) cin >> w1[i];
    for (int i = 0; i < d; ++i) cin >> w2[i];

    vector<double> a;
    int m = (n + b - 1) / b;

    // 1 & 2. 分块、计算均值、打分
    for (int k = 0; k < m; ++k) {
        int start_idx = k * b;
        int end_idx = min((k + 1) * b, n);
        int block_size = end_idx - start_idx;

        vector<double> h_k(d, 0.0);
        for (int i = start_idx; i < end_idx; ++i) {
            for (int j = 0; j < d; ++j) {
                h_k[j] += x[i][j];
            }
        }
        for (int j = 0; j < d; ++j) {
            h_k[j] /= block_size;
        }

        double s_k = dot_product(w1, h_k) + 2.0;
        double z_k = max(0.0, s_k);
        
        vector<double> c_k(d);
        double c_k_sum = 0;
        for(int i = 0; i < d; ++i) {
            c_k[i] = w2[i] * z_k + 1.0;
            c_k_sum += c_k[i];
        }
        
        double a_k = c_k_sum / sqrt(d);
        a.push_back(a_k);
    }

    // 3. 最优二段划分
    double total_sum = accumulate(a.begin(), a.end(), 0.0);
    double left_sum = 0.0;
    double max_min_sum = -1.0;

    for (size_t i = 0; i < a.size() - 1; ++i) {
        left_sum += a[i];
        double right_sum = total_sum - left_sum;
        if (max_min_sum < 0 || min(left_sum, right_sum) > max_min_sum) {
            max_min_sum = min(left_sum, right_sum);
        }
    }
    
    // 4. 输出
    long long result = round(100.0 * max_min_sum);
    cout << result << "\n";

    return 0;
}
import java.util.Scanner;
import java.util.Locale;
import java.util.ArrayList;
import java.util.List;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in).useLocale(Locale.US);

        int n = sc.nextInt();
        int d = sc.nextInt();
        int b = sc.nextInt();

        double[][] x = new double[n][d];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < d; j++) {
                x[i][j] = sc.nextDouble();
            }
        }

        double[] w1 = new double[d];
        double[] w2 = new double[d];
        for (int i = 0; i < d; i++) w1[i] = sc.nextDouble();
        for (int i = 0; i < d; i++) w2[i] = sc.nextDouble();

        List<Double> a = new ArrayList<>();
        int m = (n + b - 1) / b;

        // 1 & 2. 分块、计算均值、打分
        for (int k = 0; k < m; k++) {
            int startIdx = k * b;
            int endIdx = Math.min((k + 1) * b, n);
            int blockSize = endIdx - startIdx;

            double[] h_k = new double[d];
            for (int i = startIdx; i < endIdx; i++) {
                for (int j = 0; j < d; j++) {
                    h_k[j] += x[i][j];
                }
            }
            for (int j = 0; j < d; j++) {
                h_k[j] /= blockSize;
            }

            double s_k = 0.0;
            for(int i = 0; i < d; i++) s_k += w1[i] * h_k[i];
            s_k += 2.0;

            double z_k = Math.max(0.0, s_k);

            double c_k_sum = 0;
            for (int i = 0; i < d; i++) {
                c_k_sum += (w2[i] * z_k + 1.0);
            }
            
            double a_k = c_k_sum / Math.sqrt(d);
            a.add(a_k);
        }
        
        // 3. 最优二段划分
        double totalSum = 0;
        for (double val : a) totalSum += val;
        
        double leftSum = 0;
        double maxMinSum = -1.0;

        for (int i = 0; i < a.size() - 1; i++) {
            leftSum += a.get(i);
            double rightSum = totalSum - leftSum;
            if(maxMinSum < 0 || Math.min(leftSum, rightSum) > maxMinSum) {
               maxMinSum = Math.min(leftSum, rightSum);
            }
        }

        // 4. 输出
        long result = Math.round(100.0 * maxMinSum);
        System.out.println(result);
    }
}
import math

def main():
    n, d, b = map(int, input().split())

    x = [list(map(float, input().split())) for _ in range(n)]
    w1 = list(map(float, input().split()))
    w2 = list(map(float, input().split()))

    a = []
    m = (n + b - 1) // b

    # 1 & 2. 分块、计算均值、打分
    for k in range(m):
        start_idx = k * b
        end_idx = min((k + 1) * b, n)
        block = x[start_idx:end_idx]
        block_size = len(block)

        h_k = [0.0] * d
        for vec in block:
            for i in range(d):
                h_k[i] += vec[i]
        h_k = [val / block_size for val in h_k]
        
        s_k = sum(w1[i] * h_k[i] for i in range(d)) + 2.0
        z_k = max(0.0, s_k)
        
        c_k = [(val * z_k) + 1.0 for val in w2]
        c_k_sum = sum(c_k)
        
        a_k = c_k_sum / math.sqrt(d)
        a.append(a_k)

    # 3. 最优二段划分
    total_sum = sum(a)
    left_sum = 0.0
    max_min_sum = -1.0

    for i in range(len(a) - 1):
        left_sum += a[i]
        right_sum = total_sum - left_sum
        if max_min_sum < 0 or min(left_sum, right_sum) > max_min_sum:
            max_min_sum = min(left_sum, right_sum)

    # 4. 输出
    result = round(100.0 * max_min_sum)
    print(result)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法: 模拟, 前缀和
  • 时间复杂度:
    • 分块和计算所有均值向量 的总时间复杂度为 ,因为每个输入向量只被访问一次。
    • 计算所有打分 的总时间复杂度为 ,这通常小于
    • 最优二段划分的时间复杂度为
    • 因此,总的时间复杂度由第一步主导。
  • 空间复杂度: ,主要用于存储输入的向量序列