"""Count the ways to cut an integer array into three equal-sum parts.

Input (stdin): a line with ``n``, then a line with the array's integers.
Output (stdout): the number of valid splits.
"""

from bisect import bisect_left
from itertools import accumulate


def count_equal_sum_splits(arr):
    """Return the number of ways to split *arr* into three contiguous parts
    with equal sums, where every part contains at least one positive element.

    A split is identified by cut points ``1 <= i < j < len(arr)``; the parts
    are ``arr[:i]``, ``arr[i:j]`` and ``arr[j:]``.

    Runs in O(n log n): for each second cut ``j`` the number of compatible
    first cuts is found with one binary search instead of a linear scan.

    >>> count_equal_sum_splits([1, 2, 3, 0, 3])
    2
    """
    n = len(arr)
    if n < 3:
        return 0
    total = sum(arr)
    if total % 3 != 0:
        return 0
    target = total // 3

    # prefix_sum[k] == sum(arr[:k]); prefix_pos[k] == number of positives in arr[:k].
    prefix_sum = [0, *accumulate(arr)]
    prefix_pos = [0, *accumulate(1 if x > 0 else 0 for x in arr)]

    # For every valid first cut i (first part sums to target and holds a
    # positive), record prefix_pos[i]. i increases and prefix_pos is
    # non-decreasing, so this list is sorted and can be bisected.
    first_cut_pos = [
        prefix_pos[i]
        for i in range(1, n - 1)
        if prefix_sum[i] == target and prefix_pos[i] >= 1
    ]

    twice_target = 2 * target
    total_pos = prefix_pos[n]
    # For a valid second cut j (prefix sums to 2*target, third part holds a
    # positive), the middle part arr[i:j] holds a positive iff
    # prefix_pos[i] < prefix_pos[j]. That inequality also forces i < j,
    # because prefix_pos is non-decreasing, so no separate index check is needed.
    return sum(
        bisect_left(first_cut_pos, prefix_pos[j])
        for j in range(2, n)
        if prefix_sum[j] == twice_target and total_pos - prefix_pos[j] >= 1
    )


def solve():
    """Read ``n`` and the array from stdin and print the number of valid splits."""
    n = int(input())
    arr = list(map(int, input().split()))
    if n < 3:
        print(0)
        return
    # Only the first n values are used, honouring the declared length.
    print(count_equal_sum_splits(arr[:n]))


if __name__ == "__main__":
    solve()