Let’s Program a Calculus Student II
Turning Symbolic Differentiation Automatic
25 April 2022
In the previous post, we wrote a data type representing a formula that could appear in a Calculus class and discussed how to find its derivative. The approach we chose was rather algebraic: we took each of the formulas for a derivative and taught the program how to apply them recursively.
Today we will redefine these symbolic derivatives using a different approach: automatic differentiation. This new way to calculate derivatives will depend only on the evaluation function for expressions. This decouples differentiation from whatever representation we choose for our expressions and, even more importantly, it is always nice to learn different ways to build something!
I first heard of this idea while reading the documentation of the ad package and just had my mind blown. Here we will follow a simpler path and construct a minimal AD implementation, but for any serious business I really recommend taking a look at that package. It is really awesome.
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving, DeriveFunctor #-}
module Calculus.AutoDiff where
import Calculus.Expression
On polymorphism, evaluation and reflection
Recall our evaluation function from the previous post. Its signature was
eval :: Floating a => Expr a -> a -> a
The way we interpreted it was that if we supplied an expression e and a value c of type a, it would collapse the expression, substituting all instances of the variable X by c, and return the resulting value. But thanks to currying we may also view eval as taking an expression e and returning a Haskell function eval e :: a -> a. Thus, our code is capable of transforming expressions into functions.
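For instance, assuming the Expr constructors from the previous post, a quick ghci session illustrates the curried view (the binding square here is just an illustrative name):

ghci> square = eval (X :*: X)
ghci> square 3
9.0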
At this point, one may ask if we could also do the opposite: can we take an ordinary Haskell function and find the symbolic expression that it represents? The answer, quite surprisingly to me, is yes, provided that it is polymorphic.
If you take a function such as g :: Double -> Double that only works for a single type¹, all hope is lost. Any information regarding “the shape” of the operation performed by the function will have already disappeared at runtime and perhaps even been optimized away by the compiler (as it should be). Nevertheless, polymorphic functions that work for any Floating type, such as f :: Floating a => a -> a, are flexible enough to still carry information about their syntax tree even at runtime. One reason for this is that we defined a Floating instance for Expr a, allowing the function f to be specialized to the type Expr a -> Expr a. Thus, we can convert between polymorphic functions and expressions.
uneval :: (forall a. Floating a => a -> a) -> (forall b. Floating b => Expr b)
Notice the explicit forall: uneval only accepts polymorphic arguments.² After finding the right type signature, the inverse to eval is really simple to write. The arithmetic operations on an Expr a just build a syntax tree, thus we can construct an expression from a polymorphic function by substituting its argument by the constructor X.
uneval f = f X
Let’s test it in ghci to see that it works:
ghci> uneval (\x -> x^2 + 1)
X :*: X :+: Const 1.0
it :: Floating b => Expr b
ghci> uneval (\x -> exp (-x) * sin x)
Apply Exp (Const (-1.0) :*: X) :*: Apply Sin X
it :: Floating b => Expr b
The uneval function allows us to compute a syntax tree for a polymorphic function during a program’s runtime. We can then manipulate this expression and turn the result back into a function through eval. Or, if we know how to do some interesting operation with functions, we can do the opposite process and apply it to our expression! This will be our focus in the next section.
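As a sanity check, here is a round trip through both operators, recovering the original function’s behavior:

ghci> eval (uneval (\x -> x^2 + 1)) 3
10.0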
Automatic Differentiation
In math, derivatives are concisely defined via a limiting process: f'(x) = \lim_{\varepsilon \to 0}\frac{f(x + \varepsilon) - f(x)}{\varepsilon}.
But when working with derivatives in a computer program, we can’t necessarily take limits of an arbitrary function. So how do we deal with derivatives?
One approach is numerical differentiation, where we approximate the limit by using a really small \varepsilon:
numDiff' eps f x = (f (x + eps) - f x) / eps

numDiff = numDiff' 1e-10
This is prone to numerical stability issues and doesn’t compute the real derivative but only an approximation to it.
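As a quick illustration, here is a hypothetical helper (approxDerivErr is just an example name, not part of our module) measuring how far the approximation lands from the true derivative of x ↦ x² at 3, which is exactly 6:

-- The gap between numDiff's answer and the exact derivative.
-- It is small, but generally nonzero due to rounding and truncation.
approxDerivErr :: Double
approxDerivErr = abs (numDiff (\x -> x * x) 3 - 6)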
Another approach is what we followed in the previous post: symbolic differentiation. This is the same way one is taught to compute derivatives by hand: you take the algebraic rules that you learned in calculus class and implement them as transformations on an expression type representing a syntax tree. One difficulty of this, as you may have noticed, is that symbolic calculations require lots of rewriting to get the derivative into a proper form. They also require that you work directly with expressions and not with functions. This, despite being mitigated by our eval and uneval operators, can be pretty inefficient when your code is naturally composed of functions. Besides that, if we wanted to change our Expr type, for example to use a more efficient representation under the hood, to add a :^: constructor for power operations, or to add new transcendental functions, we would have to modify both our eval and diff functions to account for this.
A third option that solves all the previous issues is automatic differentiation. This uses the fact that any Floating a => a -> a is in fact a composition of arithmetic operations and some simple transcendental functions such as exp, cos, sin, etc. Since we know how to differentiate those, we can augment our function evaluation to calculate at the same time both the function value and the exact value of the derivative at any given point. As we will see, we will even be able to recover symbolic differentiation as a special case of automatic differentiation.
Dual Numbers
Here we will do the simplest case of automatic differentiation, namely forward-mode AD using dual numbers. This is only for illustrative purposes. If you are planning to use automatic differentiation in a real program, I recommend taking a look at the ad package.
In mathematics, a dual number is an expression a + b\varepsilon with the additional property that \varepsilon^2 = 0. One can think of it as augmenting the real numbers with an infinitesimal factor. As another intuition: this definition is very similar to the complex numbers, with the difference that instead of i^2 = -1, we have \varepsilon^2 = 0.³
The nicety of dual numbers is that they can automatically calculate the derivative of any analytic function. To see how this works, let’s look at the Taylor series of a function f expanded around a point a.
f(a + b\varepsilon) = \sum_{n=0}^\infty \frac{f^{(n)}(a)}{n!} (b\varepsilon)^n = f(a) + bf'(a)\varepsilon.
Therefore, since every power \varepsilon^n with n \ge 2 vanishes, applying f to a number with an infinitesimal part amounts to taking its first order expansion.
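As a concrete check, take f(x) = x^2. Then (a + b\varepsilon)^2 = a^2 + 2ab\varepsilon + \cancel{b^2\varepsilon^2} = a^2 + 2ab\varepsilon, which is exactly f(a) + bf'(a)\varepsilon.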
Ok, back to Haskell. As usual, we translate this definition into Haskell as a parameterized data type carrying two values.
data Dual a = Dual a a
deriving (Show, Eq)
Later, it will also be useful to have functions extracting the real and infinitesimal parts of a dual number.
realPart (Dual a _) = a
epsPart  (Dual _ b) = b
Alright, just like with expressions, we will want to make Dual a into a number. The sum and product of two dual numbers are respectively linear and bilinear because, well… we wouldn’t be calling them “sum” and “product” if they weren’t. In math, it reads as
\begin{aligned} (a + b\varepsilon) + (c + d\varepsilon) &= (a + c) + (b + d)\varepsilon, \\ (a + b\varepsilon) \cdot (c + d\varepsilon) &= ac + (bc + ad)\varepsilon + \cancel{bd\varepsilon^2}. \end{aligned}
If you found that these bear a strong resemblance to the sum and product rules for derivatives, it is because they do! These are our building blocks for differentiation.
instance Num a => Num (Dual a) where
  -- Linearity
  (Dual a b) + (Dual c d) = Dual (a + c) (b + d)
  (Dual a b) - (Dual c d) = Dual (a - c) (b - d)
  -- Bilinearity and cancel ε^2
  (Dual a b) * (Dual c d) = Dual (a * c) (b*c + a*d)
  -- Embed integers as only the real part
  fromInteger n = Dual (fromInteger n) 0
  -- These below are not differentiable functions...
  -- But their first order expansion equals this except at zero.
  abs    (Dual a b) = Dual (abs a)    (b * signum a)
  signum (Dual a _) = Dual (signum a) 0
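We can already take a tiny derivative with nothing but this instance. Squaring 3 + \varepsilon should give 9 + 6\varepsilon, since the derivative of x^2 at 3 is 6:

ghci> Dual 3 1 * Dual 3 1
Dual 9 6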
For division, we use the same trick as with complex numbers and multiply by the denominator’s conjugate. Notice how the \varepsilon-component below is exactly the quotient rule.
\begin{aligned} \frac{a + b\varepsilon}{c + d\varepsilon} &= \frac{a + b\varepsilon}{c + d\varepsilon} \cdot \frac{c - d\varepsilon}{c - d\varepsilon} \\ &= \frac{ac + (bc - ad)\varepsilon}{c^2} \\ &= \frac{a}{c} + \frac{bc - ad}{c^2}\varepsilon \end{aligned}
instance (Fractional a) => Fractional (Dual a) where
  (Dual a b) / (Dual c d) = Dual (a / c) ((b*c - a*d) / c^2)
  fromRational r = Dual (fromRational r) 0
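A small check: the derivative of 1/x at 2 is -1/4, and indeed dividing 1 by 2 + \varepsilon gives

ghci> Dual 1 0 / Dual 2 1
Dual 0.5 (-0.25)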
Finally, to extend the transcendental functions to the dual numbers, we use the first order expansion described above. We begin by writing a helper function that represents this expansion; the factor b * f' a in it is precisely the chain rule.
-- First order expansion of a function f with derivative f'.
fstOrd :: Num a => (a -> a) -> (a -> a) -> Dual a -> Dual a
fstOrd f f' (Dual a b) = Dual (f a) (b * f' a)
And the Floating instance is essentially our calculus cheatsheet again.
instance Floating a => Floating (Dual a) where
-- Embed as a real part
pi = Dual pi 0
-- First order approximation of the function and its derivative
exp = fstOrd exp exp
log = fstOrd log recip
sin = fstOrd sin cos
cos = fstOrd cos (negate . sin)
asin = fstOrd asin (\x -> 1 / sqrt (1 - x^2))
acos = fstOrd acos (\x -> -1 / sqrt (1 - x^2))
atan = fstOrd atan (\x -> 1 / (1 + x^2))
sinh = fstOrd sinh cosh
cosh = fstOrd cosh sinh
asinh = fstOrd asinh (\x -> 1 / sqrt (x^2 + 1))
acosh = fstOrd acosh (\x -> 1 / sqrt (x^2 - 1))
atanh = fstOrd atanh (\x -> 1 / (1 - x^2))
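A quick check that this chain of definitions behaves: differentiating sin at 0 should give cos 0 = 1, and indeed

ghci> sin (Dual 0 1)
Dual 0.0 1.0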
Derivatives of functions
Now that we have set up all the dual number tooling, it is time to calculate some derivatives. From the first order expansion f(a + b\varepsilon) = f(a) + bf'(a)\varepsilon, we see that by applying a function to a + \varepsilon, that is, setting b = 1, we calculate both f and its derivative at a. Let’s test this in ghci:
ghci> f x = x^2 + 1
f :: Num a => a -> a
ghci> f (Dual 3 1)
Dual 10 6
it :: Num a => Dual a
Just as we expected! We can thus write a differentiation function by doing this procedure and taking only the \varepsilon component.
autoDiff f c = epsPart (f (Dual c 1))
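Trying it on the same example as before, the derivative of x^2 + 1 at 3 comes out right away:

ghci> autoDiff (\x -> x^2 + 1) 3
6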
Some cautionary words: remember from the previous discussion that to have access to the structure of a function, we need it to be polymorphic. In particular, our autoDiff has type Num a => (Dual a -> Dual b) -> (a -> b). It takes a function on dual numbers and spits out a function on ordinary numbers. But for our use case this is fine, because we can specialize this signature to
autoDiff :: (forall a . Floating a => a -> a) -> (forall a . Floating a => a -> a)
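As a last check before moving on, autoDiff also handles compositions with transcendental functions, here the derivative of e^{\sin x} at 0, whose exact value is \cos 0 \cdot e^{\sin 0} = 1:

ghci> autoDiff (\x -> exp (sin x)) 0
1.0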
Derivatives of expressions
Recall that we can use eval to turn an expression into a function and, reciprocally, we can apply a polymorphic function to the constructor X to turn it into an expression. This property, which for the mathematicians among you probably bears a strong resemblance to a similarity transformation, allows us to “lift” autoDiff into the world of expressions. So, what happens if we take eval f and compute its derivative at the point X? We get the symbolic derivative of f, of course!
diff_ f = autoDiff (eval f) X
Some tests in the REPL to see that it works:
ghci> diff_ (sin (X^2))
(Const 1.0 :*: X :+: X :*: Const 1.0) :*: Apply Cos (X :*: X)
it :: Floating a => Expr a
This function nevertheless has a flaw: it depends too much on polymorphism. While our symbolic differentiator from the previous post worked for an expression such as f :: Expr Double, this new function depends on being able to convert f to a polymorphic function, which it can’t do in this case. This becomes clear by looking at the type signature of diff_:

diff_ :: Floating a => Expr (Dual (Expr a)) -> Expr a
But not all hope is lost! Our differentiator works. All we need is to discover how to turn an Expr a into an Expr (Dual (Expr a)) and we get the proper type.
Let’s think… Is there a canonical way of embedding a value as an expression? Of course there is! The Const constructor does exactly that. Similarly, we can view a “normal” number as a dual number with zero infinitesimal part. Thus, if we change each coefficient in an expression by the rule \c -> Dual (Const c) 0, we get an expression of the type we need without changing any meaning.
To help us change the coefficients, let’s give a Functor instance to Expr. We could write it by hand, but let’s use some GHC magic to automatically derive it for us.
deriving instance Functor Expr
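For example, fmap now lets us rewrite every coefficient in an expression (the exact printed form depends on the derived Show instance from the previous post):

ghci> fmap (*2) (X :+: Const 1.0)
X :+: Const 2.0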
Finally, our differentiation function is equal to diff_, except that it first converts all coefficients of the input to the proper type.
-- Symbolically differentiate expressions
diff :: Floating a => Expr a -> Expr a
diff f = autoDiff (eval (fmap from f)) X
  where from x = Dual (Const x) 0
Just apply it to a monomorphic expression and voilà!
ghci> diff (sin (X^2) :: Expr Double)
(Const 1.0 :*: X :+: X :*: Const 1.0) :*: Apply Cos (X :*: X)
it :: Expr Double
References
- ad package on Hackage
- The Simple Essence of Automatic Differentiation by Conal Elliott.
Footnotes

1. Also known by the fancy name of monomorphic functions. These are functions without any free type parameter; that is, no lowercase type variable appears in the type signature.

2. This is not the most general possible definition of uneval, but it is simple and clear enough for this presentation.

3. If you’re into algebra, you can view the complex numbers as the polynomial quotient \mathbb{R}[X] / \langle X^2 + 1 \rangle while the dual numbers are \mathbb{R}[X] / \langle X^2 \rangle.