题目链接
题目描述
给定一个长度为 的 01 字符串
。你需要把它切分成若干个连续段,要求每个连续段内恰好包含一个数字
1
。
求切分方案的总数量,结果对 取模。
输入:
- 第一行输入一个整数
。
- 第二行输入一个长度为
的 01 字符串
。
输出:
- 输出一个整数,表示满足要求的切分方案数量对
取模的结果。
解题思路
这道题的核心在于理解“每个连续段内恰好包含一个数字 1”这个约束条件。这个条件极大地简化了问题,我们可以通过乘法原理来解决。
首先,我们找到字符串中所有字符 1
的位置。设它们的(从0开始的)下标分别为 。
-
特殊情况处理:
- 如果字符串中没有
1
(即),那么无法满足“每个段包含一个1”的条件,所以方案数为 0。
- 如果字符串中只有一个
1
(即),那么不需要进行任何切分,整个字符串本身就是唯一的一个段,满足条件。所以方案数为 1。
- 如果字符串中没有
-
一般情况 (
):
- 根据题意,第一个段必须包含第一个
1
(在处),第二个段必须包含第二个
1
(在处),以此类推。
- 这意味着,我们必须在每两个相邻的
1
之间进行一次且仅一次切分。 - 考虑第一和第二个
1
,它们分别位于和
。第一个段必须包含
处的
1
,但不包含处的
1
。所以,切分点必须在和
之间。
- 切分点可以在下标为
的任何一个位置之后。因此,在
和
之间,共有
种切分选择。
- 同理,对于任意一对相邻的
1
(位于和
),都有
种切分选择。
- 根据题意,第一个段必须包含第一个
-
最终计算:
- 由于每次切分的选择是相互独立的,根据乘法原理,总的方案数就是所有相邻
1
之间切分选择数量的乘积。 - 总方案数 =
- 由于最终结果需要取模,我们在计算乘积的每一步都要进行取模操作,以防止中间结果溢出。
- 由于每次切分的选择是相互独立的,根据乘法原理,总的方案数就是所有相邻
算法步骤:
- 遍历输入字符串,记录下所有
1
的下标。 - 如果没有找到
1
,输出 0。 - 如果只找到一个
1
,输出 1。 - 如果找到多个
1
,则计算相邻1
下标之差的乘积,每步都对取模。
代码
#include <iostream>
#include <vector>
#include <string>
using namespace std;
const int MOD = 1e9 + 7;
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int n;
cin >> n;
string s;
cin >> s;
vector<int> one_indices;
for (int i = 0; i < n; ++i) {
if (s[i] == '1') {
one_indices.push_back(i);
}
}
if (one_indices.empty()) {
cout << 0 << '\n';
} else {
long long ans = 1;
for (size_t i = 0; i < one_indices.size() - 1; ++i) {
ans = (ans * (one_indices[i+1] - one_indices[i])) % MOD;
}
cout << ans << '\n';
}
return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.List;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
String s = sc.next();
final int MOD = 1_000_000_007;
List<Integer> oneIndices = new ArrayList<>();
for (int i = 0; i < n; i++) {
if (s.charAt(i) == '1') {
oneIndices.add(i);
}
}
if (oneIndices.isEmpty()) {
System.out.println(0);
} else {
long ans = 1;
for (int i = 0; i < oneIndices.size() - 1; i++) {
ans = (ans * (oneIndices.get(i + 1) - oneIndices.get(i))) % MOD;
}
System.out.println(ans);
}
}
}
MOD = 1_000_000_007
n = int(input())
s = input()
one_indices = [i for i, char in enumerate(s) if char == '1']
if not one_indices:
print(0)
else:
ans = 1
for i in range(len(one_indices) - 1):
ans = (ans * (one_indices[i+1] - one_indices[i])) % MOD
print(ans)
算法及复杂度
- 算法:计数、乘法原理
- 时间复杂度:
,需要一次遍历来找出所有
1
的位置。 - 空间复杂度:
,其中
是字符串中
1
的数量,用于存储1
的下标。在最坏情况下为。