题目链接

【模板】多重集合操作

题目描述

你需要动态维护一个初始为空的多重集合 M(允许元素重复),并支持以下六种操作:

  1. 插入: 向集合 M 中插入一个整数 x
  2. 删除: 从集合 M 中删除一个整数 x。如果 x 存在多个,只删除一个。
  3. 查询个数: 查询整数 x 在集合 M 中的出现次数。
  4. 查询大小: 查询集合 M 中元素的总个数(计入重复元素)。
  5. 查询前驱: 查询 x 的前驱(小于 x 的最大元素)。
  6. 查询后继: 查询 x 的后继(大于 x 的最小元素)。

输入描述:

  • 第一行一个整数 q),表示操作总数。
  • 接下来 q 行,每行输入操作类型 op)和操作数 x

输出描述:

  • 对于查询操作(类型 3, 4, 5, 6),输出相应的结果,每个结果占一行。
  • 如果前驱或后继不存在,输出 -1

示例 输入:

6
1 2
1 2
3 2
2 2
3 2
4 0

输出:

2
1
1

解题思路

这道题是"核心代码模式",要求我们实现对一个多重集合(允许重复元素)的操作。

  • 核心数据结构

    • C++: std::multiset 是 C++ 标准库中提供的基于红黑树的多重集合容器,完美符合本题的所有要求,所有操作都非常高效。
    • Java: 标准库中的 TreeSet 不允许重复。因此,我们采用 TreeMap<Integer, Integer> 来模拟多重集合,其中 key 存元素值,value 存该元素的出现次数。这样既能保持元素有序,又能处理重复。同时需要一个额外变量来维护集合总大小。
    • Python: 标准库没有内置的平衡二叉搜索树。我们继续使用有序 listbisect 模块来模拟。有序列表天然支持重复元素。但插入和删除操作为 ,可能会超时。
  • 函数实现:

    • insertValue: C++直接 insert。Java 中更新 TreeMapx 的计数。Python 中使用 bisect.insort_left 插入。
    • eraseValue: C++ 中需使用 find 定位到迭代器再 erase,以确保只删除一个。Java 中将 x 的计数减一,若计数归零则从 TreeMap 中移除。Python 中 find 到第一个 x 然后 pop
    • xCount: C++直接调用 count。Java 从 TreeMap 获取计数。Python 中用 bisect_rightbisect_left 的差值计算个数。
    • sizeOfSet: C++直接 size()。Java 返回维护的总大小变量。Python 返回 len(list)
    • getPre/getBack: C++使用 lower_bound/upper_bound。Java 使用 lowerKey/higherKey。Python 使用 bisect_left/bisect_right。如果不存在,统一返回-1。

我们将这些逻辑填充到题目给定的函数模板中。

代码

#include<bits/stdc++.h>
using namespace std;

multiset<int> ms;

void insertValue(int x){
    ms.insert(x);
}
void eraseValue(int x){
    auto it = ms.find(x);
    if (it != ms.end()) {
        ms.erase(it);
    }
}
int xCount(int x){
    return ms.count(x);
}
int sizeOfSet(){
    return ms.size();
}
int getPre(int x){
    auto it = ms.lower_bound(x);
    if(it == ms.begin()) return -1;
    it--;
    return *it;
}
int getBack(int x){
    auto it = ms.upper_bound(x);
    if(it == ms.end()) return -1;
    return *it;
}

int main(){
    int q,op,x;
    cin>>q;
    while(q--){
        cin>>op;
        if(op==1){
            cin>>x;
            insertValue(x);
        }
        if(op==2){
            cin>>x;
            eraseValue(x);
        }
        if(op==3){
            cin>>x;
            cout<<xCount(x)<<endl;
        }
        if(op==4){
            cout<<sizeOfSet()<<endl;
        }
        if(op==5){
            cin>>x;
            cout<<getPre(x)<<endl;
        }
        if(op==6){
            cin>>x;
            cout<<getBack(x)<<endl;
        }
    }
    return 0;
}
import java.util.*;

public class Main {
    static TreeMap<Integer, Integer> map = new TreeMap<>();
    static int totalSize = 0;

    public static void insertValue(int x) {
        map.put(x, map.getOrDefault(x, 0) + 1);
        totalSize++;
    }

    public static void eraseValue(int x) {
        if (map.containsKey(x)) {
            int count = map.get(x);
            if (count > 1) {
                map.put(x, count - 1);
            } else {
                map.remove(x);
            }
            totalSize--;
        }
    }

    public static int xCount(int x) {
        return map.getOrDefault(x, 0);
    }

    public static int sizeOfSet() {
        return totalSize;
    }

    public static int getPre(int x) {
        Integer pre = map.lowerKey(x);
        return (pre == null) ? -1 : pre;
    }

    public static int getBack(int x) {
        Integer post = map.higherKey(x);
        return (post == null) ? -1 : post;
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int q = scanner.nextInt();
        while (q-- > 0) {
            int op = scanner.nextInt();
            switch (op) {
                case 1:
                    insertValue(scanner.nextInt());
                    break;
                case 2:
                    eraseValue(scanner.nextInt());
                    break;
                case 3:
                    System.out.println(xCount(scanner.nextInt()));
                    break;
                case 4:
                    System.out.println(sizeOfSet());
                    break;
                case 5:
                    System.out.println(getPre(scanner.nextInt()));
                    break;
                case 6:
                    System.out.println(getBack(scanner.nextInt()));
                    break;
            }
        }
        scanner.close();
    }
}
import sys
import bisect

# 使用 list+bisect 模拟多重集合。
# 插入/删除为 O(N),可能超时。
data = []

def insertValue(x):
    bisect.insort_left(data, x)

def eraseValue(x):
    idx = bisect.bisect_left(data, x)
    if idx < len(data) and data[idx] == x:
        data.pop(idx)

def xCount(x):
    left_idx = bisect.bisect_left(data, x)
    right_idx = bisect.bisect_right(data, x)
    return right_idx - left_idx

def sizeOfSet():
    return len(data)

def getPre(x):
    idx = bisect.bisect_left(data, x)
    if idx > 0:
        return data[idx - 1]
    else:
        return -1

def getBack(x):
    idx = bisect.bisect_right(data, x)
    if idx < len(data):
        return data[idx]
    else:
        return -1

def main():
    q = int(input())
    for _ in range(q):
        line = map(int,input().split())
        cnt,op,x=0,0,0
        for i in line:
            if(cnt==0):
                op=i
            else:
                x=i
            cnt+=1
        
        if op == 1:
            insertValue(x)
        elif op == 2:
            eraseValue(x)
        elif op == 3:
            print(xCount(x))
        elif op == 4:
            print(sizeOfSet())
        elif op == 5:
            print(getPre(x))
        elif op == 6:
            print(getBack(x))

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:
    • C++: 平衡二叉搜索树 (std::multiset)
    • Java: 平衡二叉搜索树 (java.util.TreeMap)
    • Python: 有序列表+二分查找
  • 时间复杂度:
    • C++: 所有操作平均和最坏时间复杂度均为 ,其中 N 是集合中元素总数。总复杂度
    • Java: 所有操作平均和最坏时间复杂度均为 ,其中 K 是集合中不同元素的个数。总复杂度
    • Python: 查找、计数、前驱、后继为 。插入和删除为 。总复杂度最坏为
  • 空间复杂度,用于存储集合中的元素。