回溯算法: 以数独求解为例
回溯问题的解决核心是遍历决策树. 在进行决策时, 实际需要考虑的问题是:
- 路径: 目前已经做的决策是什么?
- 选择列表: 还可以做那些选择?
- 结束条件: 何时算作 “到达决策树底端”, 问题解决?
回溯问题的框架为:
1
2
3
4
5
6
7
8
9
10
11
12
result = []
def backtrack(路径, 选择列表):
if 满足结束条件:
result.add(新路径)
return
for 选择 in 选择列表:
作出新选择 # 在原有决策的基础上作出下一步的选择
backtrack(新路径, 新选择列表) #检查能否进一步做选择
撤销新选择 # 如果不能进一步做选择的话则撤销无效的选择
下面以数独求解问题为例:
Q: 给定一个 $9*9$ 的数独, 求出它的其中一个解.
程序输出:
- 我们需要打印待求解的数独和数独的一个解. 鉴于我们求解过程中修改是在数独的基础上进行的, 因此两个情况的代码可以复用:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
def print_board(self, flag=0): # Print the sudoku board if flag == 1: print("\n= = = = = SOLUTION = = = = = \n") for r in range(self.len): if r % 3 == 0 and r != 0: print("= = = = = = = = = = = =") for c in range(len(self.board[0])): if c % 3 == 0 and c != 0: print(" | ", end="") if c == 8: print(self.board[r][c]) else: print(str(self.board[r][c]) + " ", end="")
- 除此以外, 我们还希望得知程序的运行时间以评价执行效率:
1 2 3
def stats(self): self.end = time.time() print("Time elapsed: {}s".format(self.end - self.start))
程序逻辑:
- 要求解数独, 我们首先需要找到待填的空位:
1 2 3 4 5 6 7 8 9
def search_empty(self): # Search for empty slots for r in range(self.len): for c in range(self.len): if self.board[r][c] == 0: return r, c # return row, column return None
- 要定义回溯函数, 我们还需要确定选择列表:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
def find_valid_nums(self, pos): # find all possible valid numbers for position 'pos' num_list = list(range(1, 10)) for i in range(9): if self.board[pos[0]][i] != 0 and self.board[pos[0]][i] in num_list: num_list.remove(self.board[pos[0]][i]) if self.board[i][pos[1]] != 0 and self.board[i][pos[1]] in num_list: num_list.remove(self.board[i][pos[1]]) x, y = pos[0] // 3, pos[1] // 3 for i in range(3*x, 3*x + 3): for j in range(3*y, 3*y + 3): if self.board[i][j] != 0 and self.board[i][j] in num_list: num_list.remove(self.board[i][j]) return num_list
- 此外, 我们还需要依据数独规则判断作出的选择是否有效:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
def validate(self, num, pos): # Validate the selected choice for c in range(len(self.board[0])): if self.board[pos[0]][c] == num and pos[1] != c: return False for r in range(len(self.board[0])): if self.board[r][pos[1]] == num and pos[0] != r: return False x = pos[0] // 3 y = pos[1] // 3 for i in range(x*3, x*3 + 3): for j in range(y*3, y*3 + 3): if self.board[i][j] == num and (i,j) != pos: return False
- 最后, 我们可以定义回溯主函数:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
def solution(self): # Solution, using DFS Algorithm find = self.search_empty() if not find: return True else: row, col = find for i in self.find_valid_nums([row, col]): if self.validate(i, (row, col)): self.board[row][col] = i if self.solution(): return True self.board[row][col] = 0 return False
程序整理总结如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import time
class Sudoku:
"""Define Sudoku Class, and work out its solution"""
def __init__(self):
# Initialize
self.board = [
[7, 0, 0, 0, 3, 4, 8, 0, 0],
[8, 0, 4, 6, 0, 0, 0, 0, 0],
[0, 3, 9, 0, 5, 0, 0, 0, 0],
[1, 0, 0, 5, 0, 0, 6, 0, 0],
[0, 4, 0, 7, 0, 9, 0, 3, 0],
[0, 0, 3, 0, 0, 8, 0, 0, 9],
[0, 0, 0, 0, 7, 0, 3, 2, 0],
[0, 2, 6, 0, 0, 1, 9, 0, 5],
[0, 0, 7, 9, 2, 0, 0, 0, 4]
]
self.len = len(self.board)
self.start = time.time()
self.end = 0
def main(self):
# Construct Main Program, maintain outputs
self.print_board()
self.solution()
self.print_board(1)
self.stats()
def search_empty(self):
# Search for empty slots
for r in range(self.len):
for c in range(self.len):
if self.board[r][c] == 0:
return r, c # return row, column
return None
def print_board(self, flag=0):
# Print the sudoku board
if flag == 1:
print("\n= = = = = SOLUTION = = = = = \n")
for r in range(self.len):
if r % 3 == 0 and r != 0:
print("= = = = = = = = = = = =")
for c in range(len(self.board[0])):
if c % 3 == 0 and c != 0:
print(" | ", end="")
if c == 8:
print(self.board[r][c])
else:
print(str(self.board[r][c]) + " ", end="")
def find_valid_nums(self, pos):
# find all possible valid numbers for position 'pos'
num_list = list(range(1, 10))
for i in range(9):
if self.board[pos[0]][i] != 0 and self.board[pos[0]][i] in num_list:
num_list.remove(self.board[pos[0]][i])
if self.board[i][pos[1]] != 0 and self.board[i][pos[1]] in num_list:
num_list.remove(self.board[i][pos[1]])
x, y = pos[0] // 3, pos[1] // 3
for i in range(3*x, 3*x + 3):
for j in range(3*y, 3*y + 3):
if self.board[i][j] != 0 and self.board[i][j] in num_list:
num_list.remove(self.board[i][j])
return num_list
def validate(self, num, pos):
# Validate the selected choice
for c in range(len(self.board[0])):
if self.board[pos[0]][c] == num and pos[1] != c:
return False
for r in range(len(self.board[0])):
if self.board[r][pos[1]] == num and pos[0] != r:
return False
x = pos[0] // 3
y = pos[1] // 3
for i in range(x*3, x*3 + 3):
for j in range(y*3, y*3 + 3):
if self.board[i][j] == num and (i,j) != pos:
return False
return True
def solution(self):
# Solution, using DFS Algorithm
find = self.search_empty()
if not find:
return True
else:
row, col = find
for i in self.find_valid_nums([row, col]):
if self.validate(i, (row, col)):
self.board[row][col] = i
if self.solution():
return True
self.board[row][col] = 0
return False
def stats(self):
self.end = time.time()
print("Time elapsed: {}s".format(self.end - self.start))
if __name__ == "__main__":
sud = Sudoku()
sud.main()