【图论·习题】Cow at Large G(LCA+STL set)
Problem
题目描述
最后,Bessie被迫去了一个远方的农场。这个农场包含N个谷仓(2 <= N <= 105)和N-1条连接两个谷仓的双向隧道,所以每两个谷仓之间都有唯一的路径。每个只与一条隧道相连的谷仓都是农场的出口。当早晨来临的时候,Bessie将在某个谷仓露面,然后试图到达一个出口。
但当Bessie露面的时候,她的位置就会暴露。一些农民在那时将从不同的出口谷仓出发尝试抓住Bessie。农民和Bessie的移动速度相同(在每个单位时间内,每个农民都可以从一个谷仓移动到相邻的一个谷仓,同时Bessie也可以这么做)。农民们和Bessie总是知道对方在哪里。如果在任意时刻,某个农民和Bessie处于同一个谷仓或在穿过同一个隧道,农民就可以抓住Bessie。反过来,如果Bessie在农民们抓住她之前到达一个出口谷仓,Bessie就可以逃走。
Bessie不确定她成功的机会,这取决于被雇佣的农民的数量。给定Bessie露面的谷仓K,帮助Bessie确定为了抓住她所需要的农民的最小数量。假定农民们会自己选择最佳的方案来安排他们出发的出口谷仓。
输入格式
输入的第一行包含N和K。接下来的N – 1行,每行有两个整数(在1~N范围内)描述连接两个谷仓的一条隧道。
输出格式
输出为了确保抓住Bessie所需的农民的最小数量。
Solution
这道题其实考试的时候大致想出来了,最后再维护的时候卡壳了…
我们这里采用画图的方式来理解一下喽:
如果任意叶节点A和叶节点B,在选A的情况下要B不选,必然满足:
在这里Lca(a,b)表示A和B的最近公共祖先;变形一下,就是:
不等式的左边是Lca(a,b)到A的路径,右边是起点到Lac(a,b)的路径。如果我们把Lca(a,b)当做是交界口的话,一定是农夫到交界口的距离小于或等于才能够追上奶牛。这一点很容易想到,问题是农夫的人数怎么求。
我们可以每一次找到深度最小的那个点,其实也是贪心:因为要尽可能具体出口靠近来更容易抓住奶牛。然后把每一个满足上述不等式的叶子结点去掉,答案就是用了的点,或者n-去掉了的点。
在这里最值可以用STL priority来进行维护,但是不支持删除操作;我们可以使用自带判重功能、排序功能、删除和插入功能的STL set来进行维护,其中set因为具有判重功能,所以删除的erase离可以带某个具体数值,这里是pair类型。然后迭代器it在枚举时如果没有等到it++就把当前it删除,就会造成RE的大好局面,所以需要特殊处理一下。
代码如下:
#include<bits/stdc++.h>
#define make make_pair
using namespace std;
int n,k;
int in[200000];
int deep[200000];
int fa[200000][30];
vector<int>a[200000];
set< pair<int,int> >s;
void dfs(int x)
{
for (int i=0;i<a[x].size();++i)
{
int y=a[x][i];
if (fa[x][0] == y) continue;
fa[y][0]=x;
deep[y]=deep[x]+1;
dfs(y);
}
}
void dp(void)
{
fa[k][0]=-1;
for (int i=1;i<=20;++i)
for (int j=1;j<=n;++j)
if (fa[j][i-1] == -1) fa[j][i]=-1;
else fa[j][i]=fa[fa[j][i-1]][i-1];
}
int Lca(int u,int v)
{
if (deep[u]>deep[v]) swap(u,v);
for (int i=0,d=deep[v]-deep[u];d;i++,d>>=1)
if (d&1 == 1) v=fa[v][i];
if (u == v) return u;
for (int i=20;i>=0;--i)
if (fa[u][i]!=fa[v][i])
u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
int main(void)
{
freopen("atlarge.in","r",stdin);
freopen("atlarge.out","w",stdout);
scanf("%d %d",&n,&k);
for (int i=1;i<n;++i)
{
int x,y;
scanf("%d %d",&x,&y);
a[x].push_back(y);
a[y].push_back(x);
in[x] ++,in[y] ++;
}
dfs(k);
dp();
int ans=0;
for (int i=1;i<=n;++i)
if (in[i] == 1)
s.insert(make(deep[i],i));
while (s.size())
{
int u=(*s.begin()).second;
s.erase(s.begin());
for (set< pair<int,int> >::iterator it=s.begin();it!=s.end();)
{
int v=(*it).second,tmp=(*it).first;
it++;
if (deep[Lca(u,v)]*2>=deep[u]) s.erase(make(tmp,v));
}
ans++;
}
printf("%d\n",ans);
return 0;
}