【题解】codeforces1097H数位dp+合并技巧
题意
按原题给出一个序列生成方式。求第[l,r]为间有多少子串A满足A <= B
<= 定义为每一位都小于等于
题解
思路:
维护长度为d^i的区间的信息,使得区间可以合并。
这样合并区间的思路非常常见。在计数和线段树。。。都可以用到
注意:
维护的时候如果位数不够补充成0 这样合并的时候不用再check整块是否合法
关于位运算,用all维护所有状态,可以方面的实现删除和保留一些
总结和反思:
这是复习+写题,但是仍然花了很多时间理清细节、我对于代码的析构能力很差,知道做法还是不能清晰的呈现。要加强。提高代码能力!
关于位数不够补充成0 , 一开始没有想到,看了别人的代码也没有懂,直到最后才想到。我的思维远不够敏锐!
计算复杂度要仔细,开始少乘了一个d,以为不用bitset也能过、
前缀和后缀倒着维护,这样bitset可以直接&起来。有点像卷积的思路
这道题写了2h20min,非常低效!
#include<bits/stdc++.h>
using namespace std;
#define rep(i,l,r) for(register int i = l ; i <= r ; i++)
#define repd(i,r,l) for(register int i = r ; i >= l ; i--)
#define rvc(i,S) for(register int i = 0 ; i < (int)S.size() ; i++)
#define rvcd(i,S) for(register int i = ((int)S.size()) - 1 ; i >= 0 ; i--)
#define fore(i,x)for (register int i = head[x] ; i ; i = e[i].next)
#define forup(i,l,r) for (register int i = l ; i <= r ; i += lowbit(i))
#define fordown(i,id) for (register int i = id ; i ; i -= lowbit(i))
#define pb push_back
#define prev prev_
#define stack stack_
#define mp make_pair
#define fi first
#define se second
#define lowbit(x) (x&(-x))
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
typedef pair<int,int> pr;
const ll inf = 2e18;
const int N = 3e4 + 10;
const int maxn = 2000020;
const ll mod = 1e9 + 7;
struct node{
ll num,len;
bitset <N> pre,suf;
node() { num = len = 0; pre.reset(),suf.reset(); }
}f[70][62];
int gen[120],n,m,d,B[N];
ll l,r,pow_[70];
int a[maxn];
bitset <N> all;
node merge(const node &a,const node &b){
if ( !a.len ) return b;
node res;
res.len = a.len + b.len;
res.pre = a.pre , res.suf = b.suf; //pre和suf中包含了长度1..n-1的所有信息,所以合并的时候直接&就好
res.num = a.num + b.num;
if ( a.len <= n - 2 ) res.pre &= (b.pre >> a.len) | (all << (n - 1 - a.len));
if ( b.len <= n - 2 ) res.suf &= (a.suf << b.len) | (all >> (n - 1 - b.len));
if ( res.len >= n ){
bitset<N> tmp = a.suf & b.pre;
if ( a.len <= n - 2 ) tmp &= all >> (n - 1 - a.len); //如果a的长度不足,则删除高位
if ( b.len <= n - 2 ) tmp &= all << (n - 1 - b.len); //删除低位
res.num += tmp.count();
}
return res;
}
void init(){
pow_[0] = 1;
rep(i,1,64) pow_[i] = pow_[i - 1] * d;
int c = 0;
while ( pow_[c + 1] < r ) c++;
rep(i,1,n - 1) all[i] = 1;
rep(j,0,m){
f[0][j].len = 1;
if ( n == 1 ) f[0][j].num = j <= B[1];
else{
//把没有出现的位置看成0
//这样便于合并、否则每次合并还需要check整段是否合法
rep(k,1,n){
f[0][j].pre[k - 1] = j <= B[k];
f[0][j].suf[k] = j <= B[k];
}
}
}
rep(i,1,c){
rep(j,0,m){
rep(k,0,d - 1){
f[i][j] = merge(f[i][j],f[i - 1][(j + gen[k]) % m]);
}
}
}
}
ll solve(ll n){
int c = 0;
while ( pow_[c + 1] < n ) c++;
int add = 0; node res;
repd(i,c,0){
rep(j,0,d - 1){
if ( n >= pow_[i] ){
n -= pow_[i];
res = merge(res,f[i][(gen[j] + add) % m]);
}
else{
add += gen[j];
break;
}
}
}
cout<<res.num<<endl;
return res.num;
}
int main(){
scanf("%d %d",&d,&m);
rep(i,0,d - 1) scanf("%d",&gen[i]);
scanf("%d",&n);
rep(i,1,n) scanf("%d",&B[i]);
cin>>l>>r;
init();
cout<<solve(r) - solve(l + n - 2)<<endl;
}