๋ฐฑํŠธ๋ž˜ํ‚น (Backtracking)

๋ฐฑํŠธ๋ž˜ํ‚น (Backtracking)

๊ฐœ์š”

๋ฐฑํŠธ๋ž˜ํ‚น์€ ํ•ด๋ฅผ ์ฐพ๋Š” ๋„์ค‘ ๋ง‰ํžˆ๋ฉด ๋˜๋Œ์•„๊ฐ€์„œ ๋‹ค์‹œ ํ•ด๋ฅผ ์ฐพ๋Š” ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค. ๊ฐ€์ง€์น˜๊ธฐ(pruning)๋ฅผ ํ†ตํ•ด ๋ถˆํ•„์š”ํ•œ ํƒ์ƒ‰์„ ์ค„์ž…๋‹ˆ๋‹ค.


๋ชฉ์ฐจ

  1. ๋ฐฑํŠธ๋ž˜ํ‚น ๊ฐœ๋…
  2. ์ˆœ์—ด๊ณผ ์กฐํ•ฉ
  3. N-Queens
  4. ๋ถ€๋ถ„์ง‘ํ•ฉ
  5. ์Šค๋„์ฟ 
  6. ์—ฐ์Šต ๋ฌธ์ œ

1. ๋ฐฑํŠธ๋ž˜ํ‚น ๊ฐœ๋…

๊ธฐ๋ณธ ์›๋ฆฌ

๋ฐฑํŠธ๋ž˜ํ‚น:
1. ํ•ด๋ฅผ ํ•˜๋‚˜์”ฉ ๊ตฌ์„ฑํ•ด ๋‚˜๊ฐ
2. ์กฐ๊ฑด์„ ๋งŒ์กฑํ•˜์ง€ ์•Š์œผ๋ฉด ์ด์ „ ๋‹จ๊ณ„๋กœ ๋˜๋Œ์•„๊ฐ
3. ๊ฐ€์ง€์น˜๊ธฐ๋กœ ํƒ์ƒ‰ ๊ณต๊ฐ„ ์ถ•์†Œ

DFS + ์กฐ๊ฑด ๊ฒ€์‚ฌ + ๋˜๋Œ์•„๊ฐ€๊ธฐ

์ƒํƒœ ๊ณต๊ฐ„ ํŠธ๋ฆฌ

N=3์ผ ๋•Œ ์ˆœ์—ด ํƒ์ƒ‰:

                    []
         /          |          \
       [1]         [2]         [3]
       / \         / \         / \
    [1,2][1,3] [2,1][2,3] [3,1][3,2]
      |    |     |    |     |    |
   [1,2,3][1,3,2][2,1,3][2,3,1][3,1,2][3,2,1]

๊ฐ€์ง€์น˜๊ธฐ ์˜ˆ: ์ฒซ ์›์†Œ๊ฐ€ ์กฐ๊ฑด ์œ„๋ฐ˜ ์‹œ ํ•ด๋‹น ์„œ๋ธŒํŠธ๋ฆฌ ์ „์ฒด ์Šคํ‚ต

๊ธฐ๋ณธ ํ…œํ”Œ๋ฆฟ

def backtrack(candidate):
    if is_solution(candidate):
        output(candidate)
        return

    for next_choice in choices(candidate):
        if is_valid(next_choice):  # ๊ฐ€์ง€์น˜๊ธฐ
            candidate.append(next_choice)
            backtrack(candidate)
            candidate.pop()  # ๋˜๋Œ๋ฆฌ๊ธฐ

2. ์ˆœ์—ด๊ณผ ์กฐํ•ฉ

2.1 ์ˆœ์—ด (Permutation)

n๊ฐœ ์ค‘ r๊ฐœ๋ฅผ ์ˆœ์„œ ์žˆ๊ฒŒ ๋‚˜์—ด
nPr = n! / (n-r)!

[1, 2, 3]์˜ ๋ชจ๋“  ์ˆœ์—ด:
[1,2,3], [1,3,2], [2,1,3], [2,3,1], [3,1,2], [3,2,1]
// C++
void permute(vector<int>& nums, int start, vector<vector<int>>& result) {
    if (start == nums.size()) {
        result.push_back(nums);
        return;
    }

    for (int i = start; i < nums.size(); i++) {
        swap(nums[start], nums[i]);
        permute(nums, start + 1, result);
        swap(nums[start], nums[i]);  // ๋˜๋Œ๋ฆฌ๊ธฐ
    }
}

