Learning sudoku by doing gradient descent on a linear program¶
"OptNet: Differentiable Optimization as a Layer in Neural Networks" (arxiv) describes a way to differentiate linear programs, which is an interesting thing one might think to do. cvxpylayers
is a python package implementing their ideas using the cvxpy
backend, and I'm going to try and use that implementation to train a model to learn sudoku. You can download the Jupyter notebook here, which includes a bunch of printing/visualization code that I've left out.
Introduction¶
Constraint programming lets you find optimal values of variables under some constraints. In general, we specify the output variables we're solving for, an objective function we'd like to minimize or maximize, and the constraints that define the set of feasible solutions; then we call a solver to do math we don't understand. If no assignment satisfies the constraints, the LP is infeasible, and if the objective can be improved without limit, the LP is unbounded. There's a lot of good theory behind this stuff, but practically, a solver will give you a solution or throw an error if your problem is malformed, infeasible, or unbounded.
Here we use cvxpy to find the smallest value of $x + y$ under the constraints that $4x - y = 1$ and $x, y$ are nonnegative integers. It turns out $x=1$ and $y=3$.
import cvxpy as cvx
# Set up output variables
x = cvx.Variable(integer=True)
y = cvx.Variable(integer=True)
# Set up model parameters
a = 4
b = -1
# Set up problem constraints & objective
constraints = [a*x + b*y == 1, x >= 0, y >= 0]
objective = cvx.Minimize(x + y)
problem = cvx.Problem(objective, constraints)
# Find a solution satisfying the constraints
problem.solve()
x.value.item(), y.value.item()
(1.0, 3.0)
We can generate completed sudokus like this by writing the rules as a set of constraints. The rules are that every row, column, and 3x3 square (what I've learned is called a 'chute') contains each number 1 through 9.
It's a bit tricky to do this using a 9x9 grid. Instead, we'll use a one-hot encoding, as in the OptNet paper. This means we have a 9x9x9 tensor indexed by row, column, and digit: board[row, col, digit]. If we want to put a 1 in the top left corner of the board, we'd write board[0, 0, 0] = 1; and to put a 2 directly to the right of it we'd write board[0, 1, 1] = 1.
The rule each row has one of each digit can be written in pseudo-python constraint form as sum(board[row, :, digit]) == 1 for all rows & digits; for columns it'd be sum(board[:, col, digit]) == 1 for all columns & digits. Since there's only one digit in each cell, we should also write sum(board[row, column, :]) == 1 for all rows & columns.
# Set up one-hot output board
board = cvx.Variable((9, 9, 9), boolean=True)
# Set up one-hot sudoku constraints
rows_ct = [cvx.sum(board[r, :, d]) == 1 for r in range(9) for d in range(9)] # Each row has one of each digit
cols_ct = [cvx.sum(board[:, c, d]) == 1 for c in range(9) for d in range(9)] # Each col has one of each digit
nums_ct = [cvx.sum(board[r, c, :]) == 1 for r in range(9) for c in range(9)] # Each cell has one digit
chutes_ct = [cvx.sum(board[r*3:(r+1)*3, c*3:(c+1)*3, d]) == 1 for r in range(3) for c in range(3) for d in range(9)] # ditto w/ chutes
To solve from an initial solution or hint, we can find the board closest to it that satisfies the constraints. Since both of these are just vectors, we'll minimize the distance between them: our objective is to find a board such that sum((board - clues)^2) is minimal.
# Set up initial solution/input board
input = cvx.Parameter((9, 9, 9))
# Set up the objective & constraints
objective = cvx.Minimize(cvx.sum_squares(board - input)) # Find board closest to initial solution
constraints = rows_ct + cols_ct + nums_ct + chutes_ct # Constraints describing solved sudoku
Here we'll set the first row to be 1 through 9 in order, and solve from there.
import numpy as np
# Set initial row to be 1 -> 9
input.value = np.zeros((9, 9, 9))
for c in range(9):
    input.value[0, c, c] = 1
# Set problem & solve
sudoku = cvx.Problem(objective, constraints)
sudoku.solve()
# Make the solution pretty & readable
to_decimal(board)
1 2 3 4 5 6 7 8 9
5 9 4 2 8 7 6 1 3
7 6 8 1 9 3 5 4 2
2 1 7 9 6 4 8 3 5
8 4 9 5 3 2 1 6 7
6 3 5 7 1 8 2 9 4
4 7 1 6 2 9 3 5 8
9 8 6 3 7 5 4 2 1
3 5 2 8 4 1 9 7 6
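(to_decimal is part of the printing/visualization code left out of this post; here's a minimal hypothetical sketch of what it might look like, assuming it just collapses the one-hot tensor back into a 9x9 grid of digits, with 0 where no digit is set:)
import numpy as np
# Hypothetical sketch of the omitted to_decimal helper -- not the notebook's
# actual implementation. Collapses a one-hot (9, 9, 9) board into a printable
# 9x9 grid of digits, using 0 for cells with no digit set.
def to_decimal(board):
    values = board.value if hasattr(board, 'value') else board  # accept cvx.Variable or ndarray
    grid = np.zeros((9, 9), dtype=int)
    for r in range(9):
        for c in range(9):
            cell = np.round(values[r, c, :])
            if cell.sum() == 1:  # exactly one digit set in this cell
                grid[r, c] = int(cell.argmax()) + 1
    print('\n'.join(' '.join(str(v) for v in row) for row in grid))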
Linear programming (and constraint programming in general) seems to provide some powerful tools related to reasoning. By being creative with our constraints, we can prove some statements about sudoku to be true or false. We do this by adding a constraint that defines the set of counterexamples, then trying to solve the resulting problem. If we can't solve it, the set of counterexamples is empty and the statement is true; otherwise the solution is a counterexample.
Consider the statement a diagonal (starting at the top left) of a sudoku board has fewer than six unique numbers. We know the diagonal has to have at least three unique numbers, since otherwise the three diagonal cells inside a single chute would contain a repeat, but arguments going the other way don't seem as straightforward.
We can test this statement by trying to find a counterexample, e.g. a sudoku with exactly six numbers present and three missing on the diagonal. The particular numbers don't matter, so we can just pick 1 -> 6 and 7 -> 9 respectively. Then the constraint ensures that digits 1 -> 6 each appear at least once on the diagonal, and digits 7 -> 9 appear nowhere on it.
# Counterexample - a diagonal with exactly six unique numbers
diagonal_example_ct_1 = [cvx.sum([board[r, r, d] for r in range(9)]) >= 1 for d in range(6)]
diagonal_example_ct_2 = [cvx.sum([board[r, r, d] for r in range(9)]) == 0 for d in range(6, 9)]
cvx.Problem(cvx.Minimize(0), constraints + diagonal_example_ct_1 + diagonal_example_ct_2).solve()
to_decimal(board)
6 8 9 4 2 7 5 3 1
5 3 4 9 6 1 7 8 2
2 7 1 5 8 3 4 9 6
1 4 7 6 9 2 3 5 8
9 2 3 8 1 5 6 7 4
8 5 6 7 3 4 1 2 9
7 6 8 1 5 9 2 4 3
3 1 5 2 4 8 9 6 7
4 9 2 3 7 6 8 1 5
Here we can see the statement is false, since this is a valid sudoku with only numbers 1 -> 6 on the diagonal. In fact, a diagonal can have any number of unique digits $\geq$ three.
for n in range(1, 10):
    diagonal_example_ct_1 = [cvx.sum([board[r, r, d] for r in range(9)]) >= 1 for d in range(n)]
    diagonal_example_ct_2 = [cvx.sum([board[r, r, d] for r in range(9)]) == 0 for d in range(n, 9)]
    diagonal_sudoku = cvx.Problem(cvx.Minimize(0), constraints + diagonal_example_ct_1 + diagonal_example_ct_2)
    diagonal_sudoku.solve()
    if diagonal_sudoku.status in ['infeasible', 'unbounded']:
        print(f'{n}: no examples')
    else:
        print(f'{n}: yes examples')
1: no examples
2: no examples
3: yes examples
4: yes examples
5: yes examples
6: yes examples
7: yes examples
8: yes examples
9: yes examples
Hopefully this convinces you that linear programs have some powerful reasoning/AI potential. They seem to implement a kind of set logic, where sets are defined descriptively by a set of constraints. We can test whether something is in a set, whether the set is empty, and (sometimes, if we're creative enough) whether something is in the complement, and so forth. But we don't do this by comparing set elements; instead we use the set's definition (or 'properties', loosely speaking).
To drive this point further, I've made a linear program that generates sudokus uniquely defined by hints of a given size. Here 'uniquely defined' means every cell sees at least 8 hint entries among its row, column, and chute, so the hint admits only one solution.
from itertools import product

# Generates random sudokus
class Sudoku:
    def __init__(self):
        self.board = cvx.Variable((9, 9, 9), boolean=True)
        self.constraints \
            = [cvx.sum(self.board[r, :, d]) == 1 for r in range(9) for d in range(9)] \
            + [cvx.sum(self.board[:, c, d]) == 1 for c in range(9) for d in range(9)] \
            + [cvx.sum(self.board[r, c, :]) == 1 for r in range(9) for c in range(9)] \
            + [cvx.sum(self.board[3*r:3*(r+1), 3*c:3*(c+1), d]) == 1 for r in range(3) for c in range(3) for d in range(9)]
        self.make_objective = lambda: cvx.Minimize(cvx.sum_squares(self.board - self.random_ones()))

    # Makes a sudoku uniquely defined by a hint of size n
    def unique_hint_of_length_n(self, n):
        minimum_hint = cvx.Variable((9, 9, 9), boolean=True)
        hint_ct = [self.board >= minimum_hint, cvx.sum(minimum_hint) == n]
        uniqueness_ct = []
        for y in range(0, 9, 3):
            for x in range(0, 9, 3):
                for r in range(y, y+3):
                    for c in range(x, x+3):
                        uniqueness_ct += [cvx.sum(minimum_hint[y:y+3, x:x+3, :])
                                          + cvx.sum(minimum_hint[:, c, :])
                                          + cvx.sum(minimum_hint[r, :, :]) >= 8]
        problem = cvx.Problem(cvx.Minimize(cvx.sum_squares(minimum_hint - self.random_ones())),
                              constraints=self.constraints + hint_ct + uniqueness_ct)
        problem.solve()
        if problem.status not in ['infeasible', 'unbounded']:
            return minimum_hint.value.copy()

    def complete_solution(self, hint):
        hint_ct = [self.board[r, c, d] == hint[r, c, d]
                   for r, c, d in product(range(9), repeat=3) if hint[r, c, d] == 1]
        sudoku = cvx.Problem(self.make_objective(), self.constraints + hint_ct)
        sudoku.solve()
        if sudoku.status not in ['infeasible', 'unbounded']:
            return self.board.value.copy()

    # Minimizes solution's distance to some randomly strewn 1's
    def random_ones(self, n_random=3):
        random_ones = np.zeros((9, 9, 9), dtype=int)
        for _ in range(n_random):
            random_ones.reshape(-1)[np.random.randint(9*9*9)] = 1
        return random_ones

    def generate(self):
        sudoku = cvx.Problem(self.make_objective(), self.constraints)
        sudoku.solve()
        if sudoku.status not in ['infeasible', 'unbounded']:
            return self.board.value.copy()
to_decimal(Sudoku().unique_hint_of_length_n(27))
0 3 0 0 0 6 9 0 0
0 0 9 0 7 8 0 0 0
6 0 0 0 0 0 0 3 7
5 0 4 0 0 0 2 0 0
0 0 8 5 0 0 3 0 0
0 0 0 4 1 0 0 0 5
3 0 0 8 0 0 0 2 0
0 0 0 0 9 4 6 0 0
0 5 2 0 0 0 0 4 0
Experiment Setup¶
For our experiment, we'll train a linear program to solve incomplete sudokus like in OptNet. As a linear program, it will minimize the distance to an input hint under learned constraints; as a neural net it minimizes the prediction error between the output and the complete solution, updating these constraints.
cvxpylayers gives us this ability by putting an LP solver into a neural network layer. It's not obvious what that would even entail, since linear programs aren't usually represented as functions with inputs and outputs. Generally, the inputs to a linear program are an objective function we want to minimize, $f(x)$, and some functions representing constraints in the form $g(x) \leq 0$ or $h(x) = 0$. The output is an optimal solution $x_{\text{opt}}$ minimizing $f$ under $g, h$. It turns out that under mild assumptions, whether $x$ is optimal depends only on the derivatives of $f$, $g$, and $h$ (the KKT conditions); and so some talented mathematicians can figure out the derivative of $x_{\text{opt}}$ from $f'$, $g'$, and $h'$ by implicitly differentiating those conditions (which is what OptNet does).
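To make this concrete, here's a tiny standalone sketch (an illustration, not the sudoku model) of differentiating through a solve with cvxpylayers: we project a point onto the nonnegative orthant, where $x_{\text{opt}} = \max(\text{point}, 0)$, and backprop through the projection.
import cvxpy as cvx
import torch
from cvxpylayers.torch import CvxpyLayer

# An LP-style problem with a parameter: project `point` onto x >= 0
proj_x = cvx.Variable(3)
point = cvx.Parameter(3)
projection = cvx.Problem(cvx.Minimize(cvx.sum_squares(proj_x - point)), [proj_x >= 0])
layer = CvxpyLayer(projection, parameters=[point], variables=[proj_x])

pt = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
(x_opt,) = layer(pt)   # tensor([1., 0., 3.])
x_opt.sum().backward()
print(pt.grad)         # ~tensor([1., 0., 1.]): gradient flows only where the
                       # nonnegativity constraint is inactive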
A sort of brute force way to approach sudoku is to learn $g$ and $h$ as rules in the form of a (9*9*9)^2 tensor relating every cell to every possible game state. That's actually not too big a tensor for modern hardware, and it's the ideal approach in a sense: by trying to shrink it we might accidentally cheat and encode prior information about the game. Unfortunately, cvxpy doesn't seem to be able to handle matrices with (9*9*9)^2 floats, so I've devised a way of splitting the workload among multiple layers in parallel. This probably has to do with the fact that cvxpy is actually a compiler, and its IR isn't lowered very far toward machine code.
import torch
import torch.nn as nn
from cvxpylayers.torch import CvxpyLayer

T = torch.float32

def add_and_norm(x, y):
    z = x + y
    m, s = z.mean(), z.std()
    return (z - m) / s

# Layers representing the rules of sudoku, perhaps
class SudokuNetLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # Poor man's attention -- learns to set the value of unimportant cells to 0
        self.mask = nn.Parameter(torch.rand(9*9*9, dtype=T))
        # LP trainable parameters
        self._A = nn.Parameter(torch.rand(9*9*9, dtype=T))
        self._b = nn.Parameter(torch.rand(9*9*9, dtype=T))
        # LP solver parameters
        self.input = cvx.Parameter(9*9*9)
        self.A = cvx.Parameter(9*9*9)
        self.b = cvx.Parameter(9*9*9)
        # LP output variables
        self.x = cvx.Variable(9*9*9)
        # LP layer
        self.objective = cvx.Minimize(cvx.sum_squares(self.input - self.x))
        self.constraints = [cvx.multiply(self.A, self.x) <= self.b]  # elementwise product
        self.problem = cvx.Problem(self.objective, self.constraints)
        self.cvx = CvxpyLayer(self.problem, parameters=[self.input, self.A, self.b], variables=[self.x])

    def forward(self, board):
        board = board.reshape(-1)
        mask = torch.clamp(self.mask, min=0, max=1)
        input = board * mask
        (x,) = self.cvx(input, self._A, self._b)
        return x
# Model definition
class SudokuNet(nn.Module):
    # Initialize the parallel layers & "just one more" attention layer
    def __init__(self, n_rules=9):
        super().__init__()
        self.rules = nn.ModuleList()
        for _ in range(n_rules):
            self.rules.append(SudokuNetLayer())
        self.attn = nn.MultiheadAttention(embed_dim=9*9*9, num_heads=9)

    # Attend over the output of the parallel layers
    def forward(self, input):
        rule_predictions = tuple(rule(input) for rule in self.rules)
        stacked = torch.stack(rule_predictions)
        V, _ = self.attn(stacked, stacked, stacked, need_weights=False)
        V = torch.sum(V, dim=0)
        V = V.reshape(9, 9, 9)
        return V
SudokuNet()
SudokuNet(
  (rules): ModuleList(
    (0-8): 9 x SudokuNetLayer(
      (cvx): CvxpyLayer()
    )
  )
  (attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=729, out_features=729, bias=True)
  )
)
To train the model, we'll give it a uniquely defining hint with a size of around 60-70, then minimize the mean squared error between the model's predictions and the complete solution. To track how well it's doing, we'll also save the absolute difference between the prediction and the complete solution for each digit into a black and white image, and call it the 'reconstruction error'.
# Logging & visualization
from torch.utils.tensorboard import SummaryWriter
log_dir = 'sudoku'
writer = SummaryWriter(log_dir=log_dir, flush_secs=1)
# Initialize sudoku generator, network & optimizer
sudoku = Sudoku()
net = SudokuNet()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)
# Training loop
for step in range(60000):
    # Generate a random sudoku board & its solution
    input = sudoku.unique_hint_of_length_n(np.random.randint(60, 70))
    solution = sudoku.complete_solution(input)
    # Solve it
    input = torch.tensor(input, dtype=T)
    solution = torch.tensor(solution, dtype=T)
    output = net(input)
    # Accumulate loss
    loss = nn.functional.mse_loss(output, solution)
    # Backprop
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Track loss & difference to solution
    writer.add_scalars('loss', {'mse': loss.item()}, global_step=step)
    for i in range(9):
        writer.add_image(f'solution/{i}', solution[:, :, i], dataformats='HW', global_step=step)
        writer.add_image(f'diff/{i}', torch.abs(output[:, :, i] - solution[:, :, i]), dataformats='HW', global_step=step)
writer.flush()
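For completeness, here's a minimal sketch of reading a prediction off the trained net, assuming we simply take the highest-scoring digit in each cell (this step isn't part of the training loop above):
# Sketch (assumed): discretize the net's continuous output by taking the
# highest-scoring digit per cell, then measure how many cells match the
# complete solution from the last training example.
with torch.no_grad():
    pred = net(input)                        # (9, 9, 9) scores
predicted_digits = pred.argmax(dim=-1) + 1   # (9, 9) grid of digits 1-9
true_digits = solution.argmax(dim=-1) + 1
cell_accuracy = (predicted_digits == true_digits).float().mean().item()
print(f'cell accuracy: {cell_accuracy:.2%}')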
Results & discussion¶
It takes on the order of days to train the model, but it does make progress. For reasons I'll get into in a moment, we can be relatively sure the model converges and will start to make correct predictions, but first I'd like to zoom out and get a bigger view of things.
The difference between science fiction AI and the AI models we hear so much about today might be summarized by counterfactual reasoning; i.e. the ability to correctly answer 'what if?' questions. We often say that an AI is supposed to learn a 'model of the world', but what ends up happening is that it learns a passive set of relationships in some data. What we really want is for it to learn cause and effect, with an active idea of how relationships change in response to things. (This is a distinction I borrow from GEB, which I highly recommend reading.) The connection between counterfactual reasoning and causal reasoning is well-known, to the point that some argue they are the same thing. And while the jury's still out on whether that's true, it can't be denied that counterfactual reasoning is an essential part of reasoning in general. We know from our sudoku generator that linear programs get counterfactuality for free; they can even tell you when your hypotheticals are infeasible and don't warrant a response.
We do have models that reason counterfactually in the continuous case. Variational autoencoders, for example, can give reasonable answers to questions like "what's between a 9 and a g". But this assumes this sort of question is meaningful, which is insane. This problem is baked pretty deep into the AI architectures that've recently become popular, which essentially answer questions like "what would a thing in between the history of the Balkan wars and the format of a 9th grade APUSH essay look like". It's all running on gradient descent, and you can't take the gradient of a thing that's not continuous (unless you use subgradients, but that's not what these models use).
The goal, then, should be an architecture capable of learning counterfactual reasoning in the discrete case. Maybe instead of approximating a discrete function, we could approximate the boundary of a related discrete set in some continuous way; this is the idea behind SudokuNet. In earlier drafts of this post I had the model learn 9x9 bitmasks representing rules, where for each bitmask the rule is cells a & b must contain different numbers if bitmask[a] == 1 and bitmask[b] == 1; i.e. sum(bitmask * sudoku[:, :, d]) <= 1 for d in range(9). You might convince yourself that all the rules of sudoku can be represented this way, and that an LP finding discrete solutions under the correct bitmask constraints would just be finding sudokus (see the sketch below). In fact, it turns out encoding the rules this way gets you most of the way there. With some regularization (maximizing the number of 1's in a bitmask, making sure bitmasks are different from each other), a humble linear model will learn the correct masks with no fancy LP backpropagation needed. (Hence our current SudokuNet should eventually converge, since it's just a more complicated version with 9x9x9 bitmasks and less regularization.) This isn't surprising, since sudoku is a somewhat famous example of a game about satisfying constraints, but there does seem to be some interesting potential here for situations that are "like sudoku if you squint through a stack of 128 neural networks".
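As a concrete sketch, one such bitmask constraint might look like this in cvxpy, reusing the board variable from earlier; the first-row mask here is just an illustrative stand-in for a mask the model would learn:
# Sketch (illustrative): one 9x9 bitmask rule. Cells a & b must differ whenever
# bitmask[a] == bitmask[b] == 1; on a one-hot board that means each digit
# appears at most once among the masked cells.
import numpy as np

example_mask = np.zeros((9, 9))
example_mask[0, :] = 1  # stand-in mask covering the first row

bitmask_ct = [cvx.sum(cvx.multiply(example_mask, board[:, :, d])) <= 1 for d in range(9)]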
There are probably more specialized methods of learning constraints for linear programs that solve discrete problems, and straightforward gradient descent on a continuous model might be like wandering around drunk at night in comparison. For example, if you were able to determine not only whether a constraint makes a problem infeasible, but also how many solutions it excludes when the problem is feasible, you'd be able to greedily optimize for the constraints that exclude the most solutions, i.e. the constraints that matter most to the problem. There's some recent work related to this question, the volume of the set of solutions to an LP, that seems pretty interesting.
I'll be trying to add a comments section soon, but if you'd like to discuss this post you can find my contact info on my GitHub. Thanks for reading!