日々drdrする人のメモ

今日も明日もdrdr

CodeChef May Challenge 2018: Edges in Spanning Trees

CodeChefのMay Challenge 2018の問題: Edges in Spanning Trees (Code: EDGEST)
問題: https://www.codechef.com/MAY18A/problems/EDGEST

問題

1からNの番号がついた {N}個の頂点を含む2つの木 {T_1, T_2}が与えられる。

 {T_1}に含まれる各辺 {e_1}について、以下の条件を満たす {T_2}上の辺 {e_2}の数を計算せよ。

  •  {T_1}から辺 {e_1}を除去し {e_2}を追加したグラフ {T_1 - e_1 + e_2}が木になる
  •  {T_2}から辺 {e_2}を除去し {e_1}を追加したグラフ {T_2 - e_2 + e_1}が木になる

制約

1つのテストケースに含まれるケース数:  {1 \le T \le 10}
 {2 \le N \le 2 * 10^5}
1つのテストケースに含まれる {N}の合計は {2 * 10^5}を超えない

解法

結構悩んだ。
特に、各辺 {e_1}と交換できる辺 {e_2}をどのように管理して数え上げするかで無限に悩んだ。通すのに5日程度かかったけど、なんとか自力で通せたので嬉しい。


最終的に、LCA + 部分木の和を計算するEuler Tour (Binary Indexed Tree)で通した。


はじめに、 {T_1} {e_1}が繋ぐ頂点番号同士を繋ぐ辺が {T_2}にも存在する場合、この {e_1}に対する {e_2}の個数は1となる。
これ以降は同じ頂点番号を繋ぐ {e_2}が存在しない {e_1}について考える。


まず、各辺 {e_1 = (a, b)}と交換できる辺 {e_2 = (x, y)}が満たすべき条件は以下の2つが存在する

  1.  {T_2}において、頂点aと頂点bの間のパスa-b上に {e_2}が含まれる (下図の左)
  2.  {T_1}から {e_1}を除去した時にできる2つの部分木それぞれに頂点xと頂点yが片方ずつ含まれる (下図の右)

f:id:smijake3:20180515225016p:plain

この問題では制約上、 {N-1}個の各辺 {e_1}に対し、条件(1)と(2)を満たす辺 {e_2}の個数を {O(\log N)}で計算する必要がある。


今回は、各 {e_1 = (a, b)}について、木 {T_2}において頂点aと頂点bの間のパスa-b上に存在する辺 {e_2}の中から、条件(2)を満たす辺の個数を数える。

この数え上げを行う際、パスa-b上の数え上げを行う代わりに、2つのパスに分解してパスa-w上パスw-b上に分けて数え上げを行うことを考える。この頂点wは、木 {T_2}における頂点aと頂点bのLCAとする。
また、条件(2)を満たす辺 {e_2}を数えるときに {T_2}において条件(1)を満たさない根頂点Oと頂点wの間のパスO-w上に存在する辺を含めて数え上げて、あとからそれらを引いても答えを計算できる。

これらのことから、各辺 {e_1 = (a, b)}について条件を満たす辺 {e_2}の数は、
 {f_{e_1}(p, q)} = ( {e_1}について {T_2}上のパスp-q上に存在する辺 {e_2}の中で条件(2)を満たす辺の数)
とすると、
 {\underline{f_{e_1}(O, a) + f_{e_1}(O, b) - 2f_{e_1}(O, w)}}
を計算することで求まることが分かる。
f:id:smijake3:20180515231055p:plain

この変形によって、 {T_2}の根頂点Oからある頂点vまでに含まれる辺 {e_2}の中から条件(2)を満たす辺を数え上げる問題になるため、 {T_2}を根からDFSで遷移しながら数え上げできるようになる。


次に、DFSを行いながら各 {e_1}について {T_2}上のパスO-v間に存在する辺 {e_2}の内、条件(2)を満たす辺の個数 {f_{e_1}(O, v)}を数えることを考える。
これは条件(2)の通りに、 {T_2}上のパスO-v間に存在する辺 {e_2}の内、片方の部分木に辺 {e_2}が繋ぐ内の一方の頂点のみが含まれる個数を計算すればよい。

