🐹.dev
[BOJ 31427] 신촌 도로망 관리와 쿼리2024년 3월 11일 생성(2024년 3월 11일 수정)알고리즘브루트 포스최소 스패닝 트리

문제 소개

문제 링크 : https://boj.kr/31427

Tier : Platinum V

Tag : Bruteforcing, Minimum Spanning Tree

무방향 그래프가 주어지고, 각 도로(= 간선)는 5개의 학교가 각각 소유하고 있습니다.

이때, 각 학교마다 도로의 비용은 모두 같으며 쿼리마다 각 학교의 도로의 비용이 주어지게 됩니다.

쿼리가 주어질 때마다 도로 관리비의 합의 최소를 구해야 하는 문제입니다.

풀이

1. 지문의 조건으로부터 최소 스패닝 트리 문제임을 알아차리기

문제의 조건을 보면, "임의의 두 정점을 직접적으로 잇는 도로는 최대 한 개만 존재하고, 어떤 두 정점을 선택하더라도 그 두 정점을 연결하는 경로가 존재함이 보장된다."

가 있습니다. 이는 그래프의 형태가 트리일 때의 주요 특징 중 하나입니다.

따라서, 우리는 이 조건을 보자마자 트리 문제임을 알아차려야 하며, 관리비(= 간선 비용)의 합의 최소를 구해야 하는 문제이므로

최소 스패닝 트리를 구성하는 문제임을 알 수 있습니다.

2. 시간 복잡도 체크 후, 브루트포스 가능 여부 판단해보기

일반적으로 최소 스패닝 트리 구성 시, 크루스칼 알고리즘을 이용하면

간선 정렬 시 O(NlogN), Union-Find를 적절히 최적화하여 O(logN) 미만 정도이므로

커봤자 O(Nlog^2N) 정도에 구성할 수 있습니다.

그런데, 이 문제에서는 쿼리 당 이러한 과정이 필요하므로 (만약 브루트 포스로 구한다면)

최종 시간 복잡도는 O(QNlog^2N) 정도가 될 것입니다. 그런데, Q*N의 최대는 10억 정도이므로 2.5초 정도의 제한 시간에 푸는 것은 불가능함을 알 수 있습니다.

일반적으로 연산 1억 회에 1초로 간주하면 시간 초과 여부를 쉽게 판단할 수 있습니다.

3. 이 문제만의 특징을 통해, 최적화 방법 고안해보기

학교는 총 5개이며, 이로서 간선의 비용의 종류도 최대 5가지인 점을 캐치해야 합니다.

만약 간선의 비용의 종류가 천차만별이라면, 위와 같이 정렬을 수행해야겠지만

5가지로 한정되는 경우 각 학교의 비용 순위에 대한 모든 경우의 수를 미리 판단해 볼 수 있게 됩니다.

다시 말해, 각 학교 간의 비교 관계는 5! = 120가지 이므로 이에 대해 최소 스패닝 트리를 미리 구성해 놓을 경우

하나의 쿼리에 대해 O(1)의 시간에 구할 수 있게 됩니다.

4. 쿼리를 처리하기 앞서 모든 비교 관계에 대해 전처리하기

특정 비교 관계에 대해, 간선의 비용을 1, 2, 3, 4, 5로 임의로 정한 다음에 최소 스패닝 트리를 구성합니다.

우리가 여기서 구하는 건 총 비용이 아니라 각 학교 당 최소 스패닝 트리에 포함되는 간선의 개수이기 때문에, 저런 식으로 간선의 비용을 설정할 수 있는 것입니다.

이후에, 특정 쿼리가 들어왔을 때 비교 관계를 연산하여 120가지의 경우 중 하나를 찾은 다음

(특정 학교의 간선의 비용) * (특정 학교가 관리 중인 도로의 개수)를 모두 더해주면 됩니다.

코드

⚠️ 해당 문제의 시간 제한이 C++ 기준이므로, Python으로 제출 시 시간 초과가 발생합니다.

필자는 C++로 바꾸어 제출했습니다.