vector<vector<int>> permutations(vector<int>& nums) {
    vector<vector<int>> result;
    permute(nums, 0, result);
    return result;
}
def permutations(nums):
    result = []

    def backtrack(start):
        if start == len(nums):
            result.append(nums[:])
            return

        for i in range(start, len(nums)):
            nums[start], nums[i] = nums[i], nums[start]
            backtrack(start + 1)
            nums[start], nums[i] = nums[i], nums[start]

    backtrack(0)
    return result

# ๋˜๋Š” itertools
from itertools import permutations
list(permutations([1, 2, 3]))

2.2 ์กฐํ•ฉ (Combination)

n๊ฐœ ์ค‘ r๊ฐœ๋ฅผ ์ˆœ์„œ ์—†์ด ์„ ํƒ
nCr = n! / (r! ร— (n-r)!)

[1, 2, 3, 4]์—์„œ 2๊ฐœ ์„ ํƒ:
[1,2], [1,3], [1,4], [2,3], [2,4], [3,4]
// C++
void combine(int n, int r, int start, vector<int>& current,
             vector<vector<int>>& result) {
    if (current.size() == r) {
        result.push_back(current);
        return;
    }

    for (int i = start; i <= n; i++) {
        current.push_back(i);
        combine(n, r, i + 1, current, result);
        current.pop_back();
    }
}

vector<vector<int>> combinations(int n, int r) {
    vector<vector<int>> result;
    vector<int> current;
    combine(n, r, 1, current, result);
    return result;
}
def combinations(n, r):
    result = []

    def backtrack(start, current):
        if len(current) == r:
            result.append(current[:])
            return

        for i in range(start, n + 1):
            current.append(i)
            backtrack(i + 1, current)
            current.pop()

    backtrack(1, [])
    return result

# ๋˜๋Š” itertools
from itertools import combinations
list(combinations([1, 2, 3, 4], 2))

2.3 ์ค‘๋ณต ์ˆœ์—ด/์กฐํ•ฉ

# ์ค‘๋ณต ์ˆœ์—ด: ๊ฐ™์€ ์›์†Œ ์—ฌ๋Ÿฌ ๋ฒˆ ์„ ํƒ ๊ฐ€๋Šฅ
def permutations_with_repetition(nums, r):
    result = []

    def backtrack(current):
        if len(current) == r:
            result.append(current[:])
            return

        for num in nums:
            current.append(num)
            backtrack(current)
            current.pop()

    backtrack([])
    return result

# ์ค‘๋ณต ์กฐํ•ฉ
def combinations_with_repetition(nums, r):
    result = []

    def backtrack(start, current):
        if len(current) == r:
            result.append(current[:])
            return

        for i in range(start, len(nums)):
            current.append(nums[i])
            backtrack(i, current)  # i+1์ด ์•„๋‹Œ i
            current.pop()

    backtrack(0, [])
    return result

3. N-Queens

๋ฌธ์ œ

Nร—N ์ฒด์ŠคํŒ์— N๊ฐœ์˜ ํ€ธ์„ ์„œ๋กœ ๊ณต๊ฒฉํ•  ์ˆ˜ ์—†๊ฒŒ ๋ฐฐ์น˜

ํ€ธ์˜ ๊ณต๊ฒฉ ๋ฒ”์œ„: ๊ฐ€๋กœ, ์„ธ๋กœ, ๋Œ€๊ฐ์„ 

4ร—4 ์˜ˆ์‹œ (ํ•˜๋‚˜์˜ ํ•ด):
. Q . .
. . . Q
Q . . .
. . Q .

์•Œ๊ณ ๋ฆฌ์ฆ˜

ํ–‰ ๋‹จ์œ„๋กœ ํ€ธ ๋ฐฐ์น˜:
1. ์ฒซ ํ–‰์— ํ€ธ ๋ฐฐ์น˜ ์‹œ๋„
2. ๋‹ค์Œ ํ–‰์— ํ€ธ ๋ฐฐ์น˜ (์ถฉ๋Œ ๊ฒ€์‚ฌ)
3. ์ถฉ๋Œํ•˜๋ฉด ๋ฐฑํŠธ๋ž˜ํ‚น
4. N๊ฐœ ๋ฐฐ์น˜ ์™„๋ฃŒํ•˜๋ฉด ํ•ด ์ถœ๋ ฅ

์ถฉ๋Œ ๊ฒ€์‚ฌ:
- ๊ฐ™์€ ์—ด: cols[col] == True
- ๋Œ€๊ฐ์„ 1 (โ†˜): row - col ๊ฐ’์ด ๊ฐ™์Œ
- ๋Œ€๊ฐ์„ 2 (โ†™): row + col ๊ฐ’์ด ๊ฐ™์Œ

