题目链接

实现多通道二维卷积

题目描述

实现一个带有 stride, padding, dilationbias 参数的多通道二维卷积。

  • 输入

    1. 输入张量 (形状 )
    2. 卷积核 (形状 )
    3. 卷积参数 (bias, stride, padding, dilation)
    4. 偏置向量 (如果 bias=1)
  • 计算公式: 对于每个输出位置 ,其值为: 其中,输入坐标 通过以下公式从输出坐标和卷积核坐标映射得到: 如果 越界,则该项的贡献视为 0。

  • 输出: 将计算出的输出张量(形状为 )按“通道优先、行优先、列优先”的顺序展开为一行,所有数值保留4位小数。

解题思路

这是一个纯粹的模拟题,核心是精确地将题目中给出的数据格式和卷积公式翻译成代码。

  1. 数据解析与存储

    • 输入张量:形状为 c x y。可以创建一个三维数组 input[c][x][y]。由于输入数据是“平铺”成一维的,我们需要手动根据索引 (ic, ix, iy) 计算其在一维数组中的位置 ic*x*y + ix*y + iy 来填充这个三维数组。
    • 卷积核权重:形状为 out x in x k x k。同理,创建一个四维数组 weight[out][in][k][k],并根据索引 (oc, ic, ki, kj) 计算一维位置 oc*in*k*k + ic*k*k + ki*k + kj 来填充。
    • 参数和偏置:读取 bias, stride, padding, dilation。如果 bias 标志为1, 则额外读取 out 个偏置值存入一个一维数组。否则,可以认为偏置数组全为0。
  2. 计算输出尺寸

    • 在进行卷积计算前,必须先根据官方公式计算出输出张量的空间尺寸
      • x_out = floor((x + 2*padding - dilation*(k-1) - 1) / stride + 1)
      • y_out 的计算方法相同,只需将 x 换成 y
    • 基于此,创建一个三维数组 output[out][x_out][y_out] 并初始化为0。
  3. 执行卷积运算

    • 这是算法的核心,严格遵循题目的计算公式,通过多层嵌套循环实现。
    • 外层循环 (遍历输出)
      • for oc in 0..out-1 (遍历每个输出通道)
      • for oh in 0..x_out-1 (遍历每个输出像素的行)
      • for ow in 0..y_out-1 (遍历每个输出像素的列)
    • 初始化
      • 在最内层,首先将当前输出像素值 val 初始化为对应的偏置 bias[oc]
    • 内层循环 (执行求和)
      • for ic in 0..in-1 (遍历每个输入通道)
      • for ki in 0..k-1 (遍历卷积核的行)
      • for kj in 0..k-1 (遍历卷积核的列)
    • 坐标映射与累加
      • 计算输入坐标 ihiw
      • 边界检查:判断 ihiw 是否在输入张量的有效范围内 (0 <= ih < x0 <= iw < y)。
      • 如果坐标有效,则将 input[ic][ih][iw] * weight[oc][ic][ki][kj] 累加到 val
    • 赋值:完成所有求和循环后,将 val 赋给 output[oc][oh][ow]
  4. 格式化输出

    • 遍历计算完成的 output 数组,按照“通道优先、行优先、列优先”的顺序,将每个浮点数格式化为保留四位小数的字符串,并用空格连接,最后统一输出。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <iomanip>

