日々drdrする人のメモ

今日も明日もdrdr

CodeChef October Challenge 2019: Queries on Matrix

CodeChef October Challenge 2019 の問題: Queries on Matrix (Code: JIIT)

Contest Page | CodeChef

問題概要

 N \times M の行列が与えられる。全ての要素は初め 0 である。

 r行目 c列目の要素を  (r, c) として記述する。

以下の操作をちょうど  Q 回行う。

  1. 1つの要素  (x, y)  (1 \le x \le N, 1 \le y \le M) を選ぶ
  2.  x 行目の全ての要素に 1 を足す
  3.  y 列目の全ての要素に 1 を足す

 Q 回の操作の後、行列全体の内で要素が奇数となる要素の数を  Z 個とする要素の操作列は何通り存在するかを modulo 998244353 で求めよ。
(操作列は、異なる順序で要素を選ぶもの同士を区別する)

制約

  • 1つのテストケースのケース数:  1 \le T \le 50
  •  1 \le N, M \le 2000
  •  0 \le Z \le N \cdot M
  •  1 \le Q \le 10^{18}

解法

行列累乗で試行錯誤しながら計算した。


まず Q回の操作の後、 N行の中で奇数回1を足した行の数を a M列の中で奇数回1を足した列の数を bとする時、奇数となる要素の数は  aM + bN - 2 ab になるため、これが  Z と等しくなるように  a, b を選べばよい。
この選び方は  2Z = NM 2Z \not = NM かで場合分けすればよく、選び方は  O(N + M) になる。

行と列については独立に考えられるため、前もって  1 \le a \le N となる行の選び方 と  1 \le b \le M となる列の選び方 をそれぞれ計算し、全ての選び方について計算すればよい。
ここからは行について考える。(列についても同様に求められる)


DPで求める場合、以下で求めることができる。

  •  \mathit{dp}[i][a] = i回操作した後、奇数回選んだ行が  a 個となる通り数

このDPは以下で計算できる。

  •  \mathit{dp}[i][a] = \mathit{dp}[i-1][a-1] * (N - a + 1) + \mathit{dp}[i+1][a+1] * (a + 1)

最終的に各 aについて  dp[Q][a] を求めればよい。(この計算量は  O(NQ))


次に行列累乗を用いて計算することを考える。
 k回後の操作において、奇数回選んだ行数が  i 個となる通り数を x_iとする、長さ N+1 のベクトル  \mathbf{x}_k = [x_0, x_1, ..., x_N]^\mathrm{T}とし、このベクトルに行列を掛けることで最終的に  \mathbf{x}_Q を求める。
初期値の  \mathbf{x}_0 [1, 0, 0, ..., 0]^\mathrm{T} となる。


以下を満たす  (N+1) \times (N+1) の行列 Aを考える。

  •  1 \le i \le N について
    •  A[i-1][i] = N-i
    •  A[i][i-1] = i
  • それ以外の要素は 0

そして、 \mathbf{x}_Q = A^Q \mathbf{x}_0 を求める。

f:id:smijake3:20191018234238p:plain:w320
N = 5 の時の行列

この計算量は  O(N^3 \log Q) となる。


ここから、この行列累乗を高速化する。

行列を対角化する。
まず、行列 Aについて固有値を求めてみると、 \lambda = -N, 2-N, ..., N-2, N となることが分かる。

次に各固有値に対する固有ベクトルを求める必要があるが、行列 (A - \lambda I)がtridiagonalの形になるため、固有ベクトル O(N) で求めることができる。
具体的には、行列 Aに対するある固有値 Xの1つの固有ベクトル  v_X = [a_0, a_1, ..., a_N]^\mathrm{T} は以下の数列から求まる。

  •  a_0 = 1,  a_1 = X
  •  \displaystyle a_k = \frac{X \cdot a_{k-1} - (N + 2 - k) \cdot a_{k-2}}{k}

これらの固有ベクトルから成る行列 P = [v_N v_{N-2} ... v_{2-N} v_{-N}] を考える。
この時、この行列 P逆行列 P^{-1} \displaystyle \frac{1}{2^N} P になるため、逆行列を改めて求める必要がない。

また、対角行列  \Lambda = \mathit{diag}(N, N-2, ..., 2-N, -N) Q 乗の  \Lambda^Q O(N \log Q) で求まる。

よって最終的に求めたいベクトル  \displaystyle \mathbf{x}_Q = A^Q \mathbf{x}_0 = \frac{1}{2^N} P \Lambda^Q P \mathbf{x}_0 O(N (N + \log Q)) で求められる。

f:id:smijake3:20191019102403p:plain:w640
N = 5 の例

全体の計算量は  O(N (N + \log Q) + M (M + \log Q))

実装

提出コード(C++14): Solution: 27139983 | CodeChef

#define N 2003

const ll mod = 998244353;

int n, m, z;
ll q;

ll xs[N], ys[N];
ll vs[N][N];

ll fast_pow(ll x, ll n) {
  ll res = 1;
  while(n > 0) {
    if(n & 1) {
      res = res * x % mod;
    }
    x = x * x % mod;
    n >>= 1;
  }
  return res;
}

ll revv[N];
void calc(ll xs[N], int n, ll q) {
  rep(i, n+1) xs[i] = 0;
  rep(i, n+1) {
    // 固有ベクトルの計算
    ll x = (mod + n - 2*i) % mod;
    vs[i][0] = 1; vs[i][1] = x;
    for(int j = 2; j <= n; ++j) {
      vs[i][j] = ((x * vs[i][j-1] + vs[i][j-2] * (mod+j-n-2)) % mod) * revv[j] % mod;
    }

    // x_Q の計算
    ll v = fast_pow(x, q % (mod-1)) * vs[0][i] % mod;
    rep(j, n+1) {
      xs[j] += v * vs[i][j];
      xs[j] %= mod;
    }
  }
  ll rev = fast_pow(revv[2], n);
  rep(i, n+1) xs[i] = xs[i] * rev % mod;
}

int main() {
  revv[0] = revv[1] = 1;
  for(int i = 2; i <= 2000; ++i) revv[i] = fast_pow(i, mod-2);
  int t; cin >> t;
  while(t--) {
    cin >> n >> m >> q >> z;
    ll res = 0;
    calc(xs, n, q);
    calc(ys, m, q);

    if(n*m == 2*z) {
      // NM = 2Z の場合は (N - 2a) (M - 2b) = 0 を満たす
      if(n % 2 == 0) {
        rep(b, m+1) {
          if(b * 2 != m) res += xs[n/2] * ys[b] % mod;
        }
      }
      if(m % 2 == 0) {
        rep(a, n+1) {
          if(a * 2 != n) res += xs[a] * ys[m/2] % mod;
        }
      }
      if(n % 2 == 0 && m % 2 == 0) {
        res += xs[n/2] * ys[m/2] % mod;
      }
      res %= mod;
    } else {
      // NM != 2Z の場合は Z = aM + bN - 2ab を満たすa, bを走査
      rep(a, n+1) {
        int p = (z - a*m), q = (n - 2*a);
        if(q < 0) {
          p = -p; q = -q;
        }
        if(q == 0 || p < 0 || p % q != 0) {
          continue;
        }
        int b = p / q;
        if(0 <= b && b <= m) {
          res += xs[a] * ys[b] % mod;
        }
      }
      res %= mod;
    }
    cout << res << "\n";
  }
  return 0;
}