ST算法与LCA

RMQ问题

RMQ(Range Minimum/Maximum Query),即区间最值查询。对于长度为n的数列arr,回答若干询问Q(i,j),返回数列arr中下标在i,j之间的最大/小值。如果只有一次询问,那一遍for就可以搞定,但是如果有多次询问就无法在很快的时间处理出来。

ST算法

ST算法是一个在线算法,它可以在O(nlogn)时间内进行预处理,然后在O(1)的时间内回答每个查询,假设现在的数组为arr[] = {1,3,6,7,4,2,5,9},算法步骤如下:

一、预处理(以处理区间最小值为例)

dp[i][j]dp[i][j]表示从第i位开始连续2j2^j个数(也就是到i+2j1i+2^j-1)中的最小值。例如dp[2][1]dp[2][1]表示从第2个数开始,连续2个数的最小值,即3,6之间的最小值,即dp[2][1]=3dp[2][1]=3,从dp数组的含义我们就知道,dp[i][0]=arr[i]dp[i][0]=arr[i](下标均是从1开始),初值有了,剩下的就是状态转移方程。首先把dp[i][j]dp[i][j]平均分成两段(因为一定是偶数个数字),从i到i+2j11i+2^{j-1}-1为一段,i+2j1i+2{j-1}i+2j1i+2^j-1为一段(每段长度都为2j12^{j-1})。假设i=1,j=3时就是1,3,6,7和4,2,5,9这两段。dp[i][j]dp[i][j]就是这两段最大值的最大值。于是得到了状态转移方程式**dp[i][j]=max(dp[i][j1],dp[i+2j1][j1])dp[i][j] = max(dp[i][j-1],dp[i+2^{j-1}][j-1])**

for(int i = 1;i <= n;i++)
    dp[i][0] = arr[i];
for(int j = 1;(1 << j) <= n;j++)
    for(int i = 1;i + (1 << j) - 1 <= n;i++)
        dp[i][j] = Math.min(dp[i][j-1],dp[i + (1<<(j - 1))][j-1]);
二、查询

假设我们需要查询区间[L,R]中的最小值,令k=log2(RL+1)k=log_2(R-L+1),则区间[L,R]的最小值res=min(dp[L][k],dp[R(1&lt;&lt;k)+1][k])res=min(dp[L][k],dp[R-(1&lt;&lt;k)+1][k]),为什么这样就可以保证区间最值?dp[L][k]维护的是[L,L+2k1][L,L+2^k-1]dp[L][R2k+1][k]dp[L][R - 2^k+1][k]维护的是[R2k+1,R][R-2^k+1,R],因此只要证明R2k+1l+2k1R-2^k+1 ≤ l+2^k-1即可,这里证明省略

int k = (int) (Math.log(r - l + 1) / Math.log(2));
int min = Math.min(dp_min[l][k],dp_min[r - (1 << k) + 1][k]);
举个栗子

L=4,R=6L=4,R=6,此时k=log2(RL+1)=log23=1k=log_2(R-L+1)=log_23=1,所以RMQ(4,6)=min(dp[4][1],dp[5][1])=min(4,2)=2RMQ(4,6)=min(dp[4][1],dp[5][1])=min(4,2)=2,很容易看出来答案是正确的

题目链接:POJ3264

ST算法与LCA
ST算法板子题,用java的同学要注意的就是把你所有会的输入输出优化全用上,不然会TLE

import java.io.InputStreamReader;
import java.util.Scanner;

public class CF522A {
	final static int N = 50005;
	static int[][] dp_min = new int[N][25];
	static int[][] dp_max = new int[N][25];

