电商活动排序
[题目链接](https://www.nowcoder.com/practice/19cefecc2ff7439b81358999178cc4b4)
思路
本题要求根据决策树中"信息增益比"来选择最优特征。输入是一个二维数组,每行是一条样本,前几列是特征,最后一列是标签(0 或 1)。需要对每个特征计算信息增益比,返回信息增益比最大的特征索引。
信息增益比的计算步骤
第一步:计算数据集的信息熵
$$H(D) = -\sum_{k=1}^{K} \frac{|C_k|}{|D|} \log_2 \frac{|C_k|}{|D|}$$
其中 $C_k$ 是属于第 $k$ 类的样本集合,$K$ 是类别数目。
第二步:计算特征 $A$ 的条件熵
按特征 $A$ 的不同取值将数据集 $D$ 划分为若干子集 $D_1, D_2, \dots, D_n$,则:
$$H(D \mid A) = \sum_{i=1}^{n} \frac{|D_i|}{|D|} H(D_i)$$
第三步:信息增益
$$g(D, A) = H(D) - H(D \mid A)$$
第四步:计算属性熵(固有值)
$$H_A(D) = -\sum_{i=1}^{n} \frac{|D_i|}{|D|} \log_2 \frac{|D_i|}{|D|}$$
注意这里 $D_i$ 是按特征 $A$ 的取值划分的子集(而非按标签划分)。
第五步:信息增益比
$$g_R(D, A) = \frac{g(D, A)}{H_A(D)}$$
算法流程
- 解析输入的二维数组;
- 统计标签分布,计算信息熵 $H(D)$;
- 对每个特征,按其取值分组,计算条件熵 $H(D \mid A)$、信息增益 $g(D, A)$、属性熵 $H_A(D)$,进而得到信息增益比;
- 返回信息增益比最大的特征索引。
样例演示
输入 [[0,0,0,0,0],[0,0,0,1,0],[0,1,0,1,1],[0,1,1,0,0],[0,0,0,0,0]],共 5 条样本,4 个特征,标签为 [0,0,1,0,0]。
- 特征 0 全为 0,无法划分,属性熵为 0,跳过;
- 特征 1 取值 {0,1}:取值 0 的子集标签全为 0,取值 1 的子集标签为 {1, 0},信息增益比约为 0.332;
- 特征 3 的划分结构与特征 1 完全相同,信息增益比并列最大,遍历时保留先出现的特征 1;特征 2 的信息增益比(约 0.101)最低。
输出 1。
复杂度分析
- 时间复杂度:$O(n \times m)$,其中 $n$ 是样本数,$m$ 是特征数。对每个特征遍历一次所有样本(若用有序映射分组,另带对数因子)。
- 空间复杂度:$O(n \times m)$,存储输入数据及分组统计信息。
代码
#include <bits/stdc++.h>
using namespace std;
// Reads one line containing a nested array such as [[0,0,1],[1,0,1]], where
// each inner array is a sample (feature columns first, 0/1 label last), and
// prints the index of the feature with the largest information gain ratio
// (the C4.5 split criterion).
int main() {
    string input;
    getline(cin, input);

    // --- Parse the nested bracket syntax with a small character machine ---
    vector<vector<int>> samples;  // one vector per inner [...] group
    vector<int> current;          // row being accumulated
    string token;                 // digits of the number being scanned
    int level = 0;                // bracket nesting depth; 2 == inside a row
    for (char ch : input) {
        switch (ch) {
            case '[':
                if (++level == 2) {
                    current.clear();
                    token.clear();
                }
                break;
            case ']':
                if (level == 2) {
                    if (!token.empty()) {
                        current.push_back(stoi(token));
                        token.clear();
                    }
                    samples.push_back(current);
                }
                --level;
                break;
            case ',':
                // Commas outside a row (level != 2) are separators we ignore.
                if (level == 2 && !token.empty()) {
                    current.push_back(stoi(token));
                    token.clear();
                }
                break;
            default:
                if (level == 2 && ch != ' ') token += ch;
        }
    }

    const int n = samples.size();
    const int width = samples[0].size();
    const int featureCount = width - 1;  // last column is the label

    // Base-2 Shannon entropy of a value->count histogram over `total` items.
    auto entropy = [](map<int, int>& counts, int total) -> double {
        double result = 0.0;
        for (auto& kv : counts) {
            if (kv.second > 0) {
                double p = (double)kv.second / total;
                result -= p * log2(p);
            }
        }
        return result;
    };

    // H(D): entropy of the label column.
    map<int, int> labelCounts;
    for (auto& s : samples) labelCounts[s[width - 1]]++;
    double baseEntropy = entropy(labelCounts, n);

    int answer = 0;
    double best = -1e18;
    for (int f = 0; f < featureCount; f++) {
        // Group by this feature's value: label histogram and subset size.
        map<int, map<int, int>> byValue;
        map<int, int> sizes;
        for (auto& s : samples) {
            byValue[s[f]][s[width - 1]]++;
            sizes[s[f]]++;
        }
        // Conditional entropy H(D|A) = sum |Di|/|D| * H(Di).
        double conditional = 0.0;
        for (auto& g : byValue) {
            int sz = sizes[g.first];
            conditional += (double)sz / n * entropy(g.second, sz);
        }
        double gain = baseEntropy - conditional;
        // Intrinsic value H_A(D): entropy of the split itself.
        double intrinsic = 0.0;
        for (auto& g : sizes) {
            double p = (double)g.second / n;
            intrinsic -= p * log2(p);
        }
        if (intrinsic == 0) continue;  // constant feature: cannot split
        double ratio = gain / intrinsic;
        if (ratio > best) {
            best = ratio;
            answer = f;
        }
    }
    cout << answer << endl;
    return 0;
}
import java.util.*;
/**
 * Picks the best splitting feature by information gain ratio (C4.5).
 *
 * <p>Input: one line containing a nested array such as [[0,0,1],[1,0,1]],
 * where each inner array is a sample whose last element is the 0/1 label and
 * the preceding elements are feature values. Output: the index of the
 * feature with the largest gain ratio (first wins on ties).
 */
public class Main {

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        String raw = in.nextLine().trim();

        // Strip the outermost brackets, then split into per-sample chunks.
        if (raw.startsWith("[")) raw = raw.substring(1);
        if (raw.endsWith("]")) raw = raw.substring(0, raw.length() - 1);
        List<int[]> samples = new ArrayList<>();
        for (String chunk : raw.split("\\]\\s*,\\s*\\[")) {
            String body = chunk.replace("[", "").replace("]", "").trim();
            String[] tokens = body.split("\\s*,\\s*");
            int[] sample = new int[tokens.length];
            for (int i = 0; i < tokens.length; i++) {
                sample[i] = Integer.parseInt(tokens[i].trim());
            }
            samples.add(sample);
        }

        int n = samples.size();
        int width = samples.get(0).length;
        int featureCount = width - 1; // last column is the label

        // H(D): entropy of the label distribution.
        Map<Integer, Integer> labelCounts = new HashMap<>();
        for (int[] s : samples) {
            labelCounts.merge(s[width - 1], 1, Integer::sum);
        }
        double baseEntropy = calcEntropy(labelCounts, n);

        int answer = 0;
        double best = Double.NEGATIVE_INFINITY;
        for (int f = 0; f < featureCount; f++) {
            // Per feature value: label histogram and subset size.
            Map<Integer, Map<Integer, Integer>> byValue = new HashMap<>();
            Map<Integer, Integer> sizes = new HashMap<>();
            for (int[] s : samples) {
                byValue.computeIfAbsent(s[f], k -> new HashMap<>())
                        .merge(s[width - 1], 1, Integer::sum);
                sizes.merge(s[f], 1, Integer::sum);
            }
            // Conditional entropy H(D|A) = sum |Di|/|D| * H(Di).
            double conditional = 0;
            for (Map.Entry<Integer, Map<Integer, Integer>> e : byValue.entrySet()) {
                int sz = sizes.get(e.getKey());
                conditional += (double) sz / n * calcEntropy(e.getValue(), sz);
            }
            double gain = baseEntropy - conditional;
            // Intrinsic value H_A(D): entropy of the split itself.
            double intrinsic = 0;
            for (int sz : sizes.values()) {
                double p = (double) sz / n;
                intrinsic -= p * Math.log(p) / Math.log(2);
            }
            if (intrinsic == 0) continue; // constant feature: cannot split
            double ratio = gain / intrinsic;
            if (ratio > best) {
                best = ratio;
                answer = f;
            }
        }
        System.out.println(answer);
    }

    /** Base-2 Shannon entropy of a value-to-count histogram over {@code total} items. */
    static double calcEntropy(Map<Integer, Integer> counts, int total) {
        double h = 0;
        for (int c : counts.values()) {
            if (c > 0) {
                double p = (double) c / total;
                h -= p * Math.log(p) / Math.log(2);
            }
        }
        return h;
    }
}
import math
from collections import Counter
def calc_entropy(counts, total):
    """Return the base-2 Shannon entropy of a count histogram.

    counts: mapping from class/value to its count; zero counts contribute
            nothing to the sum.
    total:  number of items the counts were taken over.
    """
    return -sum(
        ((c / total) * math.log2(c / total) for c in counts.values() if c > 0),
        0.0,
    )
