题目链接

我们惺惺相惜

题目描述

给定一个长度为 的数组 ,我们定义一个区间是“好的”,当且仅当这个区间可以被分成两个非空的、元素相对顺序不变的严格单调递增子序列。对于给出的多次询问,你需要回答指定区间是不是“好的”。

输入:

  • 第一行一个整数 ,表示有 次询问。
  • 对于每次询问,第一行两个整数 ,第二行 个整数表示数组
  • 接下来 行,每行两个整数 ,表示询问的区间。

输出:

  • 对于每次询问,如果区间是“好的”,输出 YES,否则输出 NO

解题思路

一个区间是“好的”,等价于该区间的子数组可以被划分为两个严格递增子序列。根据 Dilworth 定理,这又等价于该子数组的最长非递增子序列的长度不超过

因此,一个区间不是“好的”,当且仅当它包含一个长度为 的非递增子序列。即存在三个下标 满足

我们的目标就是对于每个查询 ,判断是否存在这样的三元组

我们可以采用离线处理的方式来高效解决此问题:

  1. 预处理“坏对”: 我们首先找出所有可能构成非递增三元组 的“坏对” 。对于数组中的每一个元素 (作为中间值),我们寻找它左侧最靠右的 使得 ,以及右侧最靠左的 使得

    • 所有左边界 (即每个位置左边第一个大于等于它的元素位置)可以通过一次单调栈遍历在 时间内求出。
    • 所有右边界 (即每个位置右边第一个小于等于它的元素位置)也可以通过一次反向的单調栈遍历在 时间内求出。
    • 这样我们就得到了所有满足 的“最小”区间 [p, k],我们称之为“坏对”。
  2. 离线查询与线段树: 一个查询区间 是“坏的”,当且仅当它完全包含某个“坏对” [p, k],即 。 这个问题可以转化为一个二维查询问题,但可以通过离线处理更高效地解决:

    • 我们将所有查询按右端点 分组。
    • 我们将所有“坏对”按右端点 分组。
    • 我们从左到右遍历 。在每一步 : a. 将所有右端点 的“坏对” 加入一个数据结构。这个数据结构需要支持查询某个区间内是否存在值。我们使用线段树,将“坏对”的左端点 更新到线段树的第 个位置上。 b. 处理所有右端点 的查询 。我们在线段树上查询区间 上的最大值。如果这个最大值 max_p 满足 max_p >= l,说明在 区间内存在一个已激活的“坏对”的左端点 ,并且这个“坏对”的右端点 ,因此该“坏对” [p, k] 被查询区间 完全包含。此时答案为 "NO",否则为 "YES"。

通过这种方式,我们可以在 的时间内处理所有查询。

代码

#include <iostream>
#include <vector>
#include <algorithm>
#include <stack>

using namespace std;

const int N = 200005;

struct SegTreeNode {
    int max_val;
};

SegTreeNode tree[4 * N];
int a[N];
vector<pair<int, int>> queries_at_r[N];
vector<int> bad_pairs_at_k[N];
string ans[N];

void build(int node, int start, int end) {
    tree[node].max_val = 0;
    if (start == end) return;
    int mid = (start + end) / 2;
    build(2 * node, start, mid);
    build(2 * node + 1, mid + 1, end);
}

void update(int node, int start, int end, int idx, int val) {
    if (start == end) {
        tree[node].max_val = max(tree[node].max_val, val);
        return;
    }
    int mid = (start + end) / 2;
    if (start <= idx && idx <= mid) {
        update(2 * node, start, mid, idx, val);
    } else {
        update(2 * node + 1, mid + 1, end, idx, val);
    }
    tree[node].max_val = max(tree[2 * node].max_val, tree[2 * node + 1].max_val);
}

int query(int node, int start, int end, int l, int r) {
    if (r < start || end < l) {
        return 0;
    }
    if (l <= start && end <= r) {
        return tree[node].max_val;
    }
    int mid = (start + end) / 2;
    int p1 = query(2 * node, start, mid, l, r);
    int p2 = query(2 * node + 1, mid + 1, end, l, r);
    return max(p1, p2);
}

