Tree POJ - 1741 (树分治入门)
参考论文:https://wenku.baidu.com/view/8861df38376baf1ffc4fada8.html
这题的两个关键部分: 1. 求出树的重心(这在之前的博客里面提到,不在叙述)
2. 如何进行分治。
首先,如果我们选定一个点作为根节点,那么一条路径要么经过这个根节点,要么不经过这个根节点。
题目要求出所有路径小于等于k的路径,我们可以这样考虑(先不考虑重复的),依次把每个点都当做根节点,那么我们只要dfs一遍就可以求出所有点到根节点的距离,然后考虑两两组合,只要符合,我们就统计。
上面的思路必然会有重复的,第一,在同一子树上的两个点,都连向根的话,必然重复了一段路径,所以这部分应该删除。
所以我们的思路就是先直接求出以一个点为根节点的所有满足个数,然后递归到子树,减去在子树中满足的(在子树中满足,对应于前面的,就是在同一颗子树重复的)。
对于以下这棵树:
显然A点是它的重心。
我们假设现在分治到了A点(当前点为A)
我们一开始求解贡献时,会有以下路径被处理出来:
A—>A
A—>B
A—>B—>C
A—>B—>D
A—>E
A—>E—>F (按照先序遍历顺序罗列)
那么我们在合并答案是会将上述6条路径两两进行合并。
这是注意到:
合并A—>B—>C 和 A—>B—>D 肯定是不合法的!!
因为这并不是一条树上(简单)路径,出现了重边,我们要想办法把这种情况处理掉。
处理方法很简单,减去每个子树的单独贡献。
例如对于以B为根的子树,就会减去:
B—>B
B—>C
B—>D
这三条路径组合的贡献
接着就是代码了(应该看代码更直观一点)
#include<iostream>
#include<map>
#include<string>
#include<cstring>
#include<vector>
#include<algorithm>
#include<set>
#include<sstream>
#include<cstdio>
#include<cmath>
#include<climits>
using namespace std;
const int maxn=1e4+7;
const int inf=0x3f3f3f3f;
typedef long long ll;
const int mod=1e9+7;
int n,k,allnode;
int head[maxn*2];
int num;
int dp[maxn];
int size[maxn];
int Focus,M;
ll dist[maxn];
int deep[maxn];
bool vis[maxn];
ll ans;
struct Edge
{
int u,v,w,next;
}edge[maxn<<2];
void addEdge(int u,int v,int w)
{
edge[num].u=u;
edge[num].v=v;
edge[num].w=w;
edge[num].next=head[u];
head[u]=num++;
}
void init()
{
memset(head,-1,sizeof(head));
memset(dist,0,sizeof(dist));
memset(vis,0,sizeof(vis));
num=0;
}
void getFocus(int u,int pre)
{
size[u]=1;
dp[u]=0;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(v==pre||vis[v]) continue;
getFocus(v,u);
size[u]+=size[v];
dp[u]=max(dp[u],size[v]);
}
dp[u]=max(dp[u],allnode-size[u]);
if(M>dp[u])
{
M=dp[u];
Focus=u;
}
}
void dfs(int u,int pre)
{
deep[++deep[0]]=dist[u];
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v,w=edge[i].w;
if(v==pre||vis[v]) continue;
dist[v]=dist[u]+w;
dfs(v,u);
}
}
int cal(int x,int now)
{
dist[x]=now,deep[0]=0;
dfs(x,0);
sort(deep+1,deep+deep[0]+1);
int ans=0;
for(int l=1,r=deep[0];l<r;)
{
if(deep[l]+deep[r]<=k)
{
ans+=r-l;
l++;
}
else r--;
}
return ans;
}
void solve(int x)
{
vis[x]=1;
ans+=cal(x,0);
for(int i=head[x];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(vis[v]) continue;
ans-=cal(v,edge[i].w);
allnode=size[v];
Focus=0,M=1e9;
getFocus(v,x);
solve(Focus);
}
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
while(scanf("%d%d",&n,&k)!=EOF&&(n+k))
{
init();
for(int i=1,u,v,w;i<n;i++)
{
scanf("%d%d%d",&u,&v,&w);
addEdge(u,v,w);
addEdge(v,u,w);
}
Focus=ans=0;
allnode=n,M=1e9;
getFocus(1,0);
solve(Focus);
printf("%lld\n",ans);
}
return 0;
}