特征相关性排序

[牛客链接](https://www.nowcoder.com/practice/16d30d36a32d414c8bd85507773d8a80)

思路

本题要求计算数据集中每个特征与目标变量之间的 皮尔逊相关系数(Pearson Correlation Coefficient),并按相关系数从大到小排序输出。

输入格式

  • 第一行:$n$(样本数)和 $m$(特征数)
  • 接下来 $n$ 行:每行 $m+1$ 个浮点数,前 $m$ 个为特征值,最后一个为目标值
  • 最后一行:一个整数 $k$(本题中不影响输出)

皮尔逊相关系数

对于特征列 $x$ 和目标列 $y$,皮尔逊相关系数定义为:

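$$
r = \frac{\sum_{i=1}^{n}(x_i-\bar{x})(y_i-\bar{y})}{\sqrt{\sum_{i=1}^{n}(x_i-\bar{x})^2}\,\sqrt{\sum_{i=1}^{n}(y_i-\bar{y})^2}}
$$

其中 $\bar{x}$、$\bar{y}$ 分别是 $x$、$y$ 的均值。当分母为 0 时(即某列方差为 0),相关系数取 0。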

核心步骤

  1. 读入数据:将每行的前 $m$ 个值作为特征,第 $m+1$ 个值作为目标。
  2. 计算目标列的均值和偏差:预先算好 $\bar{y}$ 和每个 $y_i-\bar{y}$,避免重复计算。
  3. 逐特征计算相关系数:对每个特征列 $x$,计算 $\bar{x}$ 和各偏差 $x_i-\bar{x}$,代入公式得到相关系数 $r$。
  4. 排序输出:按相关系数降序排列,相关系数相同时按特征编号升序排列,保留 4 位小数。

样例演示

输入的 6 个样本、4 个特征中,特征 0 与目标完全线性相关($r=1$),其余特征相关性依次递减:

| 特征 | 相关系数 |
|------|----------|
| 0 | 1.0000 |
| 1 | 0.8286 |
| 2 | 0.6713 |
| 3 | 0.5562 |

整个过程就是按定义直接计算即可,没有特殊的算法技巧。

代码

#include <bits/stdc++.h>
using namespace std;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    cin >> n >> m;

    // One row per sample: columns 0..m-1 are features, column m is the target.
    vector<vector<double>> rows(n, vector<double>(m + 1));
    for (auto& row : rows)
        for (double& cell : row)
            cin >> cell;

    // Trailing integer from the input; it is read but never used.
    int k;
    cin >> k;

    // Target statistics are shared by every feature, so compute them once.
    double ySum = 0;
    for (const auto& row : rows) ySum += row[m];
    const double yMean = ySum / n;

    vector<double> yDev(n);
    double ySq = 0;
    for (int i = 0; i < n; ++i) {
        yDev[i] = rows[i][m] - yMean;
        ySq += yDev[i] * yDev[i];
    }
    const double yNorm = sqrt(ySq);

    // Pearson correlation of each feature column with the target column.
    vector<double> corr(m);
    for (int j = 0; j < m; ++j) {
        double xSum = 0;
        for (const auto& row : rows) xSum += row[j];
        const double xMean = xSum / n;

        double cross = 0, xSq = 0;
        for (int i = 0; i < n; ++i) {
            const double xDev = rows[i][j] - xMean;
            cross += xDev * yDev[i];
            xSq += xDev * xDev;
        }
        const double xNorm = sqrt(xSq);
        const double denom = xNorm * yNorm;
        // Zero spread in either column leaves r undefined; report 0 instead.
        corr[j] = denom > 0 ? cross / denom : 0.0;
    }

    // Feature indices ordered by correlation descending, ties by index ascending.
    vector<int> order(m);
    iota(order.begin(), order.end(), 0);
    sort(order.begin(), order.end(), [&corr](int a, int b) {
        if (corr[a] == corr[b]) return a < b;
        return corr[a] > corr[b];
    });

    cout << fixed << setprecision(4);
    for (int j : order)
        cout << j << " " << corr[j] << "\n";
    return 0;
}
import java.util.*;

public class Main {
    /**
     * Ranks every feature by its Pearson correlation with the target and prints
     * one {@code "<feature index> <correlation>"} line per feature, correlation
     * descending, ties broken by ascending feature index, with 4 decimals.
     *
     * <p>Input: {@code n m}, then {@code n} rows of {@code m + 1} numbers (the
     * features followed by the target), then an integer {@code k} that does not
     * affect the output.
     */
    public static void main(String[] args) {
        // Scanner parses numbers with the default locale's symbols; pin ROOT so
        // "1.5" is accepted even when the JVM runs under e.g. de_DE (',' decimal).
        Scanner sc = new Scanner(System.in).useLocale(Locale.ROOT);
        int n = sc.nextInt(), m = sc.nextInt();
        double[][] data = new double[n][m + 1];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= m; j++) {
                data[i][j] = sc.nextDouble();
            }
        }
        // The trailing k never influences the ranking; consume it only if present.
        if (sc.hasNextInt()) {
            sc.nextInt();
        }

        double[] corr = correlations(data, n, m);

        Integer[] order = new Integer[m];
        for (int j = 0; j < m; j++) {
            order[j] = j;
        }
        // Equal values (including -0.0 vs 0.0) fall through to the index tie-break.
        Arrays.sort(order, (a, b) -> corr[a] != corr[b]
                ? Double.compare(corr[b], corr[a])
                : Integer.compare(a, b));

        StringBuilder sb = new StringBuilder();
        for (int j : order) {
            // Explicit ROOT locale: String.format would otherwise emit "1,0000"
            // under comma-decimal default locales.
            sb.append(j).append(' ')
                    .append(String.format(Locale.ROOT, "%.4f", corr[j]))
                    .append('\n');
        }
        System.out.print(sb);
    }

    /**
     * Computes the Pearson correlation of each feature column {@code 0..m-1}
     * with the target column {@code m} over the first {@code n} rows.
     *
     * @param data sample rows, each holding {@code m} features followed by the target
     * @param n number of rows to use
     * @param m number of feature columns
     * @return array of length {@code m}; an entry is 0 when the feature or the
     *     target column has zero variance (the coefficient is undefined there)
     */
    private static double[] correlations(double[][] data, int n, int m) {
        double meanY = 0;
        for (int i = 0; i < n; i++) {
            meanY += data[i][m];
        }
        meanY /= n;

        // Target deviations are reused for every feature.
        double[] dy = new double[n];
        double sumSqY = 0;
        for (int i = 0; i < n; i++) {
            dy[i] = data[i][m] - meanY;
            sumSqY += dy[i] * dy[i];
        }
        double denY = Math.sqrt(sumSqY);

        double[] corr = new double[m];
        for (int j = 0; j < m; j++) {
            double meanX = 0;
            for (int i = 0; i < n; i++) {
                meanX += data[i][j];
            }
            meanX /= n;

            double num = 0;
            double sumSqX = 0;
            for (int i = 0; i < n; i++) {
                double dx = data[i][j] - meanX;
                num += dx * dy[i];
                sumSqX += dx * dx;
            }
            double den = Math.sqrt(sumSqX) * denY;
            corr[j] = den > 0 ? num / den : 0.0;
        }
        return corr;
    }
}
import sys
# Shadow the builtin input() with sys.stdin.readline for faster line reads.
# Returned lines keep their trailing newline; int(), float() and str.split()
# all tolerate it.
input = sys.stdin.readline

def main():
    """Rank features by Pearson correlation with the target and print them.

    Reads ``n m``, then ``n`` rows of ``m + 1`` numbers (features followed by
    the target), then an integer ``k`` that does not affect the output.
    Prints ``"<feature index> <correlation>"`` lines ordered by correlation
    descending, ties broken by ascending index, with 4 decimals.
    """
    # Parse whitespace-separated tokens rather than lines: the result no longer
    # depends on how rows are wrapped, and a missing trailing k (never used)
    # no longer raises ValueError.
    tokens = sys.stdin.read().split()
    if len(tokens) < 2:
        return
    n, m = int(tokens[0]), int(tokens[1])
    width = m + 1
    values = [float(t) for t in tokens[2:2 + n * width]]
    data = [values[i * width:(i + 1) * width] for i in range(n)]

    # Target deviations are shared by every feature, so compute them once.
    y = [row[m] for row in data]
    mean_y = sum(y) / n if n else 0.0
    dy = [v - mean_y for v in y]
    den_y = sum(d * d for d in dy) ** 0.5

    results = []
    for j in range(m):
        x = [row[j] for row in data]
        mean_x = sum(x) / n if n else 0.0
        dx = [v - mean_x for v in x]
        den_x = sum(d * d for d in dx) ** 0.5
        num = sum(a * b for a, b in zip(dx, dy))
        # Zero variance in either column makes r undefined; report 0 instead.
        corr = num / (den_x * den_y) if den_x * den_y > 0 else 0.0
        results.append((j, corr))

    # Descending correlation, ties by ascending feature index.
    results.sort(key=lambda t: (-t[1], t[0]))
    for idx, corr in results:
        print(f"{idx} {corr:.4f}")

# Script entry point: runs main() as soon as the module is executed (there is
# no __name__ == "__main__" guard, so importing the module also runs it).
main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', line => lines.push(line));
rl.on('close', () => {
    // Tokenise on any run of whitespace. `split(' ')` would turn double spaces
    // into empty strings, which Number() silently converts to 0, shifting the
    // columns; line-indexed parsing would also break on blank lines.
    const tokens = lines.join(' ').split(/\s+/).filter(t => t.length > 0).map(Number);
    if (tokens.length < 2) return;
    const [n, m] = tokens;
    const width = m + 1;

    // Each sample row holds m features followed by the target value.
    const data = [];
    for (let i = 0; i < n; i++) {
        const start = 2 + i * width;
        data.push(tokens.slice(start, start + width));
    }
    // The trailing integer k (if present) does not affect the output.

    let meanY = 0;
    for (let i = 0; i < n; i++) meanY += data[i][m];
    meanY /= n;

    let denY = 0;
    for (let i = 0; i < n; i++) {
        const d = data[i][m] - meanY;
        denY += d * d;
    }
    denY = Math.sqrt(denY);

    const results = [];
    for (let j = 0; j < m; j++) {
        let meanX = 0;
        for (let i = 0; i < n; i++) meanX += data[i][j];
        meanX /= n;

        let num = 0, denX = 0;
        for (let i = 0; i < n; i++) {
            const dx = data[i][j] - meanX;
            const dy = data[i][m] - meanY;
            num += dx * dy;
            denX += dx * dx;
        }
        denX = Math.sqrt(denX);
        // Zero variance in either column makes r undefined; report 0 instead.
        const corr = (denX * denY > 0) ? num / (denX * denY) : 0.0;
        results.push([j, corr]);
    }

    // Descending correlation, ties by ascending feature index.
    results.sort((a, b) => {
        if (a[1] !== b[1]) return b[1] - a[1];
        return a[0] - b[0];
    });

    const out = [];
    for (const [idx, corr] of results) {
        out.push(idx + ' ' + corr.toFixed(4));
    }
    console.log(out.join('\n'));
});

复杂度分析

  • 时间复杂度 $O(nm + m\log m)$,其中 $n$ 为样本数,$m$ 为特征数。对每个特征遍历一次所有样本计算相关系数,最后对 $m$ 个特征排序。
  • 空间复杂度 $O(nm)$,存储数据矩阵。