日々drdrする人のメモ

今日も明日もdrdr

CodeChef December Cook-Off 2018: Swag Subsets

CodeChef December Cook-Off 2018 の問題: Swag Subsets (Code: SOSTD)
問題ページ: https://www.codechef.com/COOK101A/problems/SOSTD

コンテスト中に無限に実装バグらせて解けなくて、後で解いたやつ。

問題概要

2つの整数の列  {A_1, A_2, ..., A_N} {B_1, B_2, .., B_N} が与えられる。

ここで、  {S \subseteq \{1, 2, ..., N\}}となる空でない部分集合 {S}に対してswagness

 { \displaystyle ( \max_{(p \in S)} A_p ) \cdot ( \max_{(p \in S)} B_p ) }

で定義する。

 {2^N - 1}通り考えられる空でない部分集合 {S}swagnessの全ての和をMOD  {10^9 + 7} で求めよ。

制約

  • 1つのテストケースに含まれるケース数:  {1 \le T \le 1000}
  •  {2 \le N \le 2 \times 10^5}
  •  {1 \le A_i, B_i \le 10^6}
  • 1つのテストケースの {N}の和は  {10^6} を超えない

解法

遅延評価セグメント木 + BIT

解法の方針は、2つの数列の中で、最大となる {A_i} {B_j}を決め打ちし、その組合せの通り数を掛けて和を求めていく。

具体的には、 {A_{k_0} \le A_{k_1} \le ... \le A_{k_{N-1}}} を満たすように2つの数列を並べ替え、
 {i = 1}から、 {S \subseteq \{k_0, k_1, ..., k_{i-1}, k_i\}, k_i \in S}となる集合Sのswagnessの和を求めていく。


ここで、 {S \subseteq \{k_0, k_1, ..., k_{i-1}, k_i\}, k_i \in S}となる全ての集合Sについて考える。

 {B_{k_0}, B_{k_1}, ..., B_{k_{i-1}}} のうち、  {B_{k_i}}より小さい整数の数を {M}とすると、swagness {A_{k_i} \cdot B_{k_i}} になる集合Sの通り数は  {2^M} になる。

また、 {B_{k_0}, B_{k_1}, ..., B_{k_{i-1}}} のうち、 {B_{k_i}} から j番目に大きい数 (  {j = 1, 2, ...} ) を  {B_{d_j}} としたとき、swagness {A_{k_i} \cdot B_{d_j}} になる集合Sの通り数は  {2^{M + j - 1}} になる。

よって、 {B_{k_i} < B_{d_1} < B_{d_2} < ...} について、swagnessの和は

 {\displaystyle A_{k_i} ( 2^M \cdot B_{k_i} + \sum_j 2^{M+j-1} \cdot B_{d_j} ) }

で求まる。



これを各 {i}について計算することで、解が求まる。この計算は遅延評価セグメント木 + BITで実現できる。

具体的に、遅延評価セグメント木は数列  {a_1, ..., a_m} に対して以下のクエリを処理できるように実装すればよい。

1.  {a_l, a_{l+1}, ..., a_{r-1}}の和を求める
2.  {a_l, a_{l+1}, ..., a_{r-1}}をx倍する
3.  {a_i} の値に x を加算する

BITでは、 {2^{(B_{k_i}\text{よりも小さい数の個数})}}を管理する。

1ケースの計算量は  {O(N \log N)}

実装

PyPy2では重かった。
提出コード: Solution: 22071761 | CodeChef

C++の遅延評価セグメント木を持ってなくて困った。この後きちんと実装しておく。

提出コード: Solution: 22072105 | CodeChef

#define N 200006
#define LV 19
#define MOD 1000000007

// 遅延評価セグメント木
int n0, lv;

int ids[LV*2], icur = 0;
ll lazy[4*N], data[4*N];

void st_init(int n) {
  n0 = 1; lv = 0;
  while(n0 < n) n0 <<= 1, lv++;
  rep(i, 2*n0) lazy[i] = 1, data[i] = 0;
}

