import sys


class Fenw:
    """Fenwick (binary indexed) tree over 1-based positions ``1..size``.

    Supports adding a value at a single position and summing either a prefix
    or a closed range of positions.
    """

    def __init__(self, size):
        self.n = size
        # Slot 0 is never touched for positive indices; the backing list is
        # allocated with headroom beyond ``size``.
        self.tree = [0] * (size + 2)

    def update(self, index, delta):
        """Add ``delta`` at position ``index``; positions above ``n`` are ignored."""
        tree, limit = self.tree, self.n
        while index <= limit:
            tree[index] += delta
            index += index & (-index)  # climb to the next node covering index

    def query(self, index):
        """Return the sum of positions ``1..index``; ``query(0)`` is 0."""
        acc = 0
        tree = self.tree
        while index:
            acc += tree[index]
            index &= index - 1  # drop the lowest set bit
        return acc

    def range_query(self, l, r):
        """Return the sum of positions ``l..r`` inclusive, or 0 when ``l > r``."""
        if l <= r:
            return self.query(r) - self.query(l - 1)
        return 0


def _count_splits(values, n):
    """Count cut pairs (i, j) with ``1 <= i < j <= n - 1`` that split the input.

    The pieces are ``values[:i]``, ``values[i:j]`` and ``values[j:n]``.  A pair
    counts when all three pieces have equal sums and each piece contains at
    least one element equal to 1.  ``n`` is the declared length from the
    input header.
    """
    grand_total = sum(values)
    if grand_total % 3:
        return 0
    third = grand_total // 3

    # sums[k] / ones_seen[k]: sum of, and count of 1s among, the first k values.
    sums = [0] * (n + 1)
    ones_seen = [0] * (n + 1)
    for k in range(1, n + 1):
        v = values[k - 1]
        sums[k] = sums[k - 1] + v
        ones_seen[k] = ones_seen[k - 1] + int(v == 1)

    ones_total = ones_seen[n]
    if ones_total < 3:
        return 0  # three pieces that each need a 1 are impossible

    # First cut i: the left piece sums to one third and holds a 1.
    firsts = [i for i in range(1, n - 1)
              if sums[i] == third and ones_seen[i] > 0]
    # Second cut j: the first two pieces sum to two thirds and the right
    # piece still holds a 1.
    seconds = [j for j in range(2, n)
               if sums[j] == 2 * third and ones_seen[j] < ones_total]
    if not firsts or not seconds:
        return 0

    # Sweep first cuts from right to left.  Before cut i is handled, every
    # second cut j > i has been recorded at tree position ones_seen[j] + 1.
    # Positions >= ones_seen[i] + 2 are exactly those j whose middle piece
    # contains at least one 1.
    tree = Fenw(ones_total + 2)
    remaining = len(seconds)  # seconds[:remaining] are not yet recorded
    ways = 0
    for i in reversed(firsts):
        while remaining > 0 and seconds[remaining - 1] > i:
            remaining -= 1
            tree.update(ones_seen[seconds[remaining]] + 1, 1)
        lower = ones_seen[i] + 1
        if lower <= ones_total:
            ways += tree.range_query(lower + 1, ones_total + 1)
    return ways


def main():
    """Read ``n`` and then up to ``n`` integers from stdin; print the split count.

    Prints nothing when stdin contains no tokens.
    """
    tokens = sys.stdin.read().split()
    if not tokens:
        return
    n = int(tokens[0])
    values = [int(tok) for tok in tokens[1:1 + n]]
    print(_count_splits(values, n))


if __name__ == "__main__":
    main()