最优分词器

题意

给定一个纯小写字母组成的文本串 text,一个词典(每个词带有分值),以及一些"相邻词对"的转移加分(可正可负)。

要求把 text 完整切分成词典中的词序列,使得词典分之和 + 相邻词之间的转移加分之和最大。如果无法完整覆盖,输出

思路

为什么普通 DP 不够?

如果没有转移加分,这就是一个经典的"词典分词"问题:dp[i] 表示前 个字符能获得的最大分值,枚举以位置 结尾的所有词典词转移即可。

但加上了转移加分后,问题变了——当前能拿多少分,不仅取决于"切到了哪里",还取决于"上一个词是什么"。这提示我们:DP 状态需要多记一维信息

状态设计

定义 为一个映射: 表示恰好覆盖 text[0:i]、最后一个词为 时的最大总分。

转移时,枚举从位置 开始、能匹配的词典词

$$

起始状态 为空映射(表示还没选过任何词),此时不加转移分。

最终答案就是 ,其中 是文本长度。如果 为空,说明无法完整切分,输出

复杂度分析

设文本长度为 ,词典大小为 ,词最大长度为 。每个位置最多尝试 个词,每次转移遍历 中的所有前驱词(最多 个),总体时间复杂度 ,空间

实现要点

  1. 可以按词长分组,快速跳过长度不匹配的词。
  2. 对每个位置 ,如果 为空就直接跳过——说明这个位置不可达。
  3. 转移加分用 (prev_word, cur_word) 做 key 存哈希表,查不到就是

代码

import sys
from collections import defaultdict

def solve():
    input_data = sys.stdin.read().split('\n')
    idx = 0
    text = input_data[idx].strip(); idx += 1
    n = int(input_data[idx].strip()); idx += 1
    word_score = {}
    for _ in range(n):
        parts = input_data[idx].strip().split()
        idx += 1
        word_score[parts[0]] = int(parts[1])
    m = int(input_data[idx].strip()); idx += 1
    trans = {}
    for _ in range(m):
        parts = input_data[idx].strip().split()
        idx += 1
        trans[(parts[0], parts[1])] = int(parts[2])

    L = len(text)
    # dp[i] = dict: last_word -> best score to cover text[0:i]
    dp = [None] * (L + 1)
    dp[0] = {}

    words_by_len = defaultdict(list)
    for w in word_score:
        words_by_len[len(w)].append(w)

    for i in range(L):
        if dp[i] is None:
            continue
        for wlen, words in words_by_len.items():
            if i + wlen > L:
                continue
            substr = text[i:i + wlen]
            if substr not in word_score:
                continue
            w = substr
            ws = word_score[w]
            ni = i + wlen
            if dp[ni] is None:
                dp[ni] = {}
            if not dp[i]:
                # 起始位置,无前驱词
                if w not in dp[ni] or ws > dp[ni][w]:
                    dp[ni][w] = ws
            else:
                for prev_w, prev_score in dp[i].items():
                    bonus = trans.get((prev_w, w), 0)
                    val = prev_score + ws + bonus
                    if w not in dp[ni] or val > dp[ni][w]:
                        dp[ni][w] = val

    if dp[L] is None or not dp[L]:
        print(0)
    else:
        print(max(dp[L].values()))

solve()
#include <bits/stdc++.h>
using namespace std;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    string text;
    cin >> text;
    int n;
    cin >> n;

    unordered_map<string, int> word_score;
    vector<string> words;
    for(int i = 0; i < n; i++){
        string w; int s;
        cin >> w >> s;
        word_score[w] = s;
        words.push_back(w);
    }

    int m;
    cin >> m;
    unordered_map<string, int> word_id;
    for(int i = 0; i < (int)words.size(); i++)
        word_id[words[i]] = i;

    map<pair<int,int>, int> trans;
    for(int i = 0; i < m; i++){
        string a, b; int s;
        cin >> a >> b >> s;
        if(word_id.count(a) && word_id.count(b))
            trans[{word_id[a], word_id[b]}] = s;
    }

    int L = text.size(), W = words.size();
    // dp[i]: last_word_id -> best score
    vector<unordered_map<int,long long>> dp(L + 1);
    dp[0][-1] = 0; // sentinel

    for(int i = 0; i < L; i++){
        if(dp[i].empty()) continue;
        for(int wi = 0; wi < W; wi++){
            int wl = words[wi].size();
            if(i + wl > L) continue;
            if(text.compare(i, wl, words[wi]) != 0) continue;
            int ws = word_score[words[wi]];
            int ni = i + wl;
            for(auto& [prev, prev_score] : dp[i]){
                long long bonus = 0;
                if(prev >= 0){
                    auto it = trans.find({prev, wi});
                    if(it != trans.end()) bonus = it->second;
                }
                long long val = prev_score + ws + bonus;
                auto it2 = dp[ni].find(wi);
                if(it2 == dp[ni].end() || val > it2->second)
                    dp[ni][wi] = val;
            }
        }
    }

    if(dp[L].empty()){
        cout << 0 << endl;
    } else {
        long long ans = LLONG_MIN;
        for(auto& [k, v] : dp[L])
            ans = max(ans, v);
        cout << ans << endl;
    }
    return 0;
}