CodeForces 1106F Lunar New Year and a Recursive Sequence(BSGS + 原根 + 矩阵类)
大致题意:告诉你F的递推式,与前K个数字有关。但是只告诉你初始前K-1个项的值和第N项的值,让你求是否存在一个满足条件的第K项。
首先,由于这个递推式是:
这个式子是前K项的乘积,不方便我们用矩阵快速幂取递推,而本题n最大可以到1e9且K最大为100,根据复杂度来判断,显然是要用到矩阵快速幂的。所以我们想办法把这个转换为矩阵可以解的形式。注意到,给出的模数p是998244353,所以我们可以联想到原根,而998244353的原根就是3。我们令 ,那么有:
于是,转换为原根之后,只要我们求出指数x,我们就能够用快速幂g^x求出对应的f的数值。那么问题就转变成了如何求这个指数x。根据我们上面的式子,结合费马小定理,可以特出指数的递推式:
对于这个式子,我们就可以用上矩阵快速来快速求某一项的数值了。我们设构造矩阵为A,那么有:
左边的an表示第n项对应的原根的指数,ak是我们要求的那一项数字对应的原根的指数。中间的矩阵我们直接用快速幂即可求出。对于左边的an,我们直接用离散对数BSGS求解即可。正常来说,我们直接求逆元即可,但是注意到这里的p-1并不是一个质数,而且题目也说了可能不存在一个合法的第k项,也就意味着不一定有互质,不一定有逆元。所以我们不能简单的快速幂求逆元,应该用exgcd去求解。具体见代码:
#include <bits/stdc++.h>
#define INF 0x3f3f3f3f3f3f3f3fll
#define LL long long
#define sc(x) scanf("%lld",&x)
#define scc(x,y) scanf("%lld%lld",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
const int N = 110;
const int mod = 998244353 - 1;
map<LL, LL> tab;
LL K,n,m,b[N];
struct Matrix
{
LL a[N][N];
Matrix(){memset(this,0,sizeof(Matrix));}
Matrix operator *(const Matrix x) const
{
Matrix ans;
for(int i=0;i<K;i++)
for(int j=0;j<K;j++)
{
for(int k=0;k<K;k++)
ans.a[i][j]+=a[i][k]*x.a[k][j]%mod;
ans.a[i][j]%=mod;
}
return ans;
}
friend Matrix operator ^(Matrix x,LL y)
{
Matrix ans;
for(int i=0;i<K;i++)
ans.a[i][i]=1;
while(y)
{
if (y&1) ans=ans*x;
x=x*x; y>>=1;
}
return ans;
}
};
LL qpow(LL x,LL n,LL p)
{
LL res=1;
while(n)
{
if (n&1) res=res*x%p;
x=x*x%p; n>>=1;
}
return res;
}
LL bsgs(LL a, LL b, LL p){
LL u = (LL) sqrt(p) + 1;
LL now = 1, step;
for (LL i = 0; i < u; i++){
LL tmp = b * qpow(now, p - 2, p) % p;
if (!tab.count(tmp)){
tab[tmp] = i;
}
(now *= a) %= p;
}
step = now;
now = 1;
for (LL i = 0; i < p; i += u){
if (tab.count(now)){
return i + tab[now];
}
(now *= step) %= p;
}
return -1;
}
LL ex_gcd(LL a,LL b,LL &x,LL &y)
{
if(b==0){x=1; y=0;return a;}
else
{
LL r=ex_gcd(b,a%b,y,x);
y-=x*(a/b); return r;
}
}
LL solve(LL a, LL b, LL c)
{
if (c == 0) return 0;
LL q = __gcd(a, b);
if (c % q){
return -1;
}
a /= q, b /= q, c /= q;
LL ans, _;
ex_gcd(a, b, ans, _);
(ans *= c) %= b;
while (ans < 0) ans += b;
return ans;
}
int main()
{
sc(K);
for(int i=0;i<K;i++) sc(b[i]);
Matrix x;
for(int i=1;i<K;i++) x.a[i][i-1]=1;
for(int i=0;i<K;i++) x.a[0][i]=b[i];
scc(n,m); x=x^(n-K); LL y=bsgs(3,m,mod+1);
LL ans=solve(x.a[0][0],mod,y);
if (ans<0) printf("%lld\n",-1LL);
else printf("%lld\n",qpow(3,ans,mod+1));
return 0;
}