def solve():
    """Read an array from stdin and print the number of valid 3-way splits.

    Input format (stdin):
        line 1: ``n``, the number of leading elements to consider
        line 2: whitespace-separated integers

    A split is a pair of cut positions ``i < j`` (``1 <= i <= n - 2``,
    ``2 <= j <= n - 1``) producing the parts ``arr[:i]``, ``arr[i:j]`` and
    ``arr[j:n]``.  A split is counted when every part sums to one third of
    the total and every part contains at least one strictly positive value.

    Prints the count (``0`` when ``n < 3`` or the total is not divisible
    by 3) and returns ``None``.
    """
    # Function-scope import keeps the block self-contained in the file.
    from bisect import bisect_left

    n = int(input())
    arr = list(map(int, input().split()))
    if n < 3:
        print(0)
        return

    total_sum = sum(arr)
    if total_sum % 3 != 0:
        print(0)
        return
    target = total_sum // 3

    # prefix_sum[k]      -> sum of arr[:k]
    # prefix_positive[k] -> number of strictly positive values in arr[:k]
    prefix_sum = [0] * (n + 1)
    prefix_positive = [0] * (n + 1)
    for k in range(1, n + 1):
        value = arr[k - 1]
        prefix_sum[k] = prefix_sum[k - 1] + value
        prefix_positive[k] = prefix_positive[k - 1] + (1 if value > 0 else 0)

    # Positive-counts of every valid first cut i (first part sums to target
    # and holds a positive).  i stops at n - 2 so that two non-empty parts
    # can still follow.  Because prefix_positive never decreases and i is
    # visited in increasing order, this list is sorted.
    first_cut_positives = [
        prefix_positive[i]
        for i in range(1, n - 1)
        if prefix_sum[i] == target and prefix_positive[i] >= 1
    ]

    count = 0
    for j in range(2, n):  # j stops at n - 1 so the third part is non-empty
        if prefix_sum[j] != 2 * target:
            continue
        if prefix_positive[n] - prefix_positive[j] < 1:
            continue  # third part has no positive value
        # The middle part arr[i:j] holds a positive exactly when
        # prefix_positive[i] < prefix_positive[j].  Since prefix_positive is
        # non-decreasing, that strict inequality already forces i < j, so the
        # number of matching first cuts is a single binary search.
        count += bisect_left(first_cut_positives, prefix_positive[j])

    print(count)
# Run only when executed as a script, so importing this module (e.g. from a
# test) does not block waiting on stdin.
if __name__ == "__main__":
    solve()