import sys
import math

def solve(s, n, k):
    """寻找一个子串[l, r],使得其中恰好有k个"01"子序列
	Args:
		s: 输入的字符串
		n: 字符串大小
		k: 目标个数
    """
    # 特殊情况:k=0
    if k == 0:
        # 任何单个字符都满足条件
        return "1 1"

    # 预处理前缀数组
    # prefix_zero[i]: s[0..i-1]中'0'的数量
    prefix_zero = [0] * (n + 1)
    # prefix_sum_01[i]: s[0..i-1]中"01"子序列的数量
    prefix_sum_01 = [0] * (n + 1)
    # zero_list: 顺序储存s中'0'的索引
    zero_list = []

    # 填充前缀数组
    for i in range(1, n + 1):
        prefix_sum_01[i] = prefix_sum_01[i - 1]

        if s[i - 1] == "0":
            prefix_zero[i] = prefix_zero[i - 1] + 1
            zero_list.append(i - 1)
        else:
            prefix_zero[i] = prefix_zero[i - 1]
            # 当前'1'贡献的"01"子序列数量 = 它前面'0'的数量
            prefix_sum_01[i] += prefix_zero[i - 1]

    # 计算子串s[l..r]中"01"子序列的数量
    def count_01(l, r):
        """
        原理:prefix_sum_01[r+1] - prefix_sum_01[l] 包含[0,l-1]中'0'的贡献,需要减去这部分:prefix_zero[l] * (区间[l..r]中'1'的数量)
        """
        total = prefix_sum_01[r + 1] - prefix_sum_01[l]
        # 通过 i - prefix_zero[i] 计算'1'的数量
        num_ones = (r + 1 - prefix_zero[r + 1]) - (l - prefix_zero[l])
        return total - prefix_zero[l] * num_ones

    # 检查整个字符串是否可能包含k个"01"子序列
    max_possible = count_01(0, n - 1)
    if max_possible < k:
        return "-1"

    # 计算理论最小长度
    # 当区间长度L时,最大"01"子序列数量约为(L/2)²(当0和1各占一半时)
    # 所以L至少需要2*sqrt(k)向下取整的个数的元素
    min_length = max(1, int(2 * math.sqrt(k)))

    # 只需遍历'0'作为起始位置,没'0'再多的'1'也没用(无深意)
    for l in zero_list:
        # 根据理论最小长度调整右边界
        if n - l < min_length:
            break

        low = l
        high = n - 1
		
        while low <= high:
            mid = (low + high) // 2
            
			# 这里要更快的话可以直接把函数删了,直接放这里运行,性能能提升1/5左右,但作为笔记没必要
			current_count = count_01(l, mid)
			
			# 达到要求直接return
            if current_count == k:
                return f"{l + 1} {mid + 1}"
			
			# 继续二分
            if current_count > k:
                high = mid - 1
            else:
                low = mid + 1
			
			# 区间元素个数小于 min_length 也不用考虑了
            if (low + high) // 2 - l < min_length:
                break
    return "-1"


def main():
    data = sys.stdin.read().strip().split()
    if not data:
        return

    n = int(data[0])
    k = int(data[1])
    s = data[2]

    result = solve(s, n, k)
    print(result)


if __name__ == "__main__":
    main()

当前思路最后的优化结果,把该基本该考虑的都考虑了,再要优化也没多大提升了,要么就是什么神仙算法了

该有的思路和原理都在注释里了,本质上就是前缀和+保存所有结果,然后二分查找所有可能