	public static void main(String[] args) {
		Scanner cin = new Scanner(new InputStreamReader(System.in));
		int n = Integer.parseInt(cin.next());
		int m = Integer.parseInt(cin.next());		
		for(int i = 1;i <= n;i++) {
			int tmp = cin.nextInt();
			dp_min[i][0] = tmp;
			dp_max[i][0] = tmp;
		}
		//预处理
		for(int j = 1;(1 << j) <= n;j++)
		    for(int i = 1;i + (1 << j) <= n + 1;i++) {
		        dp_min[i][j] = Math.min(dp_min[i][j - 1],dp_min[i + (1 << j - 1)][j - 1]);//加减优先级高于位运算
		        dp_max[i][j] = Math.max(dp_max[i][j - 1],dp_max[i + (1 << j - 1)][j - 1]);
		    }
	
		while((m--) != 0) {
			int l = Integer.parseInt(cin.next());
			int r = Integer.parseInt(cin.next());
			int k = (int) (Math.log(r - l + 1) / Math.log(2));
			int min = Math.min(dp_min[l][k],dp_min[r - (1 << k) + 1][k]);
			int max = Math.max(dp_max[l][k],dp_max[r - (1 << k) + 1][k]);
			System.out.println(max - min);
		}
	}
}

LCA

求LCA(最近公共祖先)的算法有好多种,按在线和离线分为在线算法和离线算法,离线算法有基于搜索的Tarjan算法,而在线算法则是基于DP的ST算法。首先给定一棵树
ST算法与LCA
通过深搜,可以得到这样的一个序列:

数组下标:1 2 3 4 5 6 7 8 9 10 11 12 13
遍历顺序: A B D B E F E G E B A C A
结点在树中的深度:1 2 3 2 3 4 3 4 3 2 1 2 1

要查询D和G的LCA:

  1. 在遍历序列中找到D和G第一次出现的位置,first[D]=3,first[G]=8(3,8指数组下标)
  2. 取深度数组的[3,8]那一段序列,查询一个最小值min(3,2,3,4,3,4)=2,对应遍历数组中的结点是B,所以D,G的LCA是B
#include<cstring>
#include<iostream>
using namespace std;
int n,m,s,tot = 0,cnt = 0;
//vis[i]:dfs第i个访问的结点
//r[i]:vis[i]所在的层数
//fir[i]:vis[i]第一次出现的下标
int head[1000100],nxt[1000100],to[1000100];
int fir[1000100],vis[1000100],r[1000100];
int f[20][1000100],rec[20][1000100];
void addEdge(int x,int y) {
    cnt++;
    nxt[cnt] = head[x];
    head[x] = cnt;
    to[cnt] = y;
}
void dfs(int u,int dep) {//dfs处理出三个数组 
    fir[u] = ++tot,vis[tot] = u,r[tot] = dep;
    for(int i = head[u];i != -1;i = nxt[i]) {
        int v = to[i];
        if(!fir[v]) {
            dfs(v,dep + 1);
            vis[++tot] = u,r[tot] = dep;
        }
    }
}
int main() {
    memset(head,-1,sizeof(head));
    scanf("%d%d%d",&n,&m,&s);
    for(int i = 1;i < n;i++) {
        int x,y;
        scanf("%d%d",&x,&y);
        addEdge(x,y);
        addEdge(y,x);
    }
    dfs(s,1);
    //ST表求RMQ
    for(int i = 1;i <= tot;i++)
        f[0][i] = r[i],rec[0][i] = vis[i];
    for(int i = 1;(1 << i) <= tot;i++)
        for(int j = 1;j + (1 << i) <= tot + 1;j++)
            if(f[i - 1][j] < f[i - 1][j + (1 << (i - 1))])
                f[i][j] = f[i - 1][j],rec[i][j] = rec[i - 1][j];
            else 
				f[i][j] = f[i - 1][j + (1 << (i - 1))],rec[i][j] = rec[i - 1][j + (1 << (i - 1))];
    //rec记录的是区间内深度最小值的编号
    for(int i = 1;i <= m;i++) {
    	int l,r,k = 0;
        scanf("%d%d",&l,&r);
        l = fir[l],r = fir[r];
        if(l > r) 
			swap(l,r);
        while((1 << k) <= r - l + 1) 
			k++;
        k--;
        if(f[k][l] < f[k][r - (1 << k) + 1]) 
			printf("%d\n",rec[k][l]);
        else 
			printf("%d\n",rec[k][r - (1 << k) + 1]);
    }
    return 0;
}