mini notes

競技プログラミングの解法メモを残していきます。

ABC133 F - Colorful Tree (600)

F - Colorful Tree

概要

N頂点の木が与えられる。辺iは頂点ai, biを結び、色ciがついている。また辺の長さはdiである。
Q個のクエリが与えられる。クエリiではxi, yi, ui, viが与えられ、下記の問いに答えよ。

  • 色xiの辺の長さがyiになったとき、uiからviまでの距離を求めよ。
制約

2 ≦ N ≦ 10^5
1 ≦ Q ≦ 10^5
1 ≦ d, y ≦ 10^4

方針

クエリが与えられる前の各頂点間の距離は根付き木におけるLCA(Lowest Common Ancestor)を用いて求められる。
具体的には、l: u, vのLCA、dist[u]:根からuへの距離とすると、uv間の距離 = dist[u] + dist[v] - 2 * dist[l]

LCAの求め方は下記のとおり。
・uとvのうち、根からの深さが深い方について、浅い方と同じ高さになるまで親頂点を辿る。
例えばdepth[u] < depth[v] ならv <- parent[v]を繰り返してdepth[u] = depth[v]とする。
・u <- parent[u], v<-parent[v]としてゆき、初めてu=vとなるuが元のu, vのLCA
これは、ダブリングという手法を用いて高速化できる。

クエリ(x, y, u, v)が与えられた場合の処理について。
num[u]: 根からuまでの色xの辺の個数、sum[u]: 根からuまでの色xの辺の距離の合計が分かっているとする。
このとき、色xの辺の距離をyとした後の根から頂点uまでの距離をdist'とすると、
dist'[u] = dist[u] - sum[u] + num[u] * y となり、求める距離はdist'[u] + dist'[v] - 2 * dist'[l]で求めることが出来る。

全ての色、頂点についてnum, sumを求めようとすると、N^2オーダーのメモリが必要だが、必要なのはQ個の各クエリについて、xiの1色とui, vi, liの3頂点であるため、クエリ先読みを生かすとQオーダーのメモリで十分である。

感想

アルメリアさんのブログを大いに参考にしました。ありがたいです。。
betrue12.hateblo.jp

解答

Submission #6530743 - AtCoder Beginner Contest 133

#include <bits/stdc++.h>
#define rep(i,n) for(int i=(0);i<(n);i++)
 
using namespace std;
 
typedef long long ll;
 
struct LCA{
    static const int BITLEN_MAX = 30;
    vector<int> parent[BITLEN_MAX];
    vector<int> depth;
    int bitlen;
 
    void initialize(int N, const vector<int> edges[]){
        int root = 0;
        bitlen = 1;
        while((1<<bitlen) < N) bitlen += 1;
        for(int i=0; i<bitlen; i++) parent[i].resize(N);
        depth.resize(N, -1);
 
        dfs(root, -1, 0, edges);
        for(int k=0; k<bitlen-1; k++){
            for(int v=0; v<N; v++){
                if(depth[v] == -1) continue;
                if(parent[k][v] < 0){
                    parent[k+1][v] = -1;
                }else{
                    parent[k+1][v] = parent[k][parent[k][v]];
                }
            }
        }
    }
 
    void dfs(int v, int p, int d, const vector<int> edges[]){
        parent[0][v] = p;
        depth[v] = d;
        for(auto u : edges[v]){
            if(u != p) dfs(u, v, d+1, edges);
        }
    }
 
    int calc_lca(int u, int v){
        if(depth[u] > depth[v]) swap(u, v);
        for(int k=0; k<bitlen; k++){
            if( ((depth[v]-depth[u]) >> k) & 1 ) v = parent[k][v];
        }
        if(u == v) return u;
        for(int k=bitlen-1; k>=0; k--){
            if(parent[k][u] != parent[k][v]){
                u = parent[k][u];
                v = parent[k][v];
            }
        }
        return parent[0][u];
    }
 
    int calc_dist(int u, int v){
        int l = calc_lca(u, v);
        return depth[u] + depth[v] - depth[l]*2;
    }
};
 
int N;
vector<vector<int>> edges[101010];
vector<int> edges2[101010];
 
vector<int> need[101010];
map<int, int> res_num[101010], res_sum[101010];
 
int dist[101010];
int now_num[101010], now_sum[101010];
 
 
void dfs(int i, int p){
	for(int c : need[i]){
		res_num[i][c] = now_num[c];
		res_sum[i][c] = now_sum[c];
	}
 
	for(vector<int> e : edges[i]){
		int j = e[0], c = e[1], d = e[2];
		if(j == p) continue;
 
		dist[j] = dist[i] + d;
		now_num[c]++;
		now_sum[c] += d;
		dfs(j, i);
		now_num[c]--;
		now_sum[c] -= d;
	}
}
 
int main()
{
	cin.tie(0);
	ios::sync_with_stdio(false);
 
	int Q;
	cin >> N >> Q;
 
	rep(i, N-1){
		int a, b, c, d;
		cin >> a >> b >> c >> d;
		a--; b--; c--;
 
		edges[a].push_back({b, c, d});
		edges[b].push_back({a, c, d});
		edges2[a].push_back(b);
		edges2[b].push_back(a);
	}
 
	LCA lca;
	lca.initialize(N, edges2);
 
	vector<int> X(Q), Y(Q), U(Q), V(Q), L(Q);
	rep(i, Q){
		cin >> X[i] >> Y[i] >> U[i] >> V[i];
		X[i]--; U[i]--; V[i]--;
 
		L[i] = lca.calc_lca(U[i], V[i]);
		need[U[i]].push_back(X[i]);
		need[V[i]].push_back(X[i]);
		need[L[i]].push_back(X[i]);
	}
 
	dfs(0, -1);
 
	rep(i, Q){
		int x = X[i], y = Y[i], u = U[i], v = V[i], l = L[i];
 
		int du = dist[u] - res_sum[u][x] + res_num[u][x] * y;
		int dv = dist[v] - res_sum[v][x] + res_num[v][x] * y;
		int dl = dist[l] - res_sum[l][x] + res_num[l][x] * y;
		int ans = du + dv - 2 * dl;
		cout << ans << endl;
	}
}