void gindex(int l, int r) {
  int li = (l + n0) >> 1, ri = (r + n0) >> 1;
  int lc = (l & 1) ? 0 : __builtin_ffs(li);
  int rc = (r & 1) ? 0 : __builtin_ffs(ri);
  icur = 0;
  rep(i, lv) {
    if(rc <= i) {
      ids[icur++] = ri;
    }
    if(li < ri && lc <= i) {
      ids[icur++] = li;
    }
    li >>= 1; ri >>= 1;
  }
}

void propagates() {
  repr(i, icur) {
    int k = ids[i];
    ll v = lazy[k-1];
    if(v == 1) continue;
    lazy[2*k-1] = (lazy[2*k-1] * v) % MOD;
    lazy[2*k] = (lazy[2*k] * v) % MOD;
    data[2*k-1] = (data[2*k-1] * v) % MOD;
    data[2*k] = (data[2*k] * v) % MOD;
    lazy[k-1] = 1;
  }
}

void set_value(int k, ll x) {
  gindex(k, k+1);
  propagates();
  k += n0 - 1;
  data[k] = (data[k] + x) % MOD;
  while(k > 0) {
    k = (k - 1) >> 1;
    data[k] = (data[2*k+1] + data[2*k+2]) % MOD;
  }
}

void update(int l, int r, ll x) {
  gindex(l, r);
  propagates();

  int li = n0 + l, ri = n0 + r;
  while(li < ri) {
    if(ri & 1) {
      --ri;
      lazy[ri-1] = (lazy[ri-1] * x) % MOD;
      data[ri-1] = (data[ri-1] * x) % MOD;
    }
    if(li & 1) {
      lazy[li-1] = (lazy[li-1] * x) % MOD;
      data[li-1] = (data[li-1] * x) % MOD;
      ++li;
    }
    li >>= 1; ri >>= 1;
  }
  rep(i, icur) {
    int k = ids[i];
    data[k-1] = (data[2*k-1] + data[2*k]) % MOD;
  }
}

ll query(int l, int r) {
  gindex(l, r);
  propagates();

  int li = n0 + l, ri = n0 + r;

  ll s = 0;
  while(li < ri) {
    if(ri & 1) {
      s += data[(--ri)-1];
    }
    if(li & 1) {
      s += data[(li++)-1];
    }
    li >>= 1; ri >>= 1;
  }
  return s % MOD;
}

// BIT
// - 2^(x以下の要素の数)を管理
int n;
ll bdata[N];
int a[N], b[N];
P c[N];
int get(int k) {
  int s = 1;
  while(k) {
    s = (s * bdata[k]) % MOD;
    k -= k & -k;
  }
  return s;
}
void add(int k, int x) {
  while(k <= n) {
    bdata[k] = (bdata[k] * x) % MOD;
    k += k & -k;
  }
}

// 数列の圧縮
map<int, int> compress(int *p, int sz) {
  map<int, int> result;
  set<int> v;
  for(int i = 0; i < sz; ++i) {
    v.insert(p[i]);
  }
  int k = 0;
  for(int e : v) result[e] = k++;
  return result;
}


int main() {
  int t; cin >> t;
  while(t--) {
    cin >> n;
    rep(i, n) cin >> a[i];
    rep(i, n) cin >> b[i];
    rep(i, n) c[i] = P(a[i], b[i]);
    sort(c, c+n);

    map<int, int> mp = compress(b, n);
    int m = mp.size();

    rep(i, n+1) bdata[i] = 1;

    st_init(n);

    ll ans = 0;
    rep(i, n) {
      P &p = c[i];
      ll a = p.first, b = p.second;
      int k = mp[b];
      //  w = 2^M \cdot B_{k_i}
      ll w = (b * get(k+1)) % MOD;

      // query(k+1, m) = \sum_j 2^{M+j-1} \cdot B_{d_j}
      ans = (ans + (w + query(k+1, m)) * a % MOD) % MOD;

      // セグメント木内のB_{k_i} より大きい数を2倍
      update(k+1, m, 2);
      // BITのk_i番目を更新
      add(k+1, 2);
      // セグメント木のk_i番目にwを加える
      set_value(k, w);
    }

    cout << ans << endl;
  }
  return 0;
}