백준 2143 두 배열의 합 [자바]

java

문제 분석

  • 두 배열이 주어진다
  • 두 배열의 부분합의 합으로, 목표한 수를 만드려고 한다.
  • 이 때 가능한 경우의 수는?

입력조건

목표 수 T
배열 a의 크기 n
배열 a 숫자 배열
배열 b의 크기 m
배열 b 숫자 배열

풀이과정

  • 문제 알고리즘 분류에 이분탐색이 써져 있는데, 나는 도저히 이 문제를 이분탐색으로 풀 방법이 떠오르지 않았다.
  • 그래서 이 문제를 푸는 데에 우선, a와 b에서 구할 수 있는 부분합을 전부 추출한 후,
  • 구해진 두 개의 배열을 b는 내림차순으로, a는 오름차순으로 정렬하여, 투포인터를 변형하여 계산하였다.
  • a의 부분합 배열에 포인터1, b의 부분합 배열에 포인터2를 두고,
  • a[a_pointer] + b[b_pointer]를 계산하여, 이 값이 목표보다 크다면 b포인터를 올리고, 목표보다 작다면 a포인터를 올려가며 포인터가 배열의 끝의 도달할 때 까지 계산하였다.

  • 입력받는 자료와 이 합 들을 long으로 처리해야 했다. 입력값은 1,000,000이 최대지만, 최대 1,000개 원소의 부분합을 처리하고, 또 이 부분합을 더하는 가공을 처리하므로 아주 당연하게 int자료형의 범위를 넘어감을 예측할 수 있다.
ps
  • 문제를 풀고 다른 사람의 코드를 많이 열어보았는데, 이분 탐색 뿐 아니라 상당히 다양한 방법으로 사람들이 푼 것을 볼 수 있었다.
  • 천천히 읽어보고, 다시 한 번 풀어 볼 생각이다.

코드 구성

  • 모든 부분합 입력

    ArrayList<Long> apsum = new ArrayList<>();
    ArrayList<Long> bpsum = new ArrayList<>();
    for(int i = 0; i < n; i++) {
        long tmpsum = 0;
        for(int j = i; j < n; j++) {
            tmpsum += a[j];
            apsum.add(tmpsum);
        }
    }
    for(int i = 0; i < m; i++) {
        long tmpsum = 0;
        for(int j = i; j < m; j++) {
            tmpsum += b[j];
            bpsum.add(tmpsum);
        }
    }
    
  • 투포인터

    long ans = 0;
    int ap = 0;
    int bp = 0;
    while(ap < apsum.size() && bp < bpsum.size()) {
        long sum = apsum.get(ap) + bpsum.get(bp);
      
        // 합이 T보다 크면 b쪽을 올리기
        if(sum > T) {
            long idf = bpsum.get(bp);
            while(bp < bpsum.size() && idf == bpsum.get(bp)) bp++;
        }
        // 합이 T보다 작으면 a쪽을 올리기
        else if(sum < T) {
            long idf = apsum.get(ap);
            while(ap < apsum.size() && idf == apsum.get(ap)) ap++;
        }
        else {
            long acount = 0;
            long aidf = apsum.get(ap);
            while(ap < apsum.size() && aidf == apsum.get(ap)) {
                acount++;
                ap++;
            }
            long bcount = 0;
            long bidf = bpsum.get(bp);
            while(bp < bpsum.size() && bidf == bpsum.get(bp)) {
                bcount++;
                bp++;
            }
            ans += acount * bcount;
        }
    }
    

전체 코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.stream.Stream;

public class BJ2143_두배열의합 {
	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		int T = Integer.parseInt(br.readLine());
		
		int n = Integer.parseInt(br.readLine());
		long[] a = Stream.of(br.readLine().split(" ")).mapToLong(Long::parseLong).toArray();
		int m = Integer.parseInt(br.readLine());
		long[] b = Stream.of(br.readLine().split(" ")).mapToLong(Long::parseLong).toArray();
		// 입력 끝.
		
		ArrayList<Long> apsum = new ArrayList<>();
		ArrayList<Long> bpsum = new ArrayList<>();
		for(int i = 0; i < n; i++) {
			long tmpsum = 0;
			for(int j = i; j < n; j++) {
				tmpsum += a[j];
				apsum.add(tmpsum);
			}
		}
		for(int i = 0; i < m; i++) {
			long tmpsum = 0;
			for(int j = i; j < m; j++) {
				tmpsum += b[j];
				bpsum.add(tmpsum);
			}
		}
		// 데이터 정리 완료.
		
		apsum.sort(null);
		bpsum.sort(Collections.reverseOrder());
		
//		System.out.println(apsum);
//		System.out.println(bpsum);
		
		long ans = 0;
		int ap = 0;
		int bp = 0;
		while(ap < apsum.size() && bp < bpsum.size()) {
			long sum = apsum.get(ap) + bpsum.get(bp);
			
			// 합이 T보다 크면 b쪽을 내리기
			if(sum > T) {
				long idf = bpsum.get(bp);
				while(bp < bpsum.size() && idf == bpsum.get(bp)) bp++;
			}
			else if(sum < T) {
				long idf = apsum.get(ap);
				while(ap < apsum.size() && idf == apsum.get(ap)) ap++;
			}
			else {
				long acount = 0;
				long aidf = apsum.get(ap);
				while(ap < apsum.size() && aidf == apsum.get(ap)) {
					acount++;
					ap++;
				}
				long bcount = 0;
				long bidf = bpsum.get(bp);
				while(bp < bpsum.size() && bidf == bpsum.get(bp)) {
					bcount++;
					bp++;
				}
				ans += acount * bcount;
			}
		}
		System.out.print(ans);
	}
}