noyのブログ

プログラミングとかゲームとか

ABC009 - D 漸化式

問題概要

二つの数列 { \displaystyle
A_1, A_2, …, A_K
}{ \displaystyle
 C_1, C_2, …, C_K
} が入力で与えられる。

数列 A は以下の式によって計算できる。
{ \displaystyle
A_{N+K}=(C_1 AND A_{N+K−1}) XOR (C_2 AND A_{N+K−2}) XOR …  XOR (C_K AND A_N)
}

数列 A の M 番目の値を求めよ。

解き方

漸化式を計算するには、DPを使えば良いです。しかし、制約が大きく、DPでは解けません。
別の方法を使って漸化式を解く必要があります。

そこで、行列累乗を使います(あり本 P.180参照)。
行列累乗に and や xor が使えるのかですが、それは atcoder の公式解説を参照してください。
あり本を見ながらやれば実装できます。ほとんどあり本のプログラムそのままです。
正しい順番で値を入れて、正しい値で初期化して、演算子を変えましょう。

ただ、割と引っかかるポイントが多かったです。

気をつけるところ

and は論理積

bit の and が取りたい時は、 bitand か & を使いましょう。

各項の順番

行列に値を入れる際、右から左に並べるのか、左から右に並べるのか(もちろん上下も)。

and 演算子単位元

普通に掛け算をするのであれば、単位元は 1 です。
あり本にあるプログラムや式は、行列の掛け算を行うため、 単位元に 1 が使われています。
しかし、この問題では and 演算子を使うため、単位元は全 bit が 1 の値です。
初期化ミスに気をつけましょう。

m ≤ k の時の処理

m ≤ k の時は、与えらえれた入力をそのまま返せば良いです。
入力を行列累乗のために逆順に持っていたことを忘れるとWAです(2敗)。

実装

#define int long long なので、単位元は -1 を使っています。

#include<bits/stdc++.h>
#define range(i,a,b) for(int i = (a); i < (b); i++)
#define rep(i,b) for(int i = 0; i < (b); i++)
#define all(a) (a).begin(), (a).end()
#define show(x)  cerr << #x << " = " << (x) << endl;
//const int INF = 1e8;
using namespace std;

#define int long long
typedef vector<vector<int>> mat;

mat mul(mat &a, mat &b){
	mat c(a.size(), vector<int>(b[0].size()));
	rep(i,a.size()){
		rep(k,b.size()){
			rep(j,b[0].size()){
				c[i][j] = (c[i][j] xor (a[i][k] bitand b[k][j]));
			}
		}
	}
	return c;
}

mat pow(mat a, int n){
	mat b(a.size(), vector<int>(a.size(),0));
	rep(i,b.size()) b[i][i] = -1;
	while(n > 0){
		if(n & 1) b = mul(b,a);
		a = mul(a, a);
		n >>= 1;
	}
	return b;
}

int solve(int n, mat A, vector<int> c){
	mat a(A.size(), vector<int>(A.size(), 0));
	for(int i = c.size() - 1; i >= 0; i--){
		//a[0][i] = c[A.size() - 1 - i]; ここの順番が逆だった。
		a[0][i] = c[i];
	}
	rep(i,c.size() - 1){
		a[i + 1][i] = -1;
	}

	a = pow(a,n); //行列Aのn乗。
	mat res = mul(a, A);
	return res[0][0];
}

signed main(){
	int k, m;
	cin >> k >> m;

	mat a(k, vector<int>(1));
	vector<int> c(k);
	rep(i,k) cin >> a[k - i - 1][0];
	rep(i,k) cin >> c[i];

	if(m <= k){
		cout << a[k - m][0] << endl;
	}else{
		cout << solve(m - k,a,c) << endl; //-kを忘れずに
	}
}