题意:树链是指树里的一条路径。美团外卖的形象代言人袋鼠先生最近在研究一个特殊的最长树链问题。现在树中的每个点都有一个正整数值,他想在树中找出最长的树链,使得这条树链上所有对应点的值的最大公约数大于1。请求出这条树链的长度。

解法:枚举每个质因数 p,只保留值能被 p 整除的点以及它们之间的边,在得到的森林中对每个连通块做两次 BFS 求最长路径(直径)。因为一个数的不同质因数很少(32 位整数范围内至多 9 个),每个点只会在少数几个质因数下被搜索到,所以总复杂度符合要求。

下面是官方题解给的代码:

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <map>
#include <vector>
using namespace std;
// n: number of vertices; l: number of edges used so far in the pool a[];
// len: number of primes stored in c[].
// c[]: primes <= 100000 in c[1..len] while the input values are factorized;
//      afterwards bfs() reuses the same array as its queue.
// value[i]: value of vertex i. visit[i]: BFS depth of vertex i (0 = unvisited).
// V: cleared in main() but never read.
int n, l, c[100001], len, value[100001], visit[100001], V[100001];
// b[k]: sieve flag, true when k is composite.
bool b[100001];
// factor[p]: vertices whose value is divisible by the prime p.
map< int, vector<int> > factor;
// Singly linked adjacency lists: a[] is the edge pool (each undirected edge is
// stored twice), first[x] points at the most recently added edge out of x.
struct node{
  node *next;   // next edge leaving the same vertex
  int where;    // target vertex of this edge
} a[200001], *first[100001];
inline void makelist(int x, int y) {
  a[++l].where = y;
  a[l].next = first[x];
  first[x] = &a[l];
}
// Breadth-first search from `init` over the vertices whose value is divisible
// by v (edges to other vertices are ignored).
//
// The queue is c[], which main() also uses as its prime table; this overwrites
// it, so bfs must only run after all input values have been factorized.
// visit[] stores depths counted in vertices (init has depth 1).
//
// Returns (first vertex reached at the greatest depth, that depth).
// round == 0: clears visit[] for every vertex reached before returning.
// otherwise:  leaves the marks in place; calc() uses them to skip vertices of
//             components it has already measured and clears them itself.
pair<int, int> bfs(int init, int v, int round) {
  int head = 1;
  int tail = 1;
  c[tail] = init;
  visit[init] = 1;
  int farthest = 0;
  int depth = 0;
  while (head <= tail) {
    const int cur = c[head++];
    if (visit[cur] > depth) {
      depth = visit[cur];
      farthest = cur;
    }
    for (node *e = first[cur]; e; e = e->next) {
      const int nxt = e->where;
      if (value[nxt] % v != 0 || visit[nxt] != 0)
        continue;  // outside the subgraph, or already queued
      visit[nxt] = visit[cur] + 1;
      c[++tail] = nxt;
    }
  }
  if (round == 0)
    for (int i = tail; i >= 1; --i)
      visit[c[i]] = 0;
  return make_pair(farthest, depth);
}
int calc(int v) {
  vector<int> idx = factor[v];
  int will = 0;
  for (int i = 0; i < idx.size(); i++)
    if (!visit[idx[i]])
    {
      will = max(will, bfs(bfs(idx[i], v, 0).first, v, 1).second);
    }
  for (int i = 0; i < idx.size(); i++)
    visit[idx[i]] = 0;
  return will;
}

// Reads the tree and the vertex values and groups the vertices by the distinct
// prime factors of their values. Prints the longest chain (counted in
// vertices) whose values all share one prime factor, i.e. whose gcd is > 1.
//
// Input: n, then n-1 undirected edges "x y", then the n vertex values.
int main() {
  // Linear sieve: c[1..len] receives the primes <= 100000 in increasing
  // order, and b[k] is set for every composite k.
  len = 0;
  memset(b, false, sizeof(b));
  for (int i = 2; i <= 100000; i++)
  {
    if (!b[i])
      c[++len] = i;
    for (int j = 1; j <= len; ++j)
      if (c[j] * i > 100000)
        break;
      else
      {
        b[c[j] * i] = true;
        if (!(i % c[j]))
          break;  // c[j] is i's smallest prime factor, so each composite is marked once
      }
  }

  // Each undirected edge is stored in both directions.
  scanf("%d", &n);
  memset(first, 0, sizeof(first)); l = 0;
  for (int i = 1; i < n; i++)
  {
    int x, y;
    scanf("%d%d", &x, &y);
    makelist(x, y);
    makelist(y, x);
  }

  factor.clear();
  for (int i = 1; i <= n; i++)
  {
    int x;
    scanf("%d", &x);
    value[i] = x;
    // Trial division by the sieved primes. `c[j] <= x / c[j]` is the
    // overflow-free form of `c[j] * c[j] <= x`. For x near INT_MAX, the square
    // of a prime above 46340 overflows int (undefined behaviour). After that,
    // the old loop could run past c[len] into the zero-filled tail of c[].
    for (int j = 1; j <= len && c[j] <= x / c[j]; ++j)
      if (x % c[j] == 0)
      {
        factor[c[j]].push_back(i);  // operator[] creates an empty list on first use
        while (x % c[j] == 0)
          x /= c[j];
      }
    // A remainder > 1 has no prime factor <= its square root, so it is prime.
    if (x != 1)
      factor[x].push_back(i);
  }

  // The answer is the longest chain over all primes that divide at least one
  // value (0 when every value is 1, because factor is then empty).
  int ans = 0;
  memset(visit, 0, sizeof(visit));
  for (map< int, vector<int> >::iterator itr = factor.begin(); itr != factor.end(); ++itr)
    ans = max(ans, calc(itr->first));
  printf("%d\n", ans);
}

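另外还看到一种直接用 BFS 暴力搜索硬过的写法,能通过简直玄学。它只沿输入边的方向(父→子)搜索,并且在 gcd 变成 1 时贪心地从当前点重新开始,本身并不正确。例如一条链上的值依次为 2、6、3、3 时,它输出 2,而 6-3-3 的 gcd 为 3,答案应为 3。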

#include <stdio.h> 
#include <stdlib.h> 
#include <string.h> 
#include <iostream>
#include <iostream> 
#include <cmath> 
#include<cstdio> 
#include<cstring> 
#include <algorithm>
#include <cctype>
#include <utility>
#include <map>
#include <string>
#include <cstdlib>
#include <queue>
#include <numeric>
#include <vector>
#include<set>
#include <cctype>
using namespace std;
// Array capacity (vertices are numbered from 1).
const int maxn = 100050;
// Length of the longest chain found so far; updated by bfs(), printed by main().
int maxlen = 0;
// rd[i]: number of input edges "x y" whose second endpoint is i. main() starts
// the search from a vertex with rd == 0.
int rd[maxn];
// a[i]: value of vertex i.
int a[maxn];
// v[x]: vertices y for which an edge "x y" was read (stored in one direction only).
vector<int>v[maxn];
// Fast reader: skips everything up to the next '-' or digit, then parses an
// optionally negative decimal integer into ret.
// Returns false (leaving ret untouched) if end of input is reached before any
// sign or digit, and true otherwise.
//
// The character is held in an int: storing getc()'s result in a char makes EOF
// indistinguishable from byte 0xFF, or never equal to EOF when char is unsigned.
// EOF is also checked while skipping separators. Previously only the very
// first character was checked, so trailing whitespace before EOF looped forever.
// `in` defaults to stdin, so existing calls scan(x) are unchanged.
inline bool scan(int &ret, FILE *in = stdin) {
    int c = getc(in);
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getc(in);
    if (c == EOF) return false;
    const int sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while ((c = getc(in)) >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    ret *= sgn;
    return true;
}
// Search state: the chain that currently ends at vertex x has len vertices,
// and gcd is the running gcd carried along that chain.
struct node {
    int x, len, gcd;
};
void bfs(int x, int len, int gcd)
{
    node n1, n2;
    queue<node>q;
    n1.x = x;
    n1.len = len;
    n1.gcd = gcd;
    q.push(n1);
    while (!q.empty())
    {
        node n1 = q.front();
        q.pop();
        int nx = n1.x;
        int nlen = n1.len;
        int ngcd = n1.gcd;

        maxlen = max(nlen, maxlen);
        for(int i=0;i<v[nx].size();i++)
        {
            int next = v[nx][i];
            int usegcd = __gcd(a[next], ngcd);
            if (usegcd == 1)
            {
                n2.gcd = a[next];
                n2.len = 1;
                n2.x = next;
                q.push(n2);
            }
            else
            {
                n2.gcd = usegcd;
                n2.len = 1+nlen;
                n2.x = next;
                q.push(n2);
            }
        }
    }
}/* void dfs(int x, int len,int gcd) { maxlen = max(len, maxlen); for (int i = 0; i < v[x].size(); i++) { int next = v[x][i]; int usegcd = __gcd(a[next], gcd); if (usegcd ==1) dfs(next, 1, a[next]); else dfs(next, len+1, usegcd); } }*/
// Reads n, the n-1 edges (each "x y" stored as x -> y) and the n values. Runs
// the search from a vertex that never appears as an edge's second endpoint
// (the last such vertex when there are several) and prints maxlen.
int main()
{
    int n;
    scan(n);
    for (int e = 1; e < n; ++e)
    {
        int from, to;
        scan(from);
        scan(to);
        v[from].push_back(to);
        ++rd[to];
    }
    int root = 0;
    for (int i = 1; i <= n; ++i)
    {
        if (!rd[i])
            root = i;
        scan(a[i]);
    }
    bfs(root, 1, a[root]);
    printf("%d", maxlen);
    return 0;
}