题目链接

小红书推荐算法

题目描述

给定一个商品清单和用户最近搜索过的一些关键词。需要设计一个推荐算法,将商品按照与用户搜索关键词的匹配度进行排序。

排序规则如下:

  1. 主要规则:包含用户搜索过的关键词数量越多的商品,排名越靠前。
  2. 次要规则:对于包含关键词数量相同的商品,保持它们在输入时的原始相对顺序。

解题思路

这个问题的核心在于实现一个自定义的稳定排序。

  1. 数据结构选择

    • 为了快速判断一个商品的关键词是否是用户搜索过的关键词,我们可以将所有用户搜索的关键词存储在一个哈希集合(std::setHashSetset)中。这样,查询一个关键词是否存在的时间复杂度可以降至平均 (对于 std::set)或 (对于哈希实现)。
    • 为了存储每个商品的信息并进行排序,我们可以定义一个结构体或类,其中包含商品名称、与用户搜索词的匹配数量,以及它在输入中的原始索引(用于实现稳定排序)。
  2. 算法流程

    • 首先,读取用户搜索的所有关键词,并将它们存入哈希集合 search_keywords 中。
    • 然后,遍历所有商品。对于每个商品:
      • 计算它有多少个关键词出现在 search_keywords 集合中,得到匹配数 match_count
      • 将商品信息(名称、匹配数、原始索引)存储到一个列表 products 中。
    • products 列表进行排序。排序的逻辑是:
      • 首先比较两个商品的 match_count,按降序排列。
      • 如果 match_count 相等,则比较它们的原始索引 original_index,按升序排列,以保证排序的稳定性。
      • 许多语言的内置排序函数(如 C++ 的 std::stable_sort,Python 的 sort())本身就是稳定的,我们只需要提供按 match_count 降序排列的比较逻辑即可。
    • 最后,遍历排序后的 products 列表,依次输出商品名称。

代码

#include <iostream>
#include <vector>
#include <string>
#include <set>
#include <algorithm>

using namespace std;

// 用于存储商品信息的结构体
struct Product {
    string name;
    int match_count;
    int original_index;
};

// 自定义比较函数
bool compareProducts(const Product& a, const Product& b) {
    if (a.match_count != b.match_count) {
        return a.match_count > b.match_count;
    }
    return a.original_index < b.original_index;
}

void solve() {
    int n, m;
    cin >> n >> m;

    set<string> search_keywords;
    for (int i = 0; i < m; ++i) {
        string keyword;
        cin >> keyword;
        search_keywords.insert(keyword);
    }

    vector<Product> products(n);
    for (int i = 0; i < n; ++i) {
        string name;
        int k;
        cin >> name >> k;

        int match_count = 0;
        for (int j = 0; j < k; ++j) {
            string attr;
            cin >> attr;
            if (search_keywords.count(attr)) {
                match_count++;
            }
        }
        products[i] = {name, match_count, i};
    }

    // 使用稳定排序,仅按匹配数排序即可
    stable_sort(products.begin(), products.end(), [](const Product& a, const Product& b) {
        return a.match_count > b.match_count;
    });

    for (const auto& p : products) {
        cout << p.name << endl;
    }
}

int main() {
    solve();
    return 0;
}
import java.util.*;

class Product {
    String name;
    int matchCount;
    int originalIndex;

    public Product(String name, int matchCount, int originalIndex) {
        this.name = name;
        this.matchCount = matchCount;
        this.originalIndex = originalIndex;
    }
}

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();

        Set<String> searchKeywords = new HashSet<>();
        for (int i = 0; i < m; i++) {
            searchKeywords.add(sc.next());
        }

        List<Product> products = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            String name = sc.next();
            int k = sc.nextInt();
            int matchCount = 0;
            for (int j = 0; j < k; j++) {
                if (searchKeywords.contains(sc.next())) {
                    matchCount++;
                }
            }
            products.add(new Product(name, matchCount, i));
        }

        // Java的 `sort` 是稳定的,所以只需按匹配数降序比较
        products.sort((p1, p2) -> Integer.compare(p2.matchCount, p1.matchCount));
        
        for (Product p : products) {
            System.out.println(p.name);
        }
    }
}
# Python Solution
def solve():
    n, m = map(int, input().split())
    search_keywords = set(input().split())
    
    products = []
    for i in range(n):
        line1 = input().split()
        name = line1[0]
        k = int(line1[1])
        
        attributes = input().split()
        match_count = 0
        for attr in attributes:
            if attr in search_keywords:
                match_count += 1
        
        # 存储为元组 (匹配数, 原始索引, 商品名)
        products.append((-match_count, i, name))

    # Python的sort是稳定的,我们利用元组排序
    # -match_count 实现降序,i 实现稳定性
    products.sort()

    for item in products:
        print(item[2])

if __name__ == "__main__":
    solve()

算法及复杂度

  • 算法:模拟 + 排序

  • 时间复杂度,其中 是商品数量, 是用户搜索的关键词数量, 是所有商品关键词的总数, 是关键词的平均长度。

    • 个用户关键词存入哈希集合耗时
    • 遍历所有商品的全部 个关键词并查询哈希表耗时
    • 个商品进行排序耗时
  • 空间复杂度,其中 是商品名称的平均长度。

    • 存储用户关键词的哈希集合需要 的空间。
    • 存储 个商品的信息需要 的空间。