solution.py
1import sys
2input = lambda: sys.stdin.readline().rstrip()
3
4def find(G, x):
5 if G[x] == x:
6 return G[x]
7
8 nodes = [x]
9
10 while G[nodes[-1]] != nodes[-1]:
11 nodes.append(G[nodes[-1]])
12
13 for node in nodes:
14 G[node] = nodes[-1]
15
16 return G[x]
17
18def union(G, T, x, y):
19 x, y = find(G, x), find(G, y)
20
21 # Small to Large
22 if T[x] <= T[y]:
23 x, y = y, x
24
25 G[y] = G[x]
26 T[x] += T[y]
27
28# C: [A, B, C, D, E]의 비용
29def analyse(N, E, C):
30 # A, B, C, D, E가 담당하는 각각의 간선의 수
31 result = [0] * 5
32
33 count = 0
34 offset = 0
35 G = [*range(N + 1)]
36 T = [1] * (N + 1)
37
38 E.sort(key=lambda x: C[x[0]])
39
40 while count < N - 1:
41 z, u, v = E[offset]
42 offset += 1
43
44 u, v = find(G, u), find(G, v)
45
46 if u == v:
47 continue
48
49 result[z] += 1
50 count += 1
51 union(G, T, u, v)
52
53 return result
54
55def signature(C):
56 A = [(C[x], x) for x in range(5)]
57 A.sort()
58
59 return int("".join(str(x[1]) for x in A))
60
61if __name__ == "__main__":
62 N, M, Q = map(int, input().split())
63 E = []
64
65 S = [[] for _ in range(43211)]
66
67 for _ in range(M):
68 u, v, z = input().split()
69 E.append((ord(z) - ord('A'), int(u), int(v)))
70
71 for _ in range(Q):
72 C = [*map(int, input().split())]
73 sign = signature(C)
74
75 if not S[sign]:
76 S[sign] = analyse(N, E, C)
77
78 result = 0
79
80 for x in range(5):
81 result += C[x] * S[sign][x]
82
83 print(result)
84
solution.py
1import sys
2input = lambda: sys.stdin.readline().rstrip()
3
4def find(G, x):
5 if G[x] == x:
6 return G[x]
7
8 nodes = [x]
9
10 while G[nodes[-1]] != nodes[-1]:
11 nodes.append(G[nodes[-1]])
12
13 for node in nodes:
14 G[node] = nodes[-1]
15
16 return G[x]
17
18def union(G, T, x, y):
19 x, y = find(G, x), find(G, y)
20
21 # Small to Large
22 if T[x] <= T[y]:
23 x, y = y, x
24
25 G[y] = G[x]
26 T[x] += T[y]
27
28# C: [A, B, C, D, E]의 비용
29def analyse(N, E, C):
30 # A, B, C, D, E가 담당하는 각각의 간선의 수
31 result = [0] * 5
32
33 count = 0
34 offset = 0
35 G = [*range(N + 1)]
36 T = [1] * (N + 1)
37
38 E.sort(key=lambda x: C[x[0]])
39
40 while count < N - 1:
41 z, u, v = E[offset]
42 offset += 1
43
44 u, v = find(G, u), find(G, v)
45
46 if u == v:
47 continue
48
49 result[z] += 1
50 count += 1
51 union(G, T, u, v)
52
53 return result
54
55def signature(C):
56 A = [(C[x], x) for x in range(5)]
57 A.sort()
58
59 return int("".join(str(x[1]) for x in A))
60
61if __name__ == "__main__":
62 N, M, Q = map(int, input().split())
63 E = []
64
65 S = [[] for _ in range(43211)]
66
67 for _ in range(M):
68 u, v, z = input().split()
69 E.append((ord(z) - ord('A'), int(u), int(v)))
70
71 for _ in range(Q):
72 C = [*map(int, input().split())]
73 sign = signature(C)
74
75 if not S[sign]:
76 S[sign] = analyse(N, E, C)
77
78 result = 0
79
80 for x in range(5):
81 result += C[x] * S[sign][x]
82
83 print(result)
84