多元高斯分布的概率密度函数计算
[题目链接](https://www.nowcoder.com/practice/2cce779324004584bec7fe8d1af1f5b7)
思路
给定数据点 $\mathbf{x}$、均值向量 $\boldsymbol{\mu}$ 和协方差矩阵 $\Sigma$,计算多元高斯分布的概率密度函数(PDF)值。
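多元高斯分布 PDF 公式:
$$
f(\mathbf{x}) = \frac{1}{\sqrt{(2\pi)^k \, |\Sigma|}} \exp\left(-\frac{1}{2} (\mathbf{x} - \boldsymbol{\mu})^\top \Sigma^{-1} (\mathbf{x} - \boldsymbol{\mu})\right)
$$
其中 $k$ 是维度,$|\Sigma|$ 是协方差矩阵的行列式,$\Sigma^{-1}$ 是协方差矩阵的逆矩阵。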
计算步骤
- 解析输入:从 JSON 字符串中提取 $\mathbf{x}$、$\boldsymbol{\mu}$、$\Sigma$(键名分别为 `x`、`mu`、`sigma`)。
- 计算差向量 $\mathbf{d} = \mathbf{x} - \boldsymbol{\mu}$。
- 高斯消元求逆矩阵和行列式:构造增广矩阵 $[\Sigma \mid I]$,通过带部分主元选取的高斯-约旦消元法,同时求出行列式 $|\Sigma|$ 和逆矩阵 $\Sigma^{-1}$。其中,行列式等于各步主元之积,每做一次行交换就变号一次;逆矩阵是消元结束后增广矩阵的右半部分。
- 计算指数项:$q = \mathbf{d}^\top \Sigma^{-1} \mathbf{d}$(二次型)。
- 组合结果:将归一化系数和指数项相乘得到 PDF 值,保留两位小数输出。
样例演示
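输入 $\mathbf{x}$、$\boldsymbol{\mu}$,且 $\Sigma = I$(单位矩阵)时:
- 差向量 $\mathbf{d} = \mathbf{x} - \boldsymbol{\mu}$,且 $|\Sigma| = 1$,$\Sigma^{-1} = I$
- 二次型 $q = \mathbf{d}^\top I \, \mathbf{d} = \lVert \mathbf{d} \rVert^2$
- 归一化系数 $\frac{1}{\sqrt{(2\pi)^k \cdot 1}} = (2\pi)^{-k/2}$
- PDF $= (2\pi)^{-k/2} \, e^{-q/2}$
- 保留两位小数输出。例如 $k = 2$ 且 $\mathbf{x} = \boldsymbol{\mu}$ 时 $q = 0$,PDF $= \frac{1}{2\pi} \approx 0.159$,输出 `0.16`。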
代码
#include <bits/stdc++.h>
using namespace std;
// Reads a JSON object {"x": [...], "mu": [...], "sigma": [[...], ...]} from stdin
// and prints the multivariate Gaussian density N(x; mu, sigma) with two decimals.
// Exits with status 1 (message on stderr) when a key is missing or an array is too
// short, instead of indexing out of bounds.
int main() {
    // Read the whole input rather than only the first line, so JSON spread over
    // several lines still parses (the Java solution concatenates all lines too).
    string line((istreambuf_iterator<char>(cin)), istreambuf_iterator<char>());

    // Parses the flat number list inside the first '[' ... ']' at or after `start`.
    // Commas and any whitespace separate tokens. Returns the numbers and the index
    // just past the closing ']' (or s.size() when no '[' is found).
    auto extractArray = [](const string& s, size_t start) -> pair<vector<double>, size_t> {
        vector<double> arr;
        size_t i = s.find('[', start);
        if (i == string::npos) return {arr, s.size()};
        string num;
        auto flush = [&]() {
            if (!num.empty()) {
                arr.push_back(stod(num));
                num.clear();
            }
        };
        for (++i; i < s.size() && s[i] != ']'; ++i) {
            if (s[i] == ',' || isspace(static_cast<unsigned char>(s[i]))) flush();
            else num += s[i];
        }
        flush();
        return {arr, i + 1};
    };

    // Keep positions as size_t so a missing key (npos) is detected instead of
    // being narrowed to -1 and used as an index.
    const size_t xKey = line.find("\"x\"");
    const size_t muKey = line.find("\"mu\"");
    const size_t sigmaKey = line.find("\"sigma\"");
    if (xKey == string::npos || muKey == string::npos || sigmaKey == string::npos) {
        fprintf(stderr, "input must contain \"x\", \"mu\" and \"sigma\"\n");
        return 1;
    }
    const vector<double> x = extractArray(line, xKey).first;
    const vector<double> mu = extractArray(line, muKey).first;
    const int k = static_cast<int>(x.size());
    if (static_cast<int>(mu.size()) < k) {
        fprintf(stderr, "\"mu\" has fewer entries than \"x\"\n");
        return 1;
    }

    // Step past sigma's outer '[' so each extractArray call below reads one row.
    size_t pos = line.find('[', sigmaKey);
    if (pos == string::npos) {
        fprintf(stderr, "\"sigma\" is not an array\n");
        return 1;
    }
    ++pos;
    vector<vector<double>> sigma(k);
    for (int r = 0; r < k; r++) {
        auto parsed = extractArray(line, pos);
        if (static_cast<int>(parsed.first.size()) < k) {
            fprintf(stderr, "\"sigma\" must be a %d x %d matrix\n", k, k);
            return 1;
        }
        sigma[r] = std::move(parsed.first);
        pos = parsed.second;
    }

    vector<double> diff(k);
    for (int i = 0; i < k; i++) diff[i] = x[i] - mu[i];

    // Gauss-Jordan elimination with partial pivoting on [sigma | I]:
    // det(sigma) is the product of the pivots (negated on each row swap) and the
    // right half of aug ends up holding sigma^-1.
    vector<vector<double>> aug(k, vector<double>(2 * k, 0));
    for (int i = 0; i < k; i++) {
        for (int j = 0; j < k; j++) aug[i][j] = sigma[i][j];
        aug[i][k + i] = 1.0;
    }
    double det = 1.0;
    for (int col = 0; col < k; col++) {
        int pivot = col;
        for (int row = col + 1; row < k; row++)
            if (fabs(aug[row][col]) > fabs(aug[pivot][col])) pivot = row;
        if (pivot != col) { swap(aug[col], aug[pivot]); det = -det; }
        det *= aug[col][col];
        double d = aug[col][col];
        for (int j = 0; j < 2 * k; j++) aug[col][j] /= d;
        for (int row = 0; row < k; row++) {
            if (row == col) continue;
            double f = aug[row][col];
            for (int j = 0; j < 2 * k; j++) aug[row][j] -= f * aug[col][j];
        }
    }

    // Quadratic form diff^T * sigma^-1 * diff.
    double exponent = 0;
    for (int i = 0; i < k; i++)
        for (int j = 0; j < k; j++)
            exponent += diff[i] * aug[i][k + j] * diff[j];
    double pdf = exp(-0.5 * exponent) / sqrt(pow(2 * M_PI, k) * fabs(det));
    printf("%.2f\n", pdf);
    return 0;
}
import java.util.*;
public class Main {
    /**
     * Reads one JSON object with keys {@code "x"}, {@code "mu"} and {@code "sigma"} from standard
     * input (all lines are concatenated) and prints the multivariate Gaussian density
     * N(x; mu, sigma) with exactly two decimals.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        StringBuilder sb = new StringBuilder();
        while (sc.hasNextLine()) {
            sb.append(sc.nextLine());
        }
        String line = sb.toString();
        double[] x = parseArray(line, "\"x\"");
        double[] mu = parseArray(line, "\"mu\"");
        double[][] sigma = parseMatrix(line, "\"sigma\"", x.length);
        System.out.println(formatPdf(density(x, mu, sigma)));
    }

    /**
     * Evaluates the multivariate normal density at {@code x}.
     *
     * <p>Runs Gauss-Jordan elimination with partial pivoting on the augmented matrix
     * [sigma | I]. The determinant is the product of the pivots, negated once per row swap, and
     * the right half of the augmented matrix ends up holding sigma^-1.
     *
     * @param x point to evaluate; its length k is the dimension
     * @param mu mean vector; must have at least k entries
     * @param sigma k-by-k covariance matrix; it is copied, not modified
     * @return the density value. No singularity check is performed: a zero pivot is divided
     *     through, which yields a non-finite or meaningless result
     */
    static double density(double[] x, double[] mu, double[][] sigma) {
        int k = x.length;
        double[] diff = new double[k];
        for (int i = 0; i < k; i++) {
            diff[i] = x[i] - mu[i];
        }

        double[][] aug = new double[k][2 * k];
        for (int i = 0; i < k; i++) {
            System.arraycopy(sigma[i], 0, aug[i], 0, k);
            aug[i][k + i] = 1.0;
        }
        double det = 1.0;
        for (int col = 0; col < k; col++) {
            int pivot = col;
            for (int row = col + 1; row < k; row++) {
                if (Math.abs(aug[row][col]) > Math.abs(aug[pivot][col])) {
                    pivot = row;
                }
            }
            if (pivot != col) {
                double[] tmp = aug[col];
                aug[col] = aug[pivot];
                aug[pivot] = tmp;
                det = -det;
            }
            double d = aug[col][col];
            det *= d;
            for (int j = 0; j < 2 * k; j++) {
                aug[col][j] /= d;
            }
            for (int row = 0; row < k; row++) {
                if (row == col) {
                    continue;
                }
                double f = aug[row][col];
                for (int j = 0; j < 2 * k; j++) {
                    aug[row][j] -= f * aug[col][j];
                }
            }
        }

        // Quadratic form diff^T * sigma^-1 * diff, read from the right half of aug.
        double exponent = 0;
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < k; j++) {
                exponent += diff[i] * aug[i][k + j] * diff[j];
            }
        }
        return Math.exp(-0.5 * exponent) / Math.sqrt(Math.pow(2 * Math.PI, k) * Math.abs(det));
    }

    /**
     * Formats a density with exactly two decimals and '.' as the decimal separator.
     *
     * <p>{@code Locale.ROOT} is used because plain {@code printf("%.2f")} follows the JVM default
     * locale and would print e.g. {@code 0,16} under de_DE.
     */
    static String formatPdf(double pdf) {
        return String.format(Locale.ROOT, "%.2f", pdf);
    }

    /**
     * Parses the flat numeric array that follows {@code key}, e.g. {@code "x": [1, 2.5]}.
     *
     * @param s the whole JSON text
     * @param key the quoted key to look for, e.g. {@code "\"x\""}
     * @return the numbers between the first '[' after the key and the next ']'
     * @throws IllegalArgumentException if the key or its brackets are missing
     */
    static double[] parseArray(String s, String key) {
        int open = s.indexOf('[', indexOfKey(s, key));
        int close = open < 0 ? -1 : s.indexOf(']', open);
        if (close < 0) {
            throw new IllegalArgumentException("no [...] array after key " + key);
        }
        String inner = s.substring(open + 1, close).trim();
        if (inner.isEmpty()) {
            return new double[0];
        }
        String[] parts = inner.split(",");
        double[] arr = new double[parts.length];
        for (int i = 0; i < parts.length; i++) {
            arr[i] = Double.parseDouble(parts[i].trim());
        }
        return arr;
    }

    /**
     * Parses the k-by-k nested array that follows {@code key}, e.g.
     * {@code "sigma": [[1, 0], [0, 1]]}.
     *
     * @param s the whole JSON text
     * @param key the quoted key to look for
     * @param k number of rows and of columns to read
     * @return the parsed matrix
     * @throws IllegalArgumentException if the key, the outer array or a row is missing
     */
    static double[][] parseMatrix(String s, String key, int k) {
        int idx = s.indexOf('[', indexOfKey(s, key));
        if (idx < 0) {
            throw new IllegalArgumentException("no [...] array after key " + key);
        }
        idx++; // step past the outer '[' so each row's own '[' is found below
        double[][] mat = new double[k][k];
        for (int r = 0; r < k; r++) {
            int open = s.indexOf('[', idx);
            int close = open < 0 ? -1 : s.indexOf(']', open);
            if (close < 0) {
                throw new IllegalArgumentException("expected " + k + " rows after key " + key);
            }
            String[] parts = s.substring(open + 1, close).trim().split(",");
            for (int c = 0; c < k; c++) {
                mat[r][c] = Double.parseDouble(parts[c].trim());
            }
            idx = close + 1;
        }
        return mat;
    }

    /**
     * Returns the position of {@code key} in {@code s}. A missing key is rejected because
     * {@code indexOf('[', -1)} would otherwise silently match the first array in the input.
     */
    private static int indexOfKey(String s, String key) {
        int idx = s.indexOf(key);
        if (idx < 0) {
            throw new IllegalArgumentException("missing key " + key);
        }
        return idx;
    }
}
import json
import math
import sys
def solve():
    """Read a {"x", "mu", "sigma"} JSON object from stdin and print the Gaussian density.

    The covariance matrix is inverted with Gauss-Jordan elimination (partial
    pivoting) on the augmented matrix [sigma | I]. The determinant is the
    product of the pivots, with its sign flipped on every row swap. The density
    is printed with two decimals.
    """
    payload = json.loads(sys.stdin.read().strip())
    point, mean, cov = payload["x"], payload["mu"], payload["sigma"]
    dim = len(point)
    delta = [value - mean[idx] for idx, value in enumerate(point)]

    # Left half: float copy of the covariance; right half: the identity matrix.
    table = [
        [float(cov[r][c]) for c in range(dim)] + [1.0 if c == r else 0.0 for c in range(dim)]
        for r in range(dim)
    ]

    determinant = 1.0
    for c in range(dim):
        # First row (from c downwards) with the largest |entry| in column c.
        best = max(range(c, dim), key=lambda r: abs(table[r][c]))
        if best != c:
            table[c], table[best] = table[best], table[c]
            determinant = -determinant
        head = table[c][c]
        determinant *= head
        table[c] = [v / head for v in table[c]]
        pivot_row = table[c]
        for r in range(dim):
            if r != c:
                factor = table[r][c]
                table[r] = [a - factor * b for a, b in zip(table[r], pivot_row)]

    # Quadratic form delta^T * sigma^-1 * delta; the inverse sits in the right half.
    quad = 0.0
    for i, di in enumerate(delta):
        row = table[i]
        for j, dj in enumerate(delta):
            quad += di * row[dim + j] * dj

    density = math.exp(-0.5 * quad) / math.sqrt((2 * math.pi) ** dim * abs(determinant))
    print(f"{density:.2f}")
# Run only when executed as a script, so importing this module does not block on stdin.
if __name__ == "__main__":
    solve()
复杂度分析
- 时间复杂度:$O(k^3)$,即高斯消元求逆矩阵的复杂度,其中 $k$ 是向量维度。
- 空间复杂度:$O(k^2)$,用于存储 $k \times 2k$ 的增广矩阵。

京公网安备 11010502036488号