实现多通道二维卷积

题目分析

本题要求实现一个标准的多通道二维卷积操作,支持 bias、stride、padding 和 dilation 四个参数。这是深度学习中最基础的操作之一,理解卷积的计算过程对于掌握 CNN 至关重要。

思路

核心公式

对于每个输出通道 oc 和输出位置 (oh, ow)

$$

其中:

$$

$$

越界时,对应的输入值视为 0(即 zero-padding 的效果)。

输出尺寸

$$

$$

实现步骤

  1. 读入数据:按照 channel-first、row-major 的顺序读入输入张量和卷积核权重。
  2. 计算输出尺寸:根据公式算出输出的高和宽。
  3. 六重循环计算卷积:遍历输出通道、输出位置、输入通道、卷积核位置,累加乘积。
  4. 边界检查:计算出的输入坐标如果越界则跳过(等价于 padding 为 0)。
  5. 格式化输出:保留 4 位小数,空格分隔。

代码

import sys

def main():
    data = sys.stdin.read().split()
    idx = 0

    c = int(data[idx]); idx += 1
    x = int(data[idx]); idx += 1
    y = int(data[idx]); idx += 1

    # 读入输入张量: c 个通道, 每个 x*y
    inp = []
    for ic in range(c):
        channel = []
        for i in range(x):
            row = []
            for j in range(y):
                row.append(float(data[idx])); idx += 1
            channel.append(row)
        inp.append(channel)

    out_c = int(data[idx]); idx += 1
    in_c = int(data[idx]); idx += 1
    kh = int(data[idx]); idx += 1
    kw = int(data[idx]); idx += 1

    # 读入卷积核权重: out_c * in_c * kh * kw
    weight = []
    for oc in range(out_c):
        oc_w = []
        for ic in range(in_c):
            kernel = []
            for ki in range(kh):
                row = []
                for kj in range(kw):
                    row.append(float(data[idx])); idx += 1
                kernel.append(row)
            oc_w.append(kernel)
        weight.append(oc_w)

    bias_flag = int(data[idx]); idx += 1
    stride = int(data[idx]); idx += 1
    padding = int(data[idx]); idx += 1
    dilation = int(data[idx]); idx += 1

    bias = [0.0] * out_c
    if bias_flag == 1:
        for oc in range(out_c):
            bias[oc] = float(data[idx]); idx += 1

    # 计算输出尺寸
    out_h = (x + 2 * padding - dilation * (kh - 1) - 1) // stride + 1
    out_w = (y + 2 * padding - dilation * (kw - 1) - 1) // stride + 1

    # 六重循环计算卷积
    results = []
    for oc in range(out_c):
        for oh in range(out_h):
            for ow in range(out_w):
                val = bias[oc]
                for ic in range(in_c):
                    for ki in range(kh):
                        for kj in range(kw):
                            ih = oh * stride + ki * dilation - padding
                            iw = ow * stride + kj * dilation - padding
                            if 0 <= ih < x and 0 <= iw < y:
                                val += inp[ic][ih][iw] * weight[oc][ic][ki][kj]
                results.append(f"{val:.4f}")

    print(" ".join(results))

main()

复杂度分析

  • 时间复杂度,即标准卷积的计算复杂度。
  • 空间复杂度,存储输入、权重和输出所需的空间。