๊ตฌํ˜„

// C++
class NQueens {
private:
    int n;
    vector<bool> cols, diag1, diag2;
    vector<vector<string>> results;

    void backtrack(int row, vector<int>& queens) {
        if (row == n) {
            results.push_back(generateBoard(queens));
            return;
        }

        for (int col = 0; col < n; col++) {
            if (cols[col] || diag1[row - col + n - 1] || diag2[row + col])
                continue;

            queens[row] = col;
            cols[col] = diag1[row - col + n - 1] = diag2[row + col] = true;

            backtrack(row + 1, queens);

            cols[col] = diag1[row - col + n - 1] = diag2[row + col] = false;
        }
    }

    vector<string> generateBoard(const vector<int>& queens) {
        vector<string> board(n, string(n, '.'));
        for (int i = 0; i < n; i++) {
            board[i][queens[i]] = 'Q';
        }
        return board;
    }

public:
    vector<vector<string>> solveNQueens(int n) {
        this->n = n;
        cols.assign(n, false);
        diag1.assign(2 * n - 1, false);
        diag2.assign(2 * n - 1, false);

        vector<int> queens(n);
        backtrack(0, queens);

        return results;
    }
};
def solve_n_queens(n):
    results = []
    cols = set()
    diag1 = set()  # row - col
    diag2 = set()  # row + col

    def backtrack(row, queens):
        if row == n:
            board = ['.' * q + 'Q' + '.' * (n - q - 1) for q in queens]
            results.append(board)
            return

        for col in range(n):
            if col in cols or (row - col) in diag1 or (row + col) in diag2:
                continue

            cols.add(col)
            diag1.add(row - col)
            diag2.add(row + col)

            backtrack(row + 1, queens + [col])

            cols.remove(col)
            diag1.remove(row - col)
            diag2.remove(row + col)

    backtrack(0, [])
    return results

# ํ•ด์˜ ๊ฐœ์ˆ˜๋งŒ ์„ธ๊ธฐ
def count_n_queens(n):
    count = 0
    cols = set()
    diag1 = set()
    diag2 = set()

    def backtrack(row):
        nonlocal count
        if row == n:
            count += 1
            return

        for col in range(n):
            if col in cols or (row - col) in diag1 or (row + col) in diag2:
                continue

            cols.add(col)
            diag1.add(row - col)
            diag2.add(row + col)

            backtrack(row + 1)

            cols.remove(col)
            diag1.remove(row - col)
            diag2.remove(row + col)

    backtrack(0)
    return count

4. ๋ถ€๋ถ„์ง‘ํ•ฉ

๋ชจ๋“  ๋ถ€๋ถ„์ง‘ํ•ฉ ์ƒ์„ฑ

[1, 2, 3]์˜ ๋ถ€๋ถ„์ง‘ํ•ฉ:
[], [1], [2], [3], [1,2], [1,3], [2,3], [1,2,3]

์ด 2^n๊ฐœ
// C++
vector<vector<int>> subsets(vector<int>& nums) {
    vector<vector<int>> result;
    vector<int> current;

    function<void(int)> backtrack = [&](int start) {
        result.push_back(current);

        for (int i = start; i < nums.size(); i++) {
            current.push_back(nums[i]);
            backtrack(i + 1);
            current.pop_back();
        }
    };

    backtrack(0);
    return result;
}

// ๋น„ํŠธ๋งˆ์Šคํฌ ๋ฐฉ๋ฒ•
vector<vector<int>> subsetsBitmask(vector<int>& nums) {
    int n = nums.size();
    vector<vector<int>> result;

    for (int mask = 0; mask < (1 << n); mask++) {
        vector<int> subset;
        for (int i = 0; i < n; i++) {
            if (mask & (1 << i)) {
                subset.push_back(nums[i]);
            }
        }
        result.push_back(subset);
    }

    return result;
}
def subsets(nums):
    result = []

    def backtrack(start, current):
        result.append(current[:])

        for i in range(start, len(nums)):
            current.append(nums[i])
            backtrack(i + 1, current)
            current.pop()

    backtrack(0, [])
    return result

# ๋น„ํŠธ๋งˆ์Šคํฌ
def subsets_bitmask(nums):
    n = len(nums)
    result = []

    for mask in range(1 << n):
        subset = [nums[i] for i in range(n) if mask & (1 << i)]
        result.append(subset)

    return result