void solve() {
    int n, q;
    cin >> n >> q;

    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
        queries_at_r[i].clear();
        bad_pairs_at_k[i].clear();
    }

    for (int i = 0; i < q; ++i) {
        int l, r;
        cin >> l >> r;
        queries_at_r[r].push_back({l, i});
    }

    vector<int> p(n + 1, 0), k(n + 2, n + 1);
    stack<int> st;

    // 找到每个位置 i 左边第一个 >= a[i] 的位置 p
    for (int i = 1; i <= n; ++i) {
        while (!st.empty() && a[st.top()] < a[i]) {
            st.pop();
        }
        if (!st.empty()) p[i] = st.top();
        st.push(i);
    }
    while (!st.empty()) st.pop();

    // 找到每个位置 i 右边第一个 <= a[i] 的位置 k
    for (int i = n; i >= 1; --i) {
        while (!st.empty() && a[st.top()] > a[i]) {
            st.pop();
        }
        if (!st.empty()) k[i] = st.top();
        st.push(i);
    }

    for (int i = 1; i <= n; ++i) {
        if (p[i] != 0 && k[i] != n + 1) {
            bad_pairs_at_k[k[i]].push_back(p[i]);
        }
    }

    build(1, 1, n);

    for (int i = 1; i <= n; ++i) {
        for (int p_val : bad_pairs_at_k[i]) {
            update(1, 1, n, p_val, p_val);
        }
        for (auto& query_pair : queries_at_r[i]) {
            int l = query_pair.first;
            int id = query_pair.second;
            int max_p = query(1, 1, n, l, n);
            if (max_p >= l) {
                ans[id] = "NO";
            } else {
                ans[id] = "YES";
            }
        }
    }

    for (int i = 0; i < q; ++i) {
        cout << ans[i] << "\n";
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    int t;
    cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
import java.util.Stack;

public class Main {
    static class SegTree {
        int[] tree;
        int n;

        SegTree(int size) {
            n = size;
            tree = new int[4 * n];
        }

        void update(int node, int start, int end, int idx, int val) {
            if (start == end) {
                tree[node] = Math.max(tree[node], val);
                return;
            }
            int mid = (start + end) / 2;
            if (start <= idx && idx <= mid) {
                update(2 * node, start, mid, idx, val);
            } else {
                update(2 * node + 1, mid + 1, end, idx, val);
            }
            tree[node] = Math.max(tree[2 * node], tree[2 * node + 1]);
        }

        int query(int node, int start, int end, int l, int r) {
            if (r < start || end < l || l > r) {
                return 0;
            }
            if (l <= start && end <= r) {
                return tree[node];
            }
            int mid = (start + end) / 2;
            int p1 = query(2 * node, start, mid, l, r);
            int p2 = query(2 * node + 1, mid + 1, end, l, r);
            return Math.max(p1, p2);
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int t = sc.nextInt();
        while (t-- > 0) {
            int n = sc.nextInt();
            int q = sc.nextInt();
            int[] a = new int[n + 1];
            for (int i = 1; i <= n; i++) {
                a[i] = sc.nextInt();
            }

            List<List<int[]>> queriesAtR = new ArrayList<>(n + 1);
            for (int i = 0; i <= n; i++) queriesAtR.add(new ArrayList<>());
            for (int i = 0; i < q; i++) {
                int l = sc.nextInt();
                int r = sc.nextInt();
                queriesAtR.get(r).add(new int[]{l, i});
            }

            int[] p = new int[n + 1];
            int[] k = new int[n + 2];
            Stack<Integer> stack = new Stack<>();

            for (int i = 1; i <= n; i++) {
                while (!stack.isEmpty() && a[stack.peek()] < a[i]) {
                    stack.pop();
                }
                p[i] = stack.isEmpty() ? 0 : stack.peek();
                stack.push(i);
            }
            stack.clear();

            for (int i = n; i >= 1; i--) {
                while (!stack.isEmpty() && a[stack.peek()] > a[i]) {
                    stack.pop();
                }
                k[i] = stack.isEmpty() ? n + 1 : stack.peek();
                stack.push(i);
            }
            
            List<List<Integer>> badPairsAtK = new ArrayList<>(n + 1);
            for (int i = 0; i <= n; i++) badPairsAtK.add(new ArrayList<>());

            for (int i = 1; i <= n; i++) {
                if (p[i] != 0 && k[i] != n + 1) {
                    badPairsAtK.get(k[i]).add(p[i]);
                }
            }

            String[] ans = new String[q];
            SegTree segTree = new SegTree(n + 1);

            for (int i = 1; i <= n; i++) {
                for (int pVal : badPairsAtK.get(i)) {
                    segTree.update(1, 1, n, pVal, pVal);
                }
                for (int[] query : queriesAtR.get(i)) {
                    int l = query[0];
                    int id = query[1];
                    int maxP = segTree.query(1, 1, n, l, n);
                    if (maxP >= l) {
                        ans[id] = "NO";
                    } else {
                        ans[id] = "YES";
                    }
                }
            }
            
            for (int i = 0; i < q; i++) {
                System.out.println(ans[i]);
            }
        }
    }
}
import sys

# 设置足够大的递归深度以防线段树递归时出错
sys.setrecursionlimit(200005)

def solve():
    n, q = map(int, input().split())
    a = [0] + list(map(int, input().split()))

    queries_at_r = [[] for _ in range(n + 1)]
    for i in range(q):
        l, r = map(int, input().split())
        queries_at_r[r].append((l, i))

    p = [0] * (n + 1)
    k = [n + 1] * (n + 2)
    stack = []

    # 找到每个位置 i 左边第一个 >= a[i] 的位置 p
    for i in range(1, n + 1):
        while stack and a[stack[-1]] < a[i]:
            stack.pop()
        if stack:
            p[i] = stack[-1]
        stack.append(i)
    
    stack.clear()

    # 找到每个位置 i 右边第一个 <= a[i] 的位置 k
    for i in range(n, 0, -1):
        while stack and a[stack[-1]] > a[i]:
            stack.pop()
        if stack:
            k[i] = stack[-1]
        stack.append(i)
        
    bad_pairs_at_k = [[] for _ in range(n + 1)]
    for i in range(1, n + 1):
        if p[i] != 0 and k[i] != n + 1:
            bad_pairs_at_k[k[i]].append(p[i])

    ans = [""] * q
    seg_tree = [0] * (4 * (n + 1))

    def update(node, start, end, idx, val):
        if start == end:
            seg_tree[node] = max(seg_tree[node], val)
            return
        mid = (start + end) // 2
        if start <= idx <= mid:
            update(2 * node, start, mid, idx, val)
        else:
            update(2 * node + 1, mid + 1, end, idx, val)
        seg_tree[node] = max(seg_tree[2 * node], seg_tree[2 * node + 1])

    def query(node, start, end, l, r):
        if r < start or end < l or l > r:
            return 0
        if l <= start and end <= r:
            return seg_tree[node]
        mid = (start + end) // 2
        p1 = query(2 * node, start, mid, l, r)
        p2 = query(2 * node + 1, mid + 1, end, l, r)
        return max(p1, p2)

    for i in range(1, n + 1):
        for p_val in bad_pairs_at_k[i]:
            update(1, 1, n, p_val, p_val)
        
        for l, q_id in queries_at_r[i]:
            max_p = query(1, 1, n, l, n)
            if max_p >= l:
                ans[q_id] = "NO"
            else:
                ans[q_id] = "YES"
    
    for res in ans:
        print(res)


def main():
    t = int(input())
    for _ in range(t):
        solve()

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:单调栈 + 离线查询 + 线段树
  • 时间复杂度:,其中 是数组长度, 是查询次数。单调栈预处理为 ,线段树的更新和查询为
  • 空间复杂度:,用于存储输入、查询、坏对以及线段树。