티스토리 뷰

1. 문제 링크

https://codeforces.com/contest/1324/problem/F

 

Problem - F - Codeforces

 

codeforces.com

2. 문제 개요

n개의 정점, n-1개의 간선을 가진 tree가 주어진다. 트리의 값이 0이면 black, 1이면 white이다.

 

3. 문제 힌트

DP top-down  , bottom - up 2가지를 모두 써야 한다.

dp[index]를 정의해야 하는데, dp[next] += max(dp[cur],0)  모두 연결된 간선에 대해서 max를 반복해야 할 것이다.

 

4. 문제 풀이

힌트의 의미를 문제의 흐름과 함께 풀어보자.

 

일단 dp의 초기값은 검은색이면 -1, 흰색이면 1로 두고 시작하자.

 

예제의 그림을 보면,

 

1. 예제 1번

다음과 같다.

 

루트는 의미 없지만, 1번을 루트라 가정하고 dfs탐색을 진행한다. 무한반복을 피하기 위해 dfs탐색의 매개변수에는 (현재, 이전) 정점의 index값을 넣어주자.

1 - 2 - 6 - 8의 경로를 생각해보자.

8번에 도착했다.  8번에서 6번으로 값을 던져주는 것은 본인을 선택하거나 선택하지 않거나 두 개로 나뉠 수 있다. 6번 노드 기준에서 6번 노드의 dp값은, dp[6] += max(dp[8],0)이 될 수 있겠다. +=하는 이유는 결국 최댓값은 연결된 모든 간선의 dp값을 합치는 것이기 때문이다. 음수인 경우에는 선택할 필요가 없으므로 0이 자연스레 선택될 것이다. 

이제 dp[6]이 정해졌고 dp[2]를 볼 차례이다. dp[2] += max(dp[6], 0)이 된다. max(dp[6],0)은 max(-1,0)이고 따라서 0이된다. dp[2]는 기존에 1이 있었으므로 dp[2] = 1이 된다. 이런 식으로 반복하다 보면 뭔가 문제가 있다는 것을 알 것이다.

 

dp[1]이야 연결된 모든 간선의 dp값을 더해서 올바른 값을 가지고 있다. 

2. bottom -up 으로 1번 계산

하지만 2와 3을 보더라도 1에서 나오는 간선의 dp값을 더해주지 않았다. 즉, 2를 기준으로 봤을 때, 

3. 필요한 부분

저렇게 1에서 2로 가는 부분을 구해서 최댓값을 갱신할 수 있다면 더해서 갱신해줘야 한다.

현재 dp배열에는, 

4. 현재 상태

이런 형태로 데이터가 저장되어있다.  즉, dp[2]에는 정점 2,6,8에 대해서만 최적화된 값을 가지고 있다는 의미이다.

따라서 필요한 부분을 구해주기 위해서는(그림 3번) val(필요한 부분)= dp[1] - max(0, dp[2])를 해주어야 한다. 

 

그러고 나서 새로운 dfs함수를 작성할 때 구한 val을 다시 dp[2] += max(val,0)을 하여 한번 더 누적시켜준다.

그러면 모든 경우의 수를 다 비교했으니 최적 값이 dp에 누적되게 된다.

 

※ 처음에 이렇게 두 가지 방향을 고려해야 하기 때문에 정점을 사용하기보다는 간선을 사용해서 구현해봤다.

양방향 간선이 n-1개가 있고 정점이 n개가 있기 때문에 O(2E+V)로 끝날 줄 알았다. 하지만 계속 Test set#39에서 시간 초과.. N = 200,000  E = 199,999가 되니까 시간 초과가 발생했다. 처음에 작성한 알고리즘도 위의 알고리즘처럼 밑에서 위로, 위에서 밑으로 모든 방향을 탐색하는데 왜 시간 초과인지 잘 모르겠다. 계속 분석해봐야겠다.. 

 

5. 코드

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

vector<int> dp,dpp;
int n;
vector<vector<int>> adj;

void dfs(int cur, int before)
{
	for (int i = 0; i < adj[cur].size(); ++i) {
		int next = adj[cur][i];

		if (before == next)	continue;
		dfs(next, cur);
		dp[cur] += max(dp[next], 0);
	}
	
}
void dfs2(int cur, int before)
{
	dpp[cur] = dp[cur];

	if (before)
	{
		int val = dpp[before] - max(0, dp[cur]);
		dpp[cur] += max(val, 0);
	}
	for (int i = 0; i < adj[cur].size(); ++i) {
		int next = adj[cur][i]; 
		if (before == next)	continue;
		dfs2(next, cur);
	}
	
}
int  main()
{
	scanf("%d", &n);
	dp.resize(n + 1);
	dpp.resize(n + 1);
	adj.resize(n + 1);
	
	for (int i = 1; i <= n; ++i) {
		int val;
		scanf("%d", &val);
		if (val == 0)
			dp[i] = -1;
		else
			dp[i] = 1;
	}


	for (int i = 0; i < n - 1; ++i)
	{
		int from, to;
		scanf("%d %d", &from, &to);

		adj[from].push_back(to);
		adj[to].push_back(from);
	}

	dfs(1, 0);
	dfs2(1, 0);
	for (int i = 1; i <= n; ++i)
		printf("%d ", dpp[i]);


	return 0;
}

 

 

 

지적, 댓글 언제나 환영입니다~

 

댓글
«   2025/02   »
1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28
Total
Today
Yesterday