高斯混合模型(GMM)在聚类分析中的应用
[题目链接](https://www.nowcoder.com/practice/1cc4c3afa36a44eea934be45cbbe7315)
思路
给定 $N$ 个二维数据点和参数 $K$(聚类数)、$T$(迭代次数),要求使用 EM 算法实现高斯混合模型聚类,输出每个点的聚类标签。
高斯混合模型
GMM 假设数据由 $K$ 个高斯分布的加权组合生成。每个高斯成分有三个参数:均值 $\mu_k$、协方差矩阵 $\Sigma_k$ 和混合系数 $\pi_k$(满足 $\sum_{k=1}^{K} \pi_k = 1$)。
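二维高斯的概率密度:
$$
\mathcal{N}(x \mid \mu_k, \Sigma_k) = \frac{1}{2\pi\sqrt{\det \Sigma_k}} \exp\!\left(-\frac{1}{2}(x-\mu_k)^\top \Sigma_k^{-1} (x-\mu_k)\right)
$$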
初始化
- 设置 `np.random.seed(0)`,然后用 `np.random.choice(N, K, replace=False)` 随机选取 $K$ 个互不相同的数据点作为初始均值 $\mu_k$。
- 协方差矩阵初始化为单位矩阵:$\Sigma_k = I$。
- 混合系数均等:$\pi_k = 1/K$。
EM 算法
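E 步骤:计算每个数据点 $x_i$ 属于第 $k$ 个高斯成分的后验概率(责任度):
$$
\gamma_{ik} = \frac{\pi_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}
$$
分母小于 $10^{-6}$ 时用 $10^{-6}$ 兜底,防止除零。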
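M 步骤:利用责任度更新参数:
$$
N_k = \sum_{i=1}^{N} \gamma_{ik}, \qquad
\mu_k = \frac{1}{N_k} \sum_{i=1}^{N} \gamma_{ik}\, x_i, \qquad
\Sigma_k = \frac{1}{N_k} \sum_{i=1}^{N} \gamma_{ik} (x_i - \mu_k)(x_i - \mu_k)^\top, \qquad
\pi_k = \frac{N_k}{N}
$$
若 $N_k < 10^{-6}$,则跳过该成分,保留其上一轮的参数。
迭代 $T$ 次后,每个点 $x_i$ 取使 $\gamma_{ik}$ 最大的 $k$ 作为聚类标签(并列时取最小的 $k$)。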
实现要点
- 协方差矩阵退化处理:当协方差矩阵行列式接近零或因舍入误差略为负值($\det \Sigma_k < 10^{-300}$)时,在对角线上加 $10^{-6}$ 使其可逆。题目虽说"不需要正则化",但这是保证数值稳定的必要措施。
- 随机数一致性:C++ 和 Java 需要自行实现 Mersenne Twister (MT19937) 来精确复现 `np.random.seed(0)` 的随机序列。`choice(n, k, replace=False)` 的实现方式是:先对 $[0, 1, \dots, n-1]$ 做 Fisher-Yates 洗牌(第 $i$ 步用掩码拒绝采样在 $[0, i]$ 中均匀选取交换下标),再取前 $k$ 个元素。
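- 二维特化:由于数据固定为二维,$2 \times 2$ 矩阵的行列式和逆矩阵可以用解析公式直接计算,无需通用矩阵库:$\det \Sigma = s_{00}s_{11} - s_{01}s_{10}$,$\Sigma^{-1} = \frac{1}{\det \Sigma}\begin{pmatrix} s_{11} & -s_{01} \\ -s_{10} & s_{00} \end{pmatrix}$。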
复杂度分析
- 时间复杂度:$O(T \cdot N \cdot K)$,每轮 E 步和 M 步各遍历所有数据点和所有成分。
- 空间复杂度:$O(N \cdot K)$,用于存储责任度矩阵。
代码
import numpy as np
from scipy.stats import multivariate_normal
# Seed numpy's global RNG so the initial means match the reference solution.
np.random.seed(0)

# Input format: N, then N lines "x y", then one line "K T".
N = int(input())
# Unpacking into (x, y) enforces exactly two values per line, as before.
data = [[x, y] for x, y in (map(float, input().split()) for _ in range(N))]
X = np.array(data)
K, T = map(int, input().split())
D = 2  # every point is two-dimensional

# Initial parameters: K distinct data points as means, identity covariances,
# and uniform mixing weights.
indices = np.random.choice(N, K, replace=False)
mu = np.array(X[indices])
sigma = np.array([np.eye(D)] * K)
pi = np.ones(K) / K
def compute_pdf(X, mu_k, sigma_k):
    """Evaluate a multivariate normal density at every row of ``X``.

    Args:
        X: array of shape (n, D) holding one point per row.
        mu_k: mean vector of length D.
        sigma_k: (D, D) covariance matrix. It is not modified; a regularised copy
            is used when its determinant is below 1e-300.

    Returns:
        Array of shape (n,) with the density value for each row of ``X``.
    """
    dim = X.shape[1]
    cov = sigma_k
    cov_det = np.linalg.det(cov)
    if cov_det < 1e-300:
        # Nearly singular (or numerically negative) determinant: nudge the
        # diagonal so the matrix becomes invertible.
        cov = cov + 1e-6 * np.eye(dim)
        cov_det = np.linalg.det(cov)
    precision = np.linalg.inv(cov)
    centered = X - mu_k
    # Row-wise quadratic form (x - mu)^T Sigma^{-1} (x - mu).
    quad = np.sum(centered @ precision * centered, axis=1)
    scale = 1.0 / ((2 * np.pi) ** (dim / 2.0) * np.sqrt(cov_det))
    return scale * np.exp(-0.5 * quad)
# Run T rounds of EM, then label every point with its most responsible component.
# gamma is pre-allocated so that T == 0 prints all-zero labels (as the C++ and
# Java versions do) instead of raising NameError at gamma.argmax below.
gamma = np.zeros((N, K))
for _ in range(T):
    # E-step: unnormalised responsibilities pi_k * N(x_i | mu_k, Sigma_k) ...
    gamma = np.zeros((N, K))
    for k in range(K):
        gamma[:, k] = pi[k] * compute_pdf(X, mu[k], sigma[k])
    # ... normalised per point; tiny row sums are floored at 1e-6 to avoid 0/0.
    gamma_sum = gamma.sum(axis=1, keepdims=True)
    gamma_sum = np.where(gamma_sum < 1e-6, 1e-6, gamma_sum)
    gamma = gamma / gamma_sum
    # M-step: re-estimate each component from its responsibilities.
    for k in range(K):
        Nk = gamma[:, k].sum()
        if Nk < 1e-6:
            # (Near-)empty component: keep its previous parameters.
            continue
        mu[k] = (gamma[:, k].reshape(-1, 1) * X).sum(axis=0) / Nk
        diff = X - mu[k]
        # Weighted sum of outer products diff_i diff_i^T.
        sigma[k] = (gamma[:, k].reshape(-1, 1, 1) * np.einsum('ij,ik->ijk', diff, diff)).sum(axis=0) / Nk
        pi[k] = Nk / N
# argmax picks the first maximal component on ties.
labels = gamma.argmax(axis=1)
for label in labels:
    print(label)
#include <bits/stdc++.h>
using namespace std;
/// 32-bit Mersenne Twister (MT19937) that reproduces numpy's legacy RandomState
/// stream for an integer seed, plus the helpers needed to mirror
/// np.random.choice(n, k, replace=False).
class MT19937 {
    static const int N_MT = 624, M = 397;
    // State words. They are always masked to 32 bits, so the generator is also
    // correct where unsigned long is 64 bits wide.
    unsigned long mt[N_MT];
    int mti;  // index of the next state word to temper; N_MT forces a regeneration
public:
    /// Seeds the state with the reference init_genrand recurrence.
    MT19937(unsigned long seed) {
        mt[0] = seed & 0xffffffffUL;
        for (mti = 1; mti < N_MT; mti++) {
            mt[mti] = (1812433253UL * (mt[mti - 1] ^ (mt[mti - 1] >> 30)) + mti);
            mt[mti] &= 0xffffffffUL;
        }
    }

    /// Returns the next 32-bit output (genrand_int32). All N_MT state words are
    /// regenerated once they have been consumed.
    unsigned long genrand() {
        static const unsigned long mag01[2] = {0x0UL, 0x9908b0dfUL};
        unsigned long y;
        if (mti >= N_MT) {
            int kk;
            for (kk = 0; kk < N_MT - M; kk++) {
                y = (mt[kk] & 0x80000000UL) | (mt[kk + 1] & 0x7fffffffUL);
                mt[kk] = mt[kk + M] ^ (y >> 1) ^ mag01[y & 1UL];
            }
            for (; kk < N_MT - 1; kk++) {
                y = (mt[kk] & 0x80000000UL) | (mt[kk + 1] & 0x7fffffffUL);
                mt[kk] = mt[kk + (M - N_MT)] ^ (y >> 1) ^ mag01[y & 1UL];
            }
            y = (mt[N_MT - 1] & 0x80000000UL) | (mt[0] & 0x7fffffffUL);
            mt[N_MT - 1] = mt[M - 1] ^ (y >> 1) ^ mag01[y & 1UL];
            mti = 0;
        }
        y = mt[mti++];
        // Tempering.
        y ^= (y >> 11);
        y ^= (y << 7) & 0x9d2c5680UL;
        y ^= (y << 15) & 0xefc60000UL;
        y ^= (y >> 18);
        return y;
    }

    /// Uniform integer in [0, n-1] by bitmask rejection sampling, the same scheme
    /// numpy's legacy shuffle uses. Returns 0 when n <= 1 without consuming output.
    int randBelow(int n) {
        if (n <= 1) return 0;
        unsigned long mask = n - 1;
        // Smear the highest set bit down to get the smallest all-ones mask >= n-1.
        mask |= mask >> 1; mask |= mask >> 2; mask |= mask >> 4;
        mask |= mask >> 8; mask |= mask >> 16;
        unsigned long r;
        do { r = genrand() & mask; } while (r >= (unsigned long)n);
        return (int)r;
    }

    /// Mirrors np.random.choice(n, k, replace=False): a Fisher-Yates shuffle of
    /// 0..n-1 from the back, then the first k entries.
    /// @throws std::invalid_argument if k < 0 or k > n. numpy raises ValueError
    ///         in that case; without the check the final slice would run past
    ///         the end of the pool (undefined behaviour).
    std::vector<int> choice(int n, int k) {
        if (k < 0 || k > n)
            throw std::invalid_argument("Cannot take a larger sample than population when 'replace=False'");
        std::vector<int> pool(n);
        std::iota(pool.begin(), pool.end(), 0);
        for (int i = n - 1; i > 0; i--)
            std::swap(pool[i], pool[randBelow(i + 1)]);
        return std::vector<int>(pool.begin(), pool.begin() + k);
    }
};
// 2*pi, i.e. the normalising factor (2*pi)^(D/2) of a D = 2 Gaussian density.
const double PI2 = 2.0 * std::acos(-1.0);

// Vec2: a 2-D point or mean; Mat2: a 2x2 covariance matrix.
using Vec2 = std::array<double, 2>;
using Mat2 = std::array<std::array<double, 2>, 2>;

// E-step: overwrite gamma[i][k] with the normalised responsibility of component k for point i.
static void gmmEStep(const std::vector<Vec2>& X, const std::vector<Vec2>& mu,
                     const std::vector<Mat2>& sigma, const std::vector<double>& pi_k,
                     std::vector<std::vector<double>>& gamma) {
    const int N = static_cast<int>(X.size());
    const int K = static_cast<int>(mu.size());
    for (int k = 0; k < K; k++) {
        double s00 = sigma[k][0][0], s01 = sigma[k][0][1];
        double s10 = sigma[k][1][0], s11 = sigma[k][1][1];
        double det = s00 * s11 - s01 * s10;
        // Use the same test as the Python reference (`det < 1e-300`). The previous
        // fabs(det) check let slightly negative determinants (rounding on a
        // degenerate covariance) through unregularised, producing a wrong-signed
        // inverse and an overflowing exponent.
        if (det < 1e-300) {
            s00 += 1e-6;
            s11 += 1e-6;
            det = s00 * s11 - s01 * s10;
        }
        // Closed-form inverse of the (possibly regularised) 2x2 matrix.
        const double inv00 = s11 / det, inv01 = -s01 / det;
        const double inv10 = -s10 / det, inv11 = s00 / det;
        const double nc = 1.0 / (PI2 * std::sqrt(std::fabs(det)));
        for (int i = 0; i < N; i++) {
            const double d0 = X[i][0] - mu[k][0], d1 = X[i][1] - mu[k][1];
            const double e = -0.5 * (d0 * (inv00 * d0 + inv01 * d1) + d1 * (inv10 * d0 + inv11 * d1));
            gamma[i][k] = pi_k[k] * nc * std::exp(e);
        }
    }
    for (auto& row : gamma) {
        double sum = 0;
        for (double g : row) sum += g;
        if (sum < 1e-6) sum = 1e-6;  // floor tiny totals to avoid dividing by zero
        for (double& g : row) g /= sum;
    }
}

// M-step: re-estimate mu, sigma and pi_k from gamma. A component whose total
// responsibility Nk is below 1e-6 keeps its previous parameters.
static void gmmMStep(const std::vector<Vec2>& X, const std::vector<std::vector<double>>& gamma,
                     std::vector<Vec2>& mu, std::vector<Mat2>& sigma, std::vector<double>& pi_k) {
    const int N = static_cast<int>(X.size());
    const int K = static_cast<int>(mu.size());
    for (int k = 0; k < K; k++) {
        double Nk = 0;
        for (int i = 0; i < N; i++) Nk += gamma[i][k];
        if (Nk < 1e-6) continue;

        Vec2 m{0, 0};
        for (int i = 0; i < N; i++) {
            m[0] += gamma[i][k] * X[i][0];
            m[1] += gamma[i][k] * X[i][1];
        }
        m[0] /= Nk;
        m[1] /= Nk;
        mu[k] = m;

        Mat2 s{};  // zero-initialised accumulator
        for (int i = 0; i < N; i++) {
            const double d0 = X[i][0] - m[0], d1 = X[i][1] - m[1];
            s[0][0] += gamma[i][k] * d0 * d0;
            s[0][1] += gamma[i][k] * d0 * d1;
            s[1][0] += gamma[i][k] * d1 * d0;
            s[1][1] += gamma[i][k] * d1 * d1;
        }
        s[0][0] /= Nk; s[0][1] /= Nk;
        s[1][0] /= Nk; s[1][1] /= Nk;
        sigma[k] = s;
        pi_k[k] = Nk / N;
    }
}

// Reads N points, then K and T; runs T EM iterations and prints one label per point.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int N;
    std::cin >> N;
    std::vector<Vec2> X(N);
    for (auto& p : X) std::cin >> p[0] >> p[1];
    int K, T;
    std::cin >> K >> T;

    // Same initial means as np.random.seed(0); np.random.choice(N, K, replace=False).
    MT19937 rng(0);
    const auto indices = rng.choice(N, K);
    std::vector<Vec2> mu(K);
    std::vector<Mat2> sigma(K);
    std::vector<double> pi_k(K, 1.0 / K);
    for (int k = 0; k < K; k++) {
        mu[k] = X[indices[k]];
        sigma[k] = {{{1, 0}, {0, 1}}};
    }

    std::vector<std::vector<double>> gamma(N, std::vector<double>(K));
    for (int t = 0; t < T; t++) {
        gmmEStep(X, mu, sigma, pi_k, gamma);
        gmmMStep(X, gamma, mu, sigma, pi_k);
    }

    for (const auto& row : gamma) {
        // First index of the maximum, matching numpy's argmax tie-breaking.
        int label = 0;
        for (int k = 1; k < K; k++)
            if (row[k] > row[label]) label = k;
        std::cout << label << '\n';
    }
    return 0;
}
import java.util.*;
public class Main {
    /**
     * Reads N 2-D points followed by K and T from stdin, runs T EM iterations of a
     * K-component Gaussian mixture and prints one cluster label per line.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        double[][] X = new double[N][2];
        for (int i = 0; i < N; i++) {
            X[i][0] = sc.nextDouble();
            X[i][1] = sc.nextDouble();
        }
        int K = sc.nextInt(), T = sc.nextInt();
        // Same initial means as np.random.seed(0); np.random.choice(N, K, replace=False).
        MT19937 rng = new MT19937(0);
        int[] indices = rng.choice(N, K);
        double[][] mu = new double[K][2];
        double[][][] sigma = new double[K][2][2];
        double[] pi_k = new double[K];
        double[][] gamma = new double[N][K];
        double PI2 = 2.0 * Math.PI;
        for (int k = 0; k < K; k++) {
            mu[k][0] = X[indices[k]][0];
            mu[k][1] = X[indices[k]][1];
            sigma[k][0][0] = 1;
            sigma[k][1][1] = 1;
            pi_k[k] = 1.0 / K;
        }
        for (int t = 0; t < T; t++) {
            // E-step: unnormalised responsibilities pi_k * N(x_i | mu_k, Sigma_k).
            for (int k = 0; k < K; k++) {
                double s00 = sigma[k][0][0], s01 = sigma[k][0][1];
                double s10 = sigma[k][1][0], s11 = sigma[k][1][1];
                double det = s00 * s11 - s01 * s10;
                // Same test as the Python reference (det < 1e-300). Math.abs(det)
                // let slightly negative determinants through unregularised.
                if (det < 1e-300) {
                    s00 += 1e-6; s11 += 1e-6;
                    det = s00 * s11 - s01 * s10;
                }
                // Closed-form 2x2 inverse.
                double inv00 = s11 / det, inv01 = -s01 / det;
                double inv10 = -s10 / det, inv11 = s00 / det;
                double nc = 1.0 / (PI2 * Math.sqrt(Math.abs(det)));
                for (int i = 0; i < N; i++) {
                    double d0 = X[i][0] - mu[k][0], d1 = X[i][1] - mu[k][1];
                    double e = -0.5 * (d0 * (inv00 * d0 + inv01 * d1) + d1 * (inv10 * d0 + inv11 * d1));
                    gamma[i][k] = pi_k[k] * nc * Math.exp(e);
                }
            }
            // Normalise per point; tiny totals are floored at 1e-6 to avoid 0/0.
            for (int i = 0; i < N; i++) {
                double sum = 0;
                for (int k = 0; k < K; k++) sum += gamma[i][k];
                if (sum < 1e-6) sum = 1e-6;
                for (int k = 0; k < K; k++) gamma[i][k] /= sum;
            }
            // M-step; a (near-)empty component keeps its previous parameters.
            for (int k = 0; k < K; k++) {
                double Nk = 0;
                for (int i = 0; i < N; i++) Nk += gamma[i][k];
                if (Nk < 1e-6) continue;
                mu[k][0] = 0; mu[k][1] = 0;
                for (int i = 0; i < N; i++) {
                    mu[k][0] += gamma[i][k] * X[i][0];
                    mu[k][1] += gamma[i][k] * X[i][1];
                }
                mu[k][0] /= Nk; mu[k][1] /= Nk;
                sigma[k] = new double[2][2];
                for (int i = 0; i < N; i++) {
                    double d0 = X[i][0] - mu[k][0], d1 = X[i][1] - mu[k][1];
                    sigma[k][0][0] += gamma[i][k] * d0 * d0;
                    sigma[k][0][1] += gamma[i][k] * d0 * d1;
                    sigma[k][1][0] += gamma[i][k] * d1 * d0;
                    sigma[k][1][1] += gamma[i][k] * d1 * d1;
                }
                sigma[k][0][0] /= Nk; sigma[k][0][1] /= Nk;
                sigma[k][1][0] /= Nk; sigma[k][1][1] /= Nk;
                pi_k[k] = Nk / N;
            }
        }
        // Label = first index of the maximum responsibility (numpy argmax tie-break).
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < N; i++) {
            int label = 0;
            double maxG = gamma[i][0];
            for (int k = 1; k < K; k++) {
                if (gamma[i][k] > maxG) { maxG = gamma[i][k]; label = k; }
            }
            sb.append(label).append('\n');
        }
        System.out.print(sb);
    }
}
/**
 * 32-bit Mersenne Twister (MT19937) reproducing numpy's legacy RandomState stream for
 * an integer seed, plus the helpers used to mirror np.random.choice(n, k, replace=False).
 * State words are stored in longs and kept masked to 32 bits, because Java has no
 * unsigned int.
 */
class MT19937 {
    private static final int N_MT = 624, M = 397;
    /** Twist constant indexed by the low bit of y; allocated once instead of on every call. */
    private static final long[] MAG01 = {0x0L, 0x9908b0dfL};
    private long[] mt = new long[N_MT];
    private int mti;

    /** Seeds the state with the reference init_genrand recurrence. */
    MT19937(long seed) {
        mt[0] = seed & 0xffffffffL;
        for (mti = 1; mti < N_MT; mti++)
            mt[mti] = (1812433253L * (mt[mti - 1] ^ (mt[mti - 1] >> 30)) + mti) & 0xffffffffL;
    }

    /** Returns the next 32-bit output (genrand_int32) as a non-negative long. */
    long genrand() {
        long y;
        if (mti >= N_MT) {
            int kk;
            for (kk = 0; kk < N_MT - M; kk++) {
                y = (mt[kk] & 0x80000000L) | (mt[kk + 1] & 0x7fffffffL);
                mt[kk] = mt[kk + M] ^ (y >> 1) ^ MAG01[(int) (y & 1L)];
            }
            for (; kk < N_MT - 1; kk++) {
                y = (mt[kk] & 0x80000000L) | (mt[kk + 1] & 0x7fffffffL);
                mt[kk] = mt[kk + (M - N_MT)] ^ (y >> 1) ^ MAG01[(int) (y & 1L)];
            }
            y = (mt[N_MT - 1] & 0x80000000L) | (mt[0] & 0x7fffffffL);
            mt[N_MT - 1] = mt[M - 1] ^ (y >> 1) ^ MAG01[(int) (y & 1L)];
            mti = 0;
        }
        y = mt[mti++];
        // Tempering.
        y ^= (y >> 11);
        y ^= (y << 7) & 0x9d2c5680L;
        y ^= (y << 15) & 0xefc60000L;
        y ^= (y >> 18);
        return y;
    }

    /**
     * Uniform integer in [0, n-1] by bitmask rejection sampling, as numpy's legacy
     * shuffle does. Returns 0 when n <= 1 without consuming output.
     */
    int randBelow(int n) {
        if (n <= 1) return 0;
        long mask = n - 1;
        mask |= mask >> 1; mask |= mask >> 2; mask |= mask >> 4;
        mask |= mask >> 8; mask |= mask >> 16;
        long r;
        do { r = genrand() & mask; } while (r >= n);
        return (int) r;
    }

    /**
     * Mirrors np.random.choice(n, k, replace=False): a Fisher-Yates shuffle of 0..n-1
     * from the back, then the first k entries.
     *
     * @throws IllegalArgumentException if k < 0 or k > n. numpy raises ValueError;
     *         Arrays.copyOf would otherwise silently pad the result with zeros.
     */
    int[] choice(int n, int k) {
        if (k < 0 || k > n)
            throw new IllegalArgumentException("Cannot take a larger sample than population when 'replace=False'");
        int[] pool = new int[n];
        for (int i = 0; i < n; i++) pool[i] = i;
        for (int i = n - 1; i > 0; i--) {
            int j = randBelow(i + 1);
            int tmp = pool[i]; pool[i] = pool[j]; pool[j] = tmp;
        }
        return java.util.Arrays.copyOf(pool, k);
    }
}

京公网安备 11010502036488号