この数え上げはEuler Tourとセグ木 (or BIT)を使えば実現できる。
Euler Tourを使うことで、ある頂点v以下の部分木に含まれる頂点の値の和や頂点の値の更新を {O(\log N)}で計算することができる。

今回は {T_1}上のEuler Tourを構築し、セグ木でクエリを処理できるようにした上で、 {T_2}上をDFSをしながらクエリを処理していく。
まず、DFSで新たに子ノードを訪れる際に辺 {e_2 = (x, y)}を通過する場合は、セグ木で頂点x, yに+1、頂点x, yのLCAである頂点uに-2を足して更新する。逆に子ノードを訪れ終わってDFSにおける後退を行う場合は逆の更新を行う。
そして、頂点vに到達した時点で、辺 {e_1 = (a, b)}の頂点a, bのうち深さが深い方の頂点以下の部分木の和を計算することで {f_{e_1}(O, v)}が計算できる。

部分和と辺 {e_2}の値の関係を図にすると以下の感じになる。 {e_2}の片方の頂点が部分木に含まれている場合(下図の右)のみ部分木の和が+1され、両方含まれる(下図の左)や両方含まれない場合(下図の中央)は+1されない。
f:id:smijake3:20180515232143p:plain

このDFS実装を行う際、"辺 {e_2}を通過する" = "辺 {e_2}が繋ぐ頂点の内、深さが深い方の頂点を訪れる" と解釈すると実装しやすい。


ここまでをまとめた解法は以下の通りである。

  •  {T_1}のEuler Tourによって部分木の和をセグ木で管理する
  •  {T_2}をDFSし、通過する辺 {e_2 = (x, y)}に応じてセグ木を更新する
    • 新しく通過した辺(つまり、DFSにおける前進)の場合、頂点xと頂点yに+1、頂点w=LCA(x, y)に-2する
    • 既に通過してた辺(つまり、DFSにおける後退)の場合、頂点xと頂点yに-1、頂点w=LCA(x, y)に+2する
  • 各辺 {e_1 = (a, b)}について、頂点aと頂点b、頂点u=LCA(a, b)に到達した時点でセグ木で部分木の和を計算する (この時depth[a] < depth[b]とする)
    • 頂点aもしくは頂点bに到達した場合、頂点b以下の部分木に含まれる頂点の値の和を計算し、辺 {e_1}に対する解に、計算した和を足す
    • 頂点uに到達した場合、頂点b以下の部分木に含まれる頂点の和の値を計算し、辺 {e_1}に対する解から、計算した和の2倍の値を引く

実装

計算量は {O(N \log{N})}

PyPy2では間に合わなかったのでC++

提出コード (C++14): https://www.codechef.com/viewsolution/18529025

#define N 200005
#define L 20
 
int n, l;
vector<int> g1[N], g2[N], ss[N], tt[N];
vector<P> e1;
set<P> m2;
 
 
int prev1[N][L], prev2[N][L];
int d1[N], d2[N];
int unq[N], ans[N];
int aa[N], bb[N];

// LCAの前準備
void lca(vector<int> g[N], int prev[N][L], int d[N]) {
  queue<int> que;
  rep(i, n) rep(j, l) prev[i][j] = -1;
  rep(i, n) d[i] = -1;
  d[0] = 0;
  que.push(0);
  while(!que.empty()) {
    int v = que.front(); que.pop();
    rep(i, g[v].size()) {
      int w = g[v][i];
      if(d[w] == -1) {
        prev[w][0] = v;
        d[w] = d[v] + 1;
        que.push(w);
      }
    }
  }
  repl(k, 1, l-1) {
    rep(i, n) {
      if(prev[i][k-1] != -1) {
        prev[i][k] = prev[prev[i][k-1]][k-1];
      }
    }
  }
}

// LCAのクエリ処理
int query(int u, int v, int prev[N][L], int d[N]) {
  int dd = d[v] - d[u];
  if(dd < 0) {
    dd = -dd;
    int tmp = u; u = v; v = tmp;
  }
  rep(k, l) {
    if(dd & 1) v = prev[v][k];
    dd >>= 1;
  }
 
 
  if(u == v) return u;
 
  repr(k, l) {
    if(prev[u][k] != prev[v][k]) {
      u = prev[u][k]; v = prev[v][k];
    }
  }
  return prev[u][0];
}

