Gurobi學習筆記——求解數獨問題


Gurobi學習筆記——求解數獨問題

本文以Gurobi官方提供的數獨案例為例,將介紹以下知識點:

  • 設置變量的屬性Attribute
  • 如何固定變量的值
  • 使用生成器添加多個約束
  • quicksum() 函數的使用

設置變量的屬性

Gurobi中的Var類具有多個屬性(Attribute),如LB,UB,Obj等。詳情可以參見文檔根目錄/docs/refman.pdf

這些屬性可以在創建變量時,使用關鍵字參數傳入,也可以調用相應的方法進行相應的更改

# 創建一個目標漢中的系數為2.0,變量名稱為x的0-1變量
x = m.addVar(obj = 2.0, vtype = GRB.BINARY, name = 'x')
# 在得到Var對象后,再對屬性進行修改
var.setAttr(GRB.Attr.UB, 0.0)
var.setAttr("ub", 0.0)
# 可以直接用.訪問其屬性
var.UB = 0

固定變量的值

如果要固定某一變量的值,可以令該變量的上限和下限等於同一個常數

constant = 9
var.UB = constant
var.LB = constant

注意固定變量的值不等同於變量中Start屬性Start屬性為MIP問題中的初始解,可以用一些啟發式規則得到該問題的某一可行解,並以此進行計算。之后的初始解可能會發生改變。因此,應注意區分。

生成器生成多個約束

Python中的生成器(generator)對象,可以在需要的時候再對某一循環進行按需調用。詳情可通過廖雪峰老師的教程了解。

在Gurobi中,生成器通常按照列生成式的方法編寫,只不過將外部的方括號[]變為圓括號()

# 生成一個2*x列表 (x= 1,...,9)
>>> L = [2x for x in range(10)]
>>> L
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> g = (2x for x in range(10))
>>> g
<generator object <genexpr> at 0x1022ef630>

在日常建模中,通常針對某一集合中的每個元素遍歷並添加約束,如:

我們可以使用m.addConstr(),針對每一種i,j的組合,為其添加約束。

Gurobi也支持將生成器傳入m.addConstrs(), 以達到批量添加的目的。

# 創建3維0-1變量
vars = m.addVars(n, n, n, vtype=GRB.BINARY, name="G")

for i in range(n):
    for j in range(n):
        m.addConstr(vars.sum(i,j,'*')==1)
# 等價於
m.addConstrs((
      vars.sum(i,j,'*')==1 
      for i in range(n) 
      for j in range(n)))

兩種寫法的區別僅僅是將for循環寫在外面還是使用生成器,其他方面沒有區別,但前者的可讀性可能會稍強些

quicksum()函數的使用

quicksum(data)是Gurobi推薦的求和函數,其執行效率高於Python內置的sum()函數。因此,在大規模添加模型時,建議優先使用quicksum()

quicksum(data)接受含有Var或者表達式(LinExpr, QuadExpr)的List對象,並將其中的所有的元素相加,生成求和表達式

expr = quicksum([2*x, 2*y+1, 4*z*z])
expr = quicksum([x, y, z])

上文中的案例,除了可以用tupledictsum函數,也可以寫作quicksum

不過tupledictsum方法支持通配符*,書寫起來更加簡便。

for i in range(n):
    for j in range(n):
    	# 對於當前的i和j,在v維度上進行求和
        m.addConstr((gp.quicksum(vars[i,j,v] for v in range(n))==1)

前面說data需要List類型的對象,此處可以理解為生成器作為參數傳入list()函數中,可以一次性將生成器轉化為list對象

數獨案例

參考自Gurobi自帶的案例文件根目錄/examples/python/sudoku.py

如果要運行該案例,需要在命令行中添加數據文件根目錄/examples/data/sudoku1

數獨盤面是個九宮,每一宮又分為九個小格。在這八十一格中給出一定的已知數字和解題條件,利用邏輯和推理,在其他的空格上填入1-9的數字。使1-9每個數字在每一行、每一列和每一宮中都只出現一次,所以又稱“九宮格”。(參考自百度百科

本案例采用三維0-1決策變量x(i,j,v)x(i,j,v)=1代表在第ij列的格子上填的數字為v+1(i, j, k均從0開始計數);否則,x(i,j,v)=0

因此,具有以下約束:

約束1:每個格子只能有一個數

約束2: 每行元素不重復

約束3: 每列元素不重復

約束4: 每個子區域內(3*3),沒有重復的元素

以下是本人增加中文注釋后的代碼,希望通過前面的解讀,還是比較易懂的。

import gurobipy as gp
from gurobipy import GRB
import math


"""
假設數獨模型是這個樣子
.284763..
...839.2.
7..512.8.
..179..4.
3........
..9...1..
.5..8....
..692...5
..2645..8
"""
# 假設數據存放在同目錄下的data文件夾下
f = open("./data/sudoku1")

grid = f.read().split()

n = len(grid)
s = int(math.sqrt(n))

m = gp.Model()
vars = m.addVars(n, n, n, vtype=GRB.BINARY, name="G")


# 讀入數據,將已知的
for i in range(n):
    for j in range(n):
        # 如果該位置的數已知,則通過設置LB的方式,固定變量
        if grid[i][j] != '.':
            # 注意此處索引方式的不同
            # grid為二維list, vars為dict
            v = int(grid[i][j]) - 1
            vars[i, j, v].LB = 1

# 同一個位置,只能選一個數字
for i in range(n):
    for j in range(n):
        m.addConstr(vars.sum(i,j,'*')==1)
# 等價於
# m.addConstrs((
#       vars.sum(i,j,'*')==1 
#       for i in range(n) 
#       for j in range(n)))


# 添加行約束
# 對於每行而言,每個數字只能出現一次
for i in range(n):
    for v in range(n):
        m.addConstr(vars.sum(i, '*', v) == 1)
# 等價於
# m.addConstrs((
#     vars.sum('*', j, v) == 1 
#     for i in range(n) 
#     for v in range(n)))


# 添加列約束
# 對於每列而言,每個數字只能出現一次
for j in range(n):
    for v in range(n):
        m.addConstr(vars.sum('*', j, v) == 1)
#等價於
# m.addConstrs((
#     vars.sum(i, '*', v) == 1 
#     for j in range(n) 
#     for v in range(n)))


# 添加子矩陣約束
# 每個子矩陣內,數字不能重復
for i0 in range(s):
    for j0 in range(s):
        for v in range(n):
            m.addConstr(gp.quicksum(vars[i, j, v] for i in range(
                i0*s, (i0+1)*s) for j in range(j0*s, (j0+1)*s)) == 1)
# 等價於
# m.addConstrs((
#     gp.quicksum(vars[i, j, v] for i in range(i0*s, (i0+1)*s) for j in range(j0*s, (j0+1)*s)) == 1
#      for i0 in range(s) 
#      for j0 in range(s) 
#      for v in range(n)))


# 開始優化模型
m.optimize()

# 獲取vars變量的X屬性
# 獲得的tupledict對象,solution
solution = m.getAttr('X', vars)

for i in range(n):
    sol = ''
    for j in range(n):
        for v in range(n):
            if solution[i, j, v] > 0.5:
                sol += str(v+1)
    print(sol)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM