题目链接
题目描述
给定一批样本的预测标签 pred、真实标签 trueY 以及各类别在总体评估中的权重 weights。请计算加权精确率(Precision)、加权召回率(Recall)与加权 分数。
- 对于每个类别
,统计:
:预测为
且真实为
的样本数。
:预测为
但真实不为
的样本数。
:真实为
但预测不为
的样本数。
- 计算每类指标:
(分母为 0 则记为 0)。
(分母为 0 则记为 0)。
(分母为 0 则记为 0)。
- 加权汇总:
输入描述:
- 第 1 行:预测结果
pred,空格分隔。 - 第 2 行:真实标签
trueY,空格分隔。 - 第 3 行:各类别权重
weights,按类别的顺序给出。
输出描述:
- 一行三个数:
(空格分隔,均保留 2 位小数)。
解题思路
本题的核心在于对每一个类别分别统计其混淆矩阵中的各项指标,最后按照给定的权重进行加权平均。
- 读取输入与类别确定:
- 读取
pred和trueY。 - 读取
weights,权重数组的大小即为总类别数。
- 读取
- 统计 TP, FP, FN:
- 遍历
pred和trueY数组,对于每个位置的预测值和真实值
:
- 若
,则类别
的
增加 1。
- 若
,则类别
的
增加 1,类别
的
增加 1。
- 若
- 遍历
- 分层计算指标:
- 对于每个类别
:
- 根据统计结果计算
。
- 注意处理分母为 0 的情况。
- 根据统计结果计算
- 对于每个类别
- 加权求和并格式化输出:
- 使用
fixed和setprecision(2)(C++) 或String.format(Java) 或"{:.2f}".format(Python) 保证输出保留两位小数且不足位补零。
- 使用
代码
#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <iomanip>
using namespace std;
int main() {
string line1, line2, line3;
getline(cin, line1);
getline(cin, line2);
getline(cin, line3);
vector<int> pred, trueY;
vector<double> weights;
int val;
double w;
stringstream ss1(line1);
while (ss1 >> val) pred.push_back(val);
stringstream ss2(line2);
while (ss2 >> val) trueY.push_back(val);
stringstream ss3(line3);
while (ss3 >> w) weights.push_back(w);
int k = weights.size();
vector<int> tp(k, 0), fp(k, 0), fn(k, 0);
for (int i = 0; i < pred.size(); ++i) {
int p = pred[i];
int t = trueY[i];
if (p == t) {
tp[p]++;
} else {
fp[p]++;
fn[t]++;
}
}
double total_p = 0, total_r = 0, total_f1 = 0;
for (int i = 0; i < k; ++i) {
double pc = (tp[i] + fp[i] == 0) ? 0 : (double)tp[i] / (tp[i] + fp[i]);
double rc = (tp[i] + fn[i] == 0) ? 0 : (double)tp[i] / (tp[i] + fn[i]);
double f1c = (pc + rc == 0) ? 0 : (2 * pc * rc) / (pc + rc);
total_p += weights[i] * pc;
total_r += weights[i] * rc;
total_f1 += weights[i] * f1c;
}
cout << fixed << setprecision(2) << total_p << " " << total_r << " " << total_f1 << endl;
return 0;
}
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
String line1 = sc.nextLine();
String line2 = sc.nextLine();
String line3 = sc.nextLine();
String[] pStr = line1.trim().split("\\s+");
String[] tStr = line2.trim().split("\\s+");
String[] wStr = line3.trim().split("\\s+");
int n = pStr.length;
int k = wStr.length;
int[] pred = new int[n];
int[] trueY = new int[n];
double[] weights = new double[k];
for (int i = 0; i < n; i++) {
pred[i] = Integer.parseInt(pStr[i]);
trueY[i] = Integer.parseInt(tStr[i]);
}
for (int i = 0; i < k; i++) {
weights[i] = Double.parseDouble(wStr[i]);
}
int[] tp = new int[k];
int[] fp = new int[k];
int[] fn = new int[k];
for (int i = 0; i < n; i++) {
int p = pred[i];
int t = trueY[i];
if (p == t) {
tp[p]++;
} else {
fp[p]++;
fn[t]++;
}
}
double totalP = 0, totalR = 0, totalF1 = 0;
for (int i = 0; i < k; i++) {
double pc = (tp[i] + fp[i] == 0) ? 0 : (double) tp[i] / (tp[i] + fp[i]);
double rc = (tp[i] + fn[i] == 0) ? 0 : (double) tp[i] / (tp[i] + fn[i]);
double f1c = (pc + rc == 0) ? 0 : (2 * pc * rc) / (pc + rc);
totalP += weights[i] * pc;
totalR += weights[i] * rc;
totalF1 += weights[i] * f1c;
}
System.out.println(String.format("%.2f %.2f %.2f", totalP, totalR, totalF1));
}
}
def solve():
pred = list(map(int, input().split()))
true_y = list(map(int, input().split()))
weights = list(map(float, input().split()))
k = len(weights)
tp = [0] * k
fp = [0] * k
fn = [0] * k
for p, t in zip(pred, true_y):
if p == t:
tp[p] += 1
else:
fp[p] += 1
fn[t] += 1
total_p, total_r, total_f1 = 0.0, 0.0, 0.0
for i in range(k):
pc = tp[i] / (tp[i] + fp[i]) if (tp[i] + fp[i]) > 0 else 0
rc = tp[i] / (tp[i] + fn[i]) if (tp[i] + fn[i]) > 0 else 0
f1c = (2 * pc * rc) / (pc + rc) if (pc + rc) > 0 else 0
total_p += weights[i] * pc
total_r += weights[i] * rc
total_f1 += weights[i] * f1c
print(f"{total_p:.2f} {total_r:.2f} {total_f1:.2f}")
if __name__ == "__main__":
solve()
算法及复杂度
- 算法:多分类性能指标计算。通过对每类进行局部指标(Precision, Recall, F1)计算,再利用加权平均获得全局指标。
- 时间复杂度:
。其中
为样本数,
为类别数。统计
需要遍历一次样本,计算指标需要遍历一次类别。
- 空间复杂度:
。用于存储每个类别的
统计值。

京公网安备 11010502036488号