頑張ってPythonで解いた。
C: 高速フーリエ変換 - AtCoder Typical Contest 001 | AtCoder
今回の問題は畳み込みを行う問題。で計算しないと間に合わないので、FFTで計算する必要がある。
しかし、複素数上で行うFFTでは精度がアレして死んだりするのでmod上で計算できるFMTというものを使って解いた。
この問題の解説にFFTの解説があるので、ここでは簡単な説明を書いとく。
FMT
FMT(Fast Modulo Transform, 高速剰余変換)。NTT(Number Theoretical Transform, 数論変換)とも言われるらしい。
FMTはFFTと同じ方法で計算を行うが、複素数上ではなく剰余環上で計算するという違いがある。
具体的には、 (mod ) となるような を利用して計算する。この時は素数で、フェルマーの小定理から の形になる。
ここでは、 (mod ) とする。
順変換
について、 (mod )を再帰的に計算すればよい。
この時、 とする。
Nが偶数であれば、 の時と同時に も計算できる。
Nが奇数の時も計算はできるけど、簡単のために今回は (すべて偶数)で計算することにした。
Pythonの実装
今回のFMTはCooley-Tukey型で実装している。この型だと、in-placeとして一つの配列の中で計算できる利点があり、無駄にlistオブジェクトを生成しないようにできる。
パラメタの決定
今回の問題の制約から、FMTにおけるNの上限は であり、計算数値の上限は である。そのため、今回のFMTでは、 、 となるようなパラメタを計算して求めた。
の値を小さくして複数回のFMTの結果と中国余剰定理を用いて解を出すという方法もあったが、FMT一回で計算したほうが効率がよいと考えた。
パラメタ決定のために書いたコードは以下に載せてる。狙った範囲を愚直に計算して求めるやつを書いた。
https://gist.github.com/tjkendev/16a0c0fe5e5dca811ac0171f3491ff62
その他の処理のネック
Python標準のinputやprintは入出力数が多くなると遅くなってTLEする要因になるため、今回はsys.stdin.readとsys.stdout.writeを用いている。
実装
で実装。
# FMT用のパラメタ omega = 103 n = 2**18 P = 5880*n + 1 rev = pow(omega, P-2, P) # バタフライ演算としてのbit反転処理 # in-placeで計算するために利用 def bit_reverse(d): n = len(d) ns = n>>1; nss = ns>>1 ns1 = ns + 1 i = 0 for j in xrange(0, ns, 2): if j<i: d[i], d[j] = d[j], d[i] d[i+ns1], d[j+ns1] = d[j+ns1], d[i+ns1] d[i+1], d[j+ns] = d[j+ns], d[i+1] k = nss; i ^= k while k > i: k >>= 1; i ^= k return d # FMTをループで計算 def fmt_bu(A, n, base, half, Q): N = n m = 1 while n>1: n >>= 1 # ω^{2m} ≡ 1 となるω w = pow(base, n, Q) wk = 1 for j in xrange(m): for i in xrange(j, N, 2*m): # U = g(ω^{2k}), V = ω^k * h(ω^{2k}) U = A[i]; V = (A[i+m]*wk) % Q A[i] = (U + V) % Q # half = ω^{N/2} A[i+m] = (U + V*half) % Q wk = (wk * w) % Q m <<= 1 return A # FMTの順変換 def fmt(f, l, Q=P): if l == 1: return f A = f[:] # bit反転 bit_reverse(A) return fmt_bu(A, n, omega, pow(omega, n/2, Q), Q) # FMTの逆変換 def ifmt(F, l, Q=P): if l == 1: return F A = F[:] # bit反転 bit_reverse(A) # 逆変換なので、ωの代わりにω^{-1}を渡す f = fmt_bu(A, n, rev, pow(rev, n/2, Q), Q) # Nで割って返す n_rev = pow(n, Q-2, Q) return [(e * n_rev) % Q for e in f] # FMTを利用した畳込み処理 def convolute(a, b, l, Q=P): A = fmt(a, l, Q) B = fmt(b, l, Q) C = [(s * t) % Q for s, t in zip(A, B)] c = ifmt(C, l, Q) return c import sys inp = map(int, sys.stdin.read().split()) # 長さnに足りないところはゼロ埋め m = inp[0] f = inp[1::2] + [0]*(n-m) g = inp[2::2] + [0]*(n-m) # 畳み込み fg = convolute(f, g, m) sys.stdout.write("0\n") sys.stdout.write("\n".join(map(str, fg[:2*m-1]))) sys.stdout.write("\n")
この実装でACできる
Submission #1051733 - AtCoder Typical Contest 001 | AtCoder