import java.util.*;

/**
 * Dynamic-programming walk over a grid with three rows and {@code n} columns.
 */
public class Solution {

    /**
     * Computes the best total obtainable by visiting exactly one cell per column, left to right.
     *
     * <p>From column {@code col - 1} to column {@code col} the walk may either stay in its row
     * for free or step to an adjacent row (0 and 1, or 1 and 2) at a cost of {@code m[col - 1]}.
     * Rows 0 and 2 are not adjacent, so there is no direct step between them.
     *
     * <p>Only the previous column is needed to compute the next one. Three rolling values
     * replace a full {@code 3 x n} table. Arithmetic is plain {@code int} and wraps on overflow.
     *
     * @param n  number of columns; a value {@code <= 0} yields {@code 0} without reading any array
     * @param a1 cell values of row 0; must have at least {@code n} elements
     * @param a2 cell values of row 1; must have at least {@code n} elements
     * @param a3 cell values of row 2; must have at least {@code n} elements
     * @param m  row-change costs, where {@code m[i]} applies between columns {@code i} and
     *           {@code i + 1}; must have at least {@code n - 1} elements
     * @return the maximum achievable total over all valid walks
     */
    public int solve(int n, int[] a1, int[] a2, int[] a3, int[] m) {
        if (n <= 0) {
            return 0;
        }

        // Best totals ending in each row at the current column.
        int top = a1[0];
        int mid = a2[0];
        int bot = a3[0];

        for (int col = 1; col < n; col++) {
            final int fee = m[col - 1];
            // Compute all three into temporaries so every update reads the previous column.
            final int nextTop = Math.max(top, mid - fee) + a1[col];
            final int nextMid = Math.max(mid, Math.max(top - fee, bot - fee)) + a2[col];
            final int nextBot = Math.max(bot, mid - fee) + a3[col];
            top = nextTop;
            mid = nextMid;
            bot = nextBot;
        }

        return Math.max(top, Math.max(mid, bot));
    }
}