这个题目和leetcode的124题一样,其中有题解,但是大多数用的都是递归的方式。
在牛客,这个题目被归在了动态规划中,咱们这里就用动态规划的方式解决。
思路
动态规划首先是思考最优子结构性质,假设是指以node为开始节点的路径的最大和。那么
此时,如果以node为root的子树的最大路径和是最终结果,那么结果为
的计算需要从子节点到父节点,具体来说,可以采用递归或非递归的方式后序遍历计算。
ps: 不理解题目为什么采用这种输入方式,缺点在于:
- 让找子节点比较麻烦,导致代码比较长
- 当只有一个子节点时,无法判断是左节点还是右节点。只能强制规定是左节点
代码
#include<iostream>
#include<stack>
#include<vector>
#include<limits.h>
#include<math.h>
using namespace std;
// 寻找pos的左节点
int find_left(vector<int> positions, int pos) {
int len = positions.size();
if (pos >= len) {
return -1;
}
for(int i = 0;i < len;i++) {
if (positions[i] == pos+1) {
return i;
}
}
return -1;
}
// 寻找右节点
int find_right(vector<int> positions, int pos) {
bool is_right = false;
int len = positions.size();
if (pos >= len) {
return -1;
}
for(int i = 0;i < len;i++) {
if (positions[i] == pos+1) {
if (is_right) {
return i;
}
is_right = true;
}
}
return -1;
}
// 后序遍历
void post_order_traversal(vector<int> values, vector<int> positions) {
int len = values.size();
int pos = 0;
stack<int> pos_stack;
int max_sum = INT_MIN;
int last_visited_pos = 0;
while(pos < len || !pos_stack.empty()) {
while(pos < len) {
pos_stack.push(pos);
pos = find_left(positions, pos);
if (pos < 0) {
break;
}
}
pos = pos_stack.top();
if (find_right(positions, pos) < 0 || find_right(positions, pos) == last_visited_pos) {
// 处理
int left = find_left(positions, pos);
int right = find_right(positions, pos);
if (left >= 0 && right >= 0) {
max_sum = max(max_sum, values[pos] + values[left] + values[right]);
values[pos] = max(max(values[left], values[right]) + values[pos], 0);
} else if (left >= 0) {
max_sum = max(max_sum, values[pos] + values[left]);
values[pos] = max(max(values[pos], values[pos] + values[left]), 0);
} else if (right >= 0) {
max_sum = max(max_sum, values[pos] + values[right]);
values[pos] = max(max(values[pos], values[pos] + values[right]), 0);
} else {
max_sum = max(values[pos], max_sum);
values[pos] = max(0, values[pos]);
}
last_visited_pos = pos;
pos_stack.pop();
pos = len + 1;
} else {
pos = find_right(positions, pos);
}
}
cout << max_sum;
}
int main() {
int n;
cin >> n;
vector<int> value_list(n);
vector<int> pos_list(n);
for (int i = 0;i < n;i++) {
cin >> value_list[i];
}
for (int i = 0;i < n;i++) {
cin >> pos_list[i];
}
post_order_traversal(value_list, pos_list);
return 0;
}