import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n = scanner.nextInt();
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = scanner.nextInt();
        }

        long[] preSum = new long[n];
        preSum[0] = a[0];
        long sum = a[0];
        for (int i = 1; i < n; i++) {
            preSum[i] = preSum[i - 1] + a[i];
            sum += a[i];
        }

        if (sum % 3 != 0) {
            System.out.println(0);
            return;
        }
        long k = sum / 3;

        boolean[] prefixPos = new boolean[n];
        prefixPos[0] = a[0] > 0;
        for (int i = 1; i < n; i++) {
            prefixPos[i] = prefixPos[i - 1] || (a[i] > 0);
        }

        boolean[] suffixPos = new boolean[n];
        suffixPos[n - 1] = a[n - 1] > 0;
        for (int i = n - 2; i >= 0; i--) {
            suffixPos[i] = suffixPos[i + 1] || (a[i] > 0);
        }

        int[] nextPos = new int[n];
        int lastPos = n;
        for (int i = n - 1; i >= 0; i--) {
            if (a[i] > 0) {
                lastPos = i;
            }
            nextPos[i] = lastPos;
        }

        List<Integer> listI = new ArrayList<>();
        for (int i = 0; i <= n - 3; i++) {
            if (preSum[i] == k && prefixPos[i]) {
                listI.add(i);
            }
        }

        List<Integer> listJ = new ArrayList<>();
        for (int j = 1; j <= n - 2; j++) {
            if (preSum[j] == 2 * k && suffixPos[j + 1]) {
                listJ.add(j);
            }
        }

        Collections.sort(listJ);

        long ans = 0;
        for (int i : listI) {
            int iPlus1 = i + 1;
            int jMin = Math.max(iPlus1, nextPos[iPlus1]);

            int left = 0, right = listJ.size() - 1;
            int pos = listJ.size();
            while (left <= right) {
                int mid = left + (right - left) / 2;
                int jVal = listJ.get(mid);
                if (jVal >= jMin) {
                    pos = mid;
                    right = mid - 1;
                } else {
                    left = mid + 1;
                }
            }
            ans += listJ.size() - pos;
        }

        System.out.println(ans);
    }
}

https://www.nowcoder.com/discuss/727521113110073344

思路:

  1. 输入处理:读取数组并计算前缀和数组。
  2. 总和检查:检查数组总和是否能被3整除。
  3. 预处理数组:计算前缀正数存在数组、后缀正数存在数组和下一个正数位置数组。
  4. 收集分割点:找到所有可能的分割点i和j。
  5. 二分查找优化:通过二分查找快速统计满足条件的分割点组合数,确保每个子数组满足正数条件。