题目链接:http://acm.split.hdu.edu.cn/showproblem.php?pid=4812
题意:给出一棵树,让你寻找一条路径,使得路径上的点相乘mod10^6+3等于k,输出路径的两个端点,按照字典序最小输出。
解法:这类问题很容易想到树的分治,每次找出树的重心,以重心为根,将树分成若干棵子树,然后对于每棵子树再一样的操作,现在就需要求一重心为根,寻找路径,依次遍历每一个子树,然后记录子树中点到根的权值的乘积X,然后通过在哈希表中寻找K×逆元(x),看是否存在,存在则更新答案,我这里用map来维护。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 1e5+3;
const int mod = 1e6+3;
LL inv[mod];
struct edge{
int to,next;
edge(){}
edge(int to,int next):to(to),next(next){}
}E[maxn*2];
int head[maxn],edgecnt,mi;
void add(int u, int v){
E[edgecnt].to=v,E[edgecnt].next=head[u],head[u]=edgecnt++;
}
pair <int,int> ans;
int n,siz[maxn],mx[maxn],root,val[maxn];
LL K;
bool vis[maxn];
void init(){
memset(head,-1, sizeof(head));
edgecnt=0;
memset(vis,0,sizeof(vis));
}
void dfssize(int u, int fa){//处理子树的大小
siz[u]=1;
mx[u]=0;
for(int i=head[u];~i;i=E[i].next){
int v=E[i].to;
if(v!=fa&&!vis[v]){
dfssize(v, u);
siz[u]+=siz[v];
if(siz[v]>mx[u]) mx[u]=siz[v];
}
}
}
void dfsroot(int r, int u, int fa){//求重心
if(siz[r]-siz[u]>mx[u]) mx[u]=siz[r]-siz[u];
if(mx[u]<mi) mi=mx[u],root=u;
for(int i=head[u]; ~i; i=E[i].next){
int v=E[i].to;
if(v!=fa&&!vis[v]) dfsroot(r, v, u);
}
}
void cal(int u, int fa, int length, unordered_map<int, int>&status){
length = (LL)length*val[u]%mod;
int &s=status[length];
if(!s) s=u;
else s=min(s,u);
for(int i=head[u];~i;i=E[i].next){
int v = E[i].to;
if(v==fa || vis[v]) continue;
cal(v, u, length, status);
}
}
void solve(int u){
mi=n;
dfssize(u,0);
dfsroot(u,u,0);
u=root;
vis[u]=1;
unordered_map<int,int>preStatus;
preStatus[val[u]]=u;
for(int i=head[u];~i;i=E[i].next){
int v=E[i].to;
if(vis[v]) continue;
unordered_map<int, int>nowStatus;
cal(v,u,1,nowStatus);
for(auto &p:nowStatus){
int need = K*inv[p.first]%mod;
auto it = preStatus.find(need);
if(it==preStatus.end()) continue;
int x=p.second;
int y=it->second;
if(x>y) swap(x,y);
ans = min(ans, make_pair(x,y));
}
for(auto &p:nowStatus){
int length = ((LL)p.first*val[u])%mod;
int point = p.second;
auto it = preStatus.find(length);
if(it == preStatus.end())
preStatus[length] = point;
else
it->second = min(point,it->second);
}
}
for(int i=head[u]; ~i; i=E[i].next){
int v = E[i].to;
if(vis[v]) continue;
solve(v);
}
}
int main(){
inv[0]=inv[1]=1;
for(int i=2; i<mod; i++) inv[i] = (mod-mod/i)*inv[mod%i]%mod;
while(~scanf("%d%lld",&n,&K)){
init();
for(int i=1; i<=n; i++) scanf("%d", &val[i]);
for(int i=1; i<n; i++){
int u,v;
scanf("%d%d", &u,&v);
add(u,v);
add(v,u);
}
ans = make_pair(INT_MAX,INT_MAX);
solve(1);
if(ans.first==INT_MAX){
puts("No solution");
}else{
printf("%d %d\n", ans.first,ans.second);
}
}
return 0;
}