エイシングプログラミングコンテスト2021(AtCoder Beginner Contest 202)- Count Descendants

問題リンク

オンラインとオフラインどっちが好みの人が多いんでしょうか?

解法

問題文を整理すると,各クエリ  i において,「深さが  D _ {i} でかつ, U _ {i} の部分木に属しているものの個数を求めよ」となります. ある頂点  v が部分木に属するかどうかは,オイラーツアー (参考 : Euler Tour のお勉強 | maspyのHP) などをすると, \mathrm{in}[U _ {i}] \leq \mathrm{in}[z] \lt \mathrm{out}[U _ {i}] かどうかを確認すればわかります.各クエリは「深さが  D _ {i} でかつ, [\mathrm{in}[U _ {i}], \mathrm{out}[U_{i}]) である頂点の個数を求めよ」という問題になります.

オンラインクエリで解く方法

BinaryTrie と呼ばれるデータ構造 ( 参考 : 非負整数値を扱う Trie について - kazuma8128’s blog ) を各深さ毎に持ちます.各深さの BinaryTrie にそれぞれの, \mathrm{in}[i] を管理すると,BinaryTrie は区間にある頂点の個数を求めるのは  O(\log N) でできることから,全体で  O(N + (Q + N) \log N) で解くことができました.なお,座標圧縮などをすることにより,BinaryIndexedTree を用いてもオンラインクエリで解くことができます.

オフラインクエリで解く方法

区間和を求めたくなりますが,各深さごとに区間  N の BinaryIndexedTree を持つのは,メモリが厳しいです.ここで,1 つの BinaryIndexedTree を全ての深さで使いまわします.予め,クエリを先読みし,深さごとにクエリを管理します.そして,各深さごとに順にクエリを処理していきます.ある深さ d のクエリを処理する場合,まず,深さ  d \mathrm{in}[i] を BinaryIndexedTree で管理します. そして,深さ  d に関するクエリについて処理します.深さ  d のクエリを全て処理し終えたら,深さ  d \mathrm{in}[i] を BinaryIndexedTree から取り除きます.計算量は  O(N + (Q + N) \log N) です.

実装例 ( オンライン)

#include <iostream>
#include <string>
#include <algorithm>
#include <vector>
#include <queue>
#include <utility>
#include <tuple>
#include <cmath>
#include <numeric>
#include <set>
#include <map>
#include <array>
#include <complex>
#include <iomanip>
#include <cassert>
#include <random>
#include <chrono>
#include <valarray>
#include <bitset>
using ll = long long;
using std::cin;
using std::cout;
using std::endl;
std::mt19937 rnd(std::chrono::steady_clock::now().time_since_epoch().count());
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }
const int inf = (int)1e9 + 7;
const long long INF = 1LL << 60;

namespace KKT89
{
    template< typename T, int MAX_LOG, typename D = int >
    struct BinaryTrie {
    public:
      struct Node {
        Node *nxt[2];
        D exist;
        std::vector< int > accept;

        Node() : nxt{nullptr, nullptr}, exist(0) {}
      };

      Node *root;

      explicit BinaryTrie() : root(new Node()) {}

      explicit BinaryTrie(Node *root) : root(root) {}

      void add(const T &bit, int idx = -1, D delta = 1, T xor_val = 0) {
        root = add(root, bit, idx, MAX_LOG, delta, xor_val);
      }

      void erase(const T &bit, T xor_val = 0) {
        add(bit, -1, -1, xor_val);
      }

      Node *find(const T &bit, T xor_val = 0) {
        return find(root, bit, MAX_LOG, xor_val);
      }

      D count(const T &bit, T xor_val = 0) {
        auto node = find(bit, xor_val);
        return node ? node->exist : 0;
      }

      std::pair< T, Node * > min_element(T xor_val = 0) {
        assert(root->exist > 0);
        return kth_element(0, xor_val);
      }

      std::pair< T, Node * > max_element(T xor_val = 0) {
        assert(root->exist > 0);
        return kth_element(root->exist - 1, xor_val);
      }

      std::pair< T, Node * > kth_element(D k, T xor_val = 0) { // 0-indexed
        assert(0 <= k && k < root->exist);
        return kth_element(root, k, MAX_LOG, xor_val);
      }

      D count_less(const T &bit, T xor_val = 0) { // < bit
        return count_less(root, bit, MAX_LOG, xor_val);
      }

    private:

      virtual Node *clone(Node *t) { return t; }

      Node *add(Node *t, T bit, int idx, int depth, D x, T xor_val, bool need = true) {
        if(need) t = clone(t);
        if(depth == -1) {
          t->exist += x;
          if(idx >= 0) t->accept.emplace_back(idx);
        } else {
          bool f = (xor_val >> depth) & 1;
          auto &to = t->nxt[f ^ ((bit >> depth) & 1)];
          if(!to) to = new Node(), need = false;
          to = add(to, bit, idx, depth - 1, x, xor_val, need);
          t->exist += x;
        }
        return t;
      }

      Node *find(Node *t, T bit, int depth, T xor_val) {
        if(depth == -1) {
          return t;
        } else {
          bool f = (xor_val >> depth) & 1;
          auto &to = t->nxt[f ^ ((bit >> depth) & 1)];
          return to ? find(to, bit, depth - 1, xor_val) : nullptr;
        }
      }

      std::pair< T, Node * > kth_element(Node *t, D k, int bit_index, T xor_val) { // 0-indexed
        if(bit_index == -1) {
          return {0, t};
        } else {
          bool f = (xor_val >> bit_index) & 1;
          if((t->nxt[f] ? t->nxt[f]->exist : 0) <= k) {
            auto ret = kth_element(t->nxt[f ^ 1], k - (t->nxt[f] ? t->nxt[f]->exist : 0), bit_index - 1, xor_val);
            ret.first |= T(1) << bit_index;
            return ret;
          } else {
            return kth_element(t->nxt[f], k, bit_index - 1, xor_val);
          }
        }
      }

      D count_less(Node *t, const T &bit, int bit_index, T xor_val) {
        if(bit_index == -1) return 0;
        D ret = 0;
        bool f = (xor_val >> bit_index) & 1;
        if((bit >> bit_index & 1) and t->nxt[f]) ret += t->nxt[f]->exist;
        if(t->nxt[f ^ (bit >> bit_index & 1)]) ret += count_less(t->nxt[f ^ (bit >> bit_index & 1)], bit, bit_index - 1, xor_val);
        return ret;
      }
    };
}
void solve()
{
    int n; cin >> n;
    std::vector<std::vector<int>> g(n);
    for (int i = 1; i < n; ++i)
    {
        int p; cin >> p;
        p -= 1;
        g[i].emplace_back(p);
        g[p].emplace_back(i);
    }
    std::vector<int> dep(n), in(n), out(n);
    std::vector<KKT89::BinaryTrie<unsigned, 20, int>> vb(n);
    {
        int idx = 0;
        auto dfs = [&](auto &&self, int cur, int pre)->void
        {
            if(pre >= 0)
                dep[cur] = dep[pre] + 1;
            in[cur] = idx++;
            vb[dep[cur]].add(in[cur]);
            for(const auto &nxt : g[cur])
            {
                if(nxt == pre)
                    continue;
                self(self, nxt, cur);
            }
            out[cur] = idx;
        };
        dfs(dfs, 0, -1);
    }
    int kkt; cin >> kkt;
    while(kkt--)
    {
        int u, d; cin >> u >> d;
        u -= 1;
        cout << vb[d].count_less(out[u]) - vb[d].count_less(in[u]) << "\n";
    }
}
int main()
{
  std::cin.tie(nullptr);
  std::ios::sync_with_stdio(false);


  int kkt = 1;
  // cin >> kkt;
  while(kkt--)
    solve();
  return 0;
}

実装例 ( オフライン )

#include <iostream>
#include <string>
#include <algorithm>
#include <vector>
#include <queue>
#include <utility>
#include <tuple>
#include <cmath>
#include <numeric>
#include <set>
#include <map>
#include <array>
#include <complex>
#include <iomanip>
#include <cassert>
#include <random>
#include <chrono>
#include <valarray>
#include <bitset>
using ll = long long;
using std::cin;
using std::cout;
using std::endl;
std::mt19937 rnd(std::chrono::steady_clock::now().time_since_epoch().count());
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }
const int inf = (int)1e9 + 7;
const long long INF = 1LL << 60;

namespace KKT89
{
    template<typename T>
    struct BinaryIndexedTree {
        int n;
        std::vector<T> bit;
        BinaryIndexedTree() :n(0) {}
        BinaryIndexedTree(int _n) :n(_n) { bit = std::vector<T>(n + 1); }
        void add1(int idx, T val) {
            for (int i = idx; i <= n; i += i & -i) bit[i] += val;
        }
        T sum1(int idx) {
            T res = 0;
            for (int i = idx; i > 0; i -= i & -i) res += bit[i];
            return res;
        }
        //0-indexed
        void add(int idx, T val) { add1(idx + 1, val); }
        //0-indexed [left, right)
        T sum(int left, int right) { return sum1(right) - sum1(left); }

        int lower_bound(T x) {
            int res = 0;
            int k = 1;
            while (2 * k <= n) k <<= 1;
            for (; k > 0; k >>= 1) {
                if (res + k <= n and bit[res + k] < x) {
                    x -= bit[res + k];
                    res += k;
                }
            }
            return res;
        }
    };
}

void solve()
{
    int n; cin >> n;
    std::vector<std::vector<int>> g(n);
    for (int i = 1; i < n; ++i)
    {
        int p; cin >> p;
        p -= 1;
        g[i].emplace_back(p);
        g[p].emplace_back(i);
    }
    KKT89::BinaryIndexedTree<ll> bit(n);
    std::vector<int> dep(n), in(n), out(n);
    std::vector<std::vector<int>> vdep(n);
    {
        int idx = 0;
        auto dfs = [&](auto &&self, int cur, int pre)->void
        {
            if(pre >= 0)
                dep[cur] = dep[pre] + 1;
            in[cur] = idx++;
            for(const auto &nxt : g[cur])
            {
                if(nxt == pre)
                    continue;
                self(self, nxt, cur);
            }
            out[cur] = idx;
        };
        dfs(dfs, 0, -1);
        for (int i = 0; i < n; ++i)
        {
            vdep[dep[i]].emplace_back(in[i]);
        }
    }
    int kkt; cin >> kkt;
    std::vector<std::vector<std::pair<int, int>>> query(n);
    std::vector<int> res(kkt);
    for (int i = 0; i < kkt; ++i)
    {
        int u, d; cin >> u >> d;
        u -= 1;
        query[d].emplace_back(u, i);
    }
    for (int i = n - 1; i >= 0; --i)
    {
        for(const auto &idx : vdep[i])
        {
            bit.add(idx, 1);
        }
        for(const auto &[u, idx] : query[i])
        {
            //cout << in[u] << " " << out[u] << endl;
             res[idx] = bit.sum(in[u], out[u]);
        }
        for(const auto &idx : vdep[i])
        {
            bit.add(idx, -1);
        }
    }
    for (int i = 0; i < kkt; ++i)
    {
        cout << res[i] << "\n";
    }
}
int main()
{
  std::cin.tie(nullptr);
  std::ios::sync_with_stdio(false);


  int kkt = 1;
  // cin >> kkt;
  while(kkt--)
    solve();
  return 0;
}