题目链接

最优分词器

题目描述

给定一个字符串 text、一个包含分值的词典、以及一些相邻词对的转移加分。任务是对 text 进行分词,使得所有词的“词典分”与所有相邻词对的“转移加分”之和最大。如果字符串无法被词典完全分词,则输出0。

解题思路

这是一个典型的序列分割求最优解问题,可以使用动态规划来解决。

1. DP状态定义

一个简单的 dp[i] 表示“前 i 个字符的最优得分”是不足够的,因为它丢失了计算“转移加分”所必需的“上一个词”的信息。

因此,我们需要一个更丰富的状态定义。我们令 dp[i] 为一个哈希表(map或dictionary),其中:

  • key: 是一个字符串 word
  • value: 是一个 long long 类型的分数。

dp[i][word] 的含义是:将字符串的前 i 个字符(即 text[0...i-1])进行分词,并且最后一个词是 word 时,能获得的最大总分

2. DP转移方程

为了计算 dp[i] 中的所有可能状态,我们遍历所有可能的分割点 j ():

  • current_word = text.substr(j, i - j),即子串 text[j...i-1]
  • 如果 current_word 在词典中,我们就可以尝试用它来作为前 i 个字符分词的最后一个词。
  • 此时,我们需要找到一个最优的前驱状态。我们遍历 dp[j] 中的每一个条目 (prev_word, prev_score)
    • prev_wordtext[0...j-1] 分词结果的最后一个词。
    • prev_scoredp[j][prev_word] 的值。
  • 基于这个前驱状态,我们可以计算出一种新的分词方案的总分: new_score = prev_score + 词典分(current_word) + 转移加分(prev_word, current_word)
  • 我们用 new_score 来更新 dp[i][current_word] 的最大值。

3. 基础情况

j=0 时,current_word 是整个分词序列的第一个词,没有 prev_word。此时,它的得分就是它自身的词典分。 dp[i][current_word] = 词典分(current_word)

4. 最终答案

遍历完所有 ij 后,dp[text.length()] 这个哈希表中就存储了所有以不同词结尾的、对整个 text 字符串的完整分词方案的最高分。我们只需要取这个哈希表中的最大值即可。

如果 dp[text.length()] 为空,说明不存在任何一种有效的分词方案可以覆盖整个字符串,此时按题意应输出 0

为了处理分数,我们需要使用 long long 来防止溢出。同时,为了方便查询,词典和转移加分规则也应该用哈希表存储。

代码

#include <iostream>
#include <string>
#include <vector>
#include <map>
#include <algorithm>

using namespace std;
using ll = long long;

