코딩 테스트 준비! (백준)/TREE

[TREE] 백준 1167번,1967번 트리의 지름(C++)

lee-soo 2025. 5. 13. 18:22

https://www.acmicpc.net/problem/1167

 

이 문제에 대해 간단히 설명하자면

 

여러개의 노드가 있을때, 한 노드와 또 다른 노드 사이의 거리가 제일 멀 때의 거리를 구하라

그게 지름이고

 

이것이다.

 

 

예를들어 이렇게 노드가 있으면?

 

가장 긴 거리는

3+2+6 =11이다

 

그래서 내가 먼저 처음에 생각한건

 

"무조건 거리를 잴 때는, 연결된 노드가 1개밖에 없는 경우이다"

라고 생각해서

 

조건문으로 자식이 한개인 경우를 걸러낸 후,

그 노드들에 대해 dfs를 통해 distance들을 알아내고

max를 지정해줬었다

 

#include <iostream>
#include <vector>
#include <queue>
#include <algorithm>

using namespace std;
bool visited[100001];
bool visited_2[100001];
vector<pair<int, int>> v[100001];
int distance_ = 0;
int max_dis = 0;
int start;
void dfs(int temp);

int main()
{
    ios_base ::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int n;
    cin >> n;
    for (int i = 0; i < n; i++)
    {
        int t1;
        cin >> t1;
        while (1)
        {
            int t2, dis;
            cin >> t2;
            if (t2 == -1)
                break;
            cin >> dis;
            v[t1].push_back({dis, t2});
        }
    }
    for (int i = 1; i <= n; i++)
    {
        if (v[i].size() == 1)
        {
            visited[i] = true;
            visited_2[i] = true;
            start = i;
            dfs(i);
        }
        distance_ = 0;
        for (int j = 1; j <= n; j++)
            visited_2[j] = false;
    }
    cout << max_dis << "\n";
}

void dfs(int temp)
{
    if (temp != start && v[temp].size() == 1)
        max_dis = max(max_dis, distance_);

    for (int i = 0; i < v[temp].size(); i++)
    {
        int x = v[temp][i].second;
        int temp_d = v[temp][i].first;
        if (!visited_2[x] && !visited[x])
        {
            visited_2[x] = true;
            distance_ += temp_d;
            dfs(x);
            distance_ -= temp_d;
            visited_2[x] = false;
        }
    }
}
 
이렇게
 
근데
시간초과...

보니까 노드는 최대 10만개까지 들어올 수 있는데

dfs 를 통한 내 코드의 방식은 

최악의 경우 O(n^2)이상이 뜨게된다.

..

그렇다면 어떻게든 시간을 줄여야되어서 곰곰히 생각해보았다.

 

어떤 단말노드 두개와 연결된 노드는, 한개의 노드만 모든 경로에 갔다오면, 두번째 노드는 한개의 노드가 간 거리에서 모두 첫번째 노드의 거리를 빼고 , 두번째 노드의 거리만 더하면 되지않는가?

-> 코드가 굉장히 복잡해짐

 

그래서 그냥 인터넷 찾아봤더니

 

정말 똑똑한 방법이 있었다.

 

 

내가 사용한 dfs 방법은 어떠한 단말 노드에서 가장 긴 노드의 거리를 찾는다.

 

예를들어 줄의 거리를 구하기 위해서(가장긴)

 

파란줄에서 빨간점 하나를 찾자.

 

거기서 가장 긴 거리는?

 

이렇게 될 것이다.

그 점은 가장 끝 점이 될 것이고 

또 그점에서 가장 긴 거리를 찾는 점을 찾으면?

 

 

결국 가장 긴 거리가 될 것이다

 

요점은 무엇이냐

 

1. dfs를 아무점을 사용하여 가장 끝 점을 하나 찾는다

2. dfs를 또 사용하여 가장 긴 길이를 구해낸다.

 

그렇게되면 dfs를 두번밖에 안쓰게 되니

최악의 경우에도 O(n+알파)정도의 시간이 걸리게 된다!

 

그럼 코드를 써보자

 

#include <iostream>
#include <vector>
#include <queue>
#include <algorithm>

using namespace std;
bool visited[100001];
vector<pair<int, int>> v[100001];

int distance_ = 0;
int max_dis = 0;
int first;
void dfs(int temp, int distance);

int main()
{
    int n;
    cin >> n;
    for (int i = 0; i < n; i++)
    {
        int t1;
        cin >> t1;
        while (1)
        {
            int t2, dis;
            cin >> t2;
            if (t2 == -1)
                break;
            cin >> dis;
            v[t1].push_back({dis, t2});
        }
    }

    dfs(1, 0);
    for (int j = 1; j <= n; j++)
        visited[j] = false;
    max_dis = 0;
    distance_ = 0;
    dfs(first, 0);
    cout << max_dis << "\n";
}

void dfs(int temp, int distance)
{
    visited[temp] = true;
    if (v[temp].size() == 1 && distance > max_dis)
    {
        max_dis = distance;
        first = temp;
    }
    for (int i = 0; i < v[temp].size(); i++)
    {
        int x = v[temp][i].second;
        int temp_d = v[temp][i].first;
        if (!visited[x])
            dfs(x, distance + temp_d);
    }
}
그냥 임의의점을 사용하니 1,0을 이용하였고
두번째는 첫번째에서 구한 가장 끝 점을 first라는 변수안에 넣어 구하였다.
잘 작동이 되는것을 볼 수 있다.
그렇다면 

같은 문제이지만 번호만 다른 이 문제는?

 

입력만 다르게 받는걸 상정하면 쉽게 풀 수 있다.

 

 

입력도 깔끔하게 받을 수 있으니

 

#include <iostream>
#include <vector>
#include <queue>
#include <algorithm>

using namespace std;
bool visited[100001];
vector<pair<int, int>> v[100001];

int distance_ = 0;
int max_dis = 0;
int first;
void dfs(int temp, int distance);

int main()
{
    int n;
    cin >> n;
    for (int i = 0; i < n - 1; i++)
    {
        int t1, t2, dis;
        cin >> t1 >> t2 >> dis;
        v[t1].push_back({t2, dis});
        v[t2].push_back({t1, dis});
    }

    dfs(1, 0);
    for (int j = 1; j <= n; j++)
        visited[j] = false;
    max_dis = 0;
    distance_ = 0;
    dfs(first, 0);
    cout << max_dis << "\n";
}

void dfs(int temp, int distance)
{
    visited[temp] = true;
    if (v[temp].size() == 1 && max_dis < distance)
    {
        max_dis = distance;
        first = temp;
    }
    for (int i = 0; i < v[temp].size(); i++)
    {
        int x = v[temp][i].first;
        int temp_d = v[temp][i].second;
        if (!visited[x])
            dfs(x, distance + temp_d);
    }
}
이렇게하자~
쉽다