ぺんぎんメモ

プログラミングのメモです。たまに私生活のことや鬱っぽいことを書きます。

全方位木DP(ReRooting)

木DPを使って解ける問題には、以下のものがあります。

  • ある頂点から最も遠い頂点までの距離はいくつか
  • ある頂点を根としたときの木の深さはいくつか

これらの問題に共通していることは、「ある頂点」という文言が含まれていることです。頂点数を\(N\)とおいたとき、ひとつの頂点に対して上の問題を解くのに\(O(N)\)かかるため、すべての頂点に対して上の問題を解くと\(O(N ^ 2)\)かかることになります。

全方位木DPを使うことで、この計算量を\(O(N)\)まで落とすことができます。すべての頂点に対して木DPを行いたいときは、全方位木DPを使うといいかもしれません。

実装と使用例

template <typename T, typename D>
struct ReRooting {
  using L = function<T(T, D)>;
  using M = function<T(T, T)>;
  struct Edge {
    int to;
    D data;
    Edge(int to, D data) : to(to), data(data) {}
  };
  int n;
  vector<vector<Edge>> edges;
  L lift;
  M merge;
  T e;
  vector<vector<T>> dp;
  vector<int> par;
  vector<T> ans;
  ReRooting(int n, L lift, M merge, T e) : n(n), lift(lift), merge(merge), e(e) {
    edges = vector<vector<Edge>>(n);
  }
  void add_edge(int v, int u, D d) {
    edges[v].emplace_back(u, d);
  }
  T dfs(int v, int p = -1) {
    int l = edges[v].size();
    dp[v] = vector<T>(l);
    T ret = e;
    for (int i = 0; i < l; i++) {
      auto e = edges[v][i];
      if (e.to == p) {
        par[v] = i;
        continue;
      }
      dp[v][i] = lift(dfs(e.to, v), e.data);
      ret = merge(ret, dp[v][i]);
    }
    return ret;
  }
  void bfs(int v, T t) {
    if (par[v] != -1) {
      dp[v][par[v]] = lift(t, edges[v][par[v]].data);
    }
    int l = edges[v].size();
    vector<T> dpl(l + 1, e);
    vector<T> dpr(l + 1, e);
    for (int i = 0; i < l; i++) {
      dpl[i + 1] = merge(dpl[i], dp[v][i]);
    }
    for (int i = l - 1; i >= 0; i--) {
      dpr[i] = merge(dpr[i + 1], dp[v][i]);
    }
    ans[v] = dpr[0];
    for (int i = 0; i < l; i++) {
      if (i == par[v]) continue;
      int u = edges[v][i].to;
      bfs(u, merge(dpl[i], dpr[i + 1]));
    }
  }
  vector<T> solve() {
    dp = vector<vector<T>>(n);
    par = vector<int>(n, -1);
    ans = vector<T>(n);
    dfs(0);
    bfs(0, e);
    return ans;
  }
};

全方位木DPの抽象部分を抜き出した構造体です。solve()は、すべての頂点のマージし終えた結果を返します。

欲しい値が「マージし終えた結果」ではなく「マージ前の情報」のときがあります。そういったときは、少々複雑なマージの仕方を考えるか、構造体の中身を修正する必要があります。後者の場合、修正箇所はbfs()ans[v] = dpr[0];の行になります。それ以外を修正することはおそらくありません。

使用例

木の直径 | グラフ | Aizu Online Judge

2パターン載せます。最初の実装はReRooting構造体を修正しない方針で、2つ目は構造体を修正する方針です。

ひとつめ

構造体を修正しない方針です。構造体の定義は省略しています。マージの仕方を少し考える必要があります。

int main() {
  using P = pair<int, int>;
  auto lift = [](P a, int w) {
    return P(max(0, a.first) + w, -1e9);
  };
  auto merge = [](P a, P b) {
    for (auto t : { b.first, b.second }) {
      if (a.first < t) {
        a.second = a.first;
        a.first = t;
      } else if (a.second < t) {
        a.second = t;
      }
    }
    return a;
  };
  int N; cin >> N;
  auto g = ReRooting<P, int>(N, lift, merge, P(-1e9, -1e9));
  for (int i = 1; i < N; i++) {
    int s, t, w; cin >> s >> t >> w;
    g.add_edge(s, t, w);
  }
  int ans = 0;
  for (auto r : g.solve()) {
    ans = max(ans, r.first);
    ans = max(ans, r.first + r.second);
  }
  cout << ans << '\n';
  return 0;
}
ふたつめ

構造体の中身を修正する方針です。木DPで求めるものは「最も遠い頂点までの距離」とシンプルです。ans周りの実装も素直です。

template <typename T, typename D>
struct ReRooting {
  using L = function<T(T, D)>;
  using M = function<T(T, T)>;
  struct Edge {
    int to;
    D data;
    Edge(int to, D data) : to(to), data(data) {}
  };
  int n;
  vector<vector<Edge>> edges;
  L lift;
  M merge;
  T e;
  vector<vector<T>> dp;
  vector<int> par;
  int ans;
  ReRooting(int n, L lift, M merge, T e) : n(n), lift(lift), merge(merge), e(e) {
    edges = vector<vector<Edge>>(n);
  }
  void add_edge(int v, int u, D d) {
    edges[v].emplace_back(u, d);
  }
  T dfs(int v = 0, int p = -1) {
    int l = edges[v].size();
    dp[v] = vector<T>(l);
    T ret = e;
    for (int i = 0; i < l; i++) {
      auto e = edges[v][i];
      if (e.to == p) {
        par[v] = i;
        continue;
      }
      dp[v][i] = lift(dfs(e.to, v), e.data);
      ret = merge(ret, dp[v][i]);
    }
    return ret;
  }
  void bfs(int v, T t) {
    if (par[v] != -1) {
      dp[v][par[v]] = lift(t, edges[v][par[v]].data);
    }
    int l = edges[v].size();
    vector<T> dpl(l + 1, e);
    vector<T> dpr(l + 1, e);
    for (int i = 0; i < l; i++) {
      dpl[i + 1] = merge(dpl[i], dp[v][i]);
    }
    for (int i = l - 1; i >= 0; i--) {
      dpr[i] = merge(dpr[i + 1], dp[v][i]);
    }

    int fst = 0, snd = 0;
    for (int i = 0; i < l; i++) {
      if (fst <= dp[v][i]) {
        snd = fst, fst = dp[v][i];
      } else if (snd <= dp[v][i]) {
        snd = dp[v][i];
      }
    }
    ans = max(ans, fst + snd);

    for (int i = 0; i < l; i++) {
      if (i == par[v]) continue;
      int u = edges[v][i].to;
      bfs(u, merge(dpl[i], dpr[i + 1]));
    }
  }
  int solve() {
    dp = vector<vector<T>>(n);
    par = vector<int>(n, -1);
    ans = 0;
    dfs();
    bfs(0, e);
    return ans;
  }
};

int main() {
  auto lift = [](int a, int w) {
    return a + w;
  };
  auto merge = [](int a, int b) {
    return max(a, b);
  };
  int N; cin >> N;
  auto g = ReRooting<int, int>(N, lift, merge, 0);
  for (int i = 1; i < N; i++) {
    int s, t, w; cin >> s >> t >> w;
    g.add_edge(s, t, w);
  }
  cout << g.solve() << '\n';
  return 0;
}

参考