回溯算法

What‘s your decision?

Posted by R1NG on October 27, 2020 Viewed Times

回溯算法: 以数独求解为例

回溯问题的解决核心是遍历决策树. 在进行决策时, 实际需要考虑的问题是:

  1. 路径: 目前已经做的决策是什么?
  2. 选择列表: 还可以做那些选择?
  3. 结束条件: 何时算作 “到达决策树底端”, 问题解决?

回溯问题的框架为:

1
2
3
4
5
6
7
8
9
10
11
12
result = []
def backtrack(路径, 选择列表):
    if 满足结束条件:
        result.add(新路径)
        return
    
    for 选择 in 选择列表:
        作出新选择      # 在原有决策的基础上作出下一步的选择

        backtrack(新路径, 新选择列表)       #检查能否进一步做选择
        
        撤销新选择      # 如果不能进一步做选择的话则撤销无效的选择

下面以数独求解问题为例:

Q: 给定一个 $9*9$ 的数独, 求出它的其中一个解.

程序输出:

  1. 我们需要打印待求解的数独和数独的一个解. 鉴于我们求解过程中修改是在数独的基础上进行的, 因此两个情况的代码可以复用:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    
     def print_board(self, flag=0):
         # Print the sudoku board
    
         if flag == 1:
             print("\n= = = = = SOLUTION = = = = = \n")
         for r in range(self.len):
             if r % 3 == 0 and r != 0:
                 print("= = = = = = = = = = = =")
             for c in range(len(self.board[0])):
                 if c % 3 == 0 and c != 0:
                     print(" | ", end="")
                 if c == 8:
                     print(self.board[r][c])
                 else:
                     print(str(self.board[r][c]) + " ", end="")
    
  2. 除此以外, 我们还希望得知程序的运行时间以评价执行效率:
    1
    2
    3
    
     def stats(self):
         self.end = time.time()
         print("Time elapsed: {}s".format(self.end - self.start))
    


程序逻辑:

  1. 要求解数独, 我们首先需要找到待填的空位:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    
     def search_empty(self):
         # Search for empty slots
    
         for r in range(self.len):
             for c in range(self.len):
                 if self.board[r][c] == 0:
                     return r, c   # return row, column
    
         return None
    
  2. 要定义回溯函数, 我们还需要确定选择列表:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    
     def find_valid_nums(self, pos):
         # find all possible valid numbers for position 'pos'
    
         num_list = list(range(1, 10))
         for i in range(9):
             if self.board[pos[0]][i] != 0 and self.board[pos[0]][i] in num_list:
                 num_list.remove(self.board[pos[0]][i])
             if self.board[i][pos[1]] != 0 and self.board[i][pos[1]] in num_list:
                 num_list.remove(self.board[i][pos[1]])
         x, y = pos[0] // 3, pos[1] // 3
         for i in range(3*x, 3*x + 3):
             for j in range(3*y, 3*y + 3):
                 if self.board[i][j] != 0 and self.board[i][j] in num_list:
                     num_list.remove(self.board[i][j])
         return num_list
    
  3. 此外, 我们还需要依据数独规则判断作出的选择是否有效:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    
     def validate(self, num, pos):
         # Validate the selected choice
    
         for c in range(len(self.board[0])):
             if self.board[pos[0]][c] == num and pos[1] != c:
                 return False
         for r in range(len(self.board[0])):
             if self.board[r][pos[1]] == num and pos[0] != r:
                 return False
         x = pos[0] // 3
         y = pos[1] // 3
         for i in range(x*3, x*3 + 3):
             for j in range(y*3, y*3 + 3):
                 if self.board[i][j] == num and (i,j) != pos:
                     return False
    
  4. 最后, 我们可以定义回溯主函数:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    
     def solution(self):
         # Solution, using DFS Algorithm
    
         find = self.search_empty()
         if not find:
             return True
         else:
             row, col = find
         for i in self.find_valid_nums([row, col]):
             if self.validate(i, (row, col)):
                 self.board[row][col] = i
    
                 if self.solution():
                     return True
    
                 self.board[row][col] = 0
    
         return False
    

程序整理总结如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import time


class Sudoku:
    """Define Sudoku Class, and work out its solution"""
    def __init__(self):
        # Initialize

        self.board = [
            [7, 0, 0, 0, 3, 4, 8, 0, 0],
            [8, 0, 4, 6, 0, 0, 0, 0, 0],
            [0, 3, 9, 0, 5, 0, 0, 0, 0],
            [1, 0, 0, 5, 0, 0, 6, 0, 0],
            [0, 4, 0, 7, 0, 9, 0, 3, 0],
            [0, 0, 3, 0, 0, 8, 0, 0, 9],
            [0, 0, 0, 0, 7, 0, 3, 2, 0],
            [0, 2, 6, 0, 0, 1, 9, 0, 5],
            [0, 0, 7, 9, 2, 0, 0, 0, 4]
        ]
        self.len = len(self.board)
        self.start = time.time()
        self.end = 0

    def main(self):
        # Construct Main Program, maintain outputs

        self.print_board()
        self.solution()
        self.print_board(1)
        self.stats()

    def search_empty(self):
        # Search for empty slots

        for r in range(self.len):
            for c in range(self.len):
                if self.board[r][c] == 0:
                    return r, c   # return row, column

        return None

    def print_board(self, flag=0):
        # Print the sudoku board
        if flag == 1:
            print("\n= = = = = SOLUTION = = = = = \n")
        for r in range(self.len):
            if r % 3 == 0 and r != 0:
                print("= = = = = = = = = = = =")
            for c in range(len(self.board[0])):
                if c % 3 == 0 and c != 0:
                    print(" | ", end="")
                if c == 8:
                    print(self.board[r][c])
                else:
                    print(str(self.board[r][c]) + " ", end="")

    def find_valid_nums(self, pos):
        # find all possible valid numbers for position 'pos'

        num_list = list(range(1, 10))
        for i in range(9):
            if self.board[pos[0]][i] != 0 and self.board[pos[0]][i] in num_list:
                num_list.remove(self.board[pos[0]][i])
            if self.board[i][pos[1]] != 0 and self.board[i][pos[1]] in num_list:
                num_list.remove(self.board[i][pos[1]])
        x, y = pos[0] // 3, pos[1] // 3
        for i in range(3*x, 3*x + 3):
            for j in range(3*y, 3*y + 3):
                if self.board[i][j] != 0 and self.board[i][j] in num_list:
                    num_list.remove(self.board[i][j])
        return num_list

    def validate(self, num, pos):
        # Validate the selected choice

        for c in range(len(self.board[0])):
            if self.board[pos[0]][c] == num and pos[1] != c:
                return False
        for r in range(len(self.board[0])):
            if self.board[r][pos[1]] == num and pos[0] != r:
                return False
        x = pos[0] // 3
        y = pos[1] // 3
        for i in range(x*3, x*3 + 3):
            for j in range(y*3, y*3 + 3):
                if self.board[i][j] == num and (i,j) != pos:
                    return False

        return True

    def solution(self):
        # Solution, using DFS Algorithm

        find = self.search_empty()
        if not find:
            return True
        else:
            row, col = find
        for i in self.find_valid_nums([row, col]):
            if self.validate(i, (row, col)):
                self.board[row][col] = i

                if self.solution():
                    return True

                self.board[row][col] = 0

        return False

    def stats(self):
        self.end = time.time()
        print("Time elapsed: {}s".format(self.end - self.start))


if __name__ == "__main__":
    sud = Sudoku()
    sud.main()