티스토리 뷰

1. 문제 링크

www.acmicpc.net/problem/9426

 

9426번: 중앙값 측정

첫째 줄에 N과 K가 주어진다. (1 ≤ N ≤ 250,000, 1 ≤ K ≤ 5,000, K ≤ N) 둘째 줄부터 N개 줄에 측정한 온도가 순서대로 주어진다. 온도는 0보다 크거나 같고, 65535보다 작거나 같은 정수이다.

www.acmicpc.net

2. 문제 개요

주어진 데이터의 부분 수열의 중앙값을 찾아 모두 합을 구하는 프로그램을 만드시오.

 

3. 문제 힌트

세그먼트 트리를 사용해보자.

매번 정렬하여 중간값을 꺼내는 방식이나, merge sort처럼 각 세그먼트 트리에 정렬된 값을 갖고 올라가는 것은 O(NK)로 시간이 아슬아슬할 것 같다.

 

뭔가,, 공간 복잡도를 희생해서 시간 복잡도를 줄일 수 있는 방법이 없을지 생각해보자.

세그먼트 트리는 부분합 구하기 좋으니까, 각 인덱스는 그 인덱스를 값으로 가지는 원소의 개수를 저장하면 어떨까?

 

4. 문제 풀이

세그먼트 트리를 만들거다.

그런데 문제를 보면 값의 범위가 [0, 65535]이다. 리프 노드가 65536개 있는 세그먼트 트리를 만들어주고 각 트리의 노드는 0으로 초기화 하자.

매번 정렬하면 제한시간에 걸릴 것 같으므로 공간 복잡도를 희생해서 시간을 줄였다.

 

홈페이지의 예제를 보면, 3, 4, 5, 6... 이 있다. k는 3이다.

그럼, 세그먼트 트리에 k-1개가 될 때까지 삽입시키자. 그러면 3번 인덱스에 1, 4번 인덱스에 1이 추가되어 각 노드는 1이라는 값을 가지게 될 것이다.

 

이제부터는

1개 삽입,

중간값 찾기,

가장 먼저 넣었던 노드 제거

이런 식으로 될 것이다. 위의 예제로 예를 들면, 5를 삽입하고, 4를 출력하고, 3을 제거하면 된다.

 

자 그러면, 어떻게 중간값을 찾을까?

중간값은 log k의 시간으로 구해야 nlogk로 제한 시간 내에 해결할 수 있다.

 

리프 노드까지 내려갈 때, 자식 노드들의 개수를 살펴봐야 한다.

예를 들어서, k가 9라고 하고 5번째 값을 찾으면 된다고 해보자,

 

처음에 루트는 9라는 값을 가질 것이고,

자식들이 6, 3이라는 값을 가지고 있다고 해보자.

그럼 5번째 값은 왼쪽 노드로 내려가야 한다.

 

그럼 지금 서브 트리에서 루트는 6, 자식들은 3, 3이라고 해보자.

여기서 5번째는 오른쪽 서브 트리에 있게 된다.

그러면 왼쪽 노드의 3개를 빼서 오른쪽 서브 트리에서 2번째 노드를 찾으면 된다.

 

이런 방식으로 리프 노드에 도달할 때까지 찾으면 된다.

즉, x번째라고 하면

if 왼쪽 서브 트리 >= x 왼쪽으로 이동,

else 오른쪽으로 이동

이라고 할 수 있다.

 

5. 코드

#include <cstdio>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;

const int SIZE = 65536;
vector<int> arr;
int n, k;
long long int ans;

class Segment_tree {
public:
	Segment_tree(){}
	Segment_tree(int tree_size) { tree.resize(tree_size);}

	int query_update(int cur, int start, int end, int target, int val)
	{
		if (target < start || end < target)
			return tree[cur];

		if (start == end)
			return tree[cur] += val;

		int mid = (start + end) / 2;
		return tree[cur] = query_update(cur * 2, start, mid, target, val) + query_update(cur * 2 + 1, mid + 1, end, target, val);
		
	}

	int query_find(int cur, int start, int end, int index)
	{
		
		if (start == end) {
			return start;
		}

		int mid = (start + end) / 2;
		if (tree[cur * 2] >= index)
		{
			return query_find(cur * 2, start, mid, index);
		}
		else
		{
			index -= tree[cur * 2];
			return query_find(cur * 2 + 1, mid + 1, end, index);
		}

	}

	vector<int> tree;
};

int main()
{
	int h = (int)ceil(log2(SIZE));
	int tree_size = 1 << (h + 1);
	Segment_tree sgt(tree_size);
	
	int left = 0, right = 0;
	scanf("%d %d", &n, &k);
	arr.resize(n);
	int middle = (k + 1) / 2;
	for (int i = 0; i < n; ++i)
		scanf("%d", &arr[i]);

	while (right < k-1) 
		sgt.query_update(1, 0, SIZE - 1, arr[right++], 1);

	while (right < n) {
		sgt.query_update(1, 0, SIZE - 1, arr[right++], 1);
		ans += (long long)sgt.query_find(1, 0, SIZE - 1, middle);
		sgt.query_update(1, 0, SIZE - 1, arr[left++], -1);
	}

	printf("%lld", ans);
	
	return 0;
}

 

 

6. 실행 결과

 

지적 댓글 환영입니다.

 

댓글
«   2024/05   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
Total
Today
Yesterday