[HDU2196]Computer(树形dp+二次扫描换根法)
传送门
题意:
给出一棵树,求离每个节点最远的点的距离
一开始以为是树的直径。。然后看清题意之后就可以容易看出是树形dp了,对于这种无根树且需要求每一个点的情况需要运用二次扫描换根法。那么我们来设列dp方程吧,我们思考当前点x的最远点距离是怎么得到的,只有两种情况:
1、来自他的子树(图中红色)
2、来自他的子树以外的树(图中蓝色,以下简称父亲部)
第一种情况的话可以直接自底向上树形dp得到每一个节点的子树的最远点距离。 那么第二种情况就有点难办,父亲部的最远点距离可以从哪里来呢?有两种情况:
1、父亲点fa的父亲部
2、父亲点的子树
对于第二种情况的话会有一种情况需要考虑,想到这又不大家应该也会发现,父亲部的子树可能包括红色的部分,如果我们冒冒然去继承,那么就会造成没法继承到蓝色部分的解。那么怎么办呢,我们需要判断一下,假如fa的最远点路径经过了x,那么我们就不继承他,改为继承fa子树的次远点距离。
那么我们可以设列dp方程了,我们设f[i][0]为i节点子树的最远点距离,f[i][1]为i节点子树的次远点距离,设f[i][2]为i节点的父亲部的最远点距离。那么我们列出dp方程:
当x不在fa的最远点路径上:
当x在fa的最远点路径上:
但是但是!!还没有做完!问题又来了(可能大家都发现了这个大问题),假如次远路径还经过x该怎么办呢??
那么我们可以巧妙地避开这种情况,我们看一下代码:
void dp(int x,int fa)
{
for(int k=last[x];k;k=a[k].next)
{
int y=a[k].y; if(y==fa) continue;
dp(y,x);
if(f[x][0]<f[y][0]+a[k].c)
{
t[x][0]=y;
f[x][1]=f[x][0];
f[x][0]=f[y][0]+a[k].c;
}
else if(f[x][1]<f[y][0]+a[k].c)
{
f[x][1]=f[y][0]+a[k].c;
t[x][1]=y;
}
}
}
这是一个自底向上的树形dp的部分,t数组维护最远和次远路径。
为什么这样就规避了那种情况呢?因为我们维护次远点的时候,我们通过两种方式维护,第一种方式是在最远点距离发生改变的时候,我们将原来的最远点距离就变成了次远点距离,因为我们是自底向上dp,那么对于x的孩子y,原来的最远点距离一定是y的子树之外的,所以用这个来更新最远点距离就不会关于y的子树了。
那么我们来总结一下二次扫描换根法吧:在我们需要在一个无根树上多个节点为根统计答案的时候就可以用到二次扫描换根法,具体的操作是通过实现一次自底向上的dp和一次自顶向下的dfs,通过维护上面那个图(没错那就是基本模型)来计算“换根”后的解。
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long ll;
const int N=10010;
struct node
{
int x,y,next;
ll c;
}a[N*2]; int len,last[N];
ll f[N][3];
int t[N][2];
int n;
void ins(int x,int y,ll c)
{
len++;
a[len].x=x;a[len].y=y;a[len].c=c;
a[len].next=last[x];last[x]=len;
}
void dp(int x,int fa)
{
for(int k=last[x];k;k=a[k].next)
{
int y=a[k].y; if(y==fa) continue;
dp(y,x);
if(f[x][0]<f[y][0]+a[k].c)
{
t[x][0]=y;
f[x][1]=f[x][0];
f[x][0]=f[y][0]+a[k].c;
}
else if(f[x][1]<f[y][0]+a[k].c)
{
f[x][1]=f[y][0]+a[k].c;
t[x][1]=y;
}
}
}
void dfs(int x,int fa)
{
for(int k=last[x];k;k=a[k].next)
{
int y=a[k].y; if(y==fa) continue;
if(t[x][0]!=y)
f[y][2]=max(f[x][0],f[x][2])+a[k].c;
else
f[y][2]=max(f[x][1],f[x][2])+a[k].c;
dfs(y,x);
}
}
int main()
{
while(~scanf("%d",&n))
{
memset(last,0,sizeof(last)); len=0;
for(int x=2;x<=n;x++)
{
int y; ll c;scanf("%d%lld",&y,&c);
ins(x,y,c); ins(y,x,c);
}
memset(f,0,sizeof(f));
memset(t,0,sizeof(t));
dp(1,0);
f[1][2]=f[1][1]; dfs(1,0);
for(int i=1;i<=n;i++)
{
printf("%lld\n",max(f[i][0],f[i][2]));
}
}
return 0;
}