// Euler Tourのクエリのための前処理
int sz, data[N];
void dfs0(int v, int p) {
  aa[v] = sz;
  rep(i, g1[v].size()) {
    int w = g1[v][i];
    if(w == p) continue;
    dfs0(w, v);
  }
  bb[v] = ++sz;
}

// Binary Indexed Tree 
int get(int k) {
  int s = 0;
  while(k > 0) {
    s += data[k];
    k -= k & -k;
  }
  return s;
}
 
void add(int k, int x) {
  while(k <= sz) {
    data[k] += x;
    k += k & -k;
  }
}

// 各e_1の解を計算するためのDFS
int rr[N];
void dfs(int v, int p) {
  int w0 = -1;
  if(unq[v]) {
    w0 = query(v, p, prev1, d1);
    // DFSの前進: 頂点v,pに+1, LCAのw0に-2
    add(bb[w0], -2);
    add(bb[v], 1);
    add(bb[p], 1);
  }
 
  // e_1 = (a, b)に対してf(O, a) or f(O, b)を計算
  rep(j, ss[v].size()) {
    int i = ss[v][j];
    int s = rr[i];
    ans[i] += get(bb[s]) - get(aa[s]);
  }
  // e_1 = (a, b)に対してf(O, LCA(a, b))を計算
  rep(j, tt[v].size()) {
    int i = tt[v][j];
    int s = rr[i];
    ans[i] -= (get(bb[s]) - get(aa[s])) * 2;
  }
 
  rep(i, g2[v].size()) {
    int w = g2[v][i];
    if(w == p) continue;
    dfs(w, v);
  }
 
  if(unq[v]) {
    // DFSの後退: 頂点v, pに-1, LCAのw0に+2
    add(bb[w0], 2);
    add(bb[v], -1);
    add(bb[p], -1);
  }
}
 
 
void solve() {
  cin >> n;
  rep(i, n-1) {
    int a, b; cin >> a >> b; --a; --b;
    g1[a].push_back(b);
    g1[b].push_back(a);
    e1.push_back(a < b ? P(a, b) : P(b, a));
  }
  rep(i, n-1) {
    int a, b; cin >> a >> b; --a; --b;
    g2[a].push_back(b);
    g2[b].push_back(a);
    m2.insert(a < b ? P(a, b) : P(b, a));
  }
  l = 1;
  int v = 1;
  while(v < n) v <<= 1, l++;
 
  lca(g1, prev1, d1);
  lca(g2, prev2, d2);

  // unq[i] = 1: iに紐づく辺e_2はe_1と重複するため、計算対象から除外
  rep(i, n) unq[i] = 1;
  unq[0] = 0;
 
  rep(i, n) ans[i] = 0;
  rep(i, n-1) {
    P &e = e1[i];
    int a = e.first, b = e.second;
    if(m2.find(e) != m2.end()) {
      ans[i] = 1;
      if(d2[a] < d2[b]) {
        unq[b] = 0;
      } else {
        unq[a] = 0;
      }
      continue;
    }
    int w = query(a, b, prev2, d2);
    tt[w].push_back(i);
    ss[a].push_back(i);
    ss[b].push_back(i);
  }
  sz = 0; dfs0(0, -1);
  rep(i, sz+1) data[i] = 0;
 
  rep(i, n-1) {
    P &e = e1[i];
    int a = e.first, b = e.second;
    rr[i] = (d1[a] < d1[b] ? b : a);
  }
 
  dfs(0, -1);
  rep(i, n-1) {
    cout << (i ? " " : "") << ans[i];
  }
  cout << endl;
}
 
int main() {
  int t;
  cin >> t;
  while(t--) {
    rep(i, N) g1[i].clear(), g2[i].clear(), ss[i].clear(), tt[i].clear();
    e1.clear();
    m2.clear();
 
    solve();
  }
  return 0;
}