高斯混合模型(GMM)在聚类分析中的应用

[题目链接](https://www.nowcoder.com/practice/1cc4c3afa36a44eea934be45cbbe7315)

思路

给定 $N$ 个二维数据点和参数 $K$(聚类数)、$T$(迭代次数),要求使用 EM 算法实现高斯混合模型聚类,输出每个点的聚类标签。

高斯混合模型

GMM 假设数据由 $K$ 个高斯分布的加权组合生成。每个高斯成分有三个参数:均值 $\mu_k$、协方差矩阵 $\Sigma_k$ 和混合系数 $\pi_k$。

二维高斯的概率密度:

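$$
\mathcal{N}(x \mid \mu_k, \Sigma_k) = \frac{1}{2\pi\sqrt{|\Sigma_k|}} \exp\left(-\frac{1}{2}(x-\mu_k)^{\top}\Sigma_k^{-1}(x-\mu_k)\right)
$$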

初始化

  • 设置 np.random.seed(0),然后用 np.random.choice(N, K, replace=False) 随机选取 $K$ 个数据点作为初始均值 $\mu_k$。
  • 协方差矩阵初始化为单位矩阵:$\Sigma_k = I$。
  • 混合系数均等:$\pi_k = 1/K$。

EM 算法

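E 步骤:计算每个数据点 $x_i$ 属于第 $k$ 个高斯成分的后验概率(责任度)$\gamma_{ik}$:

$$
\gamma_{ik} = \frac{\pi_k\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)}
$$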

分母小于 $10^{-6}$ 时用 $10^{-6}$ 兜底,防止除零。

M 步骤:利用责任度更新参数:

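$$
N_k = \sum_{i=1}^{N} \gamma_{ik}, \qquad
\mu_k = \frac{1}{N_k}\sum_{i=1}^{N} \gamma_{ik}\,x_i, \qquad
\Sigma_k = \frac{1}{N_k}\sum_{i=1}^{N} \gamma_{ik}\,(x_i-\mu_k)(x_i-\mu_k)^{\top}, \qquad
\pi_k = \frac{N_k}{N}
$$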

迭代 $T$ 次后,每个点取 $\gamma_{ik}$ 最大的 $k$ 作为聚类标签。

实现要点

  1. 协方差矩阵退化处理:当协方差矩阵行列式接近零(小于 $10^{-300}$)时,对角线加 $10^{-6}$ 使其可逆。题目虽说"不需要正则化",但这是保证数值稳定的必要措施。
  2. 随机数一致性:C++ 和 Java 需要自行实现 Mersenne Twister (MT19937) 来精确复现 np.random.seed(0) 的随机序列。choice(n, k, replace=False) 的实现方式是:先对 $[0, 1, \dots, n-1]$ 做 Fisher-Yates 洗牌,再取前 $k$ 个元素。
  3. 二维特化:由于数据固定为二维,$2 \times 2$ 矩阵的行列式和逆矩阵可以用解析公式直接计算,无需通用矩阵库。

复杂度分析

  • 时间复杂度:$O(T \cdot N \cdot K)$,每轮 E 步和 M 步各遍历所有数据点和所有成分。
  • 空间复杂度:$O(N \cdot K)$,存储责任度矩阵。

代码

import numpy as np
from scipy.stats import multivariate_normal

# Seed NumPy's global generator so the random choice of initial means is reproducible.
np.random.seed(0)

# Input layout: N, then N lines of "x y", then a line holding "K T".
N = int(input())
X = np.array([[px, py] for px, py in (map(float, input().split()) for _ in range(N))])

K, T = (int(token) for token in input().split())

# Initial parameters: K distinct data points as means, identity covariances,
# and uniform mixing weights.
D = 2
indices = np.random.choice(N, K, replace=False)
mu = X[indices].copy()
sigma = np.array([np.identity(D)] * K)
pi = np.ones(K) / K


def compute_pdf(X, mu_k, sigma_k):
    """Evaluate one Gaussian component's density at every row of ``X``.

    Args:
        X: array with one sample per row; the dimension is taken from ``X.shape[1]``.
        mu_k: mean of the component.
        sigma_k: covariance matrix of the component.

    Returns:
        1-D array with the density value for each row of ``X``.

    If the determinant of ``sigma_k`` is below 1e-300, 1e-6 is added to the
    diagonal (on a copy; the caller's matrix is not modified) before inverting.
    """
    dim = X.shape[1]
    cov = sigma_k
    cov_det = np.linalg.det(cov)
    if cov_det < 1e-300:
        # Degenerate covariance: nudge the diagonal so the inverse exists.
        cov = cov + 1e-6 * np.eye(dim)
        cov_det = np.linalg.det(cov)
    precision = np.linalg.inv(cov)
    centered = X - mu_k
    # Row-wise quadratic form (x - mu)^T Sigma^{-1} (x - mu).
    quad_form = np.sum(centered @ precision * centered, axis=1)
    scale = 1.0 / ((2 * np.pi) ** (dim / 2.0) * np.sqrt(cov_det))
    return scale * np.exp(-0.5 * quad_form)


# Responsibilities gamma[i, k]. Allocated before the loop so that the label
# extraction below is still defined when T == 0 (every point then gets label 0);
# previously `gamma` only existed once the loop body had run, so T == 0 raised
# NameError.
gamma = np.zeros((N, K))

for _ in range(T):
    # E-step: unnormalised responsibilities pi_k * N(x_i | mu_k, Sigma_k).
    for k in range(K):
        gamma[:, k] = pi[k] * compute_pdf(X, mu[k], sigma[k])
    gamma_sum = gamma.sum(axis=1, keepdims=True)
    # Floor the per-point normaliser at 1e-6 to avoid dividing by (near) zero.
    gamma_sum = np.where(gamma_sum < 1e-6, 1e-6, gamma_sum)
    gamma = gamma / gamma_sum

    # M-step: re-estimate each component from its responsibilities.
    for k in range(K):
        Nk = gamma[:, k].sum()
        if Nk < 1e-6:
            # Component owns (almost) no mass: keep its previous parameters.
            continue
        mu[k] = (gamma[:, k].reshape(-1, 1) * X).sum(axis=0) / Nk
        diff = X - mu[k]
        # Responsibility-weighted sum of the outer products diff_i diff_i^T.
        sigma[k] = (gamma[:, k].reshape(-1, 1, 1) * np.einsum('ij,ik->ijk', diff, diff)).sum(axis=0) / Nk
        pi[k] = Nk / N

# Hard assignment: the component with the largest responsibility for each point.
labels = gamma.argmax(axis=1)
for label in labels:
    print(label)
#include <bits/stdc++.h>
using namespace std;

/// 32-bit Mersenne Twister (MT19937) with integer seeding, plus the two helpers
/// needed to draw `k` distinct indices out of `n`: a masked-rejection bounded
/// draw and a Fisher-Yates shuffle. The accompanying write-up states that this
/// is meant to reproduce np.random.seed(0) / np.random.choice(N, K, replace=False).
class MT19937 {
    static constexpr int N_MT = 624;
    static constexpr int M = 397;
    unsigned long mt[N_MT];  // state words; only the low 32 bits are ever set
    int mti;                 // index of the next state word to temper

    /// Combines the top bit of `upper` with the low 31 bits of `lower` and
    /// applies the MT19937 twist (shift right, conditionally XOR the matrix constant).
    static unsigned long twist(unsigned long upper, unsigned long lower) {
        const unsigned long y = (upper & 0x80000000UL) | (lower & 0x7fffffffUL);
        return (y >> 1) ^ ((y & 1UL) ? 0x9908b0dfUL : 0x0UL);
    }

    /// Regenerates all N_MT state words and rewinds the read index.
    void regenerate() {
        int kk = 0;
        for (; kk < N_MT - M; kk++)
            mt[kk] = mt[kk + M] ^ twist(mt[kk], mt[kk + 1]);
        for (; kk < N_MT - 1; kk++)
            mt[kk] = mt[kk + (M - N_MT)] ^ twist(mt[kk], mt[kk + 1]);
        mt[N_MT - 1] = mt[M - 1] ^ twist(mt[N_MT - 1], mt[0]);
        mti = 0;
    }

public:
    /// Seeds the state from the low 32 bits of `seed`.
    explicit MT19937(unsigned long seed) : mti(N_MT) {
        mt[0] = seed & 0xffffffffUL;
        for (int i = 1; i < N_MT; i++) {
            mt[i] = (1812433253UL * (mt[i - 1] ^ (mt[i - 1] >> 30)) + i) & 0xffffffffUL;
        }
        // mti == N_MT forces a regeneration on the first genrand() call.
    }

    /// Returns the next tempered 32-bit output (value in [0, 2^32)).
    unsigned long genrand() {
        if (mti >= N_MT) regenerate();
        unsigned long y = mt[mti++];
        y ^= (y >> 11);
        y ^= (y << 7) & 0x9d2c5680UL;
        y ^= (y << 15) & 0xefc60000UL;
        y ^= (y >> 18);
        return y;
    }

    /// Uniform integer in [0, n). Returns 0 without consuming any output when
    /// n <= 1. Otherwise masks each output down to the bit-width of n - 1 and
    /// redraws until the value is below n.
    int randBelow(int n) {
        if (n <= 1) return 0;
        unsigned long mask = n - 1;
        mask |= mask >> 1; mask |= mask >> 2; mask |= mask >> 4;
        mask |= mask >> 8; mask |= mask >> 16;
        unsigned long r;
        do { r = genrand() & mask; } while (r >= static_cast<unsigned long>(n));
        return static_cast<int>(r);
    }

    /// Draws `k` distinct values from {0, ..., n-1}: shuffles the whole range
    /// from the back (Fisher-Yates) and returns the first `k` entries.
    /// @throws std::invalid_argument unless 0 <= k <= n. Previously k > n or
    ///         k < 0 built a vector from an iterator outside `pool`, which is
    ///         undefined behaviour.
    std::vector<int> choice(int n, int k) {
        if (n < 0 || k < 0 || k > n)
            throw std::invalid_argument("MT19937::choice requires 0 <= k <= n");
        std::vector<int> pool(n);
        std::iota(pool.begin(), pool.end(), 0);
        for (int i = n - 1; i > 0; i--)
            std::swap(pool[i], pool[randBelow(i + 1)]);
        pool.resize(k);
        return pool;
    }
};

// 2*pi, the constant factor in the 2-D Gaussian normaliser 1 / (2*pi*sqrt(|Sigma|)).
const double PI2 = 2.0 * acos(-1.0);

/// Reads N two-dimensional points followed by K and T from stdin, runs T EM
/// iterations of a K-component Gaussian mixture (initial means picked with
/// MT19937 seeded by 0, identity covariances, uniform weights) and prints one
/// label per point: the index of the component with the largest responsibility.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    using Vec2 = array<double, 2>;
    using Mat2 = array<Vec2, 2>;

    int N;
    cin >> N;
    vector<Vec2> X(N);
    for (Vec2& point : X) cin >> point[0] >> point[1];
    int K, T;
    cin >> K >> T;

    MT19937 rng(0);
    const vector<int> indices = rng.choice(N, K);

    Mat2 identity{};
    identity[0][0] = 1;
    identity[1][1] = 1;

    vector<Vec2> mu(K);
    vector<Mat2> sigma(K, identity);
    vector<double> pi_k(K, 1.0 / K);
    for (int k = 0; k < K; ++k) mu[k] = X[indices[k]];

    // gamma[i][k]: responsibility of component k for point i.
    vector<vector<double>> gamma(N, vector<double>(K));

    for (int iter = 0; iter < T; ++iter) {
        // E-step, part 1: unnormalised responsibilities pi_k * N(x_i | mu_k, Sigma_k).
        for (int k = 0; k < K; ++k) {
            const Mat2& cov = sigma[k];
            double a = cov[0][0];
            const double b = cov[0][1];
            const double c = cov[1][0];
            double d = cov[1][1];
            double det = a * d - b * c;
            if (fabs(det) < 1e-300) {
                // Degenerate covariance: bump the diagonal of the local copy only.
                a += 1e-6;
                d += 1e-6;
                det = a * d - b * c;
            }
            // Closed-form inverse of the 2x2 matrix.
            const double inv00 = d / det;
            const double inv01 = -b / det;
            const double inv10 = -c / det;
            const double inv11 = a / det;
            const double nc = 1.0 / (PI2 * sqrt(fabs(det)));
            const double weight = pi_k[k] * nc;
            for (int i = 0; i < N; ++i) {
                const double d0 = X[i][0] - mu[k][0];
                const double d1 = X[i][1] - mu[k][1];
                const double e = -0.5 * (d0 * (inv00 * d0 + inv01 * d1) + d1 * (inv10 * d0 + inv11 * d1));
                gamma[i][k] = weight * exp(e);
            }
        }

        // E-step, part 2: normalise each row, flooring the divisor at 1e-6.
        for (vector<double>& row : gamma) {
            double total = 0;
            for (const double g : row) total += g;
            if (total < 1e-6) total = 1e-6;
            for (double& g : row) g /= total;
        }

        // M-step: weighted mean, weighted covariance and mixing weight per component.
        for (int k = 0; k < K; ++k) {
            double Nk = 0;
            for (int i = 0; i < N; ++i) Nk += gamma[i][k];
            if (Nk < 1e-6) continue;  // (almost) empty component keeps its old parameters

            double m0 = 0, m1 = 0;
            for (int i = 0; i < N; ++i) {
                m0 += gamma[i][k] * X[i][0];
                m1 += gamma[i][k] * X[i][1];
            }
            m0 /= Nk;
            m1 /= Nk;
            mu[k][0] = m0;
            mu[k][1] = m1;

            double c00 = 0, c01 = 0, c10 = 0, c11 = 0;
            for (int i = 0; i < N; ++i) {
                const double d0 = X[i][0] - m0;
                const double d1 = X[i][1] - m1;
                c00 += gamma[i][k] * d0 * d0;
                c01 += gamma[i][k] * d0 * d1;
                c10 += gamma[i][k] * d1 * d0;
                c11 += gamma[i][k] * d1 * d1;
            }
            sigma[k][0][0] = c00 / Nk;
            sigma[k][0][1] = c01 / Nk;
            sigma[k][1][0] = c10 / Nk;
            sigma[k][1][1] = c11 / Nk;
            pi_k[k] = Nk / N;
        }
    }

    // Hard assignment: first index holding the largest responsibility.
    for (const vector<double>& row : gamma) {
        const auto best = max_element(row.begin(), row.end());
        cout << (best - row.begin()) << '\n';
    }
    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        double[][] X = new double[N][2];
        for (int i = 0; i < N; i++) {
            X[i][0] = sc.nextDouble();
            X[i][1] = sc.nextDouble();
        }
        int K = sc.nextInt(), T = sc.nextInt();

        MT19937 rng = new MT19937(0);
        int[] indices = rng.choice(N, K);

        double[][] mu = new double[K][2];
        double[][][] sigma = new double[K][2][2];
        double[] pi_k = new double[K];
        double[][] gamma = new double[N][K];
        double PI2 = 2.0 * Math.PI;

        for (int k = 0; k < K; k++) {
            mu[k][0] = X[indices[k]][0];
            mu[k][1] = X[indices[k]][1];
            sigma[k][0][0] = 1;
            sigma[k][1][1] = 1;
            pi_k[k] = 1.0 / K;
        }

        for (int t = 0; t < T; t++) {
            for (int k = 0; k < K; k++) {
                double s00 = sigma[k][0][0], s01 = sigma[k][0][1];
                double s10 = sigma[k][1][0], s11 = sigma[k][1][1];
                double det = s00 * s11 - s01 * s10;
                if (Math.abs(det) < 1e-300) {
                    s00 += 1e-6; s11 += 1e-6;
                    det = s00 * s11 - s01 * s10;
                }
                double inv00 = s11 / det, inv01 = -s01 / det;
                double inv10 = -s10 / det, inv11 = s00 / det;
                double nc = 1.0 / (PI2 * Math.sqrt(Math.abs(det)));
                for (int i = 0; i < N; i++) {
                    double d0 = X[i][0] - mu[k][0], d1 = X[i][1] - mu[k][1];
                    double e = -0.5 * (d0 * (inv00 * d0 + inv01 * d1) + d1 * (inv10 * d0 + inv11 * d1));
                    gamma[i][k] = pi_k[k] * nc * Math.exp(e);
                }
            }
            for (int i = 0; i < N; i++) {
                double sum = 0;
                for (int k = 0; k < K; k++) sum += gamma[i][k];
                if (sum < 1e-6) sum = 1e-6;
                for (int k = 0; k < K; k++) gamma[i][k] /= sum;
            }

            for (int k = 0; k < K; k++) {
                double Nk = 0;
                for (int i = 0; i < N; i++) Nk += gamma[i][k];
                if (Nk < 1e-6) continue;
                mu[k][0] = 0; mu[k][1] = 0;
                for (int i = 0; i < N; i++) {
                    mu[k][0] += gamma[i][k] * X[i][0];
                    mu[k][1] += gamma[i][k] * X[i][1];
                }
                mu[k][0] /= Nk; mu[k][1] /= Nk;
                sigma[k] = new double[2][2];
                for (int i = 0; i < N; i++) {
                    double d0 = X[i][0] - mu[k][0], d1 = X[i][1] - mu[k][1];
                    sigma[k][0][0] += gamma[i][k] * d0 * d0;
                    sigma[k][0][1] += gamma[i][k] * d0 * d1;
                    sigma[k][1][0] += gamma[i][k] * d1 * d0;
                    sigma[k][1][1] += gamma[i][k] * d1 * d1;
                }
                sigma[k][0][0] /= Nk; sigma[k][0][1] /= Nk;
                sigma[k][1][0] /= Nk; sigma[k][1][1] /= Nk;
                pi_k[k] = Nk / N;
            }
        }

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < N; i++) {
            int label = 0;
            double maxG = gamma[i][0];
            for (int k = 1; k < K; k++) {
                if (gamma[i][k] > maxG) { maxG = gamma[i][k]; label = k; }
            }
            sb.append(label).append('\n');
        }
        System.out.print(sb);
    }
}

/**
 * 32-bit Mersenne Twister (MT19937) with integer seeding, plus the two helpers
 * needed to draw k distinct indices out of n: a masked-rejection bounded draw
 * and a Fisher-Yates shuffle. State words are kept in {@code long}s that only
 * ever hold values in [0, 2^32). The accompanying write-up states that this is
 * meant to reproduce np.random.seed(0) / np.random.choice(N, K, replace=False).
 */
class MT19937 {
    private static final int N_MT = 624, M = 397;
    // Twist matrix constant, XORed in when the combined word is odd. Hoisted so
    // genrand() no longer allocates a lookup array on every call.
    private static final long MATRIX_A = 0x9908b0dfL;
    private static final long UPPER_MASK = 0x80000000L;
    private static final long LOWER_MASK = 0x7fffffffL;

    private final long[] mt = new long[N_MT];
    private int mti;

    /** Seeds the state from the low 32 bits of {@code seed}. */
    MT19937(long seed) {
        mt[0] = seed & 0xffffffffL;
        for (mti = 1; mti < N_MT; mti++)
            mt[mti] = (1812433253L * (mt[mti - 1] ^ (mt[mti - 1] >> 30)) + mti) & 0xffffffffL;
        // mti == N_MT here, which forces a regeneration on the first genrand() call.
    }

    /** Combines the top bit of {@code upper} with the low 31 bits of {@code lower} and applies the twist. */
    private static long twist(long upper, long lower) {
        final long y = (upper & UPPER_MASK) | (lower & LOWER_MASK);
        return (y >> 1) ^ ((y & 1L) != 0 ? MATRIX_A : 0L);
    }

    /** Regenerates all N_MT state words and rewinds the read index. */
    private void regenerate() {
        int kk = 0;
        for (; kk < N_MT - M; kk++)
            mt[kk] = mt[kk + M] ^ twist(mt[kk], mt[kk + 1]);
        for (; kk < N_MT - 1; kk++)
            mt[kk] = mt[kk + (M - N_MT)] ^ twist(mt[kk], mt[kk + 1]);
        mt[N_MT - 1] = mt[M - 1] ^ twist(mt[N_MT - 1], mt[0]);
        mti = 0;
    }

    /** Returns the next tempered 32-bit output as a non-negative long. */
    long genrand() {
        if (mti >= N_MT) regenerate();
        long y = mt[mti++];
        y ^= (y >> 11);
        y ^= (y << 7) & 0x9d2c5680L;
        y ^= (y << 15) & 0xefc60000L;
        y ^= (y >> 18);
        return y;
    }

    /**
     * Uniform integer in [0, n). Returns 0 without consuming any output when
     * n <= 1; otherwise masks each output down to the bit-width of n - 1 and
     * redraws until the value is below n.
     */
    int randBelow(int n) {
        if (n <= 1) return 0;
        long mask = n - 1;
        mask |= mask >> 1; mask |= mask >> 2; mask |= mask >> 4;
        mask |= mask >> 8; mask |= mask >> 16;
        long r;
        do { r = genrand() & mask; } while (r >= n);
        return (int) r;
    }

    /**
     * Draws k distinct values from {0, ..., n-1}: shuffles the whole range from
     * the back (Fisher-Yates) and returns the first k entries.
     *
     * @throws IllegalArgumentException unless 0 <= k <= n. Previously k > n
     *         silently padded the result with zeros via Arrays.copyOf.
     */
    int[] choice(int n, int k) {
        if (n < 0 || k < 0 || k > n)
            throw new IllegalArgumentException("MT19937.choice requires 0 <= k <= n");
        int[] pool = new int[n];
        for (int i = 0; i < n; i++) pool[i] = i;
        for (int i = n - 1; i > 0; i--) {
            int j = randBelow(i + 1);
            int tmp = pool[i]; pool[i] = pool[j]; pool[j] = tmp;
        }
        return java.util.Arrays.copyOf(pool, k);
    }
}