原题解链接:https://ac.nowcoder.com/discuss/181037

首先做一遍树形 dp，用 $f[i]$ 表示以 $i$ 为根的子树的最大收益。

询问的简单路径上的边是必须走的，所以要把这些边的贡献从 dp 值里单独拿出来，然后再强制加回去就行了。

/*
* @Author: wxyww
* @Date:   2019-01-24 17:27:42
* @Last Modified time: 2019-01-25 17:25:41
*/
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cmath>
#include<ctime>
#include<bitset>
using namespace std;
typedef long long ll;
// Maps x to its paired id: odd x -> x + 1, even x -> x - 1.
// Fully parenthesized so the expansion stays correct inside larger
// expressions (e.g. `2 * change(x)`); note that x is evaluated twice.
// NOTE(review): appears unused in this file.
#define change(x) ((x) & 1 ? (x) + 1 : (x) - 1)
// N: capacity of the vertex-indexed arrays.
// logN: number of binary-lifting levels; jumps go up to 2^(logN-1) ancestors.
const int N = 300000 + 100,logN = 20;
// Reads the next decimal integer from stdin using getchar().
// All non-digit characters before the first digit are skipped; if any of
// them is '-', the result is negated. Returns 0 when end of input is reached
// before a digit is found. The previous version stored getchar() in a char
// and never tested for EOF, so it looped forever on truncated input.
long long read() {
   long long x = 0;
   int sign = 1;
   int c = getchar();  // int, not char, so EOF stays distinguishable
   while(c < '0' || c > '9') {
      if(c == EOF) return 0;
      if(c == '-') sign = -1;
      c = getchar();
   }
   while(c >= '0' && c <= '9') {
      x = x * 10 + (c - '0');
      c = getchar();
   }
   return x * sign;
}

// One directed half of a tree edge in the linked-list adjacency structure.
// main() inserts every input edge twice, once per direction.
//   v   : vertex this half-edge points to
//   nxt : index of the next half-edge leaving the same vertex (0 ends the list)
//   w   : first weight given for the edge in the input
//   ww  : second weight (read as `c` in main) given for the edge
struct node {
   int v;
   int nxt;
   int w;
   int ww;
};
node e[N * 2];
// head[u]: index in e[] of the first half-edge leaving u (0 = none).
// ejs: number of half-edges stored so far; indices start at 1.
int head[N];
int ejs;

void add(int u,int v,int w,int c) {
   e[++ejs].w = w;e[ejs].ww = c;e[ejs].v = v;e[ejs].nxt = head[u];head[u] = ejs;
}

ll sumadd[N],sumreduce[N],f[N],lca[N][logN],mon[N][logN],kc[N][logN],red[N];
int n;
ll W[N];

void dp(int u,int father) {
   for(int i = head[u];i;i = e[i].nxt) {
      int v = e[i].v;
      if(v == father) continue;
      sumadd[v] += sumadd[u] + e[i].w;
      sumreduce[v] += sumreduce[u] + e[i].ww; 
      W[v] = e[i].w;
      dp(v,u);
      red[v] = max(0ll,f[v] - e[i].ww + e[i].w);
      f[u] += red[v];
   }
}
int dep[N];

void get_lca(int u,int father) {
    dep[u] = dep[father] + 1;
   for(int i = 1;i < logN;++i) {
      lca[u][i] = lca[lca[u][i - 1]][i - 1];
      mon[u][i] = mon[u][i - 1] + mon[lca[u][i - 1]][i - 1];
      kc[u][i] = kc[u][i - 1] + kc[lca[u][i - 1]][i - 1];
   }
   for(int i = head[u];i;i = e[i].nxt) {
      int v = e[i].v;
      if(v == father) continue;
      lca[v][0] = u;
      mon[v][0] = f[v];
      kc[v][0] = red[v];
      get_lca(v,u);
   }
}

// Answers one query for the endpoints S (= x) and T (= y).
// The result is assembled from:
//   * f[z] for every vertex z on the S-T path,
//   * minus red[z] for every path vertex z other than the LCA. red[z] is
//     already counted inside f[parent(z)], and the edge z-parent(z) lies on
//     the path, where it is accounted for separately below,
//   * plus the sum of w over all path edges (sumadd differences on both sides),
//   * minus the sum of ww over the path edges on S's side only (S up to the LCA).
// NOTE(review): apart from f[LCA], nothing outside the LCA's subtree is added.
// Confirm the problem never benefits from leaving that subtree.
ll query(int x,int y) {
    int S = x,T = y;
   ll ans = 0;
    // Make x the deeper endpoint; S and T keep their original roles.
    if(dep[x] < dep[y]) swap(x,y);
   // Lift x to y's depth, adding f - red for every vertex x leaves behind.
   for(int i = logN - 1;i >= 0;--i) {
      if(dep[lca[x][i]] >= dep[y]) {
         ans += mon[x][i];ans -= kc[x][i];
         x = lca[x][i];
      }
   }
   // Lift both endpoints while their 2^i-th ancestors differ, so that they
   // stop just below the LCA.
   for(int i = logN - 1;i >= 0;--i) {
      if(lca[x][i] != lca[y][i]) {
         ans += mon[x][i] + mon[y][i];
         ans -= kc[x][i] + kc[y][i];
         x = lca[x][i];y = lca[y][i];
      }
   }
   // If the endpoints are still distinct, they are the two children of the
   // LCA on the path: count them too, then move x onto the LCA.
   if(x != y) ans -= kc[x][0] + kc[y][0],ans += mon[x][0] + mon[y][0],x = lca[x][0];
   // x is now the LCA; its subtree DP value is counted in full.
   ans += f[x];
   // Every path edge contributes w; only the S-side edges subtract ww.
   ans += sumadd[S] - sumadd[x] + sumadd[T] - sumadd[x];
   ans -= sumreduce[S] - sumreduce[x];
    return ans;
}

// Input: n and Q, then n - 1 lines "u v w c" describing the tree edges,
// then Q lines "S T". Prints one answer per query.
int main() {
   n = read();
   int Q = read();
   for(int remaining = n - 1; remaining > 0; --remaining) {
      const int u = read();
      const int v = read();
      const int w = read();
      const int c = read();
      // Undirected tree: store both directions.
      add(u,v,w,c);
      add(v,u,w,c);
   }
   // Vertex 1 is used as the root for both passes.
   dp(1,0);
   get_lca(1,0);
   for(; Q; --Q) {
      const int S = read();
      const int T = read();
      printf("%lld\n",query(S,T));
   }
   return 0;
}
/*
7 10
1 2 5 3 
2 6 3 4
3 2 4 5
2 4 7 3
1 5 4 2
5 7 4 2

4 10
1 2 10 10000
2 3 0 10
3 4 11 0

2 1 
1 2 1 1 
2 1

3 1
1 2 2 1
1 3 2 1
3 2
*/