def solve():
    """Read a nested integer list from stdin and print the index of the
    feature with the largest information gain ratio (C4.5 split criterion).

    Input: one line such as ``[[0,0,0,0,0],[0,1,0,1,1]]``.  Each inner list
    is a sample whose last element is the 0/1 label; the preceding elements
    are feature values.  Ties are broken toward the smaller feature index.
    """
    import ast

    # ast.literal_eval only accepts Python literals; the previous eval()
    # would execute arbitrary code embedded in the input line.
    data = ast.literal_eval(input())
    rows = len(data)
    cols = len(data[0])
    num_features = cols - 1  # last column is the label
    # H(D): entropy of the label distribution.
    label_cnt = Counter(row[-1] for row in data)
    HD = calc_entropy(label_cnt, rows)
    best_ratio = float('-inf')
    best_idx = 0
    for f in range(num_features):
        # Histogram of labels for each distinct value of feature f.
        groups = {}
        for row in data:
            groups.setdefault(row[f], Counter())[row[-1]] += 1
        # Conditional entropy H(D|A) = sum |Di|/|D| * H(Di).
        HDA = 0.0
        group_sizes = {}
        for fv, cnt in groups.items():
            sz = sum(cnt.values())
            group_sizes[fv] = sz
            HDA += sz / rows * calc_entropy(cnt, sz)
        gain = HD - HDA
        # Intrinsic value H_A(D): entropy of the split itself.
        HA = 0.0
        for sz in group_sizes.values():
            p = sz / rows
            HA -= p * math.log2(p)
        if HA == 0:
            # Feature is constant across all samples: it cannot split.
            continue
        ratio = gain / HA
        if ratio > best_ratio:
            best_ratio = ratio
            best_idx = f
    print(best_idx)
