题目链接

同义词替换

题目描述

给定一篇由 个单词组成的作文和 条同义词替换规则。每条规则 u -> v 表示单词 u 可以被替换为 v(单向,大小写不敏感)。这种替换具有传递性,即若 a -> bb -> c,则 a 可以被替换为 c

目标是对作文进行任意次数的替换,以得到一篇新的作文,满足:

  1. 首要目标:新作文中字母 'r' 的出现次数(大小写不敏感)最少。
  2. 次要目标:在满足上一目标的前提下,新作文的总长度(所有单词长度之和)最短。

你需要输出这两个值:最少的 'r' 出现次数和对应的最短总长度。

解题思路

这是一个图论问题。我们可以将单词和替换规则模型化为一个有向图,然后在这个图上寻找最优解。

1. 图的构建

  • 节点 (Nodes):将每一个在作文和规则中出现的独立单词(统一转为小写)视为图中的一个节点。
  • 边 (Edges):每一条替换规则 u -> v 对应一条从节点 u 到节点 v 的有向边。

由于替换具有传递性,一个单词 w 可以被替换为图中从 w 出发能够到达的任何一个单词。因此,对于作文中的每个单词 w,我们需要在所有从 w 可达的单词中,找到一个“最优”的替代词。

2. “最优”的定义

“最优”的评判标准是双重的:首先是 'r' 的数量,其次是单词长度。我们可以将每个单词的这两个属性作为一个二元组 (r_count, length)。比较两个单词的优劣,就等同于比较它们的二元组的字典序。例如,(2, 5) 优于 (3, 4)(2, 5) 优于 (2, 6)

3. 处理环路:强连通分量 (SCC)

图中可能存在环路(例如,a -> b, b -> a)。环路中的所有单词都可以相互替换,它们在替换的可能性上是等价的。在图论中,这样一个“最大”的相互可达的节点集合被称为强连通分量 (Strongly Connected Component, SCC)

我们可以将同一个 SCC 中的所有单词视为一个整体。对于这个 SCC 中的任意一个单词,它的最优替换词至少是这个 SCC 内部最优的那个词(即 SCC 内 (r_count, length) 最小的词)。

4. 在 DAG 上进行动态规划

通过缩点(将每个 SCC 视为一个单一的节点),我们可以将原图转化为一个有向无环图 (DAG),即缩点图

  • 缩点图中的一个节点代表原始图中的一个 SCC。
  • 如果在原图中存在一条从 SCC A 中的某个词到 SCC B 中的某个词的边,那么在缩点图中就有一条从 A 到 B 的边。

现在,一个单词 w(属于 SCC C)的可替换集,就是从 C 在缩点图上可达的所有 SCC 中包含的所有单词的集合。