ํ•ฉ์ด target์ธ ๋ถ€๋ถ„์ง‘ํ•ฉ

def subset_sum(nums, target):
    result = []

    def backtrack(start, current, current_sum):
        if current_sum == target:
            result.append(current[:])
            return

        if current_sum > target:  # ๊ฐ€์ง€์น˜๊ธฐ
            return

        for i in range(start, len(nums)):
            current.append(nums[i])
            backtrack(i + 1, current, current_sum + nums[i])
            current.pop()

    backtrack(0, [], 0)
    return result

5. ์Šค๋„์ฟ 

๋ฌธ์ œ

9ร—9 ๊ฒฉ์ž, ๊ฐ ํ–‰/์—ด/3ร—3 ๋ฐ•์Šค์— 1-9๊ฐ€ ํ•œ ๋ฒˆ์”ฉ

5 3 . | . 7 . | . . .
6 . . | 1 9 5 | . . .
. 9 8 | . . . | . 6 .
------+-------+------
8 . . | . 6 . | . . 3
4 . . | 8 . 3 | . . 1
7 . . | . 2 . | . . 6
------+-------+------
. 6 . | . . . | 2 8 .
. . . | 4 1 9 | . . 5
. . . | . 8 . | . 7 9

๊ตฌํ˜„

def solve_sudoku(board):
    def is_valid(board, row, col, num):
        # ํ–‰ ๊ฒ€์‚ฌ
        if num in board[row]:
            return False

        # ์—ด ๊ฒ€์‚ฌ
        for r in range(9):
            if board[r][col] == num:
                return False

        # 3ร—3 ๋ฐ•์Šค ๊ฒ€์‚ฌ
        box_row, box_col = 3 * (row // 3), 3 * (col // 3)
        for r in range(box_row, box_row + 3):
            for c in range(box_col, box_col + 3):
                if board[r][c] == num:
                    return False

        return True

    def solve():
        for row in range(9):
            for col in range(9):
                if board[row][col] == '.':
                    for num in '123456789':
                        if is_valid(board, row, col, num):
                            board[row][col] = num

                            if solve():
                                return True

                            board[row][col] = '.'  # ๋ฐฑํŠธ๋ž˜ํ‚น

                    return False  # ๋ชจ๋“  ์ˆซ์ž ์‹คํŒจ

        return True  # ๋นˆ ์นธ ์—†์Œ = ์™„๋ฃŒ

    solve()

6. ์—ฐ์Šต ๋ฌธ์ œ

๋ฌธ์ œ 1: ๋ฌธ์ž์—ด์˜ ๋ชจ๋“  ์ˆœ์—ด

์ค‘๋ณต ๋ฌธ์ž๊ฐ€ ์žˆ์„ ๋•Œ ์ค‘๋ณต ์—†์ด ์ˆœ์—ด ์ƒ์„ฑ

์ •๋‹ต ์ฝ”๋“œ
def permute_unique(nums):
    result = []
    nums.sort()

    def backtrack(current, remaining):
        if not remaining:
            result.append(current[:])
            return

        for i in range(len(remaining)):
            # ์ค‘๋ณต ์Šคํ‚ต
            if i > 0 and remaining[i] == remaining[i-1]:
                continue

            backtrack(current + [remaining[i]],
                     remaining[:i] + remaining[i+1:])

    backtrack([], nums)
    return result

์ถ”์ฒœ ๋ฌธ์ œ

๋‚œ์ด๋„ ๋ฌธ์ œ ํ”Œ๋žซํผ ์œ ํ˜•
โญโญ N๊ณผ M ๋ฐฑ์ค€ ์ˆœ์—ด
โญโญ N-Queens ๋ฐฑ์ค€ N-Queens
โญโญ Subsets LeetCode ๋ถ€๋ถ„์ง‘ํ•ฉ
โญโญโญ Sudoku Solver LeetCode ์Šค๋„์ฟ 
โญโญโญ Combination Sum LeetCode ์กฐํ•ฉ

๋ฐฑํŠธ๋ž˜ํ‚น ํ…œํ”Œ๋ฆฟ

def backtrack(state):
    if is_goal(state):
        save_solution(state)
        return

    for choice in get_choices(state):
        if is_valid(choice, state):
            make_choice(state, choice)
            backtrack(state)
            undo_choice(state, choice)

๋‹ค์Œ ๋‹จ๊ณ„


์ฐธ๊ณ  ์ž๋ฃŒ

  • Backtracking
  • Introduction to Algorithms (CLRS) - Backtracking
to navigate between lessons