题目链接
题目描述
本题要求对一个输入序列执行一个固定参数的 LSTM (长短期记忆网络) 的前向计算。任务是输出每个时间步 t 的隐藏向量 h_t 的第一个元素 h_t[0]。
模型关键设定:
- 输入: 一个长度为
seq_len的序列。 - 记忆单元 (Hidden State) 维度:
。
- 初始状态:
- 初始细胞状态
是一个全为 1 的向量。
- 初始隐藏状态
是一个全为 0 的向量。
- 初始细胞状态
- 固定参数: LSTM 四个门(输入门、遗忘门、输出门、细胞门)的所有权重和偏置全部为 0。
输出要求:
- 依次输出从
到
seq_len的h_t[0]。 - 数值四舍五入到小数点后三位,并去掉多余的尾部零。
- 如果数值为 0,统一输出
0.0。
解题思路
虽然题目涉及 LSTM,但由于其所有权重和偏置都被设为 0,问题可以被简化为一个纯粹的数学计算任务,其输出实际上与输入序列的具体数值无关,仅与序列长度 seq_len 有关。
-
LSTM 门计算简化
- LSTM 的四个门(输入门
, 遗忘门
, 输出门
, 细胞候选门
)的计算都依赖于一个线性变换,形式为
。
- 因为所有权重
和偏置
均为 0,所以这个线性变换的结果永远是一个零向量。
- 因此,四个门的计算结果变为常数:
- LSTM 的四个门(输入门
-
状态更新公式简化
- LSTM 的状态更新公式为:
- 细胞状态:
- 隐藏状态:
- 细胞状态:
- 将上述常数值代入,得到:
- LSTM 的状态更新公式为:
-
推导最终计算公式
- 我们得到了细胞状态
的递推关系:
。
- 已知初始细胞状态
是一个全为 1 的向量,所以
在时间步
的值是一个所有元素均为
的向量。
- 将
代入隐藏状态的公式中,由于
是逐元素操作的,所以
也是一个所有元素都相等的向量,其每个元素的值为
。
- 因此,我们要求解的
的值就是
。
- 我们得到了细胞状态
-
算法实现
- 从输入中读取序列长度
seq_len。x_dim和后续的序列数据都可以忽略。 - 循环
从 1 到
seq_len。 - 在循环中,计算
value = 0.5 * tanh(pow(0.5, t))。 - 对
value进行格式化处理,满足输出要求。
- 从输入中读取序列长度
代码
#include <iostream>
#include <vector>
#include <cmath>
#include <iomanip>
#include <string>
#include <sstream>
using namespace std;
// 格式化输出函数
string format_double(double val) {
if (abs(val) < 1e-9) {
return "0.0";
}
stringstream ss;
ss << fixed << setprecision(3) << val;
string s = ss.str();
s.erase(s.find_last_not_of('0') + 1, string::npos);
if (s.back() == '.') {
s.pop_back();
}
return s;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int seq_len, x_dim;
cin >> seq_len >> x_dim;
// 后续的输入可以忽略
double power_of_half = 0.5;
for (int t = 1; t <= seq_len; ++t) {
double h_t_0 = 0.5 * tanh(power_of_half);
cout << format_double(h_t_0) << (t == seq_len ? "" : " ");
power_of_half *= 0.5; // 为下一次迭代准备
}
cout << endl;
return 0;
}
import java.util.Scanner;
import java.math.BigDecimal;
import java.math.RoundingMode;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int seqLen = sc.nextInt();
int xDim = sc.nextInt();
// 后续的输入可以忽略
StringBuilder sb = new StringBuilder();
double powerOfHalf = 0.5;
for (int t = 1; t <= seqLen; t++) {
double ht0 = 0.5 * Math.tanh(powerOfHalf);
String formatted;
if (Math.abs(ht0) < 1e-9) {
formatted = "0.0";
} else {
BigDecimal bd = new BigDecimal(String.valueOf(ht0));
bd = bd.setScale(3, RoundingMode.HALF_UP);
formatted = bd.stripTrailingZeros().toPlainString();
}
sb.append(formatted);
if (t < seqLen) {
sb.append(" ");
}
powerOfHalf *= 0.5; // 为下一次迭代准备
}
System.out.println(sb.toString());
}
}
import math
from decimal import Decimal, ROUND_HALF_UP
def main():
# 读取输入,但只使用 seq_len
line = input().split()
seq_len = int(line[0])
results = []
power_of_half = 0.5
for _ in range(seq_len):
h_t_0 = 0.5 * math.tanh(power_of_half)
# 格式化输出
if abs(h_t_0) < 1e-9:
results.append("0.0")
else:
# 使用 Decimal 进行精确的四舍五入
formatted_val = Decimal(str(h_t_0)).quantize(
Decimal('0.001'), rounding=ROUND_HALF_UP
).normalize()
results.append(formatted_val.to_eng_string())
power_of_half *= 0.5 # 为下一次迭代准备
print(" ".join(results))
if __name__ == "__main__":
main()
算法及复杂度
- 算法:数学公式计算
- 时间复杂度:
。我们只需要循环
seq_len次,每次循环内部都是常数时间的操作。通过迭代计算避免了
pow函数带来的对数复杂度。 - 空间复杂度:
,用于存储
seq_len个结果字符串以便最后一次性输出。如果边计算边输出,则空间复杂度为。

京公网安备 11010502036488号