原题解链接:https://ac.nowcoder.com/discuss/181037
首先做一遍树形 dp,用 $f_u$ 表示以 $u$ 为根的子树最大收益是多少。
询问的简单路径中,有些边是必须跑的,所以就把这些边的贡献从 $f$ 里面单独拿出来,然后再强制加回去就行了。
/*
* @Author: wxyww
* @Date: 2019-01-24 17:27:42
* @Last Modified time: 2019-01-25 17:25:41
*/
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cmath>
#include<ctime>
#include<bitset>
using namespace std;
typedef long long ll;
// change(x): maps an odd x to x + 1 and an even x to x - 1 (1 <-> 2, 3 <-> 4, ...).
// Both the argument and the whole expansion are parenthesized so the macro
// behaves as a single value inside larger expressions such as 2 * change(i)
// or change(a | b). The argument is evaluated more than once, so do not pass
// an expression with side effects.
#define change(x) (((x) & 1) ? (x) + 1 : (x) - 1)
// N: capacity of the per-vertex arrays (the edge pool uses 2 * N).
// logN: number of levels in the binary-lifting tables.
const int N = 300000 + 100,logN = 20;
// Reads the next decimal integer from stdin with getchar().
//
// Characters that are not digits are skipped; if a '-' is met while skipping,
// the parsed value is negated. Digits are then consumed until the first
// non-digit character (which is consumed as well).
//
// Returns the parsed value, or 0 when the input ends before any digit is
// found. The return type is the file's `ll` (long long).
long long read() {
    long long x = 0, f = 1;
    // getchar() returns an int; keeping it as int lets EOF be told apart from
    // every real character regardless of whether plain char is signed.
    int c = getchar();
    while (c < '0' || c > '9') {
        // Without this check an exhausted input would be skipped forever.
        if (c == EOF) return 0;
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// One directed half of an undirected edge, kept in a linked adjacency list.
struct node {
    int v;    // vertex this half-edge points to
    int nxt;  // index of the next half-edge leaving the same vertex, 0 ends the list
    int w;    // first per-edge value from the input (summed into sumadd by dp)
    int ww;   // second per-edge value from the input (summed into sumreduce by dp)
};
node e[N * 2];  // half-edge pool; add() starts at index 1, so 0 can mean "no edge"
int head[N];    // head[u]: index of the most recently added half-edge leaving u, 0 if none
int ejs;        // number of half-edges stored so far
void add(int u,int v,int w,int c) {
e[++ejs].w = w;e[ejs].ww = c;e[ejs].v = v;e[ejs].nxt = head[u];head[u] = ejs;
}
// Per-vertex tables and binary-lifting tables. Index 0 is never a real vertex:
// main starts both traversals with father = 0, and missing ancestors stay 0.
//   sumadd[v]    - sum of the edge w values on the path from the start vertex of dp to v
//   sumreduce[v] - sum of the edge ww values on the same path
//   f[u]         - sum of red[v] over the children v of u
//   red[v]       - max(0, f[v] - ww + w), using the edge that joins v to its parent
//   lca[u][i]    - ancestor of u that is 2^i steps up, or 0 when there is none
//   mon[u][i]    - mon[u][0] is f[u] (set for every vertex except the traversal root);
//                  level i adds the level i-1 sums of u and of lca[u][i-1]
//   kc[u][i]     - built exactly like mon, but from red instead of f
// sumadd, sumreduce, f and red are filled by dp; lca, mon and kc by get_lca.
ll sumadd[N],sumreduce[N],f[N],lca[N][logN],mon[N][logN],kc[N][logN],red[N];
// Number of vertices, read in main.
int n;
// W[v]: w of the edge joining v to its parent; written by dp and not read by the code shown here.
ll W[N];
void dp(int u,int father) {
for(int i = head[u];i;i = e[i].nxt) {
int v = e[i].v;
if(v == father) continue;
sumadd[v] += sumadd[u] + e[i].w;
sumreduce[v] += sumreduce[u] + e[i].ww;
W[v] = e[i].w;
dp(v,u);
red[v] = max(0ll,f[v] - e[i].ww + e[i].w);
f[u] += red[v];
}
}
int dep[N];
void get_lca(int u,int father) {
dep[u] = dep[father] + 1;
for(int i = 1;i < logN;++i) {
lca[u][i] = lca[lca[u][i - 1]][i - 1];
mon[u][i] = mon[u][i - 1] + mon[lca[u][i - 1]][i - 1];
kc[u][i] = kc[u][i - 1] + kc[lca[u][i - 1]][i - 1];
}
for(int i = head[u];i;i = e[i].nxt) {
int v = e[i].v;
if(v == father) continue;
lca[v][0] = u;
mon[v][0] = f[v];
kc[v][0] = red[v];
get_lca(v,u);
}
}
// Answers one query for the ordered pair (x, y) using the tables built by dp
// and get_lca.
//
// Both endpoints are lifted to their lowest common ancestor. For every vertex
// left behind on the way up, f of that vertex is added and red of that vertex
// is subtracted. Then f of the common ancestor is added, the w sums of both
// branches are added, and the ww sum of the branch of the first argument only
// is subtracted.
ll query(int x,int y) {
    const int S = x;
    const int T = y;
    ll ans = 0;
    // From here on x is the deeper endpoint; S and T keep the caller's order.
    if (dep[x] < dep[y]) swap(x, y);

    // Raise x to the depth of y.
    for (int lvl = logN - 1; lvl >= 0; --lvl) {
        const int up = lca[x][lvl];
        if (dep[up] < dep[y]) continue;
        ans += mon[x][lvl];
        ans -= kc[x][lvl];
        x = up;
    }

    // Raise both in lockstep while their 2^lvl ancestors differ.
    for (int lvl = logN - 1; lvl >= 0; --lvl) {
        if (lca[x][lvl] == lca[y][lvl]) continue;
        ans += mon[x][lvl] + mon[y][lvl];
        ans -= kc[x][lvl] + kc[y][lvl];
        x = lca[x][lvl];
        y = lca[y][lvl];
    }

    // Still distinct: both are children of the common ancestor, take the last step.
    if (x != y) {
        ans -= kc[x][0] + kc[y][0];
        ans += mon[x][0] + mon[y][0];
        x = lca[x][0];
    }

    // x now holds the common ancestor.
    ans += f[x];
    ans += sumadd[S] - sumadd[x] + sumadd[T] - sumadd[x];
    ans -= sumreduce[S] - sumreduce[x];
    return ans;
}
int main() {
// freopen("t.in","r",stdin);
// freopen("1.out","w",stdout);
n = read();int Q = read();
for(int i = 1;i < n;++i) {
int u = read(),v = read(),w = read(),c = read();
add(u,v,w,c);add(v,u,w,c);
}
dp(1,0);
get_lca(1,0);
// for(int i = 1;i <= n;++i) printf("%lld ",f[i]);
while(Q--) {
int S = read(),T = read();
printf("%lld\n",query(S,T));
}
return 0;
}
/*
7 10
1 2 5 3
2 6 3 4
3 2 4 5
2 4 7 3
1 5 4 2
5 7 4 2
4 10
1 2 10 10000
2 3 0 10
3 4 11 0
2 1
1 2 1 1
2 1
3 1
1 2 2 1
1 3 2 1
3 2
*/