CodeChef May Challenge 2018: Edges in Spanning Trees
CodeChefのMay Challenge 2018の問題: Edges in Spanning Trees (Code: EDGEST)
問題: https://www.codechef.com/MAY18A/problems/EDGEST
問題
1からNの番号がついた個の頂点を含む2つの木が与えられる。
木に含まれる各辺について、以下の条件を満たす上の辺の数を計算せよ。
- 木から辺を除去しを追加したグラフが木になる
- 木から辺を除去しを追加したグラフが木になる
制約
1つのテストケースに含まれるケース数:
1つのテストケースに含まれるの合計はを超えない
解法
結構悩んだ。
特に、各辺と交換できる辺をどのように管理して数え上げするかで無限に悩んだ。通すのに5日程度かかったけど、なんとか自力で通せたので嬉しい。
最終的に、LCA + 部分木の和を計算するEuler Tour (Binary Indexed Tree)で通した。
はじめに、でが繋ぐ頂点番号同士を繋ぐ辺がにも存在する場合、このに対するの個数は1となる。
これ以降は同じ頂点番号を繋ぐが存在しないについて考える。
まず、各辺と交換できる辺が満たすべき条件は以下の2つが存在する
- において、頂点aと頂点bの間のパスa-b上にが含まれる (下図の左)
- からを除去した時にできる2つの部分木それぞれに頂点xと頂点yが片方ずつ含まれる (下図の右)
この問題では制約上、個の各辺に対し、条件(1)と(2)を満たす辺の個数をで計算する必要がある。
今回は、各について、木において頂点aと頂点bの間のパスa-b上に存在する辺の中から、条件(2)を満たす辺の個数を数える。
この数え上げを行う際、パスa-b上の数え上げを行う代わりに、2つのパスに分解してパスa-w上とパスw-b上に分けて数え上げを行うことを考える。この頂点wは、木における頂点aと頂点bのLCAとする。
また、条件(2)を満たす辺を数えるときににおいて条件(1)を満たさない根頂点Oと頂点wの間のパスO-w上に存在する辺を含めて数え上げて、あとからそれらを引いても答えを計算できる。
これらのことから、各辺について条件を満たす辺の数は、
= (について上のパスp-q上に存在する辺の中で条件(2)を満たす辺の数)
とすると、
を計算することで求まることが分かる。
この変形によって、の根頂点Oからある頂点vまでに含まれる辺の中から条件(2)を満たす辺を数え上げる問題になるため、を根からDFSで遷移しながら数え上げできるようになる。
次に、DFSを行いながら各について上のパスO-v間に存在する辺の内、条件(2)を満たす辺の個数を数えることを考える。
これは条件(2)の通りに、上のパスO-v間に存在する辺の内、片方の部分木に辺が繋ぐ内の一方の頂点のみが含まれる個数を計算すればよい。
この数え上げはEuler Tourとセグ木 (or BIT)を使えば実現できる。
Euler Tourを使うことで、ある頂点v以下の部分木に含まれる頂点の値の和や頂点の値の更新をで計算することができる。
今回は上のEuler Tourを構築し、セグ木でクエリを処理できるようにした上で、上をDFSをしながらクエリを処理していく。
まず、DFSで新たに子ノードを訪れる際に辺を通過する場合は、セグ木で頂点x, yに+1、頂点x, yのLCAである頂点uに-2を足して更新する。逆に子ノードを訪れ終わってDFSにおける後退を行う場合は逆の更新を行う。
そして、頂点vに到達した時点で、辺の頂点a, bのうち深さが深い方の頂点以下の部分木の和を計算することでが計算できる。
部分和と辺の値の関係を図にすると以下の感じになる。の片方の頂点が部分木に含まれている場合(下図の右)のみ部分木の和が+1され、両方含まれる(下図の左)や両方含まれない場合(下図の中央)は+1されない。
このDFS実装を行う際、"辺を通過する" = "辺が繋ぐ頂点の内、深さが深い方の頂点を訪れる" と解釈すると実装しやすい。
ここまでをまとめた解法は以下の通りである。
- のEuler Tourによって部分木の和をセグ木で管理する
- をDFSし、通過する辺に応じてセグ木を更新する
- 各辺について、頂点aと頂点b、頂点u=LCA(a, b)に到達した時点でセグ木で部分木の和を計算する (この時depth[a] < depth[b]とする)
- 頂点aもしくは頂点bに到達した場合、頂点b以下の部分木に含まれる頂点の値の和を計算し、辺に対する解に、計算した和を足す
- 頂点uに到達した場合、頂点b以下の部分木に含まれる頂点の和の値を計算し、辺に対する解から、計算した和の2倍の値を引く
実装
計算量は。
PyPy2では間に合わなかったのでC++
提出コード (C++14): https://www.codechef.com/viewsolution/18529025
#define N 200005 #define L 20 int n, l; vector<int> g1[N], g2[N], ss[N], tt[N]; vector<P> e1; set<P> m2; int prev1[N][L], prev2[N][L]; int d1[N], d2[N]; int unq[N], ans[N]; int aa[N], bb[N]; // LCAの前準備 void lca(vector<int> g[N], int prev[N][L], int d[N]) { queue<int> que; rep(i, n) rep(j, l) prev[i][j] = -1; rep(i, n) d[i] = -1; d[0] = 0; que.push(0); while(!que.empty()) { int v = que.front(); que.pop(); rep(i, g[v].size()) { int w = g[v][i]; if(d[w] == -1) { prev[w][0] = v; d[w] = d[v] + 1; que.push(w); } } } repl(k, 1, l-1) { rep(i, n) { if(prev[i][k-1] != -1) { prev[i][k] = prev[prev[i][k-1]][k-1]; } } } } // LCAのクエリ処理 int query(int u, int v, int prev[N][L], int d[N]) { int dd = d[v] - d[u]; if(dd < 0) { dd = -dd; int tmp = u; u = v; v = tmp; } rep(k, l) { if(dd & 1) v = prev[v][k]; dd >>= 1; } if(u == v) return u; repr(k, l) { if(prev[u][k] != prev[v][k]) { u = prev[u][k]; v = prev[v][k]; } } return prev[u][0]; } // Euler Tourのクエリのための前処理 int sz, data[N]; void dfs0(int v, int p) { aa[v] = sz; rep(i, g1[v].size()) { int w = g1[v][i]; if(w == p) continue; dfs0(w, v); } bb[v] = ++sz; } // Binary Indexed Tree int get(int k) { int s = 0; while(k > 0) { s += data[k]; k -= k & -k; } return s; } void add(int k, int x) { while(k <= sz) { data[k] += x; k += k & -k; } } // 各e_1の解を計算するためのDFS int rr[N]; void dfs(int v, int p) { int w0 = -1; if(unq[v]) { w0 = query(v, p, prev1, d1); // DFSの前進: 頂点v,pに+1, LCAのw0に-2 add(bb[w0], -2); add(bb[v], 1); add(bb[p], 1); } // e_1 = (a, b)に対してf(O, a) or f(O, b)を計算 rep(j, ss[v].size()) { int i = ss[v][j]; int s = rr[i]; ans[i] += get(bb[s]) - get(aa[s]); } // e_1 = (a, b)に対してf(O, LCA(a, b))を計算 rep(j, tt[v].size()) { int i = tt[v][j]; int s = rr[i]; ans[i] -= (get(bb[s]) - get(aa[s])) * 2; } rep(i, g2[v].size()) { int w = g2[v][i]; if(w == p) continue; dfs(w, v); } if(unq[v]) { // DFSの後退: 頂点v, pに-1, LCAのw0に+2 add(bb[w0], 2); add(bb[v], -1); add(bb[p], -1); } } void solve() { cin >> n; rep(i, n-1) { int a, b; cin >> a >> b; --a; --b; g1[a].push_back(b); g1[b].push_back(a); e1.push_back(a < b ? P(a, b) : P(b, a)); } rep(i, n-1) { int a, b; cin >> a >> b; --a; --b; g2[a].push_back(b); g2[b].push_back(a); m2.insert(a < b ? P(a, b) : P(b, a)); } l = 1; int v = 1; while(v < n) v <<= 1, l++; lca(g1, prev1, d1); lca(g2, prev2, d2); // unq[i] = 1: iに紐づく辺e_2はe_1と重複するため、計算対象から除外 rep(i, n) unq[i] = 1; unq[0] = 0; rep(i, n) ans[i] = 0; rep(i, n-1) { P &e = e1[i]; int a = e.first, b = e.second; if(m2.find(e) != m2.end()) { ans[i] = 1; if(d2[a] < d2[b]) { unq[b] = 0; } else { unq[a] = 0; } continue; } int w = query(a, b, prev2, d2); tt[w].push_back(i); ss[a].push_back(i); ss[b].push_back(i); } sz = 0; dfs0(0, -1); rep(i, sz+1) data[i] = 0; rep(i, n-1) { P &e = e1[i]; int a = e.first, b = e.second; rr[i] = (d1[a] < d1[b] ? b : a); } dfs(0, -1); rep(i, n-1) { cout << (i ? " " : "") << ans[i]; } cout << endl; } int main() { int t; cin >> t; while(t--) { rep(i, N) g1[i].clear(), g2[i].clear(), ss[i].clear(), tt[i].clear(); e1.clear(); m2.clear(); solve(); } return 0; }