Weak Pair

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)
Total Submission(s): 1445    Accepted Submission(s): 468


Problem Description
You are given a rooted tree of $N$ nodes, labeled from 1 to $N$. The $i$-th node is assigned a non-negative value $a_i$. An ordered pair of nodes $(u, v)$ is said to be weak if
  (1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
  (2) $a_u \times a_v \le k$.

Can you find the number of weak pairs in the tree?
 

Input
There are multiple cases in the data set.
  The first line of input contains an integer T denoting number of test cases.
  For each case, the first line contains two space-separated integers, N and k, respectively.
  The second line contains N space-separated integers, denoting a1 to aN.
  Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes u and v, where node u is the parent of node v.

  Constraints:
  
   $1 \le N \le 10^5$
  
   $0 \le a_i \le 10^9$
  
   $0 \le k \le 10^{18}$
 

Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
 

Sample Input
1
2 3
1 2
1 2
 

Sample Output
1

【题意】给你一棵有根树,一个定值k,以及树上每个结点的值a[i],对于有序对(u,v),如果(1)u是v的祖先,且(2)a[u]*a[v]<=k,则称该有序对(u,v)是弱的,问树中有多少对有序对(u,v)是弱的。

【分析&&推荐】可以去看这篇博客,写得非常好http://blog.csdn.net/queuelovestack/article/details/52505856

【解题方法】和上面这篇文章介绍的方法一样,直接给出我的代码了。

【AC 代码】

//
//Created by just_sort 2016/9/12 16:50
//Copyright (c) 2016 just_sort.All Rights Reserved
//

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;

// Array capacity: N <= 1e5 plus a little slack.
const int maxn = 1e5+6;
// Upper bound substituted for k / a[u] when a[u] == 0 (every ancestor then qualifies).
const LL inf = 1e18;

int n;              // number of nodes in the current test case
int q;              // number of distinct entries in b, i.e. the Fenwick tree size
LL k;               // product bound read from the input
LL a[maxn];         // node values, 1-indexed
LL b[maxn*2];       // sorted, de-duplicated a[i] and k / a[i] (or inf) values
LL ans;             // number of weak pairs found so far
int c[maxn*2];      // Fenwick tree indexed by 1-based position in b
bool vis[maxn];     // vis[v] is set when v appears as the child of an edge

// Forward-star adjacency list: head[u] is u's first edge index, -1 if none.
int head[maxn];
int edgecnt;        // number of edges stored in E so far
struct edge
{
    int v;          // child endpoint of the edge
    int next;       // next edge of the same parent, or -1
};
edge E[maxn*2];
// Reset the adjacency list so that no node has any outgoing edge.
void init()
{
    memset(head, -1, sizeof head);
    edgecnt = 0;
}
// Prepend the directed edge u -> v to u's edge list.
void addedge(int u,int v)
{
    const int id = edgecnt++;
    E[id].v = v;
    E[id].next = head[u];
    head[u] = id;
}
// Value of the lowest set bit of x (0 when x == 0); the Fenwick-tree step.
int lowbit(int x)
{
    const int negated = -x;
    return negated & x;
}
// Fenwick point update: add v at 1-based position i (positions run 1..q).
void update(int i,int v)
{
    for (; i <= q; i += lowbit(i))
        c[i] += v;
}
// Fenwick prefix sum: total of the counts stored at positions 1..x.
int query(int x)
{
    int sum = 0;
    for (int i = x; i != 0; i -= lowbit(i))
        sum += c[i];
    return sum;
}
void dfs(int u)
{
    ans += query(lower_bound(b,b+q,a[u]?k/a[u]:inf)-b+1);
    update(lower_bound(b,b+q,a[u])-b+1,1);
    for(int i=head[u]; ~i; i=E[i].next){
        int v = E[i].v;
        dfs(v);
    }
    update(lower_bound(b,b+q,a[u])-b+1,-1);
}
// Read T test cases. For each one, build the rooted tree, compress all values
// that will be inserted into or queried from the Fenwick tree, and print the
// number of weak pairs.
int main()
{
    int T;
    if (scanf("%d",&T) != 1) return 0;
    while(T--)
    {
        q = 0;
        init();
        memset(c,0,sizeof(c));
        memset(vis,false,sizeof(vis));
        // b needs no clearing: b[0..q) is rewritten below before it is read.
        if (scanf("%d%lld",&n,&k) != 2) break;
        for(int i=1; i<=n; i++){
            // %lld (not the MSVC-only %I64d) so that long long is read
            // correctly with every conforming C library.
            scanf("%lld",&a[i]);
            b[q++] = a[i];                      // value inserted for node i
            b[q++] = a[i] ? k / a[i] : inf;     // bound queried for node i
        }
        std::sort(b,b+q);
        q = std::unique(b,b+q)-b;
        for(int i=1; i<n; i++){
            int u,v;
            scanf("%d%d",&u,&v);
            addedge(u,v);       // edges point parent -> child only
            vis[v] = 1;
        }
        // The root is the node that is nobody's child.
        int ss=-1;
        for(int i=1; i<=n; i++){
            if(!vis[i]) ss=i;
        }
        ans = 0;
        if (ss != -1) dfs(ss);
        printf("%lld\n",ans);
    }
    return 0;
}


【解法2】 利用Treap,复杂度是nlogn级别的,加个输入挂可以跑得很快。

【AC 代码】

//
//Created by just_sort 2016/9/12 16:50
//Copyright (c) 2016 just_sort.All Rights Reserved
//

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;

const int maxn = 1e5+6;   // array capacity: N <= 1e5 plus slack
const LL inf = 1e18;      // bound used when a[u] == 0 (every ancestor qualifies)
int n;                    // number of nodes in the current test case
LL k;                     // product bound read from the input
LL ans;                   // number of weak pairs found so far
LL a[maxn];               // node values, 1-indexed

// Buffered reader/writer over stdin/stdout built on fread/fwrite.
// Input is read in S-byte chunks. Output accumulates in wbuf and is flushed
// when the buffer fills and when the object is destroyed.
struct FastIO
{
    static const int S = 1310720;   // buffer size in bytes (input and output)
    int wpos;                       // number of pending bytes in wbuf
    char wbuf[S];
    FastIO() : wpos(0) {}
    // Next byte of stdin, or -1 once fread returns no more data.
    // NOTE: where char is signed, a 0xFF byte is also returned as -1.
    inline int xchar()
    {
        static char buf[S];
        static int len = 0, pos = 0;
        if (pos == len)
            pos = 0, len = fread(buf, 1, S, stdin);
        if (pos == len) return -1;
        return buf[pos ++];
    }
    // Skip bytes <= 32 and return the first other byte.
    // Stops at end of input (-1) rather than treating -1 as whitespace,
    // which used to loop forever on truncated input.
    inline int xskip()
    {
        int c = xchar();
        while (c <= 32 && c != -1) c = xchar();
        return c;
    }
    // Unsigned decimal integer (0 at end of input).
    inline int xuint()
    {
        int c = xskip(), x = 0;
        for (; '0' <= c && c <= '9'; c = xchar()) x = x * 10 + c - '0';
        return x;
    }
    // Optionally signed decimal int (0 at end of input).
    inline int xint()
    {
        int s = 1, c = xskip(), x = 0;
        if (c == '-') s = -1, c = xchar();
        for (; '0' <= c && c <= '9'; c = xchar()) x = x * 10 + c - '0';
        return x * s;
    }
    // Optionally signed decimal long long (0 at end of input).
    inline LL xlongint()
    {
        LL s = 1, c = xskip(), x = 0;
        if (c == '-') s = -1, c = xchar();
        for (; '0' <= c && c <= '9'; c = xchar()) x = x * 10 + c - '0';
        return x * s;
    }
    // Whitespace-delimited token into s, NUL-terminated (empty at end of input).
    inline void xstring(char *s)
    {
        int c = xskip();
        for (; c > 32; c = xchar()) * s++ = c;
        *s = 0;
    }
    inline void wchar(int x)
    {
        if (wpos == S) fwrite(wbuf, 1, S, stdout), wpos = 0;
        wbuf[wpos ++] = x;
    }
    inline void wint(LL x)
    {
        if (x < 0) wchar('-'), x = -x;
        char s[24];
        int n = 0;
        while (x || !n) s[n ++] = '0' + x % 10, x /= 10;
        while (n--) wchar(s[n]);
    }
    inline void wstring(const char *s)
    {
        while (*s) wchar(*s++);
    }
    ~FastIO()
    {
        if (wpos) fwrite(wbuf, 1, wpos, stdout), wpos = 0;
    }
} io;
/************************************Treap***************************************/
// One treap node. Equal keys share a single node: w counts the copies, and
// size is the total number of copies in the subtree. Index 0 is the null node.
struct node
{
    int l, r;   // child indices into tr (0 = no child)
    int w;      // multiplicity of key v
    int rnd;    // priority; insert/del keep smaller rnd values above larger ones
    int size;   // sum of w over the whole subtree
    LL v;       // key
};
node tr[100011];    // node pool; insert hands out indices via ++num
int num = 0;        // last allocated node index
int root = 0;       // root of the treap (0 = empty)
int size = 0;       // NOTE(review): never read; with `using namespace std` an unqualified `size` may clash with std::size in C++17
void update(int k)
{
    tr[k].size=tr[tr[k].l].size+tr[tr[k].r].size+tr[k].w;
    return;
}
void rturn(int &k)
{
    int t;t=tr[k].l;
    tr[k].l=tr[t].r;tr[t].r=k;
    tr[t].size=tr[k].size;
    update(k);k=t;
    return;
}
void lturn(int &k)
{
    int t;t=tr[k].r;
    tr[k].r=tr[t].l;tr[t].l=k;
    tr[t].size=tr[k].size;
    update(k);k=t;
    return;
}
// Insert one copy of x into the treap rooted at k (k may change via rotation).
// A new node gets a random priority. A child with a smaller priority than
// its parent is rotated up. Child links of new nodes are not reset here:
// main() clears tr[].l and tr[].r before each test case.
void insert(int &k,LL x)
{
    if (!k) {
        k = ++num;
        tr[k].size = tr[k].w = 1;
        tr[k].v = x;
        tr[k].rnd = rand();
        return;
    }
    ++tr[k].size;
    if (x == tr[k].v) {
        ++tr[k].w;
    } else if (x < tr[k].v) {
        insert(tr[k].l, x);
        if (tr[tr[k].l].rnd < tr[k].rnd) rturn(k);
    } else {
        insert(tr[k].r, x);
        if (tr[tr[k].r].rnd < tr[k].rnd) lturn(k);
    }
}
// Remove one copy of x from the treap rooted at k (k may change).
// Sizes on the search path are decremented before x is known to exist,
// so callers must only delete values that are present.
void del(int &k,LL x)
{
    if (!k) return;
    if (x != tr[k].v) {
        tr[k].size--;
        if (x > tr[k].v) del(tr[k].r, x);
        else del(tr[k].l, x);
        return;
    }
    if (tr[k].w > 1) {
        // Several copies: just drop one.
        tr[k].w--;
        tr[k].size--;
        return;
    }
    if (!tr[k].l || !tr[k].r) {
        // At most one child: splice it into k's place.
        k = tr[k].l + tr[k].r;
        return;
    }
    // Two children: rotate the smaller-priority child up, then retry below.
    if (tr[tr[k].l].rnd < tr[tr[k].r].rnd) rturn(k);
    else lturn(k);
    del(k, x);
}

// Number of values stored in the treap rooted at k that are <= x, counting
// every copy of a duplicated key (its multiplicity w).
// dfs() uses this as the number of ancestors v with a[v] <= x.
// Fix: on an exact match the previous version returned size(left) + 1, so it
// counted only one copy of x. When several ancestors share the value
// k / a[u], weak pairs were undercounted (k = 4 and a chain of three nodes
// with value 2 gave 2 instead of 3).
int query_rank(int k,LL x)
{
    int cnt = 0;
    while (k != 0) {
        if (x < tr[k].v) {
            k = tr[k].l;                        // this key and the right subtree are > x
        } else {
            cnt += tr[tr[k].l].size + tr[k].w;  // left subtree and this key are <= x
            if (x == tr[k].v) break;            // right subtree is all > x
            k = tr[k].r;
        }
    }
    return cnt;
}
int query_num(int k,int x)
{
    if (x<=tr[tr[k].l].size) return query_num(tr[k].l,x);
    if (x>tr[tr[k].l].size+tr[k].w) return query_num(tr[k].r,x-tr[tr[k].l].size-tr[k].w);
    return tr[k].v;
}
//int ans;
//void query_pro(int k,int x)//5
//{
//    if (k==0) return;
//    if (tr[k].v<x)
//    {
//        ans=k;query_pro(tr[k].r,x);
//    }
//    else query_pro(tr[k].l,x);
//}
//void query_sub(int k,int x)//6
//{
//    if (k==0) return;
//    if (tr[k].v>x)
//    {
//        ans=k;query_sub(tr[k].l,x);
//    }
//    else query_sub(tr[k].r,x);
//}
/******************************************************************************/
//
// Forward-star adjacency list: head[u] is u's first edge index, -1 if none.
int head[maxn];
int edgecnt;        // number of edges stored in E so far
bool vis[maxn];     // vis[v] is set when v appears as the child of an input edge
struct edge
{
    int v;          // far endpoint
    int next;       // next edge of the same node, or -1
};
edge E[maxn*2];     // room for both directions of every tree edge
// Empty the adjacency list: afterwards no node has any edge.
void init(){
    memset(head, -1, sizeof head);
    edgecnt = 0;
}
// Prepend the directed edge u -> v to u's edge list.
void addedge(int u,int v){
    const int id = edgecnt++;
    E[id].v = v;
    E[id].next = head[u];
    head[u] = id;
}
void dfs(int u,int fa)
{
    LL exe = inf;
    if(a[u]!=0) exe = k/(a[u]);
    ans += query_rank(root,exe);
    insert(root,a[u]);
    for(int i=head[u]; ~i; i=E[i].next){
        int v=E[i].v;
        if(v==fa) continue;
        dfs(v,u);
    }
    del(root,a[u]);
}
// Read T test cases with FastIO. For each one, build the tree (edges stored in
// both directions), find the root, and print the number of weak pairs.
int main()
{
    int T;
    T=io.xint();
    while(T--)
    {
        init();
        // The global `size` is deliberately not touched: it is never read, and
        // with `using namespace std` an unqualified `size` is ambiguous with
        // std::size in C++17, which stopped this function from compiling.
        num=root=0;
        for(int i=1; i<maxn; i++) tr[i].l=tr[i].r=0;   // insert() does not clear child links
        memset(vis,false,sizeof(vis));
        n=io.xint();
        k=io.xlongint();
        for(int i=1; i<=n; i++) a[i]=io.xlongint();
        for(int i=1; i<n; i++){
            int u=io.xint();
            int v=io.xint();
            addedge(u,v);
            addedge(v,u);
            vis[v]=1;       // u is the parent of v
        }
        // The root is the first node that is nobody's child.
        int s=-1;
        for(int i=1; i<=n; i++){
            if(!vis[i]){
                s=i;
                break;
            }
        }
        ans = 0;
        if (s != -1) dfs(s,0);
        printf("%lld\n",ans);
    }
    return 0;
}