const ll NINF = -1e18; // 使用一个足够小的数表示负无穷

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    string text;
    cin >> text;

    int n;
    cin >> n;
    map<string, int> word_scores;
    for (int i = 0; i < n; ++i) {
        string word;
        int score;
        cin >> word >> score;
        word_scores[word] = score;
    }

    int m;
    cin >> m;
    map<string, map<string, int>> transition_scores;
    for (int i = 0; i < m; ++i) {
        string prev_word, next_word;
        int score;
        cin >> prev_word >> next_word >> score;
        transition_scores[prev_word][next_word] = score;
    }

    int len = text.length();
    vector<map<string, ll>> dp(len + 1);

    for (int i = 1; i <= len; ++i) {
        for (int j = 0; j < i; ++j) {
            string current_word = text.substr(j, i - j);
            if (word_scores.count(current_word)) {
                ll word_score = word_scores[current_word];
                
                if (j == 0) {
                    dp[i][current_word] = max(dp[i].count(current_word) ? dp[i][current_word] : NINF, word_score);
                } else {
                    if (!dp[j].empty()) {
                        for (auto const& [prev_word, prev_score] : dp[j]) {
                            ll trans_score = 0;
                            if (transition_scores.count(prev_word) && transition_scores[prev_word].count(current_word)) {
                                trans_score = transition_scores[prev_word][current_word];
                            }
                            ll new_score = prev_score + word_score + trans_score;
                            dp[i][current_word] = max(dp[i].count(current_word) ? dp[i][current_word] : NINF, new_score);
                        }
                    }
                }
            }
        }
    }

    ll max_score = NINF;
    if (!dp[len].empty()) {
        for (auto const& [word, score] : dp[len]) {
            max_score = max(max_score, score);
        }
    }
    
    cout << (max_score == NINF ? 0 : max_score) << endl;

    return 0;
}
import java.util.Scanner;
import java.util.Map;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.List;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        String text = sc.nextLine();
        
        int n = Integer.parseInt(sc.nextLine());
        Map<String, Integer> wordScores = new HashMap<>();
        for (int i = 0; i < n; i++) {
            String[] line = sc.nextLine().split(" ");
            wordScores.put(line[0], Integer.parseInt(line[1]));
        }

        int m = Integer.parseInt(sc.nextLine());
        Map<String, Map<String, Integer>> transitionScores = new HashMap<>();
        for (int i = 0; i < m; i++) {
            String[] line = sc.nextLine().split(" ");
            String prevWord = line[0];
            String nextWord = line[1];
            int score = Integer.parseInt(line[2]);
            transitionScores.computeIfAbsent(prevWord, k -> new HashMap<>()).put(nextWord, score);
        }

        int len = text.length();
        List<Map<String, Long>> dp = new ArrayList<>();
        for (int i = 0; i <= len; i++) {
            dp.add(new HashMap<>());
        }

        final long NINF = Long.MIN_VALUE / 2; // 避免加法溢出

        for (int i = 1; i <= len; i++) {
            for (int j = 0; j < i; j++) {
                String currentWord = text.substring(j, i);
                if (wordScores.containsKey(currentWord)) {
                    long wordScore = wordScores.get(currentWord);

                    if (j == 0) {
                        long currentVal = dp.get(i).getOrDefault(currentWord, NINF);
                        dp.get(i).put(currentWord, Math.max(currentVal, wordScore));
                    } else {
                        if (!dp.get(j).isEmpty()) {
                            for (Map.Entry<String, Long> entry : dp.get(j).entrySet()) {
                                String prevWord = entry.getKey();
                                long prevScore = entry.getValue();
                                
                                long transScore = transitionScores.getOrDefault(prevWord, new HashMap<>()).getOrDefault(currentWord, 0);
                                long newScore = prevScore + wordScore + transScore;
                                
                                long currentVal = dp.get(i).getOrDefault(currentWord, NINF);
                                dp.get(i).put(currentWord, Math.max(currentVal, newScore));
                            }
                        }
                    }
                }
            }
        }

        long maxScore = NINF;
        if (!dp.get(len).isEmpty()) {
            for (long score : dp.get(len).values()) {
                maxScore = Math.max(maxScore, score);
            }
        }
        
        System.out.println(maxScore == NINF ? 0 : maxScore);
    }
}
import sys

def solve():
    text = sys.stdin.readline().strip()
    
    n = int(sys.stdin.readline())
    word_scores = {}
    for _ in range(n):
        word, score = sys.stdin.readline().split()
        word_scores[word] = int(score)
        
    m = int(sys.stdin.readline())
    transition_scores = {}
    for _ in range(m):
        prev_word, next_word, score = sys.stdin.readline().split()
        if prev_word not in transition_scores:
            transition_scores[prev_word] = {}
        transition_scores[prev_word][next_word] = int(score)

    text_len = len(text)
    dp = [{} for _ in range(text_len + 1)]
    
    NINF = float('-inf')

    for i in range(1, text_len + 1):
        for j in range(i):
            current_word = text[j:i]
            if current_word in word_scores:
                word_score = word_scores[current_word]
                
                if j == 0:
                    dp[i][current_word] = max(dp[i].get(current_word, NINF), word_score)
                else:
                    if dp[j]:
                        for prev_word, prev_score in dp[j].items():
                            trans_score = transition_scores.get(prev_word, {}).get(current_word, 0)
                            new_score = prev_score + word_score + trans_score
                            dp[i][current_word] = max(dp[i].get(current_word, NINF), new_score)

    final_scores = dp[text_len]
    if not final_scores:
        print(0)
    else:
        print(max(final_scores.values()))

solve()

算法及复杂度

  • 算法:动态规划
  • 时间复杂度,其中 text 字符串的长度, 是在任何一个分割点上可能出现的上一个词(prev_word)的最大数量。在最坏情况下, 可能与词典大小有关,但实际中通常远小于词典总词数。
  • 空间复杂度,用于存储DP表。此外,还需要空间存储词典和转移规则。