Learning sudoku by doing gradient descent on a linear program¶

"OptNet: Differentiable Optimization as a Layer in Neural Networks" describes a way to differentiate through linear programs, which is an interesting thing one might think to do. cvxpylayers is a Python package implementing their ideas on the cvxpy backend, and I'm going to try to use that implementation to train a model to learn sudoku. You can download the Jupyter notebook here, which includes a bunch of printing/visualization code that I've left out of this post.

Introduction¶

The general motivation and background for this topic is summarized here, with the "Integer linear programming for natural language processing applications" section being especially relevant.

Constraint programming lets you find the optimal values of a function under some constraints. In general, we specify the dependent/output variables we're solving for, an objective function we'd like to minimize or maximize, and the constraints that define the set of solutions; then we call a solver to do math we don't understand. If no assignment satisfies the constraints, the LP is infeasible; if the objective can be improved without limit, the LP is unbounded. There's a lot of good theory behind this, but practically, a solver will give you a solution or throw an error if your problem is infeasible, unbounded, or otherwise inadmissible.

Here I use cvxpy to find the smallest value of $x + y$ under the constraints that $4x - y = 1$ and $x, y$ are nonnegative integers. It turns out $x=1$ and $y=3$.

In [67]:
import cvxpy as cvx
import numpy as np

# Set up output variables
x = cvx.Variable(integer=True)
y = cvx.Variable(integer=True)

# Set up model parameters
a = 4
b = -1

# Set up problem constraints & objective
constraints = [a*x + b*y == 1, x >= 0, y >= 0]
objective   = cvx.Minimize(x + y)
problem     = cvx.Problem(objective, constraints)

# Find a solution satisfying the constraints
problem.solve()
x.value.item(), y.value.item()
Out[67]:
(1.0, 3.0)

For things that can be totally defined by a set of rules, this is very useful. For example, we don't need to know anything apart from the rules of sudoku to make, complete, or verify a puzzle. By writing the rules as a set of constraints, we can use linear programming to do all of this naturally. The rules are that every row, column, and 3x3 box (which I'll call a 'chute', though strictly that name refers to a band of three boxes) contains each number 1 through 9.

Rules like 'this row contains each digit' aren't linear in a 9x9 grid of digits, so it's tricky to encode them directly. Instead, I'll use a one-hot encoding, as in the OptNet paper. This means we have a 9x9x9 binary tensor indexed by row, column, and digit: board[row, col, digit]. If we want to put a 1 in the top left corner of the board, we'd write board[0, 0, 0] = 1, and to put a 2 directly to the right of it we'd write board[0, 1, 1] = 1.
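Since the printing/visualization helpers are left out of this post, here's roughly how a helper like to_decimal might look -- a sketch of my own, not necessarily the notebook's version. It collapses the one-hot axis back to digits, printing 0 where a cell is empty:

```python
import numpy as np

def to_decimal(board):
    """Print a one-hot (9, 9, 9) board as a 9x9 grid of digits.

    Accepts a cvxpy Variable (reads .value) or a plain numpy array.
    Cells with no digit set come out as 0.
    """
    values = board.value if hasattr(board, 'value') else board
    values = np.asarray(values).round().astype(int)
    # Broadcasting digit values 1..9 along the one-hot axis recovers the digit
    grid = (values * np.arange(1, 10)).sum(axis=2)
    for row in grid:
        print(' '.join(str(d) for d in row))
    return grid

# Example: a 1 in the top-left corner and a 2 directly to its right
example = np.zeros((9, 9, 9))
example[0, 0, 0] = 1
example[0, 1, 1] = 1
grid = to_decimal(example)  # first row prints as: 1 2 0 0 0 0 0 0 0
```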

The rule 'each row has one of each digit' can be written in python-constraint form as

sum(board[row, :, digit]) == 1 for all rows & digits;

for columns it'd be

sum(board[:, col, digit]) == 1 for all columns & digits.

Since there's only one digit in each cell, we should also write

sum(board[row, column, :]) == 1 for all rows & columns.

In [ ]:
# Set up one-hot output board
board = cvx.Variable((9, 9, 9), boolean=True)

# Set up one-hot sudoku constraints
rows_ct   = [cvx.sum(board[r, :, d]) == 1 for r in range(9) for d in range(9)] # Each row has one of each digit
cols_ct   = [cvx.sum(board[:, c, d]) == 1 for c in range(9) for d in range(9)] # Each col has one of each digit 
nums_ct   = [cvx.sum(board[r, c, :]) == 1 for r in range(9) for c in range(9)] # Each cell has one digit
chutes_ct = [cvx.sum(board[r*3:(r+1)*3, c*3:(c+1)*3, d]) == 1 for r in range(3) for c in range(3) for d in range(9)] # ditto w/ chutes

To solve from an initial solution, or hint, we find the board closest to it that satisfies the constraints. Since both are just 0/1 tensors, we can minimize the distance between them: the objective is to find a board such that sum((board - hint)^2) is minimal.

In [ ]:
# Set up initial solution/input board
input = cvx.Parameter((9, 9, 9))

# Set up the objective & constraints
objective = cvx.Minimize(cvx.sum_squares(board - input)) # Find board closest to initial solution
constraints = rows_ct + cols_ct + nums_ct + chutes_ct  # Constraints describing solved sudoku

We could see if there's a sudoku where the first row is the numbers 1 -> 9 by setting that as our hint:

In [70]:
# Set initial row to be 1 -> 9
input.value = np.zeros((9, 9, 9))
for c in range(9):
    input.value[0, c, c] = 1

# Set problem & solve
sudoku = cvx.Problem(objective, constraints)
sudoku.solve()

# Make the solution pretty & readable
to_decimal(board)
1 2 3 4 5 6 7 8 9 
5 9 4 2 8 7 6 1 3 
7 6 8 1 9 3 5 4 2 
2 1 7 9 6 4 8 3 5 
8 4 9 5 3 2 1 6 7 
6 3 5 7 1 8 2 9 4 
4 7 1 6 2 9 3 5 8 
9 8 6 3 7 5 4 2 1 
3 5 2 8 4 1 9 7 6 

Linear programming (and constraint programming in general) seems to provide some powerful tools related to reasoning. By being creative with constraints, one can prove statements about sudoku to be true, or otherwise find a counterexample. Add a constraint that describes the set of counterexamples, then try to solve the resulting problem -- if it's infeasible, there aren't any counterexamples and the statement is true; if it's feasible then any solution is a counterexample.

Consider the statement 'every diagonal (starting at the top left) of a sudoku board has fewer than six unique digits'. As a sanity check on the lower end, we know the diagonal has to have at least three unique digits, since the three diagonal cells inside any one chute must be distinct.

One counterexample would be a sudoku with six digits present and three missing on the diagonal. The particular numbers don't matter, so we can just pick 1 -> 6 and 7 -> 9 respectively. Then the counterexample constraint should be that digits 1 -> 6 have at least one entry on the diagonal, and digits 7 -> 9 have none.

In [ ]:
# Counterexample - a diagonal with exactly six unique digits
diagonal_example_ct_1 = [cvx.sum([board[r, r, d] for r in range(9)]) >= 1 for d in range(6)]
diagonal_example_ct_2 = [cvx.sum([board[r, r, d] for r in range(9)]) == 0 for d in range(6, 9)]
cvx.Problem(cvx.Minimize(0), constraints + diagonal_example_ct_1 + diagonal_example_ct_2).solve()
to_decimal(board)
6 8 9 4 2 7 5 3 1 
5 3 4 9 6 1 7 8 2 
2 7 1 5 8 3 4 9 6 
1 4 7 6 9 2 3 5 8 
9 2 3 8 1 5 6 7 4 
8 5 6 7 3 4 1 2 9 
7 6 8 1 5 9 2 4 3 
3 1 5 2 4 8 9 6 7 
4 9 2 3 7 6 8 1 5 

The statement's false, since here's a valid sudoku with digits 1 -> 6 on the diagonal. In fact, a diagonal can have any number of unique digits $\geq$ three.

In [120]:
for n in range(1, 10):
    diagonal_example_ct_1 = [cvx.sum([board[r, r, d] for r in range(9)]) >= 1 for d in range(n)]
    diagonal_example_ct_2 = [cvx.sum([board[r, r, d] for r in range(9)]) == 0 for d in range(n, 9)]
    diagonal_sudoku = cvx.Problem(cvx.Minimize(0), constraints + diagonal_example_ct_1 + diagonal_example_ct_2)
    diagonal_sudoku.solve()
    if diagonal_sudoku.status in ['infeasible', 'unbounded']:
        print(f'{n}: no examples')
    else:
        print(f'{n}: yes examples')
1: no examples
2: no examples
3: yes examples
4: yes examples
5: yes examples
6: yes examples
7: yes examples
8: yes examples
9: yes examples

I hope this illustrates some of the allure of linear programs. They implement a kind of set logic, where sets are defined descriptively by constraints. We can test whether something is in a set, whether a set is empty, (sometimes, if we're creative enough) whether it's in the complement, and so on. And we don't need to compare anything directly; instead we use a set's description (or properties, loosely speaking). The trade-off is that this is limited to situations where such set descriptions are a natural fit.
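As a small illustration of the set-membership view: once every variable is fixed, 'is this board in the set of solved sudokus?' reduces to evaluating the same sums the constraints use, no solver required. This is a numpy sketch of my own (the valid-grid construction is the standard cyclic-shift trick, not from the notebook):

```python
import numpy as np

def satisfies_sudoku_constraints(onehot):
    """Membership test for the 'solved sudoku' set: every row, column,
    cell, and 3x3 chute sum must equal exactly 1, per digit."""
    b = np.asarray(onehot)
    rows = (b.sum(axis=1) == 1).all()   # sum(board[r, :, d]) == 1
    cols = (b.sum(axis=0) == 1).all()   # sum(board[:, c, d]) == 1
    cells = (b.sum(axis=2) == 1).all()  # sum(board[r, c, :]) == 1
    chutes = all(
        b[r*3:(r+1)*3, c*3:(c+1)*3, :].sum(axis=(0, 1)).tolist() == [1] * 9
        for r in range(3) for c in range(3)
    )
    return bool(rows and cols and cells and chutes)

# A cyclic-shift construction yields one known-valid solved grid
grid = np.array([[(3*r + r//3 + c) % 9 + 1 for c in range(9)] for r in range(9)])
onehot = np.zeros((9, 9, 9))
for r in range(9):
    for c in range(9):
        onehot[r, c, grid[r, c] - 1] = 1

print(satisfies_sudoku_constraints(onehot))  # True: the grid is in the set
onehot[0, 0, :] = 0                          # blank one cell...
print(satisfies_sudoku_constraints(onehot))  # False: no longer in the set
```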

And to drive this point further, I've made a linear program for finding hints whose solutions should be unique. The constraint for 'uniqueness' here is that each cell of the hint sees at least 8 given entries among its row, column, and chute combined, the idea being that each cell is then forced to a single digit. (This is a heuristic rather than a strict guarantee, since the count allows overlaps and repeated digits.)

In [ ]:
class Sudoku:
    # Makes a sudoku uniquely defined by a hint of size n
    def unique_hint_of_length_n(self, n):
        minimum_hint = cvx.Variable((9, 9, 9), boolean=True)
        hint_ct = [self.board >= minimum_hint, cvx.sum(minimum_hint) == n]
        uniqueness_ct = []
        for y in range(0, 9, 3):
            for x in range(0, 9, 3):
                for r in range(y, y+3):
                    for c in range(x, x+3):
                        uniqueness_ct += [cvx.sum(minimum_hint[y:y+3, x:x+3, :]) + cvx.sum(minimum_hint[:, c, :]) + cvx.sum(minimum_hint[r, :, :]) >= 8]
        problem = cvx.Problem(cvx.Minimize(cvx.sum_squares(minimum_hint - self.random_ones())), constraints=self.constraints + hint_ct + uniqueness_ct)
        problem.solve()
        if problem.status not in ['infeasible', 'unbounded']:
            return minimum_hint.value.copy()

    # Other stuff useful for later
    def __init__(self):
        self.board = cvx.Variable((9, 9, 9), boolean=True)
        self.constraints\
        = [cvx.sum(self.board[r, :, d]) == 1 for r in range(9) for d in range(9)]\
        + [cvx.sum(self.board[:, c, d]) == 1 for c in range(9) for d in range(9)]\
        + [cvx.sum(self.board[r, c, :]) == 1 for r in range(9) for c in range(9)]\
        + [cvx.sum(self.board[3*r:3*(r+1), 3*c:3*(c+1), d]) == 1 for r in range(3) for c in range(3) for d in range(9)]
        self.make_objective = lambda: cvx.Minimize(cvx.sum_squares(self.board - self.random_ones()))

    def complete_solution(self, hint):
        from itertools import product
        hint_ct = [self.board[r, c, d] == hint[r, c, d] for r, c, d in product(range(9), repeat=3) if hint[r, c, d] == 1]
        sudoku = cvx.Problem(self.make_objective(), self.constraints + hint_ct)
        sudoku.solve()
        if sudoku.status not in ['infeasible', 'unbounded']:
            return self.board.value.copy()

    def random_ones(self, n_random=3):
        random_ones = np.zeros((9, 9, 9), dtype=int)
        for _ in range(n_random):
            random_ones.reshape(-1)[np.random.randint(9*9*9)] = 1
        return random_ones

    def generate(self):
        sudoku = cvx.Problem(self.make_objective(), self.constraints)
        sudoku.solve()
        if sudoku.status not in ['infeasible', 'unbounded']:
            return self.board.value.copy()

to_decimal(Sudoku().unique_hint_of_length_n(27))
0 3 0 0 0 6 9 0 0 
0 0 9 0 7 8 0 0 0 
6 0 0 0 0 0 0 3 7 
5 0 4 0 0 0 2 0 0 
0 0 8 5 0 0 3 0 0 
0 0 0 4 1 0 0 0 5 
3 0 0 8 0 0 0 2 0 
0 0 0 0 9 4 6 0 0 
0 5 2 0 0 0 0 4 0 

Experiment¶

Naturally, we'll want to turn this into a deep learning model, so that the constraints are learned from data instead of me having to think too hard to translate every idea into constraints by hand.

cvxpylayers lets us do this by putting an LP solver inside a neural network layer. It's not obvious what that should even mean, since linear programs aren't usually represented as functions with inputs and outputs. Generally, the inputs to a linear program are an objective function $f(x)$ we want to minimize, along with functions representing constraints in the form $g(x) \leq 0$ or $h(x) = 0$. The output is an optimal solution $x_{\text{opt}}$, where $f(x_{\text{opt}})$ is minimal subject to $g(x_{\text{opt}}) \leq 0$ and $h(x_{\text{opt}}) = 0$. It turns out that, under some assumptions, whether $x$ is optimal depends only on the first derivatives of $f$, $g$, and $h$ (the KKT conditions); from there, someone with a greater mathematics background than ours can derive the derivative of $x_{\text{opt}}$ with respect to the problem's parameters by implicitly differentiating those conditions.
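To make that concrete without the full machinery, here's a toy numpy sketch of my own (not cvxpylayers' implementation): for the equality-constrained problem $\min_x \tfrac{1}{2}\lVert x - p\rVert^2$ s.t. $\mathbf{1}^\top x = 1$, the KKT conditions give $x_{\text{opt}}$ in closed form, so the Jacobian $\partial x_{\text{opt}}/\partial p$ can be written down directly and checked against finite differences:

```python
import numpy as np

n = 4

def solve(p):
    """argmin_x 0.5*||x - p||^2  s.t.  sum(x) == 1, via the KKT system.

    Stationarity: x - p + nu*1 = 0; feasibility: 1'x = 1.
    Together these give nu = (sum(p) - 1)/n and x* = p - nu.
    """
    nu = (p.sum() - 1) / n
    return p - nu

def jacobian(p):
    """d x* / d p, read off the closed form: I - (1/n) * 1 1'."""
    return np.eye(n) - np.ones((n, n)) / n

p = np.array([0.3, -0.2, 1.5, 0.4])
x = solve(p)
J = jacobian(p)

# Central finite differences, column j = d x* / d p_j
eps = 1e-6
J_fd = np.stack([(solve(p + eps*e) - solve(p - eps*e)) / (2*eps)
                 for e in np.eye(n)], axis=1)
print(np.abs(J - J_fd).max())  # tiny: the KKT-derived Jacobian matches
```

cvxpylayers does the analogous thing for general convex programs, where the KKT system has no closed form and must be differentiated implicitly.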

Appreciating this result, we can train a model to solve incomplete sudokus as in the OptNet paper (though the model here is larger). As a linear program, it tries to complete a hint by finding the solution closest to it under learned constraints; as a neural network, it updates those constraints using the error between the completed hint and the actual solution.

A sort of brute-force way to approach this is to learn $g$ and $h$ as rules in the form of a $(9 \cdot 9 \cdot 9)^2$ tensor relating every cell to every possible game state. That's actually not too big a number for modern hardware, and it's the honest approach: if we tried to be more efficient, we might accidentally 'cheat' and encode some prior information about the game. Unfortunately, cvxpy doesn't seem to be able to handle matrices with $(9 \cdot 9 \cdot 9)^2$ floats, so I've devised a way of splitting the workload among multiple smaller layers in parallel.

In [ ]:
import torch
import torch.nn as nn
from cvxpylayers.torch import CvxpyLayer

T = torch.float32

def add_and_norm(x, y):
    z = x + y
    m, s = z.mean(), z.std()
    return (z - m) / s

# Layers representing the rules of sudoku perhaps
class SudokuNetLayer(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        # Poor man's attention -- learns to set the value of unimportant cells to 0
        self.mask = nn.Parameter(torch.rand(9*9*9, dtype=T))

        # LP trainable parameters
        self._A = nn.Parameter(torch.rand(9*9*9, dtype=T), requires_grad=True)
        self._b = nn.Parameter(torch.rand(9*9*9, dtype=T), requires_grad=True)

        # LP solver parameters
        self.input = cvx.Parameter(9*9*9)
        self.A = cvx.Parameter(9*9*9)
        self.b = cvx.Parameter(9*9*9)

        # LP output variables
        self.x = cvx.Variable(9*9*9)

        # LP layer
        self.objective = cvx.Minimize(cvx.sum_squares(self.input - self.x))
        self.constraints = [cvx.multiply(self.A, self.x) <= self.b]  # elementwise product, not matrix multiply
        self.problem = cvx.Problem(self.objective, self.constraints)
        self.cvx = CvxpyLayer(self.problem, parameters=[self.input, self.A, self.b], variables=[self.x])

    def forward(self, board):
        board = board.reshape(-1)
        mask = torch.clamp(self.mask, min=0, max=1)
        input = board * mask
        (x,) = self.cvx(input, self._A, self._b)
        return x

# Model definition
class SudokuNet(nn.Module):

    # Initialize the parallel layers & "just one more" attention layer
    def __init__(self, n_rules=9):
        super().__init__()
        self.rules = nn.ModuleList()
        for _ in range(n_rules):
            self.rules.append(SudokuNetLayer())
        self.attn = nn.MultiheadAttention(embed_dim=9*9*9, num_heads=9)

    # Attend through the output of the parallel layers
    def forward(self, input):
        rule_predictions = tuple(rule(input) for rule in self.rules)
        stacked = torch.stack(rule_predictions)
        V, _ = self.attn(stacked, stacked, stacked, need_weights=False)
        V = torch.sum(V, dim=0)
        V = V.reshape(9, 9, 9)
        return V    

To train the model, I'll give a hint with a size of around 60-70, and minimize the mean squared error between the model's predictions and the unique solution. To track how well it's doing, I also save the absolute difference between the prediction and the complete solution for each digit into a grayscale image called reconstruction error.

There's a lot of tuning that could go into this. Many of the layers are probably redundant, since I didn't want to 'cheat'. There's also the question of how much better or worse the model trains given smaller or larger hints, or hints that don't lead to unique solutions. Getting GPU/MKL acceleration for cvxpy is not straightforward, so it's CPU-bound for now, which means training takes a very long time (the results you see here were gathered over days). It most definitely converges (as I'll argue next), but I'll probably circle back to this post to confirm at some point.

In [ ]:
# Logging & visualization
from torch.utils.tensorboard import SummaryWriter
log_dir = 'sudoku'
writer = SummaryWriter(log_dir=log_dir, flush_secs=1)

# Initialize sudoku generator, network & optimizer
sudoku = Sudoku()
net = SudokuNet()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)

# Training loop
for step in range(60000):

    # Generate a random sudoku board
    input = sudoku.unique_hint_of_length_n(np.random.randint(60, 70))
    solution = sudoku.complete_solution(input)

    # Solve it
    input = torch.tensor(input, dtype=T)
    solution = torch.tensor(solution, dtype=T)
    output = net(input)

    # Accumulate loss
    loss = nn.functional.mse_loss(output, solution)

    # Backprop
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Track loss & difference to solution
    writer.add_scalars('loss', {'mse': loss.item()}, global_step=step)

    for i in range(9):
        writer.add_image(f'solution/{i}', solution[:, :, i], dataformats='HW', global_step=step)
        writer.add_image(f'diff/{i}', torch.abs(output[:, :, i] - solution[:, :, i]), dataformats='HW', global_step=step)

    writer.flush()
SudokuNet V2 loss trend & example prediction deltas

Results & discussion¶

For reasons I'll get into in a moment, we can be relatively sure the model converges and will start to make correct predictions, but first I'd like to zoom out and get a bigger view of things.

One thing we'd like to see in AGI could be described as precise causal reasoning. We often say that we want to learn a 'model of the world', but sometimes this just means a passive set of relationships in some data. What we can get more use out of is cause and effect, or an active idea of how relationships change in response to things. (This is a distinction I borrow from GEB, which I highly recommend everyone read.) Linear programs naturally provide a form of counterfactual reasoning, which has ties to causal reasoning so deep that some argue they're the same thing. I don't know whether that's true, but it's unarguably an important part of reasoning in general either way.

The models we have can reason counterfactually on continuous data pretty well. Variational autoencoders, for example, can give reasonable answers to questions like "what's between a 9 and a g?". But their methods assume such answers are inherently meaningful, even when the question itself has no meaning. This assumption is baked pretty deep into the AI architectures of recent popularity. They all run gradient descent, with the justification that relationships between data can be approximated as continuous. But we don't like approximations when they're wrong, and the fact that they'd be represented discretely in machines anyway makes this kind of a moot point.

We don't yet have models that reason counterfactually on discrete data, at least not comparably well. Gradient descent is entrenched as unreasonably effective and it's not clear what a purely discrete alternative looks like. But there might be new workarounds; we know linear programs are counterfactual, discrete or otherwise, and now they're differentiable. Instead of a discrete function, we could approximate the boundary of a related discrete set in some continuous way; that's the idea behind SudokuNet.

Previous drafts had the model learn 9x9 bitmasks representing rules, where each rule says 'all digits covered by this bitmask must be unique', i.e. sum(bitmask * sudoku[:, :, d]) <= 1 for d in range(9), for each bitmask. You might convince yourself that any rule in sudoku can be written like this, so an LP solving for integers under the correct masks would solve arbitrary sudokus. In fact, just restricting the possible rules this way gets you most of the way there. With some regularization (maximizing the number of 1's in a bitmask, making sure bitmasks differ from each other), a linear model will learn the correct masks without backpropagating through a linear program at all. This is part of the "key advantage" of learning declaratively as in constrained conditional models, and why our straightforwardly larger SudokuNet should converge.
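To spell out the bitmask formulation, here's a numpy sketch of my own (a check, not the draft's training code): the 27 'correct' masks are exactly the rows, columns, and chutes, and a valid one-hot board satisfies sum(mask * board[:, :, d]) <= 1 for every mask/digit pair.

```python
import numpy as np

# The 27 'correct' bitmasks: one per row, column, and chute
masks = []
for r in range(9):
    m = np.zeros((9, 9)); m[r, :] = 1; masks.append(m)
for c in range(9):
    m = np.zeros((9, 9)); m[:, c] = 1; masks.append(m)
for r in range(3):
    for c in range(3):
        m = np.zeros((9, 9)); m[r*3:(r+1)*3, c*3:(c+1)*3] = 1; masks.append(m)

# One-hot encode a known-valid grid (cyclic-shift construction)
grid = np.array([[(3*r + r//3 + c) % 9 + 1 for c in range(9)] for r in range(9)])
board = np.zeros((9, 9, 9))
for r in range(9):
    for c in range(9):
        board[r, c, grid[r, c] - 1] = 1

# Every mask covers each digit at most once, so the rule holds for all pairs
ok = all((masks[i] * board[:, :, d]).sum() <= 1
         for i in range(27) for d in range(9))
print(ok)  # True
```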

We shouldn't be surprised that a model designed around constraints can easily learn a game about constraints. But more broadly, in situations where rules are more important than examples, or are "like sudoku if you squint through a stack of 128 neural networks", the potential of this method is untapped.

SudokuNet V1 loss trend & example prediction deltas

Because constraint programming methods and discrete problems already have such a well-established relationship, gradient descent might end up looking like wandering around drunk at night when it comes to sample efficiency. For example, if you could determine not only whether a constraint makes a problem infeasible, but how many solutions it excludes when it's feasible, you could greedily optimize for the constraints that exclude the most solutions, i.e. the constraints that matter most to the problem. There's some recent work related to this question, the volume of the set of solutions to an LP, that seems pretty interesting.

I'll be trying to add a comments section soon, but if you'd like to discuss this post you can find my info on my github. Thanks for reading!