为了找到每个 SCC 的最终最优解,我们可以在这个 DAG 上进行动态规划:

  • 一个 SCC 的最优解,取决于它自身内部的最优词,以及它能到达的所有下游 SCC 的最优解。
  • 我们可以通过逆拓扑序遍历这个 DAG(或者使用记忆化搜索),从图的“末端”开始,向上游传递最优解。
  • dp[C] 表示从 SCC C 出发能达到的最优 (r_count, length)
  • 状态转移方程为:dp[C] = min(optimal_word_in_C, min(dp[C']) for all C' that are successors of C)

算法整体流程

  1. 预处理:读取所有单词,统一转为小写,并为每个独立单词分配一个 ID。存储每个单词的 (r_count, length)
  2. 建图:根据替换规则,构建邻接表。
  3. 求 SCC:使用 Tarjan 算法Kosaraju 算法 找到图中所有的强连通分量。
  4. 计算 SCC 内部最优解:遍历所有单词,为每个 SCC 找出其内部最优的 (r_count, length)
  5. 建缩点图:根据原图的边和 SCC 的划分,构建缩点后的 DAG。
  6. DP 求解:在缩点图上使用记忆化搜索(DFS),计算每个 SCC 最终能达到的最优 (r_count, length)
  7. 统计答案:遍历作文中的每个初始单词,找到它所属的 SCC,并累加该 SCC 对应的最终最优解,得到总的 r_countlength

代码

#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <stack>
#include <algorithm>
#include <cctype>

using namespace std;

struct WordInfo {
    long long r_count;
    long long len;

    bool operator<(const WordInfo& other) const {
        if (r_count != other.r_count) {
            return r_count < other.r_count;
        }
        return len < other.len;
    }
};

const WordInfo INF = {1000000000000000000LL, 1000000000000000000LL};

// Globals for Tarjan's algorithm
vector<vector<int>> adj;
vector<int> dfn, low, scc_id;
stack<int> st;
vector<bool> on_stack;
int timer, scc_count;

void tarjan(int u) {
    dfn[u] = low[u] = ++timer;
    st.push(u);
    on_stack[u] = true;

    for (int v : adj[u]) {
        if (dfn[v] == 0) {
            tarjan(v);
            low[u] = min(low[u], low[v]);
        } else if (on_stack[v]) {
            low[u] = min(low[u], dfn[v]);
        }
    }

    if (dfn[u] == low[u]) {
        ++scc_count;
        int node;
        do {
            node = st.top();
            st.pop();
            on_stack[node] = false;
            scc_id[node] = scc_count;
        } while (node != u);
    }
}

// Globals for DP
vector<vector<int>> scc_adj;
vector<WordInfo> dp_memo;
vector<bool> dp_visited;
vector<WordInfo> scc_optimal;

WordInfo solve_dp(int u_scc) {
    if (dp_visited[u_scc]) {
        return dp_memo[u_scc];
    }
    dp_visited[u_scc] = true;

    WordInfo res = scc_optimal[u_scc];
    for (int v_scc : scc_adj[u_scc]) {
        res = min(res, solve_dp(v_scc));
    }
    return dp_memo[u_scc] = res;
}

string to_lower(const string& s) {
    string lower_s = s;
    transform(lower_s.begin(), lower_s.end(), lower_s.begin(), ::tolower);
    return lower_s;
}

int count_r(const string& s) {
    int count = 0;
    for (char c : s) {
        if (c == 'r' || c == 'R') {
            count++;
        }
    }
    return count;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n_words;
    cin >> n_words;
    vector<string> initial_text(n_words);
    map<string, int> word_to_id;
    vector<WordInfo> word_infos;
    int next_id = 0;

    auto get_id = [&](const string& s) {
        string lower_s = to_lower(s);
        if (word_to_id.find(lower_s) == word_to_id.end()) {
            word_to_id[lower_s] = next_id++;
            word_infos.push_back({(long long)count_r(lower_s), (long long)lower_s.length()});
        }
        return word_to_id[lower_s];
    };

    for (int i = 0; i < n_words; ++i) {
        cin >> initial_text[i];
        get_id(initial_text[i]);
    }

    int n_rules;
    cin >> n_rules;
    vector<pair<int, int>> edges;
    for (int i = 0; i < n_rules; ++i) {
        string u_str, v_str;
        cin >> u_str >> v_str;
        int u_id = get_id(u_str);
        int v_id = get_id(v_str);
        edges.push_back({u_id, v_id});
    }

    int total_unique_words = next_id;
    adj.assign(total_unique_words, vector<int>());
    for(const auto& edge : edges) {
        adj[edge.first].push_back(edge.second);
    }

    dfn.assign(total_unique_words, 0);
    low.assign(total_unique_words, 0);
    scc_id.assign(total_unique_words, 0);
    on_stack.assign(total_unique_words, false);
    timer = 0;
    scc_count = 0;
    for (int i = 0; i < total_unique_words; ++i) {
        if (dfn[i] == 0) {
            tarjan(i);
        }
    }

    scc_optimal.assign(scc_count + 1, INF);
    for (int i = 0; i < total_unique_words; ++i) {
        scc_optimal[scc_id[i]] = min(scc_optimal[scc_id[i]], word_infos[i]);
    }

    scc_adj.resize(scc_count + 1);
    for(int u = 0; u < total_unique_words; ++u) {
        for(int v : adj[u]) {
            if (scc_id[u] != scc_id[v]) {
                scc_adj[scc_id[u]].push_back(scc_id[v]);
            }
        }
    }

    dp_memo.assign(scc_count + 1, INF);
    dp_visited.assign(scc_count + 1, false);
    for (int i = 1; i <= scc_count; ++i) {
        if (!dp_visited[i]) {
            solve_dp(i);
        }
    }

    long long total_r = 0, total_len = 0;
    for (const string& word : initial_text) {
        int id = get_id(word);
        int s_id = scc_id[id];
        WordInfo best = dp_memo[s_id];
        total_r += best.r_count;
        total_len += best.len;
    }

    cout << total_r << " " << total_len << endl;

    return 0;
}
import java.util.*;

class Main {
    static class WordInfo implements Comparable<WordInfo> {
        long rCount;
        long len;

        WordInfo(long rCount, long len) {
            this.rCount = rCount;
            this.len = len;
        }

        @Override
        public int compareTo(WordInfo other) {
            if (this.rCount != other.rCount) {
                return Long.compare(this.rCount, other.rCount);
            }
            return Long.compare(this.len, other.len);
        }
    }

    static List<List<Integer>> adj;
    static int[] dfn, low, sccId;
    static boolean[] onStack;
    static Stack<Integer> stack;
    static int timer, sccCount;

    static List<List<Integer>> sccAdj;
    static WordInfo[] dpMemo;
    static boolean[] dpVisited;
    static WordInfo[] sccOptimal;

    static void tarjan(int u) {
        dfn[u] = low[u] = ++timer;
        stack.push(u);
        onStack[u] = true;

        for (int v : adj.get(u)) {
            if (dfn[v] == 0) {
                tarjan(v);
                low[u] = Math.min(low[u], low[v]);
            } else if (onStack[v]) {
                low[u] = Math.min(low[u], dfn[v]);
            }
        }

        if (dfn[u] == low[u]) {
            sccCount++;
            int node;
            do {
                node = stack.pop();
                onStack[node] = false;
                sccId[node] = sccCount;
            } while (node != u);
        }
    }

    static WordInfo solveDp(int uScc) {
        if (dpVisited[uScc]) {
            return dpMemo[uScc];
        }
        dpVisited[uScc] = true;

        WordInfo res = sccOptimal[uScc];
        for (int vScc : sccAdj.get(uScc)) {
            WordInfo childRes = solveDp(vScc);
            if (childRes.compareTo(res) < 0) {
                res = childRes;
            }
        }
        return dpMemo[uScc] = res;
    }
    
    static int countR(String s) {
        int count = 0;
        for (char c : s.toLowerCase().toCharArray()) {
            if (c == 'r') {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int nWords = sc.nextInt();
        String[] initialText = new String[nWords];
        Map<String, Integer> wordToId = new HashMap<>();
        List<WordInfo> wordInfos = new ArrayList<>();
        int nextId = 0;

        for (int i = 0; i < nWords; i++) {
            initialText[i] = sc.next();
            String lowerS = initialText[i].toLowerCase();
            if (!wordToId.containsKey(lowerS)) {
                wordToId.put(lowerS, nextId++);
                wordInfos.add(new WordInfo(countR(lowerS), lowerS.length()));
            }
        }

        int nRules = sc.nextInt();
        List<int[]> edges = new ArrayList<>();
        for (int i = 0; i < nRules; i++) {
            String uStr = sc.next().toLowerCase();
            String vStr = sc.next().toLowerCase();
            if (!wordToId.containsKey(uStr)) {
                wordToId.put(uStr, nextId++);
                wordInfos.add(new WordInfo(countR(uStr), uStr.length()));
            }
            if (!wordToId.containsKey(vStr)) {
                wordToId.put(vStr, nextId++);
                wordInfos.add(new WordInfo(countR(vStr), vStr.length()));
            }
            edges.add(new int[]{wordToId.get(uStr), wordToId.get(vStr)});
        }

        int totalUniqueWords = nextId;
        adj = new ArrayList<>();
        for (int i = 0; i < totalUniqueWords; i++) adj.add(new ArrayList<>());
        for (int[] edge : edges) adj.get(edge[0]).add(edge[1]);

        dfn = new int[totalUniqueWords];
        low = new int[totalUniqueWords];
        sccId = new int[totalUniqueWords];
        onStack = new boolean[totalUniqueWords];
        stack = new Stack<>();
        timer = 0;
        sccCount = 0;
        for (int i = 0; i < totalUniqueWords; i++) {
            if (dfn[i] == 0) tarjan(i);
        }

        sccOptimal = new WordInfo[sccCount + 1];
        Arrays.fill(sccOptimal, new WordInfo(Long.MAX_VALUE, Long.MAX_VALUE));
        for (int i = 0; i < totalUniqueWords; i++) {
            if (wordInfos.get(i).compareTo(sccOptimal[sccId[i]]) < 0) {
                sccOptimal[sccId[i]] = wordInfos.get(i);
            }
        }

        sccAdj = new ArrayList<>();
        for (int i = 0; i <= sccCount; i++) sccAdj.add(new ArrayList<>());
        Set<Long> sccEdgeSet = new HashSet<>();
        for (int u = 0; u < totalUniqueWords; u++) {
            for (int v : adj.get(u)) {
                if (sccId[u] != sccId[v]) {
                    long edgeHash = (long)sccId[u] << 32 | sccId[v];
                    if (!sccEdgeSet.contains(edgeHash)) {
                        sccAdj.get(sccId[u]).add(sccId[v]);
                        sccEdgeSet.add(edgeHash);
                    }
                }
            }
        }

        dpMemo = new WordInfo[sccCount + 1];
        dpVisited = new boolean[sccCount + 1];
        for (int i = 1; i <= sccCount; i++) {
            if (!dpVisited[i]) solveDp(i);
        }

        long totalR = 0, totalLen = 0;
        for (String word : initialText) {
            int id = wordToId.get(word.toLowerCase());
            int sId = sccId[id];
            WordInfo best = dpMemo[sId];
            totalR += best.rCount;
            totalLen += best.len;
        }

        System.out.println(totalR + " " + totalLen);
    }
}
import sys

# It's necessary for deep recursion in Tarjan's algorithm
sys.setrecursionlimit(200005)

timer = 0
scc_count = 0
dfn = []
low = []
scc_id = []
on_stack = []
stack = []
adj = {}

def tarjan(u):
    global timer, scc_count, dfn, low, scc_id, on_stack, stack, adj
    
    timer += 1
    dfn[u] = low[u] = timer
    stack.append(u)
    on_stack[u] = True

    for v in adj.get(u, []):
        if dfn[v] == 0:
            tarjan(v)
            low[u] = min(low[u], low[v])
        elif on_stack[v]:
            low[u] = min(low[u], dfn[v])

    if dfn[u] == low[u]:
        scc_count += 1
        while True:
            node = stack.pop()
            on_stack[node] = False
            scc_id[node] = scc_count
            if node == u:
                break

dp_memo = {}
scc_adj = {}
scc_optimal = {}

def solve_dp(u_scc):
    global dp_memo, scc_adj, scc_optimal
    if u_scc in dp_memo:
        return dp_memo[u_scc]

    res = scc_optimal[u_scc]
    for v_scc in scc_adj.get(u_scc, []):
        res = min(res, solve_dp(v_scc))
    
    dp_memo[u_scc] = res
    return res

def main():
    global timer, scc_count, dfn, low, scc_id, on_stack, stack, adj
    global dp_memo, scc_adj, scc_optimal

    n_words = int(sys.stdin.readline())
    initial_text = sys.stdin.readline().strip().split()

    word_to_id = {}
    word_infos = []
    next_id = 0

    def get_id(s):
        nonlocal next_id
        lower_s = s.lower()
        if lower_s not in word_to_id:
            word_to_id[lower_s] = next_id
            r_count = lower_s.count('r')
            word_infos.append((r_count, len(lower_s)))
            next_id += 1
        return word_to_id[lower_s]

    for word in initial_text:
        get_id(word)

    n_rules = int(sys.stdin.readline())
    edges = []
    for _ in range(n_rules):
        u_str, v_str = sys.stdin.readline().strip().split()
        u_id = get_id(u_str)
        v_id = get_id(v_str)
        edges.append((u_id, v_id))

    total_unique_words = next_id
    adj = {i: [] for i in range(total_unique_words)}
    for u, v in edges:
        adj[u].append(v)

    dfn = [0] * total_unique_words
    low = [0] * total_unique_words
    scc_id = [0] * total_unique_words
    on_stack = [False] * total_unique_words
    stack = []
    timer = 0
    scc_count = 0
    for i in range(total_unique_words):
        if dfn[i] == 0:
            tarjan(i)

    scc_optimal = {}
    for i in range(total_unique_words):
        s_id = scc_id[i]
        info = word_infos[i]
        if s_id not in scc_optimal or info < scc_optimal[s_id]:
            scc_optimal[s_id] = info
    
    scc_adj = {i: [] for i in range(1, scc_count + 1)}
    for u in range(total_unique_words):
        for v in adj.get(u, []):
            if scc_id[u] != scc_id[v]:
                scc_adj[scc_id[u]].append(scc_id[v])

    dp_memo = {}
    for i in range(1, scc_count + 1):
        if i not in dp_memo:
            solve_dp(i)

    total_r, total_len = 0, 0
    for word in initial_text:
        s_id = scc_id[get_id(word)]
        best_r, best_len = dp_memo[s_id]
        total_r += best_r
        total_len += best_len

    print(total_r, total_len)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:强连通分量 (Tarjan 算法) + 缩点 + 动态规划
  • 时间复杂度:,其中 是图中独立单词的总数, 是规则的总数。整个算法的复杂度与图的规模呈线性关系。
  • 空间复杂度:,用于存储图、Tarjan 算法的辅助数组、DP 表等。