题目链接

实现简化版的 LSTM

题目描述

本题要求对一个输入序列执行一个固定参数的 LSTM (长短期记忆网络) 的前向计算。任务是输出每个时间步 t 的隐藏向量 h_t 的第一个元素 h_t[0]

模型关键设定:

  • 输入: 一个长度为 seq_len 的序列。
  • 记忆单元 (Hidden State) 维度:
  • 初始状态:
    • 初始细胞状态 是一个全为 1 的向量。
    • 初始隐藏状态 是一个全为 0 的向量。
  • 固定参数: LSTM 四个门(输入门、遗忘门、输出门、细胞门)的所有权重和偏置全部为 0

输出要求:

  • 依次输出从 seq_lenh_t[0]
  • 数值四舍五入到小数点后三位,并去掉多余的尾部零。
  • 如果数值为 0,统一输出 0.0

解题思路

虽然题目涉及 LSTM,但由于其所有权重和偏置都被设为 0,问题可以被简化为一个纯粹的数学计算任务,其输出实际上与输入序列的具体数值无关,仅与序列长度 seq_len 有关。

  1. LSTM 门计算简化

    • LSTM 的四个门(输入门 , 遗忘门 , 输出门 , 细胞候选门 )的计算都依赖于一个线性变换,形式为
    • 因为所有权重 和偏置 均为 0,所以这个线性变换的结果永远是一个零向量。
    • 因此,四个门的计算结果变为常数:
  2. 状态更新公式简化

    • LSTM 的状态更新公式为:
      • 细胞状态:
      • 隐藏状态:
    • 将上述常数值代入,得到:
  3. 推导最终计算公式

    • 我们得到了细胞状态 的递推关系:
    • 已知初始细胞状态 是一个全为 1 的向量,所以 在时间步 的值是一个所有元素均为 的向量。
    • 代入隐藏状态的公式中,由于 是逐元素操作的,所以 也是一个所有元素都相等的向量,其每个元素的值为
    • 因此,我们要求解的 的值就是
  4. 算法实现

    • 从输入中读取序列长度 seq_lenx_dim 和后续的序列数据都可以忽略。
    • 循环 从 1 到 seq_len
    • 在循环中,计算 value = 0.5 * tanh(pow(0.5, t))
    • value 进行格式化处理,满足输出要求。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <iomanip>
#include <string>
#include <sstream>

using namespace std;

// 格式化输出函数
string format_double(double val) {
    if (abs(val) < 1e-9) {
        return "0.0";
    }
    stringstream ss;
    ss << fixed << setprecision(3) << val;
    string s = ss.str();
    s.erase(s.find_last_not_of('0') + 1, string::npos);
    if (s.back() == '.') {
        s.pop_back();
    }
    return s;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int seq_len, x_dim;
    cin >> seq_len >> x_dim;

    // 后续的输入可以忽略
    
    double power_of_half = 0.5;
    for (int t = 1; t <= seq_len; ++t) {
        double h_t_0 = 0.5 * tanh(power_of_half);
        cout << format_double(h_t_0) << (t == seq_len ? "" : " ");
        power_of_half *= 0.5; // 为下一次迭代准备
    }
    cout << endl;

    return 0;
}
import java.util.Scanner;
import java.math.BigDecimal;
import java.math.RoundingMode;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int seqLen = sc.nextInt();
        int xDim = sc.nextInt();

        // 后续的输入可以忽略

        StringBuilder sb = new StringBuilder();
        double powerOfHalf = 0.5;
        for (int t = 1; t <= seqLen; t++) {
            double ht0 = 0.5 * Math.tanh(powerOfHalf);
            
            String formatted;
            if (Math.abs(ht0) < 1e-9) {
                formatted = "0.0";
            } else {
                BigDecimal bd = new BigDecimal(String.valueOf(ht0));
                bd = bd.setScale(3, RoundingMode.HALF_UP);
                formatted = bd.stripTrailingZeros().toPlainString();
            }
            sb.append(formatted);

            if (t < seqLen) {
                sb.append(" ");
            }
            powerOfHalf *= 0.5; // 为下一次迭代准备
        }
        System.out.println(sb.toString());
    }
}
import math
from decimal import Decimal, ROUND_HALF_UP

def main():
    # 读取输入,但只使用 seq_len
    line = input().split()
    seq_len = int(line[0])

    results = []
    power_of_half = 0.5
    for _ in range(seq_len):
        h_t_0 = 0.5 * math.tanh(power_of_half)
        
        # 格式化输出
        if abs(h_t_0) < 1e-9:
            results.append("0.0")
        else:
            # 使用 Decimal 进行精确的四舍五入
            formatted_val = Decimal(str(h_t_0)).quantize(
                Decimal('0.001'), rounding=ROUND_HALF_UP
            ).normalize()
            results.append(formatted_val.to_eng_string())

        power_of_half *= 0.5 # 为下一次迭代准备

    print(" ".join(results))

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:数学公式计算
  • 时间复杂度:。我们只需要循环 seq_len 次,每次循环内部都是常数时间的操作。通过迭代计算 避免了 pow 函数带来的对数复杂度。
  • 空间复杂度:,用于存储 seq_len 个结果字符串以便最后一次性输出。如果边计算边输出,则空间复杂度为