标签:主席树,LCA
题目
Description 兵库县位于日本列岛的中央位置,北临日本海,南面濑户内海直通太平洋,中央部位是森林和山地,与拥有关西机场的大阪府比邻而居,是关西地区面积最大的县,是集经济和文化于一体的一大地区,是日本西部门户,海陆空交通设施发达。濑户内海沿岸气候温暖,多晴天,有日本少见的贸易良港神户港所在的神户市和曾是豪族城邑“城下町”的姬路市等大城市,还有以疗养地而闻名的六甲山地等。 兵库县官方也大力发展旅游,为了方便,他们在县内的N个旅游景点上建立了n-1条观光道,构成了一棵图论中的树。同时他们推出了M条观光线路,每条线路由两个节点x和y指定,经过的旅游景点就是树上x到y的唯一路径上的点。保证一条路径只出现一次。 你和你的朋友打算前往兵库县旅游,但旅行社还没有告知你们最终选择的观光线路是哪一条(假设是线路A)。这时候你得到了一个消息:在兵库北有一群丧心病狂的香菜蜜,他们已经选定了一条观光线路(假设是线路B),对这条路线上的所有景点都释放了【精神污染】。这个计划还有可能影响其他的线路,比如有四个景点1-2-3-4,而【精神污染】的路径是1-4,那么1-3,2-4,1-2等路径也被视为被完全污染了。 现在你想知道的是,假设随便选择两条不同的路径A和B,存在一条路径使得如果这条路径被污染,另一条路径也被污染的概率。换句话说,一条路径被另一条路径包含的概率。
Input 第一行两个整数N,M 接下来N-1行,每行两个数a,b,表示A和B之间有一条观光道。 接下来M行,每行两个数x,y,表示一条旅游线路。
Output 所求的概率,以最简分数形式输出。
Sample Input 5 3
1 2
2 3
3 4
2 5
3 5
2 5
1 4 Sample Output 1/3
样例解释
可以选择的路径对有(1,2),(1,3),(2,3),只有路径1完全覆盖路径2。 HINT
100%的数据满足:N,M<=100000
分析
题意:给定一棵树上M条路径,求任选两条路径时一条被另一条包含的概率是多少
先吐槽毒瘤出题人,真的是精神污染orz
显然,一条路径被另一条路径所包含,那么这条路径的两个端点都在另一条路径上
对于每个节点x,开个vector把另一个端点y存进去
因为主席树是函数式的线段树,相当于前缀和
主席树中每一个节点x都是一个序列,序列的下标表示的是每个节点的入栈序和出栈序,那么将节点x的vector中的y节点入栈序和出栈序在它的序列中分别+1和-1打上标记
假设存在路径x->y,x->z,那么当后者包含前者的时候,y的入栈序一定在z的入栈序和出战序之间
y的出栈序在z的出战序之后,那么可以直接查询z的入出栈序之间的权值和,就是有多少条一端是x节点的路径被x->z所包含
在树上建立主席树,主席树的前缀和就表示从根节点到节点x的路径上序列的和
查询时就相当于查询起点在根节点到x路径上,终点在lca(x,y)与y之间的路径个数
记t为lca(x,y),那么被路径x–>y包含的路径数目就是,$ans=query(in[x])+query(in[y])-query(in[z])-query(in[fa[t]])-1$
这题被卡了几次内存,差评++,毒瘤++
code
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<vector>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define dep(i,a,b) for(int i=a;i>=b;i--)
#define ll long long
#define mem(x,num) memset(x,num,sizeof x)
#define reg(x) for(int i=last[x];i;i=e[i].next)
using namespace std;
inline ll read()
{
ll f=1,x=0;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline ll gcd(ll a,ll b){return b==0?a:gcd(b,a%b);}
const int maxn=1e5+6,maxk=4e6+6;
int ls[maxk],rs[maxk],sum[maxk],n,m,cnt=0,ind,sz;
int last[maxn],deep[maxn],root[200006],in[maxn],out[maxn],fa[maxn][17];
vector<int>a[maxn];
struct query{int x,y;}que[maxn];
struct edge{int to,next;}e[200006];
ll A,B;
bool operator < (query a,query b){return a.x==b.x?a.y<b.y:a.x<b.x;}
void insert(int u,int v){
e[++cnt]=(edge){v,last[u]};last[u]=cnt;
e[++cnt]=(edge){u,last[v]};last[v]=cnt;
}
int insert(int x,int l,int r,int pos,int v){
int t=++sz,mid=(l+r)>>1;ls[t]=ls[x];rs[t]=rs[x];
if(l==r){sum[t]=sum[x]+v;return t;}
if(pos<=mid)ls[t]=insert(ls[t],l,mid,pos,v);else rs[t]=insert(rs[t],mid+1,r,pos,v);
sum[t]=sum[ls[t]]+sum[rs[t]];return t;
}
int query(int p1,int p2,int p3,int p4,int l,int r,int L,int R){
int mid=(l+r)>>1;
if(l==L&&r==R)return sum[p1]+sum[p2]-sum[p3]-sum[p4];
if(R<=mid)return query(ls[p1],ls[p2],ls[p3],ls[p4],l,mid,L,R);
else if(L>mid)return query(rs[p1],rs[p2],rs[p3],rs[p4],mid+1,r,L,R);
else return query(ls[p1],ls[p2],ls[p3],ls[p4],l,mid,L,mid)+query(rs[p1],rs[p2],rs[p3],rs[p4],mid+1,r,mid+1,R);
}
void dfs(int x){
in[x]=++ind;
rep(i,1,16)if((1<<i)<=deep[x])fa[x][i]=fa[fa[x][i-1]][i-1];else break;
reg(x)
if(e[i].to!=fa[x][0]){
fa[e[i].to][0]=x;
deep[e[i].to]=deep[x]+1;
dfs(e[i].to);
}
out[x]=++ind;
}
void dfs2(int x){
root[x]=root[fa[x][0]];
for(int i=0;i<a[x].size();i++){
root[x]=insert(root[x],1,ind,in[a[x][i]],1);
root[x]=insert(root[x],1,ind,out[a[x][i]],-1);
}
reg(x)if(e[i].to!=fa[x][0])dfs2(e[i].to);
}
inline int lca(int x,int y){
if(deep[x]<deep[y])swap(x,y);
int t=deep[x]-deep[y];
rep(i,0,16)if((1<<i)&t)x=fa[x][i];
dep(i,16,0)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
if(x==y)return x;else return fa[x][0];
}
void solve(){
rep(i,1,m){
int x=que[i].x,y=que[i].y,t=lca(x,y);
A+=query(root[x],root[y],root[t],root[fa[t][0]],1,ind,in[t],in[x]);
A+=query(root[x],root[y],root[t],root[fa[t][0]],1,ind,in[t],in[y]);
A-=query(root[x],root[y],root[t],root[fa[t][0]],1,ind,in[t],in[t]);
A--;
}
}
int main()
{
n=read(),m=read();
rep(i,1,n-1){
int u=read(),v=read();insert(u,v);
}
rep(i,1,m){
int x=read(),y=read();a[x].push_back(y);
que[i].x=x,que[i].y=y;
}
sort(que+1,que+1+m);
dfs(1);dfs2(1);solve();int j;
for(int i=1;i<=m;i=j)
for(j=i;que[j].x==que[i].x&&que[j].y==que[i].y;j++)A-=(ll)(j-i)*(j-i-1)/2;
B=(ll)m*(m-1)/2;
ll t=gcd(A,B);A/=t,B/=t;
cout<<A<<'/'<<B<<endl;
return 0;
}