标签:启发式合并
分析
这样的三元组一定构成一个Y字形:其中两个点先走到它们的LCA,再从LCA向上走到当前点
考虑启发式合并:对每个点按距离统计子树内各点到当前点的距离,并对每一对点记录还差多少距离才能到达第三个点即可
code
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<vector>
// Inclusive ascending loop: declares int i and runs it from a up to b.
#define rep(i,a,b) for(int i=a;i<=b;i++)
// Inclusive descending loop: declares int i and runs it from a down to b.
#define dep(i,a,b) for(int i=a;i>=b;i--)
// Shorthand for the long long integer type.
#define ll long long
// Fills the whole object x with the byte value num via memset.
#define mem(x,num) memset(x,num,sizeof x)
// Walks the adjacency list of vertex x: i is an index into e[], starting at
// last[x] and following e[i].next until it reaches 0.
#define reg(x) for(int i=last[x];i;i=e[i].next)
using namespace std;
// Reads the next decimal integer from stdin.
//
// Every character that is not a digit is skipped; if a '-' is seen while
// skipping, the result is negated. The digit run is consumed together with
// the single character that terminates it.
//
// Returns the parsed value, or 0 when end of input is reached before any
// digit is found (without the EOF check the skip loop would never terminate).
inline long long read(){
    long long sign=1,value=0;
    // Kept as int, not char: EOF must stay distinguishable from real characters.
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    // At EOF this loop is skipped as well, because EOF is not a digit.
    while(ch>='0'&&ch<='9'){
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return value*sign;
}
//**********head by yjjr**********
// Shorthand for the push_back member function.
#define pb push_back
// Capacity of the per-vertex arrays below.
const int maxn=1e6+6;
// Ans: running total accumulated by dfs and printed by main.
// n: vertex count read in main.
// last[v]: index in e[] of the most recently inserted edge leaving v (0 = none).
// cnt: number of entries of e[] handed out so far by insert.
ll Ans=0;int n,last[maxn],cnt=0;
// son[x]: the neighbours of x other than its parent, collected by dfs.
vector<int> son[maxn];
// One directed half of an undirected edge, chained into a per-vertex list.
struct edge{
    int to;   // vertex this half-edge leads to
    int next; // index in e[] of the next entry of the same list, 0 ends it
}e[maxn<<1]; // two entries per insert() call; entry 0 is never written
// DP tables owned by one vertex and handed on by pointer to its parent in dfs.
// With f as the offset below, dfs maintains:
//   dis[f-j] = number of subtree vertices at distance j from the owning vertex,
//   pr[f+j]  = number of vertex pairs still missing a third vertex at distance j.
struct state{
    vector<ll> dis;
    vector<ll> pr;
    // Position of "distance 0" in both tables; dfs increments it per level climbed.
    int f;

    // Length of the distance table.
    inline int size(){
        return dis.size();
    }

    // Empties both tables and resets the offset.
    void clear(){
        dis.clear();
        pr.clear();
        f=0;
    }
}*s[maxn],T[maxn]; // s[x]: state currently used by vertex x; T[x]: storage for leaf x
// Ordering used by sort(): vertex a precedes vertex b when a's distance table
// is strictly longer than b's.
inline bool cmp(int a,int b){
    const int lenA=s[a]->size();
    const int lenB=s[b]->size();
    return lenB<lenA;
}
// Adds the undirected edge u-v by prepending one half-edge to the adjacency
// list of each endpoint.
//
// The fields are assigned one by one instead of through a `(edge){...}`
// compound literal, which is a C99 construct that C++ compilers accept only
// as a language extension.
void insert(int u,int v){
    // Half-edge u -> v becomes the new head of u's list.
    ++cnt;
    e[cnt].to=v;
    e[cnt].next=last[u];
    last[u]=cnt;

    // Half-edge v -> u becomes the new head of v's list.
    ++cnt;
    e[cnt].to=u;
    e[cnt].next=last[v];
    last[v]=cnt;
}
// Post-order DP over the subtree of x; fa is x's parent (main passes 0 for the root).
// On return s[x] points at x's state and, with f = s[x]->f, it holds:
//   s[x]->dis[f-j] = number of vertices in x's subtree at distance j from x,
//   s[x]->pr[f+j]  = number of vertex pairs in x's subtree that still miss a
//                    third vertex at distance j from x.
// Every triple completed at x is added to the global Ans.
// x takes over the tables of the child with the longest dis table; all other
// children are folded into them one by one.
// NOTE(review): recursion depth equals the tree height, so a path-shaped tree
// close to maxn vertices may exhaust the call stack — confirm the limits.
void dfs(int x,int fa){
son[x].clear();
// Recurse into every neighbour except the parent (i is the edge index from reg).
reg(x){
if(e[i].to==fa)continue;
son[x].push_back(e[i].to);dfs(e[i].to,x);
}
// Leaf: one vertex at distance 0 and no pairs; f is still 0 because T has
// static storage and is therefore zero-initialised.
if(!son[x].size()){
s[x]=T+x;s[x]->dis.pb(1);s[x]->pr.pb(0);
return;
}
// Move the child with the longest dis table to the front and adopt its state.
sort(son[x].begin(),son[x].end(),cmp);
s[x]=s[son[x][0]];
// Re-root the adopted tables at x: appending x itself and incrementing f makes
// every stored distance one larger and every "missing" distance one smaller.
s[x]->dis.pb(1);
s[x]->f++;s[x]->pr.pb(0);s[x]->pr.pb(0);
// Pairs that now miss distance 0 are completed by x itself.
Ans+=s[x]->pr[s[x]->f];
// Fold in each remaining child v; a vertex at distance j from v is at j+1 from x.
for(int i=1;i<son[x].size();i++){
int v=son[x][i],f=s[v]->f,fu=s[x]->f;
// Pairs already stored at x that miss distance j+1, completed by v's vertices at distance j from v.
rep(j,0,f)Ans+=(s[v]->dis[f-j])*(s[x]->pr[j+fu+1]);
// Pairs inside v that miss distance j from v, completed by x's vertices at distance j-1 from x.
rep(j,1,f)Ans+=(s[v]->pr[f+j])*(s[x]->dis[fu-(j-1)]);
// New pairs: one vertex from v and one already counted at x, both at distance j+1 from x.
// This runs before dis is merged below, so v's vertices are not paired with themselves.
rep(j,0,f)s[x]->pr[j+fu+1]+=(s[v]->dis[f-j])*(s[x]->dis[fu-j-1]);
// Pairs inside v carry over; measured from x they miss one step less.
rep(j,1,f)s[x]->pr[j+fu-1]+=s[v]->pr[j+f];
// Finally add v's per-distance vertex counts to x's, shifted by one.
rep(j,0,f)s[x]->dis[fu-(j+1)]+=s[v]->dis[f-j];
}
}
int main()
{
n=read();
rep(i,1,n-1){
int u=read(),v=read();
insert(u,v);
}
dfs(1,0);
cout<<Ans<<endl;
return 0;
}