多元高斯分布的概率密度函数计算

[题目链接](https://www.nowcoder.com/practice/2cce779324004584bec7fe8d1af1f5b7)

思路

给定数据点 $\mathbf{x}$、均值向量 $\boldsymbol{\mu}$ 和协方差矩阵 $\Sigma$,计算多元高斯分布在 $\mathbf{x}$ 处的概率密度函数(PDF)值。

多元高斯分布 PDF 公式

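$$
f(\mathbf{x}) = \frac{1}{\sqrt{(2\pi)^k \, |\Sigma|}} \exp\left(-\frac{1}{2} (\mathbf{x} - \boldsymbol{\mu})^T \Sigma^{-1} (\mathbf{x} - \boldsymbol{\mu})\right)
$$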

其中 $k$ 是维度,$|\Sigma|$ 是协方差矩阵的行列式,$\Sigma^{-1}$ 是协方差矩阵的逆矩阵。

计算步骤

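  1. 解析输入:从 JSON 字符串中提取 $\mathbf{x}$、$\boldsymbol{\mu}$ 和 $\Sigma$。
  2. 计算差向量 $\mathbf{d} = \mathbf{x} - \boldsymbol{\mu}$。
  3. 高斯消元求逆矩阵和行列式:构造增广矩阵 $[\Sigma \mid I]$,通过带部分主元选取的高斯-约旦消元法,同时求出行列式 $|\Sigma|$(各主元之积,每做一次行交换变号一次)和逆矩阵 $\Sigma^{-1}$。
  4. 计算指数项(二次型)$\mathbf{d}^T \Sigma^{-1} \mathbf{d}$。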
  5. 组合结果:将归一化系数和指数项相乘得到 PDF 值,保留两位小数输出。

样例演示

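输入的协方差矩阵为单位矩阵($\Sigma = I$)时:

  • 差向量 $\mathbf{d} = \mathbf{x} - \boldsymbol{\mu}$
  • 二次型 $\mathbf{d}^T I^{-1} \mathbf{d} = \sum_{i} d_i^2$
  • 归一化系数 $\frac{1}{\sqrt{(2\pi)^k \cdot |I|}} = (2\pi)^{-k/2}$
  • PDF $= (2\pi)^{-k/2} \exp\left(-\frac{1}{2} \sum_{i} d_i^2\right)$
  • 保留两位小数输出(例如 $k = 2$ 且 $\mathbf{x} = \boldsymbol{\mu}$ 时,PDF $= \frac{1}{2\pi} \approx 0.159$,输出 `0.16`)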

代码

#include <bits/stdc++.h>
using namespace std;

/*
 * Reads a JSON object of the form
 *   {"x": [...], "mu": [...], "sigma": [[...], ...]}
 * from stdin and prints the multivariate Gaussian density N(x; mu, sigma)
 * rounded to two decimals.
 *
 * The whole of stdin is read (not only the first line), so pretty-printed,
 * multi-line JSON is accepted, matching the Java and Python versions.
 * Returns 1 (with a message on stderr) when mu or a sigma row is shorter
 * than x, instead of indexing out of bounds.
 */
int main() {
    // Read all input: the JSON object may span several lines.
    string line((istreambuf_iterator<char>(cin)), istreambuf_iterator<char>());

    // Parses the flat numeric array whose '[' is the first one at or after
    // `start`. Commas and any whitespace (including newlines) separate numbers.
    // Returns the numbers and the index just past the closing ']'; when no '['
    // exists the array is empty.
    auto extractArray = [](const string& s, size_t start) -> pair<vector<double>, size_t> {
        vector<double> arr;
        size_t i = s.find('[', start);
        if (i == string::npos) return {arr, s.size()};
        string num;
        auto flush = [&arr, &num]() {
            if (!num.empty()) {
                arr.push_back(stod(num));
                num.clear();
            }
        };
        for (++i; i < s.size() && s[i] != ']'; ++i) {
            unsigned char c = static_cast<unsigned char>(s[i]);
            if (c == ',' || isspace(c)) {
                flush();
            } else {
                num += s[i];
            }
        }
        flush();
        return {arr, min(i + 1, s.size())};
    };

    // Index of the quoted key in the input, or line.size() when it is absent
    // (extractArray then finds no '[' and yields an empty array). Using size_t
    // avoids the old int(-1) position that led to reading s[-1].
    auto keyPos = [&line](const char* key) -> size_t {
        size_t p = line.find(string("\"") + key + "\"");
        return p == string::npos ? line.size() : p;
    };

    vector<double> x = extractArray(line, keyPos("x")).first;
    vector<double> mu = extractArray(line, keyPos("mu")).first;
    int k = (int)x.size();

    // sigma is [[...], [...], ...]: skip the outer '[' and read k rows.
    vector<vector<double>> sigma(k);
    size_t pos = line.find('[', keyPos("sigma"));
    if (pos != string::npos) {
        ++pos;
        for (int r = 0; r < k; r++) {
            auto [row, after] = extractArray(line, pos);
            sigma[r] = row;
            pos = after;
        }
    }

    // Reject inputs that would otherwise be read out of bounds below.
    bool ok = (int)mu.size() >= k;
    for (const auto& row : sigma) ok = ok && (int)row.size() >= k;
    if (!ok) {
        fprintf(stderr, "malformed input: expected mu of length %d and a %dx%d sigma\n", k, k, k);
        return 1;
    }

    vector<double> diff(k);
    for (int i = 0; i < k; i++) diff[i] = x[i] - mu[i];

    // Gauss-Jordan elimination with partial pivoting on [sigma | I]: the right
    // half becomes sigma^-1 and det(sigma) is the product of the pivots,
    // negated once per row swap.
    vector<vector<double>> aug(k, vector<double>(2 * k, 0));
    for (int i = 0; i < k; i++) {
        for (int j = 0; j < k; j++) aug[i][j] = sigma[i][j];
        aug[i][k + i] = 1.0;
    }
    double det = 1.0;
    for (int col = 0; col < k; col++) {
        int pivot = col;
        for (int row = col + 1; row < k; row++)
            if (fabs(aug[row][col]) > fabs(aug[pivot][col])) pivot = row;
        if (pivot != col) {
            swap(aug[col], aug[pivot]);
            det = -det;
        }
        double d = aug[col][col];
        det *= d;
        for (int j = 0; j < 2 * k; j++) aug[col][j] /= d;
        for (int row = 0; row < k; row++) {
            if (row == col) continue;
            double f = aug[row][col];
            for (int j = 0; j < 2 * k; j++) aug[row][j] -= f * aug[col][j];
        }
    }

    // Quadratic form diff^T * sigma^-1 * diff.
    double exponent = 0;
    for (int i = 0; i < k; i++)
        for (int j = 0; j < k; j++)
            exponent += diff[i] * aug[i][k + j] * diff[j];

    // acos(-1) is the portable spelling of pi (M_PI is not standard C++).
    const double PI = acos(-1.0);
    double pdf = exp(-0.5 * exponent) / sqrt(pow(2 * PI, k) * fabs(det));
    printf("%.2f\n", pdf);
    return 0;
}
import java.util.*;

public class Main {
    /**
     * Reads a JSON object with keys {@code "x"}, {@code "mu"} and {@code "sigma"} from standard
     * input and prints the multivariate Gaussian density at {@code x} with two decimals.
     *
     * <p>The value is formatted with {@link Locale#ROOT}. Formatting with the JVM default locale
     * (e.g. German or French) would print {@code 0,16} instead of {@code 0.16}.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        StringBuilder sb = new StringBuilder();
        while (sc.hasNextLine()) {
            // Keep a line break between input lines so tokens on adjacent lines never merge.
            sb.append(sc.nextLine()).append('\n');
        }
        String line = sb.toString();

        double[] x = parseArray(line, "\"x\"");
        double[] mu = parseArray(line, "\"mu\"");
        double[][] sigma = parseMatrix(line, "\"sigma\"", x.length);

        System.out.printf(Locale.ROOT, "%.2f%n", gaussianPdf(x, mu, sigma));
    }

    /**
     * Evaluates the multivariate normal density
     * {@code exp(-d^T sigma^-1 d / 2) / sqrt((2*pi)^k * |det sigma|)} with {@code d = x - mu}.
     *
     * <p>{@code sigma} is inverted by Gauss-Jordan elimination with partial pivoting; the inputs
     * are not modified. A singular {@code sigma} yields a non-finite result.
     *
     * @param x point of dimension {@code k = x.length}
     * @param mu mean vector; must have at least {@code k} entries
     * @param sigma covariance matrix; must be at least {@code k x k}
     * @return the density value at {@code x}
     */
    static double gaussianPdf(double[] x, double[] mu, double[][] sigma) {
        int k = x.length;
        double[] diff = new double[k];
        for (int i = 0; i < k; i++) {
            diff[i] = x[i] - mu[i];
        }

        // Gauss-Jordan elimination on [sigma | I]: the right half ends up holding sigma^-1, and
        // det(sigma) is the product of the pivots, negated once per row swap.
        double[][] aug = new double[k][2 * k];
        for (int i = 0; i < k; i++) {
            System.arraycopy(sigma[i], 0, aug[i], 0, k);
            aug[i][k + i] = 1.0;
        }
        double det = 1.0;
        for (int col = 0; col < k; col++) {
            // Partial pivoting: largest |entry| in this column at or below the diagonal.
            int pivot = col;
            for (int row = col + 1; row < k; row++) {
                if (Math.abs(aug[row][col]) > Math.abs(aug[pivot][col])) {
                    pivot = row;
                }
            }
            if (pivot != col) {
                double[] tmp = aug[col];
                aug[col] = aug[pivot];
                aug[pivot] = tmp;
                det = -det;
            }
            double d = aug[col][col];
            det *= d;
            for (int j = 0; j < 2 * k; j++) {
                aug[col][j] /= d;
            }
            for (int row = 0; row < k; row++) {
                if (row == col) {
                    continue;
                }
                double f = aug[row][col];
                for (int j = 0; j < 2 * k; j++) {
                    aug[row][j] -= f * aug[col][j];
                }
            }
        }

        // Quadratic form diff^T * sigma^-1 * diff.
        double exponent = 0;
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < k; j++) {
                exponent += diff[i] * aug[i][k + j] * diff[j];
            }
        }

        return Math.exp(-0.5 * exponent) / Math.sqrt(Math.pow(2 * Math.PI, k) * Math.abs(det));
    }

    /**
     * Parses the comma-separated numbers between the first {@code '['} after {@code key} and the
     * next {@code ']'}. The array is assumed to be flat (no nested brackets).
     *
     * @param s the whole input text
     * @param key the quoted JSON key, e.g. {@code "\"x\""}
     * @return the parsed numbers; empty if the brackets contain nothing
     */
    static double[] parseArray(String s, String key) {
        int idx = s.indexOf(key);
        idx = s.indexOf('[', idx);
        int end = s.indexOf(']', idx);
        String inner = s.substring(idx + 1, end).trim();
        if (inner.isEmpty()) {
            return new double[0];
        }
        String[] parts = inner.split(",");
        double[] arr = new double[parts.length];
        for (int i = 0; i < parts.length; i++) {
            arr[i] = Double.parseDouble(parts[i].trim());
        }
        return arr;
    }

    /**
     * Parses {@code k} rows of the nested array that follows {@code key}, reading the first
     * {@code k} numbers of each row.
     *
     * @param s the whole input text
     * @param key the quoted JSON key, e.g. {@code "\"sigma\""}
     * @param k number of rows and columns to read
     * @return a new {@code k x k} matrix
     */
    static double[][] parseMatrix(String s, String key, int k) {
        int idx = s.indexOf(key);
        idx = s.indexOf('[', idx);
        idx++; // skip the outer '[' so the loop below lands on each row's '['
        double[][] mat = new double[k][k];
        for (int r = 0; r < k; r++) {
            idx = s.indexOf('[', idx);
            int end = s.indexOf(']', idx);
            String inner = s.substring(idx + 1, end).trim();
            String[] parts = inner.split(",");
            for (int c = 0; c < k; c++) {
                mat[r][c] = Double.parseDouble(parts[c].trim());
            }
            idx = end + 1;
        }
        return mat;
    }
}
import json
import math
import sys

def solve():
    """Print the multivariate Gaussian density for a JSON object read from stdin.

    Stdin must hold a JSON object with keys "x" (the point), "mu" (the mean)
    and "sigma" (a k x k covariance matrix), where k = len(x). With
    d = x - mu, the value exp(-d^T sigma^-1 d / 2) / sqrt((2*pi)^k * |det sigma|)
    is printed with two decimals.

    sigma is inverted by Gauss-Jordan elimination with partial pivoting. The
    determinant is the product of the pivots, negated once per row swap. A
    zero pivot (singular sigma) raises ZeroDivisionError.
    """
    payload = json.loads(sys.stdin.read().strip())
    point, mean, cov = payload["x"], payload["mu"], payload["sigma"]
    dim = len(point)

    delta = [point[i] - mean[i] for i in range(dim)]

    # Tableau [cov | I]; after elimination its right half holds cov^-1.
    tableau = [
        [float(cov[r][c]) for c in range(dim)] + [1.0 if c == r else 0.0 for c in range(dim)]
        for r in range(dim)
    ]

    det = 1.0
    for col in range(dim):
        # Partial pivoting: the first row at/below `col` whose |entry| is largest.
        best = max(range(col, dim), key=lambda r: abs(tableau[r][col]))
        if best != col:
            tableau[col], tableau[best] = tableau[best], tableau[col]
            det = -det
        pivot_val = tableau[col][col]
        det *= pivot_val
        pivot_row = [v / pivot_val for v in tableau[col]]
        tableau[col] = pivot_row
        for r in range(dim):
            if r == col:
                continue
            factor = tableau[r][col]
            tableau[r] = [a - factor * b for a, b in zip(tableau[r], pivot_row)]

    # Quadratic form delta^T cov^-1 delta, accumulated in row-major order.
    quad = 0.0
    for i in range(dim):
        for j in range(dim):
            quad += delta[i] * tableau[i][dim + j] * delta[j]

    pdf = math.exp(-0.5 * quad) / math.sqrt((2 * math.pi) ** dim * abs(det))
    print(f"{pdf:.2f}")

# Run only when executed as a script, so importing this module does not
# consume stdin or print anything.
if __name__ == "__main__":
    solve()

复杂度分析

  • 时间复杂度 $O(k^3)$:高斯-约旦消元求逆矩阵需要 $O(k^3)$,其中 $k$ 是向量维度。
  • 空间复杂度 $O(k^2)$:存储 $k \times 2k$ 的增广矩阵。