【CF1294F】Three Paths on a Tree


Description

给出一棵无权树(可理解为边权为 \(1\))。

你需要选取三个点 \(a,b,c\),最大化 \(a,b\)\(b,c\)\(a,c\) 的简单路径的并集的长度。

输出这个最大长度和 \(a,b,c\)

Solution

有一个结论:

必定会存在一组最优解,使得 \(a,b\) 是树直径上的端点。

那我们可以套路地去把树直径两端点求出来,推荐大家用两次搜索求出树直径端点。

确定了 \(a,b\),接下来我们只要去找到最优的 \(c\),就可以最大化答案了。

此时我们注意到:\(a,b\)\(b,c\)\(a,c\) 的简单路径的并集的长度其实就是 \(\frac{\text{dist}(a,b)+\text{dist}(b,c)+\text{dist}(a,c)}{2}\)

此时 \(\text{dist}(a,b)\) 已经确定了,当 \(\text{dist}(b,c)+\text{dist}(a,c)\) 的值取到最大,那么整个式子取最大。

\(a,b\) 到所有点的简单路径距离求出来,去枚举这个最优的 \(c\) 即可,枚举的过程中记得判与 \(a,b\) 相同的情况。

Code

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>

#define RI register int

using namespace std;

inline int read() {
    int x = 0, f = 1;
    char s = getchar();

    while (s < '0' || s > '9') {
        if (s == '-')
            f = -f;

        s = getchar();
    }

    while (s >= '0' && s <= '9') {
        x = x * 10 + s - '0';
        s = getchar();
    }

    return x * f;
}

const int N = 200100, M = 400100;

int n;

int tot, head[N], ver[M], edge[M], Next[M];

void add(int u, int v, int w) {
    ver[++tot] = v;
    edge[tot] = w;
    Next[tot] = head[u];
    head[u] = tot;
}

int d[N], vis[N];

int pos;

void bfs(int sta) {
    memset(d, 0, sizeof(d));
    memset(vis, 0, sizeof(vis));

    queue<int>q;

    q.push(sta);
    vis[sta] = 1;

    while (q.size()) {
        int u = q.front();
        q.pop();

        for (RI i = head[u]; i; i = Next[i]) {
            int v = ver[i], w = edge[i];

            if (vis[v])
                continue;

            d[v] = d[u] + w;
            vis[v] = 1;

            if (d[v] > d[pos])
                pos = v;

            q.push(v);
        }
    }
}

int p1, p2;
int ans;

int tmp1[N], tmp2[N];

int main() {
    n = read();

    for (RI i = 1; i < n; i++) {
        int u = read(), v = read();
        add(u, v, 1), add(v, u, 1);
    }

    bfs(1);
    p1 = pos;

    bfs(p1);
    p2 = pos;

    for (RI i = 1; i <= n; i++)
        tmp1[i] = d[i];

    bfs(p2);

    for (RI i = 1; i <= n; i++)
        tmp2[i] = d[i];

    pos = 0;

    for (RI i = 1; i <= n; i++)
        if (tmp1[i] + tmp2[i] > tmp1[pos] + tmp2[pos] && i != p1 && i != p2)
            pos = i;

    ans = (tmp1[p2] + tmp1[pos] + tmp2[pos]) / 2;

    printf("%d\n", ans);
    printf("%d %d %d\n", p1, p2, pos);

    return 0;
}


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM