题目链接

多分类加权指标计算

题目描述

给定一批样本的预测标签 pred、真实标签 trueY 以及各类别在总体评估中的权重 weights。请计算加权精确率(Precision)、加权召回率(Recall)与加权 分数。

  • 对于每个类别 ,统计:
    • :预测为 且真实为 的样本数。
    • :预测为 但真实不为 的样本数。
    • :真实为 但预测不为 的样本数。
  • 计算每类指标:
    • (分母为 0 则记为 0)。
    • (分母为 0 则记为 0)。
    • (分母为 0 则记为 0)。
  • 加权汇总:

输入描述:

  • 第 1 行:预测结果 pred,空格分隔。
  • 第 2 行:真实标签 trueY,空格分隔。
  • 第 3 行:各类别权重 weights,按类别 的顺序给出。

输出描述:

  • 一行三个数:(空格分隔,均保留 2 位小数)。

解题思路

本题的核心在于对每一个类别分别统计其混淆矩阵中的各项指标,最后按照给定的权重进行加权平均。

  1. 读取输入与类别确定
    • 读取 predtrueY
    • 读取 weights,权重数组的大小即为总类别数
  2. 统计 TP, FP, FN
    • 遍历 predtrueY 数组,对于每个位置的预测值 和真实值
      • ,则类别 增加 1。
      • ,则类别 增加 1,类别 增加 1。
  3. 分层计算指标
    • 对于每个类别
      • 根据统计结果计算
      • 注意处理分母为 0 的情况。
  4. 加权求和并格式化输出
    • 使用 fixedsetprecision(2) (C++) 或 String.format (Java) 或 "{:.2f}".format (Python) 保证输出保留两位小数且不足位补零。

代码

#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <iomanip>

using namespace std;

int main() {
    string line1, line2, line3;
    getline(cin, line1);
    getline(cin, line2);
    getline(cin, line3);

    vector<int> pred, trueY;
    vector<double> weights;
    int val;
    double w;

    stringstream ss1(line1);
    while (ss1 >> val) pred.push_back(val);
    stringstream ss2(line2);
    while (ss2 >> val) trueY.push_back(val);
    stringstream ss3(line3);
    while (ss3 >> w) weights.push_back(w);

    int k = weights.size();
    vector<int> tp(k, 0), fp(k, 0), fn(k, 0);

    for (int i = 0; i < pred.size(); ++i) {
        int p = pred[i];
        int t = trueY[i];
        if (p == t) {
            tp[p]++;
        } else {
            fp[p]++;
            fn[t]++;
        }
    }

    double total_p = 0, total_r = 0, total_f1 = 0;
    for (int i = 0; i < k; ++i) {
        double pc = (tp[i] + fp[i] == 0) ? 0 : (double)tp[i] / (tp[i] + fp[i]);
        double rc = (tp[i] + fn[i] == 0) ? 0 : (double)tp[i] / (tp[i] + fn[i]);
        double f1c = (pc + rc == 0) ? 0 : (2 * pc * rc) / (pc + rc);
        
        total_p += weights[i] * pc;
        total_r += weights[i] * rc;
        total_f1 += weights[i] * f1c;
    }

    cout << fixed << setprecision(2) << total_p << " " << total_r << " " << total_f1 << endl;

    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        String line1 = sc.nextLine();
        String line2 = sc.nextLine();
        String line3 = sc.nextLine();

        String[] pStr = line1.trim().split("\\s+");
        String[] tStr = line2.trim().split("\\s+");
        String[] wStr = line3.trim().split("\\s+");

        int n = pStr.length;
        int k = wStr.length;

        int[] pred = new int[n];
        int[] trueY = new int[n];
        double[] weights = new double[k];

        for (int i = 0; i < n; i++) {
            pred[i] = Integer.parseInt(pStr[i]);
            trueY[i] = Integer.parseInt(tStr[i]);
        }
        for (int i = 0; i < k; i++) {
            weights[i] = Double.parseDouble(wStr[i]);
        }

        int[] tp = new int[k];
        int[] fp = new int[k];
        int[] fn = new int[k];

        for (int i = 0; i < n; i++) {
            int p = pred[i];
            int t = trueY[i];
            if (p == t) {
                tp[p]++;
            } else {
                fp[p]++;
                fn[t]++;
            }
        }

        double totalP = 0, totalR = 0, totalF1 = 0;
        for (int i = 0; i < k; i++) {
            double pc = (tp[i] + fp[i] == 0) ? 0 : (double) tp[i] / (tp[i] + fp[i]);
            double rc = (tp[i] + fn[i] == 0) ? 0 : (double) tp[i] / (tp[i] + fn[i]);
            double f1c = (pc + rc == 0) ? 0 : (2 * pc * rc) / (pc + rc);

            totalP += weights[i] * pc;
            totalR += weights[i] * rc;
            totalF1 += weights[i] * f1c;
        }

        System.out.println(String.format("%.2f %.2f %.2f", totalP, totalR, totalF1));
    }
}
def solve():
    pred = list(map(int, input().split()))
    true_y = list(map(int, input().split()))
    weights = list(map(float, input().split()))
    
    k = len(weights)
    tp = [0] * k
    fp = [0] * k
    fn = [0] * k
    
    for p, t in zip(pred, true_y):
        if p == t:
            tp[p] += 1
        else:
            fp[p] += 1
            fn[t] += 1
            
    total_p, total_r, total_f1 = 0.0, 0.0, 0.0
    for i in range(k):
        pc = tp[i] / (tp[i] + fp[i]) if (tp[i] + fp[i]) > 0 else 0
        rc = tp[i] / (tp[i] + fn[i]) if (tp[i] + fn[i]) > 0 else 0
        f1c = (2 * pc * rc) / (pc + rc) if (pc + rc) > 0 else 0
        
        total_p += weights[i] * pc
        total_r += weights[i] * rc
        total_f1 += weights[i] * f1c
        
    print(f"{total_p:.2f} {total_r:.2f} {total_f1:.2f}")

if __name__ == "__main__":
    solve()

算法及复杂度

  • 算法:多分类性能指标计算。通过对每类进行局部指标(Precision, Recall, F1)计算,再利用加权平均获得全局指标。
  • 时间复杂度:。其中 为样本数, 为类别数。统计 需要遍历一次样本,计算指标需要遍历一次类别。
  • 空间复杂度:。用于存储每个类别的 统计值。