日々drdrする人のメモ

今日も明日もdrdr

ATC001 - C問題: 高速フーリエ変換 (FMT解法)

頑張ってPythonで解いた。

C: 高速フーリエ変換 - AtCoder Typical Contest 001 | AtCoder



今回の問題は畳み込みを行う問題。 {O(N \log N)}で計算しないと間に合わないので、FFTで計算する必要がある。
しかし、複素数上で行うFFTでは精度がアレして死んだりするのでmod上で計算できるFMTというものを使って解いた。

この問題の解説にFFTの解説があるので、ここでは簡単な説明を書いとく。

FMT

FMT(Fast Modulo Transform, 高速剰余変換)。NTT(Number Theoretical Transform, 数論変換)とも言われるらしい。

FMTはFFTと同じ方法で計算を行うが、複素数上ではなく剰余環上で計算するという違いがある。
具体的には、 {\omega^i \neq 1\hspace{3mm}(0 < i < N), \omega^N = 1} (mod  {P}) となるような  {\omega, P, N} を利用して計算する。この時 {P}素数で、フェルマーの小定理から  {P = A*N + 1} の形になる。

ここでは、 {\displaystyle f(x) = \sum_{i=0}^{N-1} a_i x^i} (mod  {P}) とする。

順変換

 {0 \le i \le N-1}について、 {f_k = f(\omega^k)} (mod  {P})を再帰的に計算すればよい。

 {
\begin{align}
f_k & = & f(\omega^k) = \sum_{i=0}^{N-1} a_i \omega^{ik} \\
    & = & \sum_{i=0}^{\frac{N}{2}-1} a_{2i} \omega^{i * 2k} + \omega^k \sum_{i=0}^{\frac{N}{2}-1} a_{2i+1} (\omega^2)^{i * 2k} \\
    & = & g(\omega^{2k}) + \omega^k h(\omega^{2k})
\end{align}
}

この時、 {\displaystyle g(x) = \sum_{i=0}^{\frac{N}{2}-1} a_{2i} x^i, h(x) = \sum_{i=0}^{\frac{N}{2}-1} a_{2i+1} x^i} とする。

Nが偶数であれば、 {k} の時と同時に  {\frac{N}{2} + k} も計算できる。

 {
\begin{align}
f_{\frac{N}{2} + k} & = & f(\omega^{\frac{N}{2} + k}) = \sum_{i=0}^{N-1} a_i \omega^{i(\frac{N}{2} + k)} \\
                    & = & \sum_{i=0}^{\frac{N}{2}-1} a_{2i} \omega^{i * 2k} + \omega^{\frac{N}{2} + k} \sum_{i=0}^{\frac{N}{2}-1} a_{2i+1} \omega^{i * 2k} \\
                    & = & g(\omega^{2k}) + \omega^{\frac{N}{2} + k} h(\omega^{2k})
\end{align}
}

Nが奇数の時も計算はできるけど、簡単のために今回は  {N = 2^m} (すべて偶数)で計算することにした。

逆変換

順変換と同じように  {\displaystyle a_i = \frac{1}{N} \sum_{k=0}^{N-1} f_k \omega^{-ik}} を計算する。

この時 {N}で割るが、これはフェルマーの小定理より  {N^{P-2}} (mod  {P}) を掛けることに該当する。


Pythonの実装

今回のFMTはCooley-Tukey型で実装している。この型だと、in-placeとして一つの配列の中で計算できる利点があり、無駄にlistオブジェクトを生成しないようにできる。

パラメタの決定

今回の問題の制約から、FMTにおけるNの上限は  {2 * 10^5} であり、計算数値の上限は  {10^9} である。そのため、今回のFMTでは、 {N \gt 2 * 10^5} {P \gt 10^9} となるようなパラメタを計算して求めた。
 {P}の値を小さくして複数回のFMTの結果と中国余剰定理を用いて解を出すという方法もあったが、FMT一回で計算したほうが効率がよいと考えた。

パラメタ決定のために書いたコードは以下に載せてる。狙った範囲を愚直に計算して求めるやつを書いた。
https://gist.github.com/tjkendev/16a0c0fe5e5dca811ac0171f3491ff62

その他の処理のネック

Python標準のinputやprintは入出力数が多くなると遅くなってTLEする要因になるため、今回はsys.stdin.readとsys.stdout.writeを用いている。

実装

 {\omega = 103, N = 2^{18} (= 262144), P = 5880*N + 1 (= 1541406721)}で実装。

# FMT用のパラメタ
omega = 103
n = 2**18
P = 5880*n + 1
rev = pow(omega, P-2, P)

# バタフライ演算としてのbit反転処理
# in-placeで計算するために利用
def bit_reverse(d):
    n = len(d)
    ns = n>>1; nss = ns>>1
    ns1 = ns + 1
    i = 0
    for j in xrange(0, ns, 2):
        if j<i:
            d[i], d[j] = d[j], d[i]
            d[i+ns1], d[j+ns1] = d[j+ns1], d[i+ns1]
        d[i+1], d[j+ns] = d[j+ns], d[i+1]
        k = nss; i ^= k
        while k > i:
            k >>= 1; i ^= k
    return d

# FMTをループで計算
def fmt_bu(A, n, base, half, Q):
    N = n
    m = 1
    while n>1:
        n >>= 1
        # ω^{2m} ≡ 1 となるω
        w = pow(base, n, Q)
        wk = 1
        for j in xrange(m):
            for i in xrange(j, N, 2*m):
                # U = g(ω^{2k}), V = ω^k * h(ω^{2k})
                U = A[i]; V = (A[i+m]*wk) % Q
                A[i] = (U + V) % Q
                # half = ω^{N/2}
                A[i+m] = (U + V*half) % Q
            wk = (wk * w) % Q
        m <<= 1
    return A

# FMTの順変換
def fmt(f, l, Q=P):
    if l == 1: return f
    A = f[:]
    # bit反転
    bit_reverse(A)
    return fmt_bu(A, n, omega, pow(omega, n/2, Q), Q)

# FMTの逆変換
def ifmt(F, l, Q=P):
    if l == 1: return F
    A = F[:]
    # bit反転
    bit_reverse(A)
    # 逆変換なので、ωの代わりにω^{-1}を渡す
    f = fmt_bu(A, n, rev, pow(rev, n/2, Q), Q)
    # Nで割って返す
    n_rev = pow(n, Q-2, Q)
    return [(e * n_rev) % Q for e in f]

# FMTを利用した畳込み処理
def convolute(a, b, l, Q=P):
    A = fmt(a, l, Q)
    B = fmt(b, l, Q)
    C = [(s * t) % Q for s, t in zip(A, B)]
    c = ifmt(C, l, Q)
    return c

import sys
inp = map(int, sys.stdin.read().split())
# 長さnに足りないところはゼロ埋め
m = inp[0]
f = inp[1::2] + [0]*(n-m)
g = inp[2::2] + [0]*(n-m)

# 畳み込み
fg = convolute(f, g, m)

sys.stdout.write("0\n")
sys.stdout.write("\n".join(map(str, fg[:2*m-1])))
sys.stdout.write("\n")

この実装でACできる
Submission #1051733 - AtCoder Typical Contest 001 | AtCoder