using namespace std;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int c, x, y;
    cin >> c >> x >> y;
    vector<double> input_flat(c * x * y);
    for (int i = 0; i < c * x * y; ++i) cin >> input_flat[i];

    int out_c, in_c, k, k_dummy;
    cin >> out_c >> in_c >> k >> k_dummy;
    vector<double> weight_flat(out_c * in_c * k * k);
    for (int i = 0; i < out_c * in_c * k * k; ++i) cin >> weight_flat[i];

    int bias_flag, stride, padding, dilation;
    cin >> bias_flag >> stride >> padding >> dilation;
    vector<double> bias(out_c, 0.0);
    if (bias_flag == 1) {
        for (int i = 0; i < out_c; ++i) cin >> bias[i];
    }

    auto get_input = [&](int ic, int ix, int iy) {
        return input_flat[ic * x * y + ix * y + iy];
    };
    auto get_weight = [&](int oc, int ic, int ki, int kj) {
        return weight_flat[oc * in_c * k * k + ic * k * k + ki * k + kj];
    };

    int x_out = floor((double)(x + 2 * padding - dilation * (k - 1) - 1) / stride + 1);
    int y_out = floor((double)(y + 2 * padding - dilation * (k - 1) - 1) / stride + 1);

    vector<double> output_flat;

    for (int oc = 0; oc < out_c; ++oc) {
        for (int oh = 0; oh < x_out; ++oh) {
            for (int ow = 0; ow < y_out; ++ow) {
                double val = bias[oc];
                for (int ic = 0; ic < in_c; ++ic) {
                    for (int ki = 0; ki < k; ++ki) {
                        for (int kj = 0; kj < k; ++kj) {
                            int ih = oh * stride + ki * dilation - padding;
                            int iw = ow * stride + kj * dilation - padding;
                            if (ih >= 0 && ih < x && iw >= 0 && iw < y) {
                                val += get_input(ic, ih, iw) * get_weight(oc, ic, ki, kj);
                            }
                        }
                    }
                }
                output_flat.push_back(val);
            }
        }
    }

    cout << fixed << setprecision(4);
    for (size_t i = 0; i < output_flat.size(); ++i) {
        cout << output_flat[i] << (i == output_flat.size() - 1 ? "" : " ");
    }
    cout << "\n";

    return 0;
}
import java.util.Scanner;
import java.util.Locale;
import java.util.ArrayList;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in).useLocale(Locale.US);

        int c = sc.nextInt();
        int x = sc.nextInt();
        int y = sc.nextInt();
        double[][][] input = new double[c][x][y];
        for (int ic = 0; ic < c; ic++) {
            for (int ix = 0; ix < x; ix++) {
                for (int iy = 0; iy < y; iy++) {
                    input[ic][ix][iy] = sc.nextDouble();
                }
            }
        }

        int outC = sc.nextInt();
        int inC = sc.nextInt();
        int k = sc.nextInt();
        sc.nextInt(); // k_dummy
        double[][][][] weight = new double[outC][inC][k][k];
        for (int oc = 0; oc < outC; oc++) {
            for (int ic = 0; ic < inC; ic++) {
                for (int ki = 0; ki < k; ki++) {
                    for (int kj = 0; kj < k; kj++) {
                        weight[oc][ic][ki][kj] = sc.nextDouble();
                    }
                }
            }
        }

        int biasFlag = sc.nextInt();
        int stride = sc.nextInt();
        int padding = sc.nextInt();
        int dilation = sc.nextInt();
        double[] bias = new double[outC];
        if (biasFlag == 1) {
            for (int i = 0; i < outC; i++) {
                bias[i] = sc.nextDouble();
            }
        }

        int xOut = (x + 2 * padding - dilation * (k - 1) - 1) / stride + 1;
        int yOut = (y + 2 * padding - dilation * (k - 1) - 1) / stride + 1;

        double[][][] output = new double[outC][xOut][yOut];
        ArrayList<String> outputFlat = new ArrayList<>();

        for (int oc = 0; oc < outC; oc++) {
            for (int oh = 0; oh < xOut; oh++) {
                for (int ow = 0; ow < yOut; ow++) {
                    double val = bias[oc];
                    for (int ic = 0; ic < inC; ic++) {
                        for (int ki = 0; ki < k; ki++) {
                            for (int kj = 0; kj < k; kj++) {
                                int ih = oh * stride + ki * dilation - padding;
                                int iw = ow * stride + kj * dilation - padding;
                                if (ih >= 0 && ih < x && iw >= 0 && iw < y) {
                                    val += input[ic][ih][iw] * weight[oc][ic][ki][kj];
                                }
                            }
                        }
                    }
                    output[oc][oh][ow] = val;
                }
            }
        }

        for (int oc = 0; oc < outC; oc++) {
            for (int oh = 0; oh < xOut; oh++) {
                for (int ow = 0; ow < yOut; ow++) {
                    outputFlat.add(String.format(Locale.US, "%.4f", output[oc][oh][ow]));
                }
            }
        }

        System.out.println(String.join(" ", outputFlat));
    }
}
import math

def main():
    c, x, y = map(int, input().split())
    input_flat = list(map(float, input().split()))
    input_tensor = [[[0.0] * y for _ in range(x)] for _ in range(c)]
    idx = 0
    for ic in range(c):
        for ix in range(x):
            for iy in range(y):
                input_tensor[ic][ix][iy] = input_flat[idx]
                idx += 1

    out_c, in_c, k, _ = map(int, input().split())
    weight_flat = list(map(float, input().split()))
    weight_tensor = [[[[0.0] * k for _ in range(k)] for _ in range(in_c)] for _ in range(out_c)]
    idx = 0
    for oc in range(out_c):
        for ic in range(in_c):
            for ki in range(k):
                for kj in range(k):
                    weight_tensor[oc][ic][ki][kj] = weight_flat[idx]
                    idx += 1
    
    params = list(map(int, input().split()))
    bias_flag, stride, padding, dilation = params
    
    bias = [0.0] * out_c
    if bias_flag == 1:
        bias = list(map(float, input().split()))

    x_out = (x + 2 * padding - dilation * (k - 1) - 1) // stride + 1
    y_out = (y + 2 * padding - dilation * (k - 1) - 1) // stride + 1

    output_flat = []

    for oc in range(out_c):
        for oh in range(x_out):
            for ow in range(y_out):
                val = bias[oc]
                for ic in range(in_c):
                    for ki in range(k):
                        for kj in range(k):
                            ih = oh * stride + ki * dilation - padding
                            iw = ow * stride + kj * dilation - padding
                            if 0 <= ih < x and 0 <= iw < y:
                                val += input_tensor[ic][ih][iw] * weight_tensor[oc][ic][ki][kj]
                output_flat.append(val)

    print(" ".join([f"{v:.4f}" for v in output_flat]))

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法: 模拟
  • 时间复杂度:
    • 算法的核心是七层嵌套循环。外三层遍历输出张量的每个点,内四层遍历输入通道和卷积核的每个点。
  • 空间复杂度:
    • 主要空间开销用于存储输入张量和卷积核权重。输出张量也可以计入,其大小为