【线段树】【LCA的RMQ求法】【树上路径求交】车站
链接:https://ac.nowcoder.com/acm/contest/368/E
来源:牛客网
很好的一道题,一开始没有想到从路径来建立线段树,总想着从原树上搞,还觉得和树链剖分有点相似,后来看了题解才知道自己完全想错了。
车站
线段树+倍增+LCA。
首先车站一定在所有铁路的经过的点的交集上,所以可以用线段树求出区间路径的交集,同时维护离路径交的两个端点最远的点。找到出了最远的两个点后,那么可以倍增求出车站的位置。
具体地,线段树每个节点维护u,v,du,dv,u,v是路径交,du,dv是分别是距离u,v最远的点。合并两个区间的过程是先对两个区间的路径求路径交,然后新的du,dv一定是两个区间中四个du,dv中的两个。
对一段区间询问,先求出区间的u,v,du,dv,于是可以选择路径du->dv上中间的位置作为车站,如果中间点不在路径交集上,那么让u,v中的一个作为车站。
上面是官方题解,说得非常清晰了,值得注意的是它以路径为基础建立了线段树,和原树没有太多关系,相当精彩。另外,路径的维护方面,想到只维护四个点,也是很不容易。
补充一些细节部分。两个树上路径求交有个固定算法:倘若存在交,路径a的两个端点和路径b的两个端点两两求LCA,然后取四个lca中深度最大的那两个点,它们之间就是交路径。而当这两个点相同,并且深度小于两个路径之一时,则说明不存在交。(怎么证?唉,直观感受感受吧)
另外(du,dv)是在(u,v)的两侧,并且中间位置如果在(u,v)之间,就是满足的车站。这个有些不好理解。事实上,假设(u,v)中有一点p,p到最远的点距离必然就是max(dis(p,du),dis(p,dv)),否则有其他情况的话会和du,dv的定义相矛盾。而p每向dv靠近1,则向du远离1,当然是两者平均的时候效果最好。候选点可能有一个可能有两个,最后判断哪个编号更小。
寻找答案时注意点在二叉路径的(x,lca)还是(lca,y)上,需要用到倍增法
用倍增的LCA的话会TLE(也可能我写得太臭)无奈之下写了RMQ版本的LCA,查询的效率基本上接近O(1)
#include <cstdio>
#include <algorithm>
#include <vector>
#define kl (k<<1)
#define kr (k<<1|1)
#define M (L+R>>1)
#define lin L,M
#define rin M+1,R
#define ept (Road){-1,-1}
using namespace std;
using Road=pair<int,int>;
struct node
{
Road r,d;
}T[1<<18],res;
int n,m,u,v,q,o,l,r,x,d1,d2,d,ans1,ans2,lg2[200005];
vector<int> E[100005];
int dep[100005],p[100005][17];
int stn,rdfn[100005],st[200005][18];
void dfs(int x,int fa) //倍增的预处理和LCA(RMQ)的预处理
{
st[++stn][0]=x;
rdfn[x]=stn;
for(int &j:E[x])
if(j!=fa)
{
dep[j]=dep[x]+1;
p[j][0]=x;
for(int k=1;1<<k<=dep[j];k++)
p[j][k]=p[p[j][k-1]][k-1];
dfs(j,x);
st[++stn][0]=x;
}
}
void RMQ_init() //深搜序列中两个位置之间深度最小的那个点就是LCA
{
for(int j=1;1<<j<=stn;j++)
for(int i=1;1<<j<=stn-i+1;i++)
st[i][j]=min(st[i][j-1],st[i+(1<<j-1)][j-1],[&](int x,int y) {
return dep[x]<dep[y];
});
}
int LCA(int a,int b)
{
if(rdfn[a]>rdfn[b])
swap(a,b);
int t=lg2[rdfn[b]-rdfn[a]+1];
return min(st[rdfn[a]][t],st[rdfn[b]-(1<<t)+1][t],[&](int x,int y) {
return dep[x]<dep[y];
});
}
int dis(int a,int b,int lca=0) //两点距离
{
return dep[a]+dep[b]-dep[lca==0?LCA(a,b):lca]*2;
}
Road Cross(Road a,Road b) //求树上路径的交
{
if(a==ept||b==ept)
return ept;
int w[4]={LCA(a.first,b.first),LCA(a.first,b.second),
LCA(a.second,b.first),LCA(a.second,b.second)};
sort(w,w+4,[&](int &u,int &v) {
return dep[u]>dep[v];
});
if(w[0]==w[1]&&(dep[w[0]]<dep[LCA(a.first,a.second)]||dep[w[0]]<dep[LCA(b.first,b.second)]))
return ept;
return {w[0],w[1]};
}
void maintain(node &a,const node &b,const node &c) //线段树的维护函数
{
a.r=Cross(b.r,c.r);
if(a.r==ept)
return;
int w[4]={b.d.first,b.d.second,c.d.first,c.d.second};
a.d.first=*max_element(w,w+4,[&](int &x,int &y) {
return dis(a.r.first,x)<dis(a.r.first,y);
});
a.d.second=*max_element(w,w+4,[&](int &x,int &y) {
return dis(a.r.second,x)<dis(a.r.second,y);
});
}
void build_tree(int k,int L,int R)
{
if(L==R)
{
scanf("%d%d",&T[k].r.first,&T[k].r.second);
T[k].d={T[k].r.second,T[k].r.first};
return;
}
build_tree(kl,lin);
build_tree(kr,rin);
maintain(T[k],T[kl],T[kr]);
}
void modify(int k,int L,int R)
{
if(L==R)
{
T[k].r={u,v};
T[k].d={v,u};
return;
}
if(x<=M)
modify(kl,lin);
else
modify(kr,rin);
maintain(T[k],T[kl],T[kr]);
}
node query(int k,int L,int R)
{
if(l<=L&&R<=r)
return T[k];
node res;
if(l<=M&&r>M)
return maintain(res,query(kl,lin),query(kr,rin)),res;
else if(l<=M)
return query(kl,lin);
else
return query(kr,rin);
}
int get_Ans(Road x,int d) //倍增法通过距离求点
{
if(dep[x.first]<dep[x.second])
swap(x.first,x.second);
int lca=LCA(x.first,x.second);
if(dep[x.first]-dep[lca]<d) //判断在路径的左半部分还是右半部分
d=dis(x.first,x.second,lca)-d,x.first=x.second,x.second=lca;
else
x.second=lca;
int j;
for(j=lg2[d];d;j=lg2[d])
x.first=p[x.first][j],d-=1<<j;
return x.first;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
scanf("%d%d",&u,&v),E[u].push_back(v),E[v].push_back(u);
for(int i=2;i<=n<<1;i++)
lg2[i]=lg2[i>>1]+1;
dfs(1,0);
RMQ_init();
build_tree(1,1,m);
scanf("%d",&q);
while(q--)
{
scanf("%d",&o);
if(o==1)
{
scanf("%d%d",&l,&r);
res=query(1,1,m);
if(res.r==ept)
{
puts("-1");
continue;
}
if(res.d.first==res.d.second) //du,dv在同一侧,特判掉
{
printf("%d\n",dis(res.d.first,res.r.first)<dis(res.d.first,res.r.second)?res.r.first:res.r.second);
continue;
}
d1=dis(res.d.first,res.d.second);
d2=dis(res.d.second,res.r.first);
d=min(max(d1/2-d2,0),dis(res.r.first,res.r.second));
ans1=get_Ans(res.r,d);
ans2=get_Ans(res.r,min(d+1,dis(res.r.first,res.r.second)));
printf("%d\n",max(dis(ans1,res.d.first),dis(ans1,res.d.second))==
max(dis(ans2,res.d.first),dis(ans2,res.d.second))&&ans2<ans1?ans2:ans1);
}
else
{
scanf("%d%d%d",&x,&u,&v);
modify(1,1,m);
}
}
return 0;
}