solve()
'use strict';
// Reads one line containing a nested array (e.g. [[0,0,1],[1,0,1]]) and
// prints the index of the feature with the largest information gain ratio
// (C4.5 criterion). Each inner array is a sample: feature columns first,
// 0/1 label last. Ties keep the smaller feature index.
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const inputLines = [];
rl.on('line', (l) => inputLines.push(l));
rl.on('close', () => {
  // Remove whitespace so JSON.parse sees a compact literal.
  const samples = JSON.parse(inputLines[0].replace(/\s+/g, ''));
  const n = samples.length;
  const width = samples[0].length;
  const featureCount = width - 1; // last column is the label

  // Base-2 Shannon entropy of a value->count histogram over `total` items.
  const entropyOf = (histogram, total) => {
    let h = 0;
    for (const count of Object.values(histogram)) {
      if (count !== 0) {
        const p = count / total;
        h -= p * Math.log2(p);
      }
    }
    return h;
  };

  // H(D): entropy of the label column.
  const labelHist = {};
  for (const s of samples) {
    const label = s[width - 1];
    labelHist[label] = (labelHist[label] || 0) + 1;
  }
  const baseEntropy = entropyOf(labelHist, n);

  let best = -Infinity;
  let answer = 0;
  for (let f = 0; f < featureCount; f++) {
    // Per feature value: label histogram and subset size.
    const byValue = {};
    const sizes = {};
    for (const s of samples) {
      const v = s[f];
      const label = s[width - 1];
      if (!byValue[v]) byValue[v] = {};
      byValue[v][label] = (byValue[v][label] || 0) + 1;
      sizes[v] = (sizes[v] || 0) + 1;
    }
    // Conditional entropy H(D|A) = sum |Di|/|D| * H(Di).
    let conditional = 0;
    for (const v of Object.keys(byValue)) {
      const sz = sizes[v];
      conditional += (sz / n) * entropyOf(byValue[v], sz);
    }
    const gain = baseEntropy - conditional;
    // Intrinsic value H_A(D): entropy of the split itself.
    let intrinsic = 0;
    for (const sz of Object.values(sizes)) {
      const p = sz / n;
      intrinsic -= p * Math.log2(p);
    }
    if (intrinsic === 0) continue; // constant feature: cannot split
    const ratio = gain / intrinsic;
    if (ratio > best) {
      best = ratio;
      answer = f;
    }
  }
  console.log(answer);
});

京公网安备 11010502036488号