Solving linear diophantine equations in Haskell

Kevin Cheung

November 19, 2022

Linear diophantine equations

A linear diophantine equation has the form \[a_1 x_1+ \cdots + a_n x_n = b\] where \(a_1,\ldots,a_n, b\) are integers and \(x_1,\ldots, x_n\) are integer variables.

The case when \(n = 2\) often appears in books and resources that discuss the extended Euclidean algorithm. In this article, we will look at the general case and develop a Haskell function that solves such equations.

Simplifying assumptions

We will make the simplifying assumption that \(a_1,\ldots,a_n \geq 0\). The reason is that \((x_1,\ldots,x_n) = (v_1,\ldots,v_n)\) is a solution to \[a_1 x_1+ \cdots + a_n x_n = b\] if and only if \((x_1,\ldots,x_k,\ldots,x_n) = (v_1,\ldots,-v_k,\ldots,v_n)\) is a solution to \[a_1 x_1+ \cdots + (-a_k)x_k + \cdots + a_n x_n = b.\] Hence, if we have all the solutions to \[|a_1| x_1+ \cdots + |a_n| x_n = b,\] we can easily recover all the solutions to \[a_1 x_1+ \cdots + a_n x_n = b\] by negating the components whose coefficients are negative.
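As a small illustration (the numbers here are chosen for this remark only), consider \(3x_1 - 2x_2 = 1\). The equation with absolute values, \(3x_1 + 2x_2 = 1\), has the solution \((x_1,x_2) = (1,-1)\). Negating the second component gives \((1,1)\), and indeed \[3\cdot 1 - 2\cdot 1 = 1.\]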

We also make the assumption that at least one of \(a_1,\ldots,a_n\) is nonzero. Otherwise, the equation is trivial to solve: every integer tuple is a solution if \(b = 0\), and there is no solution if \(b \neq 0\).

A recursive algorithm

We now give a recursive algorithm. Let \(\mathbb{Z}\) denote the set of integers.

The base case is when exactly one of \(a_1,\ldots,a_n\), say \(a_k\), is nonzero.

In this case, the equation has a solution if and only if \(a_k\) divides \(b\). If \(a_k\) does divide \(b\), then the set of solutions is given by \[\{ (v_1,\ldots,b/a_k,\ldots, v_n) : v_i \in \mathbb{Z}, i \in \{1,\ldots,n\}\setminus \{k\}\}.\]

For example, for the equation \(a_1x_1 + a_2 x_2 + a_3x_3 = b\) where \(a_1 = a_2 = 0\), \(a_3 = 3\), and \(b = 6\), the set of solutions is \(\{ (u,v,2) : u, v \in \mathbb{Z}\}\).

Now, suppose that at least two of \(a_1,\ldots,a_n\) are nonzero. Let \(J = \{ i \in \{1,\ldots,n\} : a_i > 0\}\). Hence, \(J\) is the set of indices of the nonzero coefficients, and all the solutions to \[\sum_{i = 1}^n a_i x_i = b\] are given by solutions to \[\sum_{i \in J} a_i x_i = b\] extended with arbitrary integer values for the variables \(x_i\), \(i \notin J\).

Choose \(k \in J\) such that \(a_k = \min \{ a_i : i \in J \}\) where ties are arbitrarily broken. For example, if \(a_1 = 3, a_2 = 2, a_3 = 2, a_4 = 0\), then \(k\) can be 2 or 3.

For each \(i \in J\), \(i\neq k\), let \(q_i\) and \(r_i\) denote, respectively, the quotient and remainder of \(a_i\) when divided by \(a_k\). Then, \(0 \leq r_i < a_k\) and \(a_i = q_i a_k + r_i\).

Observe that if \(x_i = v_i\), \(i \in J\) is a solution to \[a_k x_k + \sum_{i \in J, i\neq k} r_i x_i = b,\] then \(x_i = v_i\), \(i \in J, i \neq k\) and \(x_k = v_k - \displaystyle\sum_{i \in J, i \neq k} q_i v_i\), give a solution to \[\sum_{i \in J} a_i x_i = b.\]

Indeed, \[\begin{align*} a_k \left(v_k - \sum_{i \in J, i\neq k} q_i v_i\right) + \sum_{i \in J, i\neq k} a_i v_i & = a_k v_k + \sum_{i \in J, i\neq k} (a_i - q_i a_k) v_i \\ & = a_k v_k + \sum_{i \in J, i\neq k} r_i v_i \\ & = b \end{align*}\]

Conversely, if \(x_i = v_i\), \(i \in J\) is a solution to \[\sum_{i \in J} a_i x_i = b,\] then \(x_i = v_i\), \(i \in J, i \neq k\) and \(x_k = v_k + \displaystyle\sum_{i \in J, i \neq k} q_i v_i\), give a solution to \[a_k x_k + \sum_{i \in J, i\neq k} r_i x_i = b.\] Indeed, \[\begin{align*} a_k \left(v_k + \sum_{i \in J, i\neq k} q_i v_i\right) + \sum_{i \in J, i\neq k} r_i v_i & = a_k v_k + \sum_{i \in J, i\neq k} (r_i + q_i a_k) v_i \\ & = a_k v_k + \sum_{i \in J, i\neq k} a_i v_i \\ & = \sum_{i \in J} a_i v_i = b \end{align*}\]

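Consequently, the solutions to \[a_1 x_1 + \cdots + a_n x_n = b\] are exactly the tuples obtained as follows: take a solution \(x_i = v_i\), \(i \in J\), to \[a_k x_k + \sum_{i \in J, i\neq k} r_i x_i = b,\] replace \(v_k\) with \(v_k - \displaystyle\sum_{i \in J, i \neq k} q_i v_i\), and assign arbitrary integer values to the variables \(x_i\), \(i \notin J\). So we can solve the equation with the remainders as coefficients and then construct the set of solutions to the original equation.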

We can continue this process until we reach the base case since the sum of the coefficients decreases by a positive integer each time: every \(r_i\) is smaller than \(a_k \leq a_i\).
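The reduction step can be sketched in a few lines of Haskell. This is only an illustration on plain lists and not the implementation developed later in the article. The names `reduceStep` and `reductions` are made up for this sketch, and the coefficients are assumed to be nonnegative.

```haskell
import Data.List (minimumBy)
import Data.Ord (comparing)

-- One reduction step on a list of nonnegative coefficients (hypothetical helper).
-- Returns Nothing when fewer than two coefficients are nonzero (base case).
-- Otherwise returns (k, qs, as') where k is the zero-based pivot index,
-- qs holds the quotients q_i (0 at index k and at zero coefficients), and
-- as' holds a_k at index k and the remainders r_i everywhere else.
reduceStep :: [Integer] -> Maybe (Int, [Integer], [Integer])
reduceStep as
  | length nonzero < 2 = Nothing
  | otherwise          = Just (k, qs, as')
  where
    nonzero   = [ (i, a) | (i, a) <- zip [0 ..] as, a /= 0 ]
    -- smallest nonzero coefficient; ties go to the earliest index
    (k, ak)   = minimumBy (comparing snd) nonzero
    (qs, as') = unzip [ if i == k then (0, a) else a `quotRem` ak
                      | (i, a) <- zip [0 ..] as ]

-- Successive coefficient lists until the base case is reached.
reductions :: [Integer] -> [[Integer]]
reductions as = case reduceStep as of
  Nothing          -> [as]
  Just (_, _, as') -> as : reductions as'
```

For the coefficients \(4, 3, 4, 8\), `reductions [4,3,4,8]` evaluates to `[[4,3,4,8],[1,3,1,2],[1,0,0,0]]`; the coefficient sums \(19, 7, 1\) strictly decrease, as the termination argument requires.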

Example

We illustrate the algorithm on \[a_1 x_1 + a_2 x_2 + a_3 x_3 + a_4 x_4 = 2\] where \(a_1 = a_3 = 4\), \(a_2 = 3\), and \(a_4 = 8\). (Hence, the equation is \(4 x_1 + 3 x_2 + 4 x_3 + 8 x_4 = 2\).)

Since \(a_2\) is the smallest nonzero coefficient, we take \(k = 2\) and obtain the equation \[r_1x_1 + a_2 x_2 + r_3 x_3 + r_4 x_4 = 2\] where \((q_1,r_1) = (q_3,r_3) = (1,1)\) and \((q_4, r_4) = (2,2)\). (Hence, the equation is \(x_1 + 3 x_2 + x_3 + 2 x_4 = 2\).)

Now, \(r_1\) and \(r_3\) are the smallest coefficients. We choose the former; dividing the other coefficients \(3, 1, 2\) by \(1\) gives quotients \(3, 1, 2\) and remainders \(0\), so we obtain the equation \[x_1 + 0x_2 + 0x_3 + 0x_4 = 2.\] The set of solutions to this last equation is \[\{(2,u,v,w) : u,v,w \in \mathbb{Z}\}.\]

Hence, the set of solutions to \(r_1x_1 + a_2 x_2 + r_3 x_3 + r_4 x_4 = 2\) is \[\{(2-3u-v-2w,u,v,w) : u,v,w \in \mathbb{Z}\}.\]

Finally, the set of solutions to \(a_1 x_1 + a_2 x_2 + a_3 x_3 + a_4 x_4 = 2\) is \[\{(2-3u-v-2w,u-(2-3u-v-2w)-v-2w,v,w) : u,v,w \in \mathbb{Z}\}.\] Simplifying gives \[\{(2-3u-v-2w,-2+4u,v,w) : u,v,w \in \mathbb{Z}\}.\]

Alternatively, we can specify the solutions as \[(x_1,x_2,x_3,x_4) = (2,-2,0,0) + u(-3,4,0,0) + v(-1,0,1,0) + w(-2,0,0,1)\] for all \(u,v,w \in \mathbb{Z}.\)
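As a sanity check, substituting the parametrized solution into the left-hand side of the original equation gives \[\begin{align*} 4(2-3u-v-2w) + 3(-2+4u) + 4v + 8w & = (8 - 6) + (-12u + 12u) + (-4v + 4v) + (-8w + 8w) \\ & = 2 \end{align*}\] for every choice of \(u,v,w \in \mathbb{Z}\).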

Haskell implementation

We will develop a small module with a solveLinearDiophantine function. Let’s begin with

{-# LANGUAGE StrictData #-}

module SolveLinearDiophantine (solveLinearDiophantine) where

Next, we define a data type for the solution set.

In the example above, we see that there are two parts to the solution set. One consists of the particular solution \((2,-2,0,0)\). The other consists of arbitrary integer linear combinations of a finite set of integer tuples. Since we will need to refer to components of tuples by indices, a natural choice for the type for representing tuples is Vector. So we will use a single Vector for the first part and a list of Vectors for the second part. Namely,

import qualified Data.Vector as V
import Data.Vector (Vector, (!), (//))

data Solution a = Solution !(Vector a) ![Vector a] deriving (Eq, Show)

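Note that we use strict fields here to help avoid space leaks. (With the StrictData pragma the fields are already strict by default; the bang annotations just make this explicit.)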

The type signature of solveLinearDiophantine is as follows:

solveLinearDiophantine :: Integral a =>    Vector a -- coefficients
                                        -> a        -- right-hand side
                                        -> Maybe (Solution a)

For the purpose of this article, we assume that all the coefficients are nonnegative to reduce noise in the code. Extending the implementation to handle negative coefficients is left as an exercise.

There are three cases to consider:

  1. All the coefficients are 0.
  2. Exactly one coefficient is nonzero.
  3. Two or more coefficients are nonzero.

So we begin with the following skeleton:

solveLinearDiophantine as b = 
  case length nzs of
    0 -> undefined
    1 -> undefined
    _ -> undefined
  where
    n = V.length as
    nzs = V.toList $ V.findIndices (/= 0) as -- indices to nonzero coefficients

The case when all the coefficients are 0 is easy. If b is \(0\), then every integer tuple is a solution, so we can set the particular solution to \([0,\ldots,0]\) and the homogeneous solutions to the list \([1,0,\ldots,0], [0,1,0,\ldots,0],\ldots, [0,\ldots,0,1]\). Otherwise, there is no solution.

If exactly one coefficient is nonzero, then we check if it divides \(b\). If not, there is no solution. Otherwise, let \(q\) be the quotient. In this case, the particular solution must be \(q\) at the index of the nonzero coefficient and \(0\) everywhere else; the homogeneous part is the list of all the tuples \([1,0,\ldots,0], [0,1,0,\ldots,0],\ldots, [0,\ldots,0,1]\) excluding the one with a \(1\) at the index of the nonzero coefficient.

Since we need to frequently construct vectors with a particular value in one entry and \(0\) everywhere else, let’s have a helper function that does exactly that:

-- Generate a vector of length n with value v in index j and 0 everywhere else
genV :: Num a => Int -> a -> Int -> Vector a
genV n v j = V.generate n (\i -> if i == j then v else 0)

Code for handling the first two cases can now be added:

solveLinearDiophantine as b = 
  case length nzs of
    0 -> case b of
            0 -> let p = V.replicate n 0
                     h = map (genV n 1) [0..n-1]
                 in Just $ Solution p h
            _ -> Nothing
    1 -> let k = head nzs
             a = as ! k
             (q, r) = b `quotRem` a
         in case r of
              0 -> let p = genV n q k
                       h = [ genV n 1 i | i <- [0..n-1], i /= k ]
                   in Just $ Solution p h
              _ -> Nothing
    _ -> undefined
  where
    n = V.length as
    nzs = V.toList $ V.findIndices (/= 0) as

For the remaining case, we simply follow the mathematical description of the algorithm: form a new set of coefficients and make a recursive call. The returned solution will need to be updated accordingly. Since the update needs to be performed on the returned particular solution as well as on each tuple in the homogeneous part, let’s have the following helper function:

modV :: Num a => Vector a -> Int -> [(Int, a)] -> Vector a
modV vs k jqs = let vk = vs ! k - sum [ q * vs ! j | (j, q) <- jqs ]
                in vs // [(k, vk)]

Mathematically, what this function does is the following: If vs is the tuple \((v_1,\ldots,v_n)\) and jqs is the list of index-value pairs \((j_1, q_1),\ldots, (j_m, q_m)\), then entry k of vs is replaced with \(v_k - \sum_{i = 1}^m q_i v_{j_i}\).
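As a small worked instance, using zero-based indices as in the Haskell code: take vs to be \((-3,1,0,0)\), k to be \(1\), and jqs to be the pairs \((0,1), (2,1), (3,2)\). The new entry at index \(1\) is \[1 - \bigl(1\cdot(-3) + 1\cdot 0 + 2\cdot 0\bigr) = 4,\] so the result is \((-3,4,0,0)\), which is the first homogeneous tuple found by hand in the example above.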

Putting everything together, we have the following code listing:

{-# LANGUAGE StrictData #-}

module SolveLinearDiophantine (solveLinearDiophantine) where

import qualified Data.Vector as V
import Data.Vector (Vector, (!), (//))

import Data.List (minimumBy)
import Data.Ord (comparing)

data Solution a = Solution !(Vector a) ![Vector a] deriving (Eq, Show)

-- Generate a vector of length n with v in index j and 0 everywhere else
genV :: Num a => Int -> a -> Int -> Vector a
genV n v j = V.generate n (\i -> if i == j then v else 0)

modV :: Num a => Vector a -> Int -> [(Int, a)] -> Vector a
modV vs k jqs = let vk = vs ! k - sum [ q * vs ! j | (j, q) <- jqs ]
                in vs // [(k, vk)]

solveLinearDiophantine :: Integral a => Vector a -> a -> Maybe (Solution a)
solveLinearDiophantine as b = 
  case length nzs of
    0 -> case b of
            0 -> let p = V.replicate n 0
                     h = map (genV n 1) [0..n-1]
                 in Just $ Solution p h
            _ -> Nothing
    1 -> let k = head nzs
             a = as ! k
             (q, r) = b `quotRem` a
         in case r of
              0 -> let p = genV n q k
                       h = [ genV n 1 i | i <- [0..n-1], i /= k ]
                   in Just $ Solution p h
              _ -> Nothing
    _ -> let (k, a) = minimumBy (comparing snd) [ (j, as ! j) | j <- nzs ]
             (qs, rs) = unzip [ let (q, r) = as ! j' `quotRem` a
                                in ((j', q), (j', r)) | j' <- nzs, j' /= k]
             as' = as // rs
         in case solveLinearDiophantine as' b of
              Nothing -> Nothing
              Just (Solution p h) -> 
                let p' = modV p k qs
                    h' = map (\vs -> modV vs k qs) h
                in Just $ Solution p' h'
  where
    n = V.length as
    nzs = V.toList $ V.findIndices (/= 0) as

This is probably not the most efficient implementation but it closely follows the mathematical description and gets the job done. If we now fire up ghci with this module loaded, we can solve the example above:

λ> solveLinearDiophantine (V.fromList [4,3,4,8]) (2::Int)
Just (Solution [2,-2,0,0] [[-3,4,0,0],[-1,0,1,0],[-2,0,0,1]])

The answer matches what was obtained by hand.
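One way to gain a little more confidence in such an answer is to plug concrete parameter values back into the equation. The sketch below does this on plain lists so that it needs nothing beyond base; `dot`, `instantiate`, and `checkBox` are hypothetical names introduced here and are not part of the module above. It only checks that the generated tuples are solutions, not that every solution is generated.

```haskell
-- Dot product of two integer lists.
dot :: [Integer] -> [Integer] -> Integer
dot xs ys = sum (zipWith (*) xs ys)

-- Build the concrete tuple p + t_1 * h_1 + ... + t_m * h_m from a particular
-- solution p, homogeneous tuples hs, and integer parameters ts.
instantiate :: [Integer] -> [[Integer]] -> [Integer] -> [Integer]
instantiate p hs ts =
  foldl (zipWith (+)) p [ map (t *) h | (t, h) <- zip ts hs ]

-- Check that every parameter choice with entries in [-bound .. bound]
-- yields a tuple x satisfying as . x == b.
checkBox :: [Integer] -> Integer -> [Integer] -> [[Integer]] -> Integer -> Bool
checkBox as b p hs bound =
  and [ dot as (instantiate p hs ts) == b
      | ts <- mapM (const [-bound .. bound]) hs ]
```

With the values printed by ghci, `checkBox [4,3,4,8] 2 [2,-2,0,0] [[-3,4,0,0],[-1,0,1,0],[-2,0,0,1]] 3` evaluates to `True`.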

As mentioned earlier, you can now try to make the necessary modifications so that the coefficients do not need to be all nonnegative. (Hint: You might find the Haskell function signum helpful.)