0515_알고리즘 백트래킹(Backtracking) 정리
Algorithm/etc.

0515_알고리즘 백트래킹(Backtracking) 정리

728x90
반응형

그 동안 문제 풀면서 자주 들었던 말.. "이거 백트래킹으로 풀면 돼^^"

찾아야 하는 경우의 수를 줄이는 것은 알았지만, 정확히 어떤 식으로 해결하는 지는 몰랐는데 드디어 오늘 SSAFY 라이브 강의에서 백트래킹 개념을 배웠다.

 

[개념 정리]

백트래킹을 내 식대로 정리하자면,

 

가장 기본적으로는 '해를 찾는 것'이다. 근데 해를 어떻게 찾느냐 하면, 깊이우선탐색(DFS)와 비슷하게 풀이된다. 이때 다른 점은 깊이우선탐색은 가능한 모든 경우의 수를 탐색하는 것이고, 백트래킹은 조건에 만족한 노드만을 찾아 탐색하다가 해를 찾으면 탐색을 멈추기 때문에 결과적으로 훨씬 적은 경우의 수를 탐색하게 된다. (물론 그렇지 않을 때도 있다.)

 

즉, 기본적인 뼈대는 DFS지만 만족해야 하는 조건을 넣어주어 탐색하는 경우의 수를 줄여주는 것.

 

 


 

가장 백트래킹을 이용한 대표적인 문제는 N-Queen 문제이다. 연습용으로 SWEA의 2806번 문제를 풀어보았다⬇⬇

[코드 구현]

def backtrack(r):
    global n, cnt
    # r = 행
    if r == n:
        # 모든 행을 다 거치면 해를 찾은 것
        cnt += 1
        return

    # 열, 대각선 파악
    for c in range(n):
        # 조건에 부합하는 지 확인
        if not col[c] and not dia_1[r+c] and not dia_2[n-r+c-1]:
            col[c]=1
            dia_1[r+c]=1
            dia_2[n-r+c-1]=1
            backtrack(r+1)
            col[c]=0
            dia_1[r+c]=0
            dia_2[n-r+c-1]=0

for tc in range(1, int(input())+1):
    n = int(input())
    # 열, 대각선(상향, 하향) 확인하기 위한 리스트
    col = [0]*n
    dia_1 = [0]*(2*n-1)
    dia_2 = [0]*(2*n-1)
    cnt = 0
    backtrack(0)
    print('#{} {}'.format(tc, cnt))

[문제 출처]

SWEA_2806번 N-Queen문제

8*8 체스보드에 8개의 퀸을 서로 공격하지 못하게 놓는 문제는 잘 알려져 있는 문제이다.퀸은 같은 행, 열, 또는 대각선 위에 있는 말을 공격할 수 있다. 이 문제의 한가지 정답은 아래 그림과 같다. 


이 문제의 조금 더 일반화된 문제는 Franz Nauck이 1850년에 제기했다. N*N 보드에 N개의 퀸을 서로 다른 두 퀸이 공격하지 못하게 놓는 경우의 수는 몇가지가 있을까? N이 주어졌을 때, 퀸을 놓는 방법의 수를 구하는 프로그램을 작성하시오.

https://swexpertacademy.com/main/code/problem/problemDetail.do?contestProbId=AV7GKs06AU0DFAXB&categoryId=AV7GKs06AU0DFAXB&categoryType=CODE

 

SW Expert Academy

SW 프로그래밍 역량 강화에 도움이 되는 다양한 학습 컨텐츠를 확인하세요!

swexpertacademy.com

 


SWEA의 learn - course -Advanced - 05 백트래킹- 최소생산비용 문제를 추가적으로 풀었다.

두 가지 실수한 점은 

- 두 번째 if문 조건을 넣지 않았다가 시간초과가 떠서 페일이었고

- 함수 호출하고 cost -= arr[product][i] 이전 비용을 다시 빼주지 않아서 제대로 된 결과를 얻지 못 했다.

[코드 구현]

def dfs_cost(product):
    global cost, res
    if product == n:
        res = min(res, cost)
        return
    if cost > res:
        return

    for i in range(n):
        if not factory[i]:
            factory[i] = 1
            cost += arr[product][i]
            dfs_cost(product+1)
            cost -= arr[product][i]
            factory[i] = 0
    return

for tc in range(1, int(input())+1):
    n = int(input())
    arr = [list(map(int, input().split())) for _ in range(n)]
    factory = [0]*n
    cost = 0
    res = 1000000000000

    dfs_cost(0)
    print('#{} {}'.format(tc, res))

 

 

반응형