茉茉的密码

思路

题意很直接:给你 $n$ 个由小写字母组成的字符串,找出一个在所有字符串中都出现过的公共子串,输出任意一个即可。

既然"任意一个"都行,那最简单的想法是——找一个所有字符串里都有的单个字符不就行了?确实可以,但这题其实是在考最长公共子串的思路,我们不妨把它做完整。

二分 + 滚动哈希

核心思路是二分答案:二分公共子串的长度 $L$,然后验证是否存在长度为 $L$ 的子串同时出现在所有字符串中。如果长度 $L$ 可行,那么它的任意子串也是公共子串,所以更短的长度一定可行,答案具有单调性,可以二分。

验证的方法是滚动哈希(Rabin-Karp):

  1. 取最短的那个字符串 $s_0$ 作为基准(公共子串长度不可能超过最短串)。
  2. 把 $s_0$ 中所有长度为 $L$ 的子串的哈希值丢进一个集合 common。
  3. 对每个其他字符串,算出它所有长度为 $L$ 的子串哈希,和 common 取交集。
  4. 如果最终交集非空,说明长度 $L$ 可行。

为了防止哈希冲突,采用双模数($10^9+7$ 和 $10^9+9$)的双哈希,基本上可以杜绝碰撞。

二分范围是 $[1, |s_0|]$,每次验证的复杂度是所有字符串长度之和乘以集合操作的开销(有序集合为 $O(\log)$,哈希集合期望 $O(1)$),整体效率很不错。

复杂度

设所有字符串的总长度为 $S$,最短串长度为 $m$。

  • 时间复杂度 $O(S \log m)$(使用哈希集合;若使用有序集合则为 $O(S \log^2 m)$),二分 $O(\log m)$ 层,每层扫描所有字符串。
  • 空间复杂度 $O(m)$(不含输入本身),存储哈希集合。

代码

#include <bits/stdc++.h>
using namespace std;

// Reads n lowercase strings and prints one longest substring shared by all of
// them. The length is binary-searched; each candidate length is checked by
// intersecting double rolling-hash sets, with the shortest string as reference.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    using ll = long long;
    using HashPair = pair<ll, ll>;

    int n;
    cin >> n;
    vector<string> strs(n);
    for(string &word : strs) cin >> word;

    // A single string is trivially its own longest common substring.
    if(n == 1){
        cout << strs[0] << endl;
        return 0;
    }

    // First shortest string: no common substring can be longer than it.
    const int shortest = (int)(min_element(strs.begin(), strs.end(),
        [](const string &a, const string &b){ return a.size() < b.size(); }) - strs.begin());

    const ll M1 = 1000000007LL, M2 = 1000000009LL;
    const ll B1 = 131, B2 = 137;

    // Slides a window of `len` characters over `text` and calls
    // visit(start, hash) for every window, left to right. The scan stops as
    // soon as visit returns true. top1/top2 must be B1^len and B2^len reduced
    // modulo M1/M2; the caller guarantees 1 <= len <= text.size().
    auto scanWindows = [&](const string &text, int len, ll top1, ll top2, auto &&visit){
        const int total = (int)text.size();
        ll a = 0, b = 0;
        for(int i = 0; i < total; ++i){
            if(i < len){
                a = (a * B1 + text[i]) % M1;
                b = (b * B2 + text[i]) % M2;
            } else {
                // Append text[i] and drop text[i - len]; adding the modulus
                // keeps the intermediate value non-negative.
                a = (a * B1 + text[i] - top1 * text[i - len] % M1 + M1) % M1;
                b = (b * B2 + text[i] - top2 * text[i - len] % M2 + M2) % M2;
            }
            if(i + 1 >= len && visit(i + 1 - len, HashPair{a, b})) return;
        }
    };

    const string &base = strs[shortest];
    int lo = 1, hi = (int)base.size(), bestStart = 0, bestLen = 0;

    while(lo <= hi){
        const int len = (lo + hi) / 2;

        ll top1 = 1, top2 = 1;
        for(int t = 0; t < len; ++t){
            top1 = top1 * B1 % M1;
            top2 = top2 * B2 % M2;
        }

        // Hashes of every length-`len` window of the reference string.
        set<HashPair> shared;
        scanWindows(base, len, top1, top2, [&](int, const HashPair &h){
            shared.insert(h);
            return false;
        });

        // Narrow the set down string by string; an empty result means this
        // length is not feasible.
        bool feasible = true;
        for(int k = 0; k < n; ++k){
            if(k == shortest) continue;
            const string &other = strs[k];
            if((int)other.size() < len){ feasible = false; break; }

            set<HashPair> kept;
            scanWindows(other, len, top1, top2, [&](int, const HashPair &h){
                if(shared.count(h)) kept.insert(h);
                return false;
            });
            if(kept.empty()){ feasible = false; break; }
            shared = std::move(kept);
        }

        if(!feasible){
            hi = len - 1;
            continue;
        }

        // Remember where the smallest surviving hash first occurs in the
        // reference string, then try a longer length.
        bestLen = len;
        const HashPair want = *shared.begin();
        scanWindows(base, len, top1, top2, [&](int start, const HashPair &h){
            if(h != want) return false;
            bestStart = start;
            return true;
        });
        lo = len + 1;
    }

    cout << base.substr(bestStart, bestLen) << endl;
    return 0;
}
import java.util.*;
import java.io.*;

/**
 * Prints one longest substring shared by all input strings.
 *
 * <p>Input: the first line holds the number of strings {@code n}; each of the
 * next {@code n} lines holds one string. The answer length is binary-searched,
 * and each candidate length is verified by intersecting sets of double
 * rolling hashes (moduli 1e9+7 and 1e9+9), using the shortest string as the
 * reference.
 */
public class Main {
    static final long MOD1 = 1_000_000_007L, MOD2 = 1_000_000_009L;
    static final long BASE1 = 131, BASE2 = 137;

    /**
     * Reads the strings from standard input and prints the answer followed by
     * a newline. When several longest common substrings exist, the one that
     * starts leftmost in the first shortest string is printed. An empty line
     * is printed when there is no common substring or no input strings.
     *
     * @param args unused
     * @throws IOException if reading standard input fails
     */
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));

        String header = br.readLine();
        int n = (header == null || header.trim().isEmpty()) ? 0 : Integer.parseInt(header.trim());
        if (n <= 0) {
            // Nothing to intersect: the common substring is empty.
            System.out.println();
            return;
        }

        String[] strs = new String[n];
        for (int i = 0; i < n; i++) {
            String line = br.readLine();
            // A line missing because of early end of input is treated as an
            // empty string, which makes the answer empty instead of crashing.
            strs[i] = (line == null) ? "" : line.trim();
        }

        if (n == 1) {
            System.out.println(strs[0]);
            return;
        }

        // The first shortest string bounds the answer length and is the reference.
        int minIdx = 0;
        for (int i = 1; i < n; i++) {
            if (strs[i].length() < strs[minIdx].length()) {
                minIdx = i;
            }
        }

        String s0 = strs[minIdx];
        int lo = 1, hi = s0.length(), ansStart = 0, ansLen = 0;

        // If length L is feasible, every shorter length is feasible as well.
        while (lo <= hi) {
            int mid = (lo + hi) >>> 1;
            int start = findCommonStart(strs, minIdx, mid);
            if (start >= 0) {
                ansStart = start;
                ansLen = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }

        System.out.println(s0.substring(ansStart, ansStart + ansLen));
    }

    /**
     * Returns the leftmost start index in {@code strs[minIdx]} of a window of
     * length {@code len} whose hash occurs in every string, or -1 if none does.
     * Requires {@code 1 <= len <= strs[minIdx].length()}.
     */
    private static int findCommonStart(String[] strs, int minIdx, int len) {
        // BASE^len modulo each modulus: the weight of the character leaving the window.
        long pw1 = 1, pw2 = 1;
        for (int i = 0; i < len; i++) {
            pw1 = pw1 * BASE1 % MOD1;
            pw2 = pw2 * BASE2 % MOD2;
        }

        long[] baseKeys = windowKeys(strs[minIdx], len, pw1, pw2);
        Set<Long> common = new HashSet<>();
        for (long key : baseKeys) {
            common.add(key);
        }

        for (int k = 0; k < strs.length; k++) {
            if (k == minIdx) {
                continue;
            }
            if (strs[k].length() < len) {
                return -1;
            }
            Set<Long> cur = new HashSet<>();
            for (long key : windowKeys(strs[k], len, pw1, pw2)) {
                if (common.contains(key)) {
                    cur.add(key);
                }
            }
            if (cur.isEmpty()) {
                return -1;
            }
            common = cur;
        }

        // The surviving hashes are shared by all strings; report the first
        // reference window that carries one of them.
        for (int i = 0; i < baseKeys.length; i++) {
            if (common.contains(baseKeys[i])) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Returns the combined double-hash key of every length-{@code len} window
     * of {@code s}, indexed by window start. Requires
     * {@code 1 <= len <= s.length()}.
     */
    private static long[] windowKeys(String s, int len, long pw1, long pw2) {
        long[] keys = new long[s.length() - len + 1];
        long h1 = 0, h2 = 0;
        for (int i = 0; i < s.length(); i++) {
            char in = s.charAt(i);
            h1 = (h1 * BASE1 + in) % MOD1;
            h2 = (h2 * BASE2 + in) % MOD2;
            if (i >= len) {
                // Remove the character that just left the window; adding the
                // modulus keeps the value non-negative before the final reduction.
                char out = s.charAt(i - len);
                h1 = (h1 - pw1 * out % MOD1 + MOD1) % MOD1;
                h2 = (h2 - pw2 * out % MOD2 + MOD2) % MOD2;
            }
            if (i >= len - 1) {
                // h1 < MOD1 and h2 < MOD2, so the packed key fits in a long.
                keys[i - len + 1] = h1 * MOD2 + h2;
            }
        }
        return keys;
    }
}
import sys
input = sys.stdin.readline

def main():
    """Read n strings and print one longest substring common to all of them.

    The first input line holds n, followed by one string per line. The
    answer length is binary-searched; a length is feasible when the sets of
    double rolling hashes of all strings' windows have a non-empty
    intersection. The shortest string is the reference the result is cut from.
    """
    n = int(input())
    strs = [input().strip() for _ in range(n)]

    # A single string is its own longest common substring.
    if n == 1:
        print(strs[0])
        return

    # Index of the first shortest string.
    min_idx = 0
    for idx, word in enumerate(strs):
        if len(word) < len(strs[min_idx]):
            min_idx = idx

    mod_a, mod_b = 10**9 + 7, 10**9 + 9
    base_a, base_b = 131, 137

    def rolling(text, width, top_a, top_b):
        """Yield (start, (hash_a, hash_b)) for each window of `width` chars.

        top_a / top_b are base**width modulo the respective modulus; the
        caller guarantees 1 <= width <= len(text).
        """
        ha = hb = 0
        for pos, ch in enumerate(text):
            code = ord(ch)
            if pos < width:
                ha = (ha * base_a + code) % mod_a
                hb = (hb * base_b + code) % mod_b
            else:
                # Append the new character and drop the one leaving the window.
                gone = ord(text[pos - width])
                ha = (ha * base_a + code - top_a * gone) % mod_a
                hb = (hb * base_b + code - top_b * gone) % mod_b
            if pos >= width - 1:
                yield pos - width + 1, (ha, hb)

    shortest = strs[min_idx]
    low, high = 1, len(shortest)
    best_start = best_len = 0

    while low <= high:
        width = (low + high) // 2
        top_a = pow(base_a, width, mod_a)
        top_b = pow(base_b, width, mod_b)

        # Hashes of every window of the reference string.
        shared = set()
        for _, pair in rolling(shortest, width, top_a, top_b):
            shared.add(pair)

        # Intersect with every other string; stop at the first empty result.
        feasible = True
        for idx, other in enumerate(strs):
            if idx == min_idx:
                continue
            if len(other) < width:
                feasible = False
                break
            kept = set()
            for _, pair in rolling(other, width, top_a, top_b):
                if pair in shared:
                    kept.add(pair)
            if not kept:
                feasible = False
                break
            shared = kept

        if not feasible:
            high = width - 1
            continue

        # Locate the first window of the reference string with a surviving
        # hash, then try a longer length.
        best_len = width
        wanted = next(iter(shared))
        for start, pair in rolling(shortest, width, top_a, top_b):
            if pair == wanted:
                best_start = start
                break
        low = width + 1

    print(shortest[best_start:best_start + best_len])

main()
const readline = require('readline');

// Double rolling-hash parameters (BigInt so products never lose precision).
const MOD_A = 1000000007n, MOD_B = 1000000009n;
const BASE_A = 131n, BASE_B = 137n;

/**
 * Slides a window of `width` characters over `text` and calls
 * visit(start, key) for each window from left to right, where key is the
 * string "hashA,hashB". Stops as soon as visit returns true.
 * topA / topB are BASE^width modulo the respective modulus; the caller
 * guarantees 1 <= width <= text.length.
 */
function eachWindowKey(text, width, topA, topB, visit) {
    let ha = 0n, hb = 0n;
    for (let pos = 0; pos < text.length; pos++) {
        const code = BigInt(text.charCodeAt(pos));
        if (pos < width) {
            ha = (ha * BASE_A + code) % MOD_A;
            hb = (hb * BASE_B + code) % MOD_B;
        } else {
            // Append the new character, drop the one leaving the window, and
            // add a multiple of the modulus so the result stays non-negative.
            const gone = BigInt(text.charCodeAt(pos - width));
            ha = ((ha * BASE_A + code) % MOD_A - topA * gone % MOD_A + MOD_A * 2n) % MOD_A;
            hb = ((hb * BASE_B + code) % MOD_B - topB * gone % MOD_B + MOD_B * 2n) % MOD_B;
        }
        if (pos >= width - 1) {
            const key = ha.toString() + ',' + hb.toString();
            if (visit(pos - width + 1, key)) return;
        }
    }
}

const lines = [];

/**
 * Runs once all input is read: line 0 holds n, lines 1..n hold the strings.
 * Prints one longest substring common to all of them.
 */
function solve() {
    const n = parseInt(lines[0]);
    const strs = [];
    for (let i = 1; i <= n; i++) strs.push(lines[i]);

    // A single string is its own longest common substring.
    if (n === 1) {
        console.log(strs[0]);
        return;
    }

    // Index of the first shortest string; it bounds the answer length.
    let minIdx = 0;
    for (let i = 1; i < n; i++) {
        if (strs[i].length < strs[minIdx].length) minIdx = i;
    }

    const base = strs[minIdx];
    let lo = 1, hi = base.length, bestStart = 0, bestLen = 0;

    while (lo <= hi) {
        const width = (lo + hi) >> 1;

        let topA = 1n, topB = 1n;
        for (let t = 0; t < width; t++) {
            topA = topA * BASE_A % MOD_A;
            topB = topB * BASE_B % MOD_B;
        }

        // Keys of every window of the reference string.
        let shared = new Set();
        eachWindowKey(base, width, topA, topB, (_, key) => {
            shared.add(key);
            return false;
        });

        // Intersect with each other string; an empty result rules this width out.
        let feasible = true;
        for (let k = 0; k < n; k++) {
            if (k === minIdx) continue;
            const other = strs[k];
            if (other.length < width) { feasible = false; break; }

            const kept = new Set();
            eachWindowKey(other, width, topA, topB, (_, key) => {
                if (shared.has(key)) kept.add(key);
                return false;
            });
            if (kept.size === 0) { feasible = false; break; }
            shared = kept;
        }

        if (!feasible) {
            hi = width - 1;
            continue;
        }

        // Find where the first surviving key occurs in the reference string,
        // then try a longer width.
        bestLen = width;
        const wanted = shared.values().next().value;
        eachWindowKey(base, width, topA, topB, (start, key) => {
            if (key !== wanted) return false;
            bestStart = start;
            return true;
        });
        lo = width + 1;
    }

    console.log(base.substring(bestStart, bestStart + bestLen));
}

const rl = readline.createInterface({ input: process.stdin });
rl.on('line', l => lines.push(l));
rl.on('close', solve);