实现简化版的 LSTM

题意

给定一个长度为 的序列,每一步输入维度为 ,用一个固定参数的 LSTM 做前向计算,输出每个时间步隐藏向量的第一个元素

LSTM 的参数设定:

  • 记忆单元数
  • 初始状态 全 1, 全 0
  • 所有门的权重和偏置均为 0

思路

这题看着吓人——LSTM 有遗忘门、输入门、输出门、候选记忆……一大堆公式。但题目给了一个极端的初始化条件:所有权重和偏置都是 0。这意味着什么?

各门的值是多少? 遗忘门 ,权重和偏置全 0,所以 。同理输入门 ,输出门

候选记忆呢? ,全 0 输入所以

那记忆单元的递推就变成了:

$$

初始 全 1,所以 ,即 的每个分量都是

隐藏状态呢?

$$

所以

整个题目化简下来就一行公式。输入数据完全不影响结果——只需要读入 ,然后对 逐个算就完了。

输出格式要注意: 保留三位小数,但要去掉末尾多余的零,零值输出 "0.0"。比如 要写成 要写成

复杂度

  • 时间:
  • 空间:(存输出结果)

代码

import math

data = list(map(float, input().split()))
seq_len = int(data[0])
x_dim = int(data[1])

results = []
for t in range(1, seq_len + 1):
    val = 0.5 * math.tanh(0.5 ** t)
    rounded = round(val, 3)
    if rounded == 0.0:
        results.append("0.0")
    else:
        s = f"{rounded:.3f}"
        if '.' in s:
            s = s.rstrip('0')
            if s.endswith('.'):
                s += '0'
        results.append(s)

print(' '.join(results))
#include <cstdio>
#include <cmath>
#include <cstring>

int main() {
    int seq_len, x_dim;
    scanf("%d %d", &seq_len, &x_dim);
    double tmp;
    for (int i = 0; i < seq_len * x_dim; i++) scanf("%lf", &tmp);

    for (int t = 1; t <= seq_len; t++) {
        double val = 0.5 * tanh(pow(0.5, t));
        double rounded = round(val * 1000.0) / 1000.0;
        if (rounded == 0.0) {
            printf("0.0");
        } else {
            char buf[32];
            snprintf(buf, sizeof(buf), "%.3f", rounded);
            int len = strlen(buf);
            while (len > 1 && buf[len-1] == '0' && buf[len-2] != '.') {
                buf[--len] = '\0';
            }
            printf("%s", buf);
        }
        if (t < seq_len) printf(" ");
    }
    printf("\n");
    return 0;
}