Dual Numbers and Automatic Differentiation
This literate essay develops an implementation of a type called Differential. A Differential is a generalization of a type called a "dual number", and the glowing, pulsing core of the SICMUtils implementation of forward-mode automatic differentiation.
As we'll discuss, passing these numbers as arguments to some function \(f\) built out of the sicmutils.generic operators allows us to build up the derivative of \(f\) in parallel to our evaluation of \(f\). Complex programs are built out of simple pieces that we know how to evaluate; we can build up derivatives of entire programs in a similar way, building them out of the derivatives of the smaller pieces of those programs.
(ns sicmutils.differential
"This namespace contains an implementation of [[Differential]], a generalized
dual number type that forms the basis for the forward-mode automatic
differentiation implementation in sicmutils.
See `sicmutils.calculus.derivative` for a fleshed-out derivative
implementation using [[Differential]]."
(:refer-clojure :rename {compare core-compare}
#?@(:cljs [:exclude [compare]]))
(:require [clojure.string :refer [join]]
[sicmutils.function :as f]
[sicmutils.generic :as g]
[sicmutils.util :as u]
[sicmutils.util.stream :as us]
[sicmutils.util.vector-set :as uv]
[sicmutils.value :as v])
#?(:clj
(:import (clojure.lang AFn IFn))))
Forward-Mode Automatic Differentiation
For many scientific computing applications, it's valuable to be able to generate a "derivative" of a function: given some wiggle in the inputs, how much wobble does the output produce?
We know how to take derivatives of many of the generic functions exposed by SICMUtils, like +, *, sin and friends. It turns out that we can take the derivatives of large, complicated functions by combining the derivatives of these smaller functions using the chain rule as a clever bookkeeping device.
The technique of evaluating a function and its derivative in parallel is called "forward-mode Automatic Differentiation". The SICMUtils wiki has more information on the history of this technique, and links to the many other implementations you'll find in different languages. See the cljdocs Automatic Differentiation page for "how do I use this?"-style questions.
NOTE: The other flavor of automatic differentiation (AD) is "reverse-mode AD". See sicmutils.tape for an implementation of this style, coming soon!
Dual Numbers and AD
Our goal is to build up derivatives of complex functions out of the derivatives of small pieces. A dual number is a relatively simple piece of machinery that will help us accomplish this goal.
A dual number is a pair of numbers of the form
\[
a+b\varepsilon
\]
where \(a\) and \(b\) are real numbers, and \(\varepsilon\) is an abstract thing, with the property that \(\varepsilon^2 = 0\).
NOTE: This might remind you of the definition of a complex number of the form \(a + bi\), where \(i\) is also a new thing with the property that \(i^2 = -1\). You are very wise! The bigger idea lurking here is the "generalized complex number", and of course mathematicians have pushed this very far. You might explore the "Split-complex numbers" too, which arise when you set \(i^2 = 1\).
Why are dual numbers useful (in SICMUtils)? If you pass \(a+b\varepsilon\) in to a function \(f\), the result is a dual number \(f(a) + Df(a) b \varepsilon\); you get both the function and its derivative evaluated at the same time!
To see why, look at what happens when you pass a dual number into the Taylor series expansion of some arbitrary function \(f\). As a reminder, the Taylor series expansion of \(f\) around some point \(a\) is:
\[
f(x) = f(a)+\frac{Df(a)}{1!}(x-a)+\frac{D^2f(a)}{2!}(x-a)^{2}+\frac{D^3f(a)}{3!}(x-a)^{3}+\cdots
\]
NOTE: See this nice overview of Taylor series expansion by Andrew Chamberlain if you want to understand this idea and why we can approximate (smooth) functions this way.
If you evaluate the expansion of \(f(x)\) around \(a\) with a dual number argument whose first component is \(a\) – take \(x=a+b\varepsilon\), for example – watch how the expansion simplifies:
\[
f(a+b\varepsilon) = f(a)+\frac{Df(a)}{1!}(b\varepsilon)+\frac{D^2f(a)}{2!}(b\varepsilon)^2+\cdots
\]
Since \(\varepsilon^2=0\) we can ignore all terms beyond the first two:
\[
f(a+b\varepsilon) = f(a)+ (Df(a)b)\varepsilon
\]
NOTE: See lift1 for an implementation of this idea.
Interesting! This justifies our claim above: applying a function to some dual number \(a+\varepsilon\) returns a new dual number, where
 the first component is \(f(a)\), the normal function evaluation
 the second component is \(Df(a)\), the derivative.
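A minimal sketch of this claim, assuming a toy encoding (not the SICMUtils representation) in which the vector [a b] stands for the dual number \(a+b\varepsilon\). Here lift1 is a hypothetical helper that pairs a function with its known derivative; the real lift1 mentioned above may differ in its details.

```clojure
;; Toy encoding: the dual number a + bε is the vector [a b].
;; lift1 (hypothetical) pairs f with its known derivative df:
;;   f(a + bε) = f(a) + (Df(a) * b)ε
(defn lift1 [f df]
  (fn [[a b]]
    [(f a) (* (df a) b)]))

;; f(x) = x^2, Df(x) = 2x
(def square (lift1 (fn [x] (* x x))
                   (fn [x] (* 2 x))))

;; f(x) = sin x, Df(x) = cos x
(def dual-sin (lift1 (fn [x] (Math/sin x))
                     (fn [x] (Math/cos x))))

(square [3 1])     ;; => [9 6]   f(3) = 9, Df(3) = 6
(square [3 2])     ;; => [9 12]  a tangent of b = 2 scales the derivative
(dual-sin [0.0 1]) ;; => [0.0 1.0]
```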
If we do this twice, the second component of the returned dual number beautifully recreates the Chain Rule:
\begin{aligned}
g(f(a+\varepsilon)) & = g(f(a) + Df(a)\varepsilon) \\
& = g(f(a)) + (Dg(f(a)))(Df(a))\varepsilon
\end{aligned}
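A quick numerical check of the chain rule, assuming a toy encoding where the vector [a b] stands for \(a+b\varepsilon\); lift1, square and cube are hypothetical helpers, and lift1 is included so the block runs on its own. Composing \(g(y)=y^3\) with \(f(x)=x^2\) gives \(x^6\), whose derivative at 2 is \(6 \cdot 2^5 = 192\):

```clojure
;; Toy encoding: [a b] is a + bε. lift1 pairs a function with its derivative.
(defn lift1 [f df]
  (fn [[a b]]
    [(f a) (* (df a) b)]))

(def square (lift1 (fn [x] (* x x)) (fn [x] (* 2 x))))     ; f(x) = x^2
(def cube   (lift1 (fn [y] (* y y y)) (fn [y] (* 3 y y)))) ; g(y) = y^3

;; g(f(2 + ε)) = g(4 + 4ε) = 64 + (Dg(4) * Df(2))ε = 64 + (48 * 4)ε
(cube (square [2 1])) ;; => [64 192]
```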
Terminology Change!
A "dual number" is a very general idea. Because we're interested in dual numbers as a bookkeeping device for derivatives, we're going to specialize our terminology. From now on, we'll rename \(a\) and \(b\) to \(x\) and \(x'\). Given a dual number of the form \(x+x'\varepsilon\), we'll refer to:
 \(x\) as the "primal" part of the dual number
 \(x'\) as the "tangent" part
 \(\varepsilon\) as the "tag"
NOTE: "primal" means \(x\) is tracking the "primal", or "primary", part of the computation. "tangent" is a synonym for "derivative". "tag" is going to make more sense shortly, when we start talking about mixing together multiple \(\varepsilon_1\), \(\varepsilon_2\) from different computations.
Binary Functions
What about functions of more than one variable? We can use the same approach by leaning on the multivariable Taylor series expansion. Take \(f(x, y)\) as a binary example. If we pass dual numbers in to the Taylor series expansion of \(f\), the \(\varepsilon\) multiplication rule will erase all higher-order terms, leaving us with:
\[
f(x+x'\varepsilon, y+y'\varepsilon) = f(x,y) + \left[\partial_1 f(x,y)x' + \partial_2 f(x,y) y'\right]\varepsilon
\]
NOTE: See lift2 for an implementation of this idea.
This expansion generalizes to n-ary functions; every new argument \(x_n + x'_n\varepsilon\) contributes \(\partial_n f(...)x'_n\) to the tangent component of the result.
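A sketch of the binary case, assuming a toy encoding where [x xt] stands for \(x + x_t\varepsilon\). lift2 is a hypothetical helper that takes \(f\) along with its two partial derivatives; seeding a tangent of 1 in one argument and 0 in the other picks out a single partial.

```clojure
;; Toy encoding: [x xt] is x + xt·ε. lift2 (hypothetical) takes f and its
;; partials d1f, d2f:
;;   f(x + xt·ε, y + yt·ε) = f(x, y) + (∂1f(x,y)·xt + ∂2f(x,y)·yt)ε
(defn lift2 [f d1f d2f]
  (fn [[x xt] [y yt]]
    [(f x y)
     (+ (* (d1f x y) xt)
        (* (d2f x y) yt))]))

;; f(x, y) = x^2 y, with ∂1f = 2xy and ∂2f = x^2
(def x2y (lift2 (fn [x y] (* x x y))
                (fn [x y] (* 2 x y))
                (fn [x _] (* x x))))

(x2y [3 1] [2 0]) ;; => [18 12]  partial in x: 2·3·2
(x2y [3 0] [2 1]) ;; => [18 9]   partial in y: 3^2
(x2y [3 1] [2 1]) ;; => [18 21]  both tangents seeded: 12 + 9
```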
We can check this with the simple cases of addition, subtraction and multiplication.
The real parts of a dual number add commutatively, so we can rearrange the components of a sum to get a new dual number:
\[
(x+x'\varepsilon)+(y+y'\varepsilon) = (x+y)+(x'+y')\varepsilon
\]
This matches the sum rule of differentiation, since the partials of \(x + y\) with respect to either \(x\) or \(y\) both equal 1.
Subtraction is almost identical and agrees with the subtraction rule:
\[
(x+x'\varepsilon)-(y+y'\varepsilon) = (x-y)+(x'-y')\varepsilon
\]
Multiplying out the components of two dual numbers again gives us a new dual number, whose tangent component agrees with the product rule:
\begin{aligned}
(x+x'\varepsilon)(y+y'\varepsilon) &= xy+(xy')\varepsilon+(x'y)\varepsilon+(x'y')\varepsilon^2 \\
&= xy+(xy'+x'y)\varepsilon
\end{aligned}
Stare at these smaller derivations and convince yourself that they agree with the Taylor series expansion method for binary functions.
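The three rules, written out for a toy encoding in which [x xt] stands for \(x + x_t\varepsilon\) (d+, d- and d* are names made up for this sketch):

```clojure
;; Toy encoding: [x xt] is the dual number x + xt·ε.
(defn d+ [[x xt] [y yt]] [(+ x y) (+ xt yt)]) ; sum rule
(defn d- [[x xt] [y yt]] [(- x y) (- xt yt)]) ; difference rule
;; The xt·yt·ε^2 term vanishes, leaving the product rule in the tangent.
(defn d* [[x xt] [y yt]] [(* x y) (+ (* x yt) (* xt y))])

(d+ [3 1] [5 2]) ;; => [8 3]
(d- [3 1] [5 2]) ;; => [-2 -1]
(d* [3 1] [5 2]) ;; => [15 11], since 3·2 + 1·5 = 11
```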
The upshot is that, armed with these techniques, we can implement a higher-order derivative function (almost!) as simply as this:
(defn derivative [f]
(fn [x]
(extract-tangent
(f (make-dual x 1)))))
As long as f is built out of functions that know how to apply themselves to dual numbers, this will all Just Work.
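A runnable version of this untagged sketch, under stated assumptions: a small Dual record stands in for Differential, and as-dual, d+, d*, make-dual and extract-tangent are hypothetical helpers that treat plain numbers as constants with tangent 0. Only d+ and d* are provided, so the example function is written in terms of them.

```clojure
;; Toy untagged dual numbers; plain numbers act as constants (tangent 0).
(defrecord Dual [primal tangent])

(defn as-dual [x] (if (instance? Dual x) x (->Dual x 0)))

(defn d+ [a b]
  (let [a (as-dual a), b (as-dual b)]
    (->Dual (+ (:primal a) (:primal b))
            (+ (:tangent a) (:tangent b)))))

(defn d* [a b]
  (let [a (as-dual a), b (as-dual b)]
    (->Dual (* (:primal a) (:primal b))
            (+ (* (:primal a) (:tangent b))
               (* (:tangent a) (:primal b))))))

(defn make-dual [x tangent] (d+ x (->Dual 0 tangent)))

(defn extract-tangent [x]
  (if (instance? Dual x) (:tangent x) 0))

(defn derivative [f]
  (fn [x]
    (extract-tangent
     (f (make-dual x 1)))))

;; f(x) = x^3 + 2x, so Df(x) = 3x^2 + 2.
(defn f [x] (d+ (d* x (d* x x)) (d* 2 x)))

((derivative f) 2) ;; => 14
```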
Multiple Variables, Nesting
All of the examples above are about first-order derivatives. Taking higher-order derivatives is, in theory, straightforward:
(derivative
(derivative f))
But this guess hits one of many subtle problems with the implementation of forward-mode AD. The double-call to derivative will expand out to this:
(fn [x]
(letfn [(inner-d [x]
(extract-tangent
(f (make-dual x 1))))]
(extract-tangent
(inner-d
(make-dual x 1)))))
The x received by inner-d will ALREADY be a dual number \(x+\varepsilon\)! This will cause two immediate problems:
 (make-dual x 1) will return \((x+\varepsilon)+\varepsilon = x+2\varepsilon\), which is not what we want.
 The extract-tangent call inside inner-d will return the \(Df(x)\) component of the dual number… which, remember, is no longer a dual number! So the SECOND call to extract-tangent has nothing to extract, and can only sensibly return 0.
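The failure is easy to reproduce with a toy, untagged implementation in which dual numbers are [primal tangent] vectors (every name here is hypothetical). The first derivative of \(x^3\) comes out right, but nesting derivative returns 0 instead of \(6x = 12\) at \(x = 2\):

```clojure
;; Toy untagged dual numbers as [primal tangent] vectors.
(defn dual? [x] (vector? x))
(defn as-dual [x] (if (dual? x) x [x 0]))

(defn d+ [a b]
  (let [[x xt] (as-dual a), [y yt] (as-dual b)]
    [(+ x y) (+ xt yt)]))

(defn d* [a b]
  (let [[x xt] (as-dual a), [y yt] (as-dual b)]
    [(* x y) (+ (* x yt) (* xt y))]))

(defn make-dual [x tangent] (d+ x [0 tangent]))
(defn extract-tangent [x] (if (dual? x) (peek x) 0))

(defn derivative [f]
  (fn [x] (extract-tangent (f (make-dual x 1)))))

(defn cube [x] (d* x (d* x x)))

((derivative cube) 2)              ;; => 12, correct: 3·2^2
(make-dual (make-dual 2 1) 1)      ;; => [2 2], i.e. 2 + 2ε rather than a fresh ε
((derivative (derivative cube)) 2) ;; => 0, but D^2(x^3) at x = 2 is 12
```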
The problem here is called "perturbation confusion", and is covered beautifully in "Confusion of Tagged Perturbations in Forward Automatic Differentiation of Higher-Order Functions", by Manzyuk et al. (2019).
The solution is to introduce a new \(\varepsilon\) for every level, and allow different \(\varepsilon\) instances to multiply without annihilating. Each \(\varepsilon\) is called a "tag". Differential (implemented below) is a generalized dual number that can track many tags at once, allowing nested derivatives like the one described above to work.
This implies that extract-tangent needs to take a tag, to determine which tangent to extract:
(defn derivative [f]
(let [tag (fresh-tag)]
(fn [x]
(-> (f (make-dual x 1 tag))
(extract-tangent tag)))))
This is close to the final form you'll find in derivative.
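A compact sketch of the tagged approach, assuming a toy representation in which a value's terms live in a map from a set of tags to a coefficient, so {#{} 3, #{1} 1} means \(3+\varepsilon_1\). Every name is a hypothetical stand-in for the real machinery; the point is that distinct tags multiply into a mixed term while a repeated tag annihilates, so each level of derivative extracts only its own tangent.

```clojure
(require '[clojure.set :as set])

;; Toy tagged numbers: a map from a set of tags to a coefficient.
(defrecord Diff [terms])

(defn ->terms [x] (if (instance? Diff x) (:terms x) {#{} x}))

(defn from-terms
  "Drops zero coefficients and collapses a tag-free result to a plain number."
  [terms]
  (let [terms (into {} (remove (fn [[_ c]] (zero? c))) terms)]
    (cond (empty? terms) 0
          (and (= 1 (count terms)) (contains? terms #{})) (get terms #{})
          :else (->Diff terms))))

(defn d+ [a b]
  (from-terms (merge-with + (->terms a) (->terms b))))

(defn d* [a b]
  (from-terms
   (apply merge-with + {}
          (for [[ta ca] (->terms a)
                [tb cb] (->terms b)
                ;; Pairs sharing a tag drop out, since ε_i · ε_i = 0;
                ;; distinct tags multiply into a higher-order term.
                :when (empty? (set/intersection ta tb))]
            {(set/union ta tb) (* ca cb)}))))

(def next-tag (atom 0))
(defn fresh-tag [] (swap! next-tag inc))

(defn make-dual [x tangent tag]
  (d+ x (->Diff {#{tag} tangent})))

(defn extract-tangent
  "Keeps only terms containing `tag`, with `tag` removed from each."
  [x tag]
  (if (instance? Diff x)
    (from-terms
     (apply merge-with + {}
            (for [[tags c] (:terms x)
                  :when (contains? tags tag)]
              {(disj tags tag) c})))
    0))

(defn derivative [f]
  (let [tag (fresh-tag)]
    (fn [x]
      (-> (f (make-dual x 1 tag))
          (extract-tangent tag)))))

(defn cube [x] (d* x (d* x x)))

((derivative cube) 3)              ;; => 27, i.e. 3·3^2
((derivative (derivative cube)) 3) ;; => 18, i.e. 6·3
```

With two tags in play, cubing \(3+\varepsilon_1+\varepsilon_2\) yields \(27+27\varepsilon_1+27\varepsilon_2+18\varepsilon_1\varepsilon_2\); extracting \(\varepsilon_1\) and then \(\varepsilon_2\) leaves 18.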
What Return Values are Allowed?
Before we discuss the implementation of dual numbers (called Differential), lift1, lift2 and the rest of the machinery that makes this all possible: what sorts of objects is f allowed to return?
The dual number approach is beautiful because we can bring to bear all sorts of operations in Clojure that never even see dual numbers. For example, square-and-cube, called with a dual number, returns a PAIR of dual numbers:
(defn square-and-cube [x]
(let [x2 (g/square x)
x3 (g/cube x)]
[x2 x3]))
Vectors don't care what they hold! We want the derivative of square-and-cube to also return a vector, whose entries represent the derivative of that entry with respect to the function's input.
But this implies that extract-tangent from the example above needs to know how to handle vectors and other collections; in the case of a vector v, by returning (mapv #(extract-tangent % tag) v).
The dual number implementation is called Differential; the way that Differential instances interact with container types in SICMUtils makes it easy for them to end up captured inside containers all over. Whenever we multiply a Differential by a structure, a function, a vector, or any similar object, our implementation of the SICMUtils generics pushes the Differential inside those objects, rather than forming a Differential with, for example, a vector in the primal and tangent parts.
Functions are more interesting. What about higher-order functions?
(defn offset-fn
"Returns a function that takes a single-argument function `g`, and returns a new
function like `g` that offsets its input by `offset`."
[offset]
(fn [g]
(fn [x]
(g (+ x offset)))))
(derivative offset-fn) here returns a function! Ideally we'd like the returned function to act exactly like:
(derivative
(fn [offset] (g (+ x offset))))
for some known g and x, but with the ability to store (derivative offset-fn) and call it later with many different g.
We might accomplish this by composing extract-tangent with the returned function, so that the extraction happens later, when the function is called.
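A sketch of that deferral, assuming a toy untagged encoding in which dual numbers are [primal tangent] vectors (all names hypothetical). When a result is a function, extract-tangent wraps it so extraction runs on whatever the function eventually returns. Because this version has no tags, it only illustrates the deferral; it does nothing to guard against the subtle bugs discussed below.

```clojure
;; Toy untagged dual numbers as [primal tangent] vectors.
(defn dual? [x] (vector? x))
(defn as-dual [x] (if (dual? x) x [x 0]))

(defn d+ [a b]
  (let [[x xt] (as-dual a), [y yt] (as-dual b)]
    [(+ x y) (+ xt yt)]))

(defn d* [a b]
  (let [[x xt] (as-dual a), [y yt] (as-dual b)]
    [(* x y) (+ (* x yt) (* xt y))]))

(defn extract-tangent [x]
  (cond (dual? x) (peek x)
        ;; A function result: defer extraction until it is called.
        (fn? x)   (fn [& args] (extract-tangent (apply x args)))
        :else     0))

(defn derivative [f]
  (fn [x] (extract-tangent (f [x 1]))))

(defn offset-fn [offset]
  (fn [g]
    (fn [x]
      (g (d+ x offset)))))

(defn square [y] (d* y y))

;; d/d(offset) of (3 + offset)^2 is 2(3 + offset).
((((derivative offset-fn) 0) square) 3) ;; => 6
((((derivative offset-fn) 1) square) 3) ;; => 8
```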
NOTE: The real implementation is more subtle! See the sicmutils.calculus.derivative namespace for the actual implementation of IPerturbed for functions and multimethods.
All of this suggests that we need to make extract-tangent an open function that other folks can extend for other container-like types (functors, specifically).
The IPerturbed protocol accomplishes this, along with two other functions that we'll use later:
(defprotocol IPerturbed
(perturbed? [this]
"Returns true if the supplied object has some known non-zero tangent to be
extracted via [[extract-tangent]], false otherwise. (Return `false` by
default if you can't detect a perturbation.)")
(replace-tag [this old-tag new-tag]
"If `this` is perturbed, returns a similar object with the perturbation
modified by replacing any appearance of `old-tag` with `new-tag`. Else,
return `this`.")
(extract-tangent [this tag]
"If `this` is perturbed, return the tangent component paired with the
supplied tag. Else, returns `([[sicmutils.value/zero-like]] this)`."))
replace-tag exists to handle subtle bugs that can arise in the case of functional return values. See the "Amazing Bug" sections in sicmutils.calculus.derivative-test for detailed examples on how this might bite you.
The default implementations are straightforward, and match the docstrings:
(extend-protocol IPerturbed
#?(:clj Object :cljs default)
(perturbed? [_] false)
(replace-tag [this _ _] this)
(extract-tangent [this _] (v/zero-like this)))
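A cut-down illustration of why a protocol fits here, for JVM Clojure only. It assumes a one-method, tagless simplification of IPerturbed and a toy ToyDual type (both hypothetical); the Object default returns a literal 0 rather than a zero-like value. Extending the protocol to vectors makes extraction element-wise, recursing into nested vectors:

```clojure
;; JVM Clojure only. A one-method, tagless simplification of IPerturbed.
(defprotocol IPerturbedSketch
  (extract-tangent [this]))

;; A toy dual number type that knows its own tangent.
(deftype ToyDual [primal tangent]
  IPerturbedSketch
  (extract-tangent [_] tangent))

(extend-protocol IPerturbedSketch
  ;; Default: nothing perturbed, so the tangent is zero.
  Object
  (extract-tangent [_] 0)

  ;; Vectors extract element-wise, recursing into nested vectors.
  clojure.lang.IPersistentVector
  (extract-tangent [v] (mapv extract-tangent v)))

(extract-tangent [(ToyDual. 4 4) 7 [(ToyDual. 8 12)]]) ;; => [4 0 [12]]
```

Protocol dispatch prefers the IPersistentVector extension over the Object fallback, so new container types can opt in without touching the core definition.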
Differential Implementation
We now have a template for how to implement derivative. What's left? We need a dual number type that we can build and split back out into primal and tangent components, given some tag. We'll call this type a Differential.
A Differential is a generalized dual number with a single primal component, and potentially many tagged terms. Identical tags annihilate (multiply to 0), but distinct tags are allowed to multiply by each other:
\[
a + b\varepsilon_1 + c\varepsilon_2 + d\varepsilon_1 \varepsilon_2 + \cdots
\]
Alternatively, you can view a Differential as a dual number with a specific tag, that's able to hold dual numbers with some other tag in its primal and tangent slots. You can turn a Differential into a dual number by specifying one of its tags. Here are the primal and tangent components for \(\varepsilon_2\):
\[
(a + b\varepsilon_1) + (c + d\varepsilon_1)\varepsilon_2
\]
And for \(\varepsilon_1\):
\[
(a + c\varepsilon_2) + (b + d \varepsilon_2) \varepsilon_1
\]
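Both factorings are mechanical to compute. This sketch assumes a term list of [tag-set coefficient] pairs, with Clojure hash sets standing in for the real tag representation and \(a=1, b=2, c=3, d=4\). split-by-tag is a hypothetical helper that returns the primal and tangent term lists, dropping the tag from the tangent terms:

```clojure
;; a + bε_1 + cε_2 + dε_1ε_2 with a = 1, b = 2, c = 3, d = 4
(def terms [[#{} 1] [#{1} 2] [#{2} 3] [#{1 2} 4]])

(defn split-by-tag
  "Returns [primal tangent] with respect to `tag`. The tangent's terms have
  `tag` removed, so each side is a term list in the remaining tags."
  [terms tag]
  [(filterv (fn [[tags _]] (not (contains? tags tag))) terms)
   (into [] (keep (fn [[tags c]]
                    (when (contains? tags tag) [(disj tags tag) c])))
         terms)])

(split-by-tag terms 2) ;; => [[[#{} 1] [#{1} 2]] [[#{} 3] [#{1} 4]]]
;; i.e. (1 + 2ε_1) + (3 + 4ε_1)ε_2

(split-by-tag terms 1) ;; => [[[#{} 1] [#{2} 3]] [[#{} 2] [#{2} 4]]]
;; i.e. (1 + 3ε_2) + (2 + 4ε_2)ε_1
```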
A differential term is implemented as a pair whose first element is a set of tags and whose second is the coefficient.
(def ^:private tags first)
(def ^:private coefficient peek)
The set of tags is implemented as a "vector set", from sicmutils.util.vector-set. This is a sorted set data structure, backed by a Clojure vector. Vector sets provide cheap "max" and "min" operations, and acceptable union, intersection and difference performance.
(defn make-term
"Returns a [[Differential]] term with the supplied vector-set of `tags` paired
with coefficient `coef`.
`tags` defaults to [[uv/empty-set]]"
([coef] [uv/empty-set coef])
([tags coef] [tags coef]))
Since the only use of a tag is to distinguish each unnamed \(\varepsilon_n\), we'll assign a new, unique positive integer for each new tag:
(let [next-tag (atom 0)]
(defn fresh-tag
"Returns a new, unique tag for use inside of a [[Differential]] term list."
[]
(swap! next-tag inc)))
(defn tag-in-term?
"Return true if `t` is in the tagset of the supplied `term`, false otherwise."
[term t]
(uv/contains? (tags term) t))
Term List Algebra
The discussion above about Taylor series expansions revealed that for single variable functions, we can pass a dual number into any function whose derivative we already know:
\[
f(a+b\varepsilon) = f(a) + (Df(a)b)\varepsilon
\]
Because we can split a Differential into a primal and tangent component with respect to some tag, we can reuse this result. We'll default to splitting Differential instances by the highest-index tag:
\begin{aligned}
f(a &+ b\varepsilon_1 + c\varepsilon_2 + d\varepsilon_1 \varepsilon_2) \\
&= f((a + b\varepsilon_1)+(c + d\varepsilon_1)\varepsilon_2) \\
&= f(a + b\varepsilon_1)+Df(a + b\varepsilon_1)(c + d\varepsilon_1)\varepsilon_2
\end{aligned}
Note that \(f\) and \(Df\) both received a dual number! One more expansion, this time in \(\varepsilon_1\), completes the evaluation (and makes abundantly clear why we want the computer doing this, not pencil-and-paper):
\begin{aligned}
f(a &+ b\varepsilon_1)+Df(a+b\varepsilon_1)(c+d\varepsilon_1)\varepsilon_2 \\
&= (f(a)+Df(a)b\varepsilon_1)+(Df(a)+D^2f(a)b\varepsilon_1)(c + d\varepsilon_1)\varepsilon_2 \\
&= f(a)+Df(a)b\varepsilon_1+Df(a)c\varepsilon_2+(Df(a)d+D^2f(a)bc)\varepsilon_1\varepsilon_2
\end{aligned}
The only operations we need to implement between lists of terms are addition and multiplication.
Addition and Multiplication
To efficiently add two Differential instances (represented as vectors of terms), we keep all terms in sorted order, sorted first by the length of each tag list (the "order" of the differential term), and secondarily by the values of the tags inside the tag list.
NOTE: Clojure vectors already implement this ordering properly, so we can use clojure.core/compare to determine an ordering on a tag list.
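A quick REPL check of that ordering claim, using plain Clojure vectors as tag lists:

```clojure
;; compare orders vectors by length first, then element by element: the
;; "order of the term first, then tag values" ordering described above.
(sort [[1 3] [2] [] [1 2] [3] [2 3]])
;; => ([] [2] [3] [1 2] [1 3] [2 3])

(compare [5] [1 2]) ;; => -1: the shorter (lower-order) tag list sorts first
```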
(defn terms:+
"Returns the sum of the two supplied sequences of differential terms; any terms
in the result with a zero coefficient will be removed.
Each input must be a sequence of `[tagset, coefficient]` pairs, sorted by
`tagset`."
[xs ys]
(loop [xs xs, ys ys, result []]
(cond (empty? xs) (into result ys)
(empty? ys) (into result xs)
:else (let [[x-tags x-coef :as x] (first xs)
[y-tags y-coef :as y] (first ys)
compare-flag (core-compare x-tags y-tags)]
(cond
;; If the terms have the same tag set, add the coefficients
;; together. Include the term in the result only if the new
;; coefficient is non-zero.
(zero? compare-flag)
(let [sum (g/+ x-coef y-coef)]
(recur (rest xs)
(rest ys)
(if (v/zero? sum)
result
(conj result (make-term x-tags sum)))))
;; Else, pass the smaller term on unchanged and proceed.
(neg? compare-flag)
(recur (rest xs) ys (conj result x))
:else
(recur xs (rest ys) (conj result y)))))))
Because we've decided to store terms as a vector, we can multiply two vectors of terms by:
 taking the Cartesian product of both term lists
 discarding all pairs of terms that share any tag (since \(\varepsilon^2=0\))
 multiplying the coefficients of all remaining pairs and unioning their tag lists
 grouping and adding any new terms with the SAME list of tags, and filtering out zeros
This final step is required by a number of different operations later, so we break it out into its own collect-terms function:
(defn collect-terms
"Build a term list up of pairs of tags => coefficients by grouping together and
summing coefficients paired with the same term list.
The result will be sorted by term list, and contain no duplicate term lists."
[tags->coefs]
(let [terms (for [[tags tags-coefs] (group-by tags tags->coefs)
:let [c (transduce (map coefficient) g/+ tags-coefs)]
:when (not (v/zero? c))]
[tags c])]
(into [] (sortby tags terms))))
terms:* implements the first three steps, and calls collect-terms on the resulting sequence:
(defn terms:*
"Returns a vector of non-zero [[Differential]] terms that represent the product
of the differential term lists `xs` and `ys`."
[xs ys]
(collect-terms
(for [[x-tags x-coef] xs
[y-tags y-coef] ys
:when (empty? (uv/intersection x-tags y-tags))]
(make-term (uv/union x-tags y-tags)
(g/* x-coef y-coef)))))
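As a numerical check, here is a simplified sketch of both operations using only clojure.core: tag lists are sorted vectors, coefficients are plain numbers, and terms+ regroups rather than performing the sorted merge of terms:+ (collect-terms, terms+ and terms* here are hypothetical stand-ins). Working the expansion of \(f(a+b\varepsilon_1+c\varepsilon_2+d\varepsilon_1\varepsilon_2)\) by hand for \(f(x)=x^3\) with \(a=2\) and \(b=c=d=1\) gives \(f(a)=8\), \(Df(a)b = Df(a)c = 12\), and \(Df(a)d + D^2f(a)bc = 12+12 = 24\) for the \(\varepsilon_1\varepsilon_2\) term; cubing the term list reproduces those coefficients.

```clojure
(require '[clojure.set :as set])

;; Simplified sketch: tag lists are sorted vectors, coefficients are numbers.
(defn collect-terms
  "Groups terms by tag list, sums coefficients, drops zeros, and sorts."
  [terms]
  (->> (group-by first terms)
       (keep (fn [[tags ts]]
               (let [c (reduce + (map peek ts))]
                 (when-not (zero? c) [tags c]))))
       (sort-by first)
       (vec)))

;; Regroups instead of doing the sorted merge that terms:+ performs.
(defn terms+ [xs ys]
  (collect-terms (concat xs ys)))

(defn terms* [xs ys]
  (collect-terms
   (for [[x-tags x-coef] xs
         [y-tags y-coef] ys
         ;; Skip pairs that share a tag, since ε^2 = 0.
         :when (empty? (set/intersection (set x-tags) (set y-tags)))]
     [(vec (sort (concat x-tags y-tags))) (* x-coef y-coef)])))

;; 2 + ε_1 + ε_2 + ε_1ε_2, i.e. a = 2 and b = c = d = 1
(def x [[[] 2] [[1] 1] [[2] 1] [[1 2] 1]])

(terms+ x x)            ;; => [[[] 4] [[1] 2] [[2] 2] [[1 2] 2]]
(terms* x (terms* x x)) ;; => [[[] 8] [[1] 12] [[2] 12] [[1 2] 24]]
```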
Differential Type Implementation
Armed with our term list arithmetic operations, we can finally implement our Differential type and implement a number of important Clojure and SICMUtils protocols.
A Differential will respond to v/kind with ::differential. Because we want Differential instances to work in any place that real numbers or symbolic arguments work, let's make ::differential derive from ::v/scalar:
(derive ::differential ::v/scalar)
Now the actual type. The terms field is a term-list vector that will remain (contractually!) sorted by its list of tags.
(declare d:apply compare equiv from-terms one?)
(deftype Differential [terms]
;; A [[Differential]] as implemented can act as a chain-rule accounting device
;; for all sorts of types, not just numbers. A [[Differential]] is
;; only [[v/numerical?]] if its coefficients are numerical.
v/Numerical
(numerical? [_]
(v/numerical? (coefficient (first terms))))
IPerturbed
(perturbed? [_] true)
;; There are 3 cases to consider when replacing some tag in a term, annotated
;; below:
(replace-tag [_ old-tag new-tag]
(letfn [(process [term]
(let [tagv (tags term)]
(if-not (uv/contains? tagv old-tag)
;; if the term doesn't contain the old tag, ignore it.
[term]
(if (uv/contains? tagv new-tag)
;; if the term _already contains_ the new tag
;; $\varepsilon_{new}$, then replacing $\varepsilon_{old}$
;; with a new instance of $\varepsilon_{new}$ would cause a
;; clash. Since $\varepsilon_{new}^2=0$, the term should be
;; removed.
[]
;; else, perform the replacement.
[(make-term (-> tagv
(uv/disj old-tag)
(uv/conj new-tag))
(coefficient term))]))))]
(from-terms
(mapcat process terms))))
;; To extract the tangent (with respect to `tag`) from a differential, return
;; all terms that contain the tag (with the tag removed!) This can create
;; duplicate terms, so use [[from-terms]] to massage the result into
;; well-formedness again.
(extract-tangent [_ tag]
(from-terms
(mapcat (fn [term]
(let [tagv (tags term)]
(if (uv/contains? tagv tag)
[(make-term (uv/disj tagv tag)
(coefficient term))]
[])))
terms)))
v/Value
(zero? [this]
(every? (comp v/zero? coefficient) terms))
(one? [this] (one? this))
(identity? [this] (one? this))
(zero-like [_] 0)
(one-like [_] 1)
(identity-like [_] 1)
(freeze [_] `[~'Differential ~@terms])
(exact? [_] false)
(kind [_] ::differential)
Object
;; Comparing [[Differential]] objects using `equals` defaults to [[equiv]],
;; which compares instances only using their non-tagged ('finite') components.
;; If you want to compare two instances using their full term lists,
;; see [[eq]].
#?(:clj (equals [a b] (equiv a b)))
(toString [_] (str "D[" (join " " (map #(join " → " %) terms)) "]"))
;; Because a [[Differential]] is an accounting device that augments other
;; operations with the ability to carry around derivatives, it's possible that
;; the coefficient slots could be occupied by function objects. If so, then it
;; becomes possible to "apply" a [[Differential]] to some arguments (apply
;; each coefficient to the arguments).
;; TODO the arity, if anyone cares to ask, might be better implemented as the
;; joint arity of all coefficients; but my guess here is that the tangent
;; terms all have to be tracking derivatives of the primal term, which have to
;; have the same arity by definition.
f/IArity
(arity [_]
(f/arity (coefficient (first terms))))
#?@(:clj
;; This one is slightly subtle. To participate in control flow operations,
;; like comparison with both [[Differential]] and non-[[Differential]]
;; numbers, [[Differential]] instances should compare using ONLY their
;; non-tagged ("finite") terms. This means that comparison will totally
;; ignore any difference in tags.
[Comparable
(compareTo [a b] (compare a b))
IFn
(invoke [this]
(d:apply this []))
(invoke [this a]
(d:apply this [a]))
(invoke [this a b]
(d:apply this [a b]))
(invoke [this a b c]
(d:apply this [a b c]))
(invoke [this a b c d]
(d:apply this [a b c d]))
(invoke [this a b c d e]
(d:apply this [a b c d e]))
(invoke [this a b c d e f]
(d:apply this [a b c d e f]))
(invoke [this a b c d e f g]
(d:apply this [a b c d e f g]))
(invoke [this a b c d e f g h]
(d:apply this [a b c d e f g h]))
(invoke [this a b c d e f g h i]
(d:apply this [a b c d e f g h i]))
(invoke [this a b c d e f g h i j]
(d:apply this [a b c d e f g h i j]))
(invoke [this a b c d e f g h i j k]
(d:apply this [a b c d e f g h i j k]))
(invoke [this a b c d e f g h i j k l]
(d:apply this [a b c d e f g h i j k l]))
(invoke [this a b c d e f g h i j k l m]
(d:apply this [a b c d e f g h i j k l m]))
(invoke [this a b c d e f g h i j k l m n]
(d:apply this [a b c d e f g h i j k l m n]))
(invoke [this a b c d e f g h i j k l m n o]
(d:apply this [a b c d e f g h i j k l m n o]))
(invoke [this a b c d e f g h i j k l m n o p]
(d:apply this [a b c d e f g h i j k l m n o p]))
(invoke [this a b c d e f g h i j k l m n o p q]
(d:apply this [a b c d e f g h i j k l m n o p q]))
(invoke [this a b c d e f g h i j k l m n o p q r]
(d:apply this [a b c d e f g h i j k l m n o p q r]))
(invoke [this a b c d e f g h i j k l m n o p q r s]
(d:apply this [a b c d e f g h i j k l m n o p q r s]))
(invoke [this a b c d e f g h i j k l m n o p q r s t]
(d:apply this [a b c d e f g h i j k l m n o p q r s t]))
(applyTo [this xs] (AFn/applyToHelper this xs))]
:cljs
[IEquiv
(equiv [a b] (equiv a b))
IComparable
(compare [a b] (compare a b))
IPrintWithWriter
(-pr-writer [x writer _]
(write-all writer (.toString x)))
IFn
(invoke [this]
(d:apply this []))
(invoke [this a]
(d:apply this [a]))
(invoke [this a b]
(d:apply this [a b]))
(invoke [this a b c]
(d:apply this [a b c]))
(invoke [this a b c d]
(d:apply this [a b c d]))
(invoke [this a b c d e]
(d:apply this [a b c d e]))
(invoke [this a b c d e f]
(d:apply this [a b c d e f]))
(invoke [this a b c d e f g]
(d:apply this [a b c d e f g]))
(invoke [this a b c d e f g h]
(d:apply this [a b c d e f g h]))
(invoke [this a b c d e f g h i]
(d:apply this [a b c d e f g h i]))
(invoke [this a b c d e f g h i j]
(d:apply this [a b c d e f g h i j]))
(invoke [this a b c d e f g h i j k]
(d:apply this [a b c d e f g h i j k]))
(invoke [this a b c d e f g h i j k l]
(d:apply this [a b c d e f g h i j k l]))
(invoke [this a b c d e f g h i j k l m]
(d:apply this [a b c d e f g h i j k l m]))
(invoke [this a b c d e f g h i j k l m n]
(d:apply this [a b c d e f g h i j k l m n]))
(invoke [this a b c d e f g h i j k l m n o]
(d:apply this [a b c d e f g h i j k l m n o]))
(invoke [this a b c d e f g h i j k l m n o p]
(d:apply this [a b c d e f g h i j k l m n o p]))
(invoke [this a b c d e f g h i j k l m n o p q]
(d:apply this [a b c d e f g h i j k l m n o p q]))
(invoke [this a b c d e f g h i j k l m n o p q r]
(d:apply this [a b c d e f g h i j k l m n o p q r]))
(invoke [this a b c d e f g h i j k l m n o p q r s]
(d:apply this [a b c d e f g h i j k l m n o p q r s]))
(invoke [this a b c d e f g h i j k l m n o p q r s t]
(d:apply this [a b c d e f g h i j k l m n o p q r s t]))
(invoke [this a b c d e f g h i j k l m n o p q r s t rest]
(d:apply this (concat [a b c d e f g h i j k l m n o p q r s t] rest)))]))
#?(:clj
(defmethod print-method Differential
[^Differential s ^java.io.Writer w]
(.write w (.toString s))))
Accessor Methods
(defn differential?
"Returns true if the supplied object is an instance of `Differential`, false
otherwise."
[dx]
(instance? Differential dx))
(defn bare-terms
"Returns the `terms` field of the supplied `Differential` object. Errors if any
other type is supplied."
[dx]
{:pre [(differential? dx)]}
(.terms ^Differential dx))
Constructors
Because a Differential is really a wrapper around the idea of a generalized dual number represented as a term-list, we need to be able to get to and from the term list format from other types, not just Differential instances.
(defn ->terms
"Returns a vector of terms that represent the supplied [[Differential]]; any
term with a [[v/zero?]] coefficient will be filtered out before return.
If you pass a non-[[Differential]], [[->terms]] will return a singleton term
list (or `[]` if the argument was zero)."
[dx]
(cond (differential? dx)
(filterv (fn [term]
(not (v/zero? (coefficient term))))
(bare-terms dx))
(v/zero? dx) []
:else [(make-term dx)]))
(defn terms->differential
"Returns a differential instance generated from a vector of terms. This method
will do some mild cleanup, or canonicalization:
- any empty term list will return 0
- a singleton term list with no tags will return its coefficient
NOTE this method assumes that the input is properly sorted, and contains no
zero coefficients."
[terms]
{:pre [(vector? terms)]}
(cond (empty? terms) 0
(and (= (count terms) 1)
(empty? (tags (first terms))))
(coefficient (first terms))
:else (->Differential terms)))
(defn from-terms
"Accepts a sequence of terms (pairs of [tag-list, coefficient]), and returns:
- a well-formed [[Differential]] instance, if the terms resolve to a
differential with non-zero infinitesimal terms
- otherwise, the lone tag-free coefficient, or 0 if every term cancels
Duplicate (by tag list) terms are allowed; their coefficients will be summed
together and removed if they sum to zero."
[tags->coefs]
(terms->differential
(collect-terms tags->coefs)))
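A sketch of the same canonicalization rules on plain data, assuming sorted vectors as tag lists and numbers as coefficients; this from-terms is a hypothetical, simplified stand-in rather than the function above:

```clojure
;; Hypothetical simplified from-terms: collect, drop zeros, then canonicalize.
(defn from-terms [terms]
  (let [collected (->> (group-by first terms)
                       (keep (fn [[tags ts]]
                               (let [c (reduce + (map peek ts))]
                                 (when-not (zero? c) [tags c]))))
                       (sort-by first)
                       (vec))]
    (cond (empty? collected) 0
          ;; A single tag-free term collapses to its coefficient.
          (and (= 1 (count collected))
               (empty? (ffirst collected)))
          (peek (first collected))
          :else collected)))

(from-terms [[[1] 2] [[] 3] [[1] -2]]) ;; => 3, tangents cancel, only the primal survives
(from-terms [[[1] 2] [[1] -2]])        ;; => 0
(from-terms [[[1] 2] [[] 3] [[1] 5]])  ;; => [[[] 3] [[1] 7]]
```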
Differential API
This next section lifts slightly-augmented versions of terms:+ and terms:* up to operate on Differential instances. These work just as before, but handle wrapping and unwrapping the term list.
(defn d:+
"Returns an object representing the sum of the two objects `dx` and `dy`. This
works by summing the coefficients of all terms with the same list of tags.
Works with non-[[Differential]] instances on either or both sides, and returns
a [[Differential]] only if it contains any non-zero tangent components."
[dx dy]
(terms->differential
(terms:+ (->terms dx)
(->terms dy))))
(defn d:*
"Returns an object representing the product of the two objects `dx` and `dy`.
This works by multiplying out all terms:
$$(dx1 + dx2 + dx3 + ...)(dy1 + dy2 + dy3...)$$
and then collecting any duplicate terms by summing their coefficients.
Works with non-[[Differential]] instances on either or both sides, and returns
a [[Differential]] only if it contains any non-zero tangent components."
[dx dy]
(terms->differential
(terms:* (->terms dx)
(->terms dy))))
(defn d:apply
"Accepts a [[Differential]] and a sequence of `args`, interprets each
coefficient as a function and returns a new [[Differential]] generated by
applying the coefficient to `args`."
[diff args]
(terms->differential
(into [] (mapcat (fn [term]
(let [result (apply (coefficient term) args)]
(if (v/zero? result)
[]
[(make-term (tags term) result)]))))
(->terms diff))))
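The same idea on plain data, assuming a vector of [tag-list coefficient] pairs whose coefficients are functions; d-apply is a hypothetical simplified stand-in for d:apply. Pairing \(x^2\) with its derivative \(2x\) under tag 1 and applying at 3 gives the term list for \(9+6\varepsilon_1\); at 0 both results vanish and are dropped.

```clojure
;; Hypothetical simplified d:apply: every coefficient is a function; apply
;; each to the same args and keep only the non-zero results.
(defn d-apply [terms args]
  (into []
        (keep (fn [[tags f]]
                (let [result (apply f args)]
                  (when-not (zero? result)
                    [tags result]))))
        terms))

;; x^2 paired with its derivative 2x under tag 1
(def square-terms [[[] (fn [x] (* x x))]
                   [[1] (fn [x] (* 2 x))]])

(d-apply square-terms [3]) ;; => [[[] 9] [[1] 6]]
(d-apply square-terms [0]) ;; => []
```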
Finally, the function we've been waiting for! bundle allows you to augment some non-Differential thing with a tag and push it through the generic arithmetic system, where it will accumulate the derivative of your original input (tagged with tag).
(defn bundle
"Generate a new [[Differential]] object with the supplied `primal` and `tangent`
components, and the supplied internal `tag` that this [[Differential]] will
carry around to prevent perturbation confusion.
If the `tangent` component is `0`, acts as identity on `primal`. `tangent`
defaults to 1.
`tag` defaults to a side-effecting call to [[fresh-tag]]; you can retrieve
this unknown tag by calling [[max-order-tag]]."
([primal]
(bundle primal 1 (fresh-tag)))
([primal tag]
(bundle primal 1 tag))
([primal tangent tag]
(let [term (make-term (uv/make [tag]) tangent)]
(d:+ primal (->Differential [term])))))
Differential Parts API
These functions give higher-level access to the components of a Differential you're typically interested in.
(defn max-order-tag
"Given one or more well-formed [[Differential]] objects, returns the
maximum ('highest order') tag found in the highest-order term of any of
the [[Differential]] instances.
If there is NO maximal tag (ie, if you provide [[Differential]] instances with
no non-zero tangent parts, or all non-[[Differential]]s), returns nil."
([dx]
(when (differential? dx)
(let [last-term (peek (->terms dx))
highest-tag (peek (tags last-term))]
highest-tag)))
([dx & dxs]
(letfn [(max-termv [dx]
(if-let [max-order (max-order-tag dx)]
[max-order]
[]))]
(when-let [orders (seq (mapcat max-termv (cons dx dxs)))]
(apply max orders)))))
A reminder: the primal-part of a Differential is all terms except for terms containing max-order-tag, and tangent-part is a Differential built out of the remaining terms, all of which contain that tag.
(defn primal-part
"Returns a [[Differential]] containing only the terms of `dx` that do NOT
contain the supplied `tag`, ie, the primal component of `dx` with respect to
`tag`.
If no tag is supplied, defaults to `([[max-order-tag]] dx)`.
NOTE: every [[Differential]] can be factored into a dual number of the form
primal + (tangent * tag)
for each tag in any of its terms. [[primal-part]] returns this first piece,
potentially simplified into a non-[[Differential]] if the primal part contains
no other tags."
([dx] (primal-part dx (max-order-tag dx)))
([dx tag]
(if (differential? dx)
(let [sans-tag? #(not (tag-in-term? % tag))]
(->> (->terms dx)
(filterv sans-tag?)
(terms->differential)))
dx)))
(defn tangent-part
  "Returns a [[Differential]] containing only the terms of `dx` that contain the
  supplied `tag`, i.e., the tangent component of `dx` with respect to `tag`.
  If no tag is supplied, defaults to `([[max-order-tag]] dx)`.
  NOTE: Every [[Differential]] can be factored into a dual number of the form
      primal + (tangent * tag)
  for each tag in any of its terms. [[tangent-part]] returns a [[Differential]]
  representing `(tangent * tag)`, or 0 if `dx` contains no terms with the
  supplied `tag`.
  NOTE: the 2-arity case is similar to `(extract-tangent dx tag)`; the only
  difference is that `extract-tangent` drops `tag` from all terms in the
  returned value. Call `extract-tangent` if you want to drop `tag`."
  ([dx] (tangent-part dx (max-order-tag dx)))
  ([dx tag]
   (if (differential? dx)
     (->> (->terms dx)
          (filterv #(tag-in-term? % tag))
          (terms->differential))
     0)))
(defn primal-tangent-pair
  "Returns a pair of the primal and tangent components of the supplied `dx`, with
  respect to the supplied `tag`. See the docs for [[primal-part]]
  and [[tangent-part]] for more details.
  [[primal-tangent-pair]] is equivalent to
      `[([[primal-part]] dx tag) ([[tangent-part]] dx tag)]`
  but slightly more efficient if you need both."
  ([dx] (primal-tangent-pair dx (max-order-tag dx)))
  ([dx tag]
   (if-not (differential? dx)
     [dx 0]
     (let [[tangent-terms primal-terms]
           (us/separatev #(tag-in-term? % tag)
                         (->terms dx))]
       [(terms->differential primal-terms)
        (terms->differential tangent-terms)]))))
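Here is a sketch of the primal/tangent split on a differential with two tags. It uses a hypothetical toy representation, sorted maps from tag vectors to coefficients, and toy-primal-tangent-pair is an illustrative name. Unlike the real function, it never simplifies a result back into a plain number.

```clojure
;; Toy representation (hypothetical): sorted maps from tag vectors to
;; coefficients; plain numbers are treated as having a tangent of 0.
(defn toy-primal-tangent-pair
  "Splits `dx` into [terms WITHOUT `tag`, terms WITH `tag`]. Unlike the real
  function, results are not simplified back to plain numbers."
  [dx tag]
  (if-not (map? dx)
    [dx 0]
    (let [has-tag? (fn [[tags _]] (some #{tag} tags))]
      [(into (sorted-map) (remove has-tag?) dx)
       (into (sorted-map) (filter has-tag?) dx)])))

;; x = 3 + 2ε₁ + 5ε₂ + 7ε₁ε₂, split with respect to ε₂:
(toy-primal-tangent-pair (sorted-map [] 3, [1] 2, [2] 5, [1 2] 7) 2)
;; => [{[] 3, [1] 2} {[2] 5, [1 2] 7}]
```

Note that the primal part 3 + 2ε₁ is itself still a differential in ε₁. That is why functions receiving a primal part may have to handle Differential inputs.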
(defn finite-term
  "Returns the coefficient of the term of the supplied [[Differential]] `dx`
  that has no tags attached to it, or `0` if there is no such term.
  [[Differential]] instances with many tags can be decomposed many times
  into [[primal-part]] and [[tangent-part]]. Repeated calls
  to [[primal-part]] (with different tags!) will eventually yield a
  non-[[Differential]] value. If you know you want this, [[finite-term]] will
  get you there in one shot.
  NOTE that this will only work with a well-formed [[Differential]], i.e., an
  instance with all terms sorted by their list of tags."
  [dx]
  (if (differential? dx)
    (let [[head] (bare-terms dx)
          ts     (tags head)]
      (if (= [] ts)
        (coefficient head)
        0))
    dx))
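The docstring's claim that finite-term is a shortcut for peeling off primal parts one tag at a time can be checked on a hypothetical toy representation: sorted maps from tag vectors to coefficients, with the tag-free term stored under the key []. The names toy-finite-term, toy-primal-part and x3 are illustrative only.

```clojure
;; Toy representation (hypothetical): sorted maps from tag vectors to
;; coefficients, with the tag-free term stored under the key [].
(defn toy-finite-term
  "Coefficient of the tag-free term of `dx`, or 0; non-maps pass through."
  [dx]
  (if (map? dx)
    (get dx [] 0)
    dx))

(defn toy-primal-part
  "Keeps only the terms of `dx` that do not contain `tag`."
  [dx tag]
  (into (sorted-map) (remove (fn [[tags _]] (some #{tag} tags))) dx))

;; 3 + 2ε₁ + 5ε₂ + 7ε₁ε₂
(def x3 (sorted-map [] 3, [1] 2, [2] 5, [1 2] 7))

(toy-finite-term x3)                            ;; => 3
(-> x3 (toy-primal-part 2) (toy-primal-part 1)) ;; => {[] 3}
```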
Comparison, Control Flow
Functions like =, < and friends don't have derivatives; instead, they're used for control flow inside of Clojure functions. To play nicely with these functions, the Differential API exposes a number of methods for comparing numbers on ONLY their finite parts.
Why? If x is a Differential instance, (< x 10) needs to return true whenever a non-Differential x with the same finite part would return true. To make this work, these operations look only at the finite part.
HOWEVER! v/one? and v/zero? are examples of SICMUtils functions that are used to skip operations, like multiplication, that must still happen for derivatives to propagate. (g/* x y) will return y if (v/one? x) is true… but to propagate the derivative through, we need this multiplication to occur. The compromise is:
- v/one? and v/zero? return true only when ALL tangent parts are zero and the finite part is either v/one? or v/zero?, respectively
- eq and compare-full similarly look at every component in the Differential supplied to both sides
while:
- equiv and compare only examine the finite part of either side.
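The split between comparisons that use only the finite part and comparisons that use every term can be sketched on a hypothetical toy representation: sorted maps from tag vectors to coefficients, where a plain number n stands for {[] n} and zero coefficients are never stored. All names here (toy-terms, toy-finite, toy-equiv, toy-eq, toy-one?) are illustrative.

```clojure
;; Toy representation (hypothetical): sorted maps from tag vectors to
;; coefficients; a plain number n stands for the single term {[] n}.
;; Assumes zero coefficients are never stored.
(defn- toy-terms [x] (if (map? x) x (if (zero? x) {} {[] x})))
(defn- toy-finite [x] (get (toy-terms x) [] 0))

(defn toy-equiv
  "Compares finite parts only, like equiv."
  [a b]
  (= (toy-finite a) (toy-finite b)))

(defn toy-eq
  "Compares every term, like eq."
  [a b]
  (= (toy-terms a) (toy-terms b)))

(defn toy-one?
  "True only if the finite part is 1 AND there are no tangent terms."
  [x]
  (= {[] 1} (toy-terms x)))

(def x (sorted-map [] 1, [1] 5)) ; 1 + 5ε₁

[(toy-equiv x 1) (toy-eq x 1) (toy-one? x) (< (toy-finite x) 10)]
;; => [true false false true]
```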
(defn one?
  "Returns true if the supplied instance has a [[finite-part]] that responds true
  to [[sicmutils.value/one?]], and zero coefficients on any of its tangent
  components; false otherwise.
  NOTE: This means that [[one?]] will not do what you expect as a conditional
  inside some function. If you want to branch inside some function you're taking
  the derivative of, prefer `(= 1 dx)`. This will only look at
  the [[finite-part]] and ignore the values of the tangent parts."
  [dx]
  (let [[p t] (primal-tangent-pair dx)]
    (and (v/one? p)
         (v/zero? t))))
(defn eq
  "For non-differentials, this is identical to [[clojure.core/=]].
  For [[Differential]] instances, equality acts on tangent components too.
  If you want to ignore the tangent components, use [[equiv]]."
  ([_] true)
  ([a b]
   (= (->terms a)
      (->terms b)))
  ([a b & more]
   ;; chain pairwise, like clojure.core/=: a = b, then b = each of `more`
   (and (eq a b)
        (apply eq b more))))
(defn compare-full
  "Comparator that compares [[Differential]] instances with each other or
  non-differentials using all tangent terms of each instance. Matches the
  response of [[eq]].
  Acts as [[clojure.core/compare]] for non-differentials."
  [a b]
  (core-compare
   (->terms a)
   (->terms b)))
(defn equiv
  "Returns true if all of the supplied objects have equal [[finite-part]]s, false
  otherwise.
  Use [[equiv]] if you want to compare non-differentials with
  [[Differential]]s and ignore all tangent components. If you _do_ want to take
  the tangent components into account, prefer [[eq]]."
  ([_] true)
  ([a b]
   (= (finite-term a)
      (finite-term b)))
  ([a b & more]
   ;; chain pairwise: compare a with b, then b with each of `more`
   (and (equiv a b)
        (apply equiv b more))))
(defn compare
  "Comparator that compares [[Differential]] instances with each other or
  non-differentials using only the [[finite-part]] of each instance. Matches the
  response of [[equiv]].
  Acts as [[clojure.core/compare]] for non-differentials."
  [a b]
  (core-compare
   (finite-term a)
   (finite-term b)))
Chain Rule and Lifted Functions
Finally, we come to the heart of it! lift1 and lift2 "lift", or augment, unary or binary functions with the ability to handle Differential instances in addition to whatever other types they previously supported.
These functions are implementations of the single and multivariable Taylor series expansion methods discussed at the beginning of the namespace.
There is yet another subtlety here, noted in the docstrings below. lift1 and lift2 really are able to lift functions like clojure.core/+ that can't accept Differentials. But the first-order derivatives that you have to supply do have to be able to take Differential instances.
This is because the primal-part of a Differential (the value that lift1 and lift2 pass to df:dx and df:dy) might itself still be a Differential carrying lower-order tags, and for Df to handle this we need to be able to take the second-order derivative.
Magically this will all Just Work if you pass an already-lifted function, or a function built out of already-lifted components, as df:dx or df:dy.
(defn lift1
  "Given:
  - some unary function `f`
  - a function `df:dx` that computes the derivative of `f` with respect to its
    single argument
  Returns a new unary function that operates on both the original type of `f`
  and [[Differential]] instances.
  NOTE: `df:dx` has to ALREADY be able to handle [[Differential]] instances. The
  best way to accomplish this is by building `df:dx` out of already-lifted
  functions, and declaring them by forward reference if you need to."
  [f df:dx]
  (fn call [x]
    (if-not (differential? x)
      (f x)
      (let [[px tx] (primal-tangent-pair x)
            fx      (call px)]
        (if (and (v/number? tx) (v/zero? tx))
          fx
          (d:+ fx (d:* (df:dx px) tx)))))))
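The chain rule that lift1 implements, f(p + tε) = f(p) + f'(p)·t·ε, can be seen in isolation with ordinary single-tag dual numbers. This sketch is a simplification. The Dual record, toy-lift1, sin* and square* are hypothetical names. Because the sketch supports only one tag, it skips the primal/tangent splitting, and the requirement that df:dx accept Differentials, both of which the real lift1 handles.

```clojure
;; Hypothetical single-tag dual number p + t·ε, with ε² = 0.
(defrecord Dual [p t])

(defn toy-lift1
  "Sketch of lift1 for single-tag duals: f(p + tε) = f(p) + f'(p)·t·ε.
  Non-dual arguments pass straight through to `f`."
  [f df:dx]
  (fn [x]
    (if (instance? Dual x)
      (->Dual (f (:p x)) (* (df:dx (:p x)) (:t x)))
      (f x))))

(def sin* (toy-lift1 #(Math/sin %) #(Math/cos %)))
(def square* (toy-lift1 #(* % %) #(* 2 %)))

;; Seeding the tangent with 1 computes d/dx sin(x²) = 2x·cos(x²) at x = 3.
(:t (sin* (square* (->Dual 3.0 1.0))))
```

Composition does the work here. Each lifted function only knows its own derivative, and the tangent threads the chain rule through the nested calls.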
(defn lift2
  "Given:
  - some binary function `f`
  - a function `df:dx` that computes the derivative of `f` with respect to its
    first argument
  - a function `df:dy`, similar to `df:dx` for the second arg
  Returns a new binary function that operates on both the original type of `f`
  and [[Differential]] instances.
  NOTE: `df:dx` and `df:dy` have to ALREADY be able to handle [[Differential]]
  instances. The best way to accomplish this is by building `df:dx` and `df:dy`
  out of already-lifted functions, and declaring them by forward reference if
  you need to."
  [f df:dx df:dy]
  (fn call [x y]
    (if-not (or (differential? x)
                (differential? y))
      (f x y)
      (let [tag     (max-order-tag x y)
            [xe dx] (primal-tangent-pair x tag)
            [ye dy] (primal-tangent-pair y tag)
            a       (call xe ye)
            b       (if (and (v/number? dx) (v/zero? dx))
                      a
                      (d:+ a (d:* dx (df:dx xe ye))))]
        (if (and (v/number? dy) (v/zero? dy))
          b
          (d:+ b (d:* (df:dy xe ye) dy)))))))
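The two-argument version of the same idea, f(x + dx·ε, y + dy·ε) = f(x, y) + (∂f/∂x·dx + ∂f/∂y·dy)·ε, can be sketched with single-tag dual numbers. The Dual record, parts, toy-lift2, mul* and sub* are hypothetical names. Like the real lift2, the sketch treats a plain number as having a zero tangent.

```clojure
;; Hypothetical single-tag dual number p + t·ε, with ε² = 0.
(defrecord Dual [p t])

(defn- parts
  "Primal and tangent of x; plain numbers have a tangent of 0."
  [x]
  (if (instance? Dual x) [(:p x) (:t x)] [x 0]))

(defn toy-lift2
  "Sketch of lift2 for single-tag duals:
  f(x + dx·ε, y + dy·ε) = f(x, y) + (∂f/∂x·dx + ∂f/∂y·dy)·ε"
  [f df:dx df:dy]
  (fn [a b]
    (if (or (instance? Dual a) (instance? Dual b))
      (let [[x dx] (parts a)
            [y dy] (parts b)]
        (->Dual (f x y)
                (+ (* (df:dx x y) dx)
                   (* (df:dy x y) dy))))
      (f a b))))

(def mul* (toy-lift2 * (fn [_ y] y) (fn [x _] x)))
(def sub* (toy-lift2 - (fn [_ _] 1) (fn [_ _] -1)))

;; d/dx [x·(x - 5)] = 2x - 5, which is -1 at x = 2
(let [x (->Dual 2 1)]
  (mul* x (sub* x 5)))
```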
(defn lift-n
  "Given:
  - some function `f` that can handle 0, 1 or 2 arguments
  - `df:dx`, a fn that returns the derivative wrt the single arg in the unary case
  - `df:dx1` and `df:dx2`, fns that return the derivative with respect to the
    first and second args in the binary case
  Returns a new any-arity function that operates on both the original type of
  `f` and [[Differential]] instances.
  NOTE: The n-ary case of `f` is populated by nested calls to the binary case.
  That means that this is NOT an appropriate lifting method for an n-ary
  function that isn't built out of associative binary calls. If you need this
  ability, please file an issue at the [sicmutils issue
  tracker](https://github.com/sicmutils/sicmutils/issues)."
  [f df:dx df:dx1 df:dx2]
  (let [f1 (lift1 f df:dx)
        f2 (lift2 f df:dx1 df:dx2)]
    (fn call
      ([] (f))
      ([x] (f1 x))
      ([x y] (f2 x y))
      ([x y & more]
       (reduce call (call x y) more)))))
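The NOTE about the n-ary case is easiest to see with a binary case that records its calls. In this hypothetical sketch, toy-lift-n mirrors only the arity dispatch of lift-n and does no lifting. traced+ is an illustrative binary function that builds a list instead of adding, which exposes the left-nested binary calls.

```clojure
(defn toy-lift-n
  "Hypothetical sketch of lift-n's arity dispatch: the 0-, 1- and 2-arity
  cases delegate to `f0`, `f1` and `f2`; extra arguments fold left through
  the binary case."
  [f0 f1 f2]
  (fn call
    ([] (f0))
    ([x] (f1 x))
    ([x y] (f2 x y))
    ([x y & more] (reduce call (call x y) more))))

;; A binary case that builds a form instead of computing, to expose the nesting.
(def traced+ (toy-lift-n (constantly 0) identity (fn [x y] (list '+ x y))))

(traced+ 1 2 3 4) ;; => (+ (+ (+ 1 2) 3) 4)
```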
Derivatives of Differentials
One more treat before we augment the generic arithmetic system. The derivative operation is linear, so:
\[
D(x+x'\varepsilon) = D(x)+D(x')\varepsilon
\]
This implementation is valid because the coefficients of a Differential can be functions.
(defmethod g/partial-derivative [::differential v/seqtype] [a selectors]
  (let [tag (max-order-tag a)
        px  (primal-part a tag)
        tx  (extract-tangent a tag)]
    (d:+ (g/partial-derivative px selectors)
         (d:* (g/partial-derivative tx selectors)
              (bundle 0 1 tag)))))
Generic Method Installation
Armed with lift1 and lift2, we can install Differential into the SICMUtils generic arithmetic system.
Any function built out of these components will work with the D operator.
(defn defunary
  "Given:
  - a generic unary multimethod `generic-op`
  - a corresponding single-arity lifted function `differential-op`
  installs an appropriate unary implementation of `generic-op` for
  `::differential` instances."
  [generic-op differential-op]
  (defmethod generic-op [::differential] [a] (differential-op a)))
(defn defbinary
  "Given:
  - a generic binary multimethod `generic-op`
  - a corresponding 2-arity lifted function `differential-op`
  installs an appropriate binary implementation of `generic-op` between
  `::differential` and `::v/scalar` instances."
  [generic-op differential-op]
  (doseq [signature [[::differential ::differential]
                     [::v/scalar ::differential]
                     [::differential ::v/scalar]]]
    (defmethod generic-op signature [a b] (differential-op a b))))
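The installation pattern can be tried out on a self-contained multimethod. In this sketch everything is a hypothetical stand-in. kind and add play the role of a generic operator's type dispatch, ::scalar replaces the real ::v/scalar hierarchy, and toy-d+ adds toy differentials represented as sorted maps from tag vectors to coefficients. As in defbinary, defmethod works on a local here because the macro expands to an addMethod call on whatever multimethod value it is given.

```clojure
;; Hypothetical stand-ins: a toy "differential" is a sorted map from tag
;; vectors to coefficients, and `add` plays the role of a generic operator.
(defn kind [x] (if (map? x) ::differential ::scalar))

(defmulti add (fn [a b] [(kind a) (kind b)]))
(defmethod add [::scalar ::scalar] [a b] (+ a b))

(defn- toy-terms [x] (if (map? x) x {[] x}))

(defn toy-d+
  "Adds two toy differentials (or numbers) term by term."
  [a b]
  (into (sorted-map) (merge-with + (toy-terms a) (toy-terms b))))

(defn toy-defbinary
  "Like defbinary: installs `differential-op` for every signature that
  involves at least one ::differential."
  [generic-op differential-op]
  (doseq [signature [[::differential ::differential]
                     [::scalar ::differential]
                     [::differential ::scalar]]]
    (defmethod generic-op signature [a b] (differential-op a b))))

(toy-defbinary add toy-d+)

(add 2 3)                        ;; => 5
(add (sorted-map [] 3, [1] 1) 1) ;; => {[] 4, [1] 1}
```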
And now we're off to the races. The rest of the namespace provides defunary and defbinary calls for all of the generic operations for which we know how to declare partial derivatives.
(defbinary g/add
  (lift2 g/add
         (fn [_ _] 1)
         (fn [_ _] 1)))
(defunary g/negate
  ;; d(-x)/dx = -1
  (lift1 g/negate (fn [_] -1)))
(defunary g/negative?
  (fn [x] (g/negative? (finite-term x))))
(defbinary g/sub
  (lift2 g/sub
         (fn [_ _] 1)
         ;; d(x - y)/dy = -1
         (fn [_ _] -1)))
(let [mul (lift2
           g/mul
           (fn [_ y] y)
           (fn [x _] x))]
  (defbinary g/mul mul)
  (defunary g/square (fn [x] (mul x x)))
  (defunary g/cube (fn [x] (mul x (mul x x))))
  (defbinary g/dot-product mul))
(defunary g/invert
  (lift1 g/invert
         ;; d(1/x)/dx = -1/x^2
         (fn [x] (g/div -1 (g/square x)))))
(defbinary g/div
  (lift2 g/div
         (fn [_ y] (g/div 1 y))
         (fn [x y] (g/div (g/negate x)
                          (g/square y)))))
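Sign mistakes are easy to make in a table of derivative rules. One way to catch them is to compare each closed-form derivative for the arithmetic operations against a central finite difference on plain doubles. This is a hypothetical check that is not part of the namespace. The helpers central-diff and close?, the step size h = 1e-6 and the 1e-6 tolerance are illustrative choices.

```clojure
;; Hypothetical sanity check on plain doubles (no Differentials involved):
;; compare closed-form derivatives with central finite differences.
(defn central-diff
  "Approximates f'(v) as (f(v + h) - f(v - h)) / 2h."
  [f v]
  (let [h 1e-6]
    (/ (- (f (+ v h)) (f (- v h))) (* 2 h))))

(defn close? [a b] (< (Math/abs (- a b)) 1e-6))

(def arithmetic-checks
  (let [x 2.0, y 5.0]
    {:negate (close? (central-diff - x) -1.0)
     :sub-dy (close? (central-diff #(- x %) y) -1.0)
     :invert (close? (central-diff #(/ 1.0 %) x) (/ -1.0 (* x x)))
     :div-dx (close? (central-diff #(/ % y) x) (/ 1.0 y))
     :div-dy (close? (central-diff #(/ x %) y) (/ (- x) (* y y)))}))
```

Each entry in arithmetic-checks should be true. The expected values are the textbook results: for example, the derivative of 1/x is -1/x², and the partial of x - y with respect to y is -1.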
(defunary g/abs
  (fn [x]
    (let [f    (finite-term x)
          func (cond (< f 0) (lift1 (fn [x] (g/negate x)) (fn [_] -1))
                     (> f 0) (lift1 (fn [x] x) (fn [_] 1))
                     (= f 0) (u/illegal "Derivative of g/abs undefined at zero")
                     :else (u/illegal (str "error! derivative of g/abs at " x)))]
      (func x))))
(defunary g/sqrt
  (lift1 g/sqrt
         (fn [x]
           (g/invert
            (g/mul (g/sqrt x) 2)))))
This first case of g/expt, where the exponent itself is non-Differential, is special-cased and slightly simpler. The second partial derivative throws because it should never be called: the exponent is not a Differential, so its tangent part is always 0 and lift2 skips that term.
(let [power (lift2
             g/expt
             (fn [x y]
               (g/mul y (g/expt x (g/sub y 1))))
             (fn [_ _]
               (u/illegal "can't get there from here")))]
  (defmethod g/expt [::differential ::v/scalar] [d n] (power d n)))
The remaining two cases allow for a differential exponent.
NOTE: I took this implementation split from scmutils, but I'm not sure that it matters… if the second partial never gets called, why is this a good optimization?
(let [expt (lift2
            g/expt
            (fn [x y]
              (g/mul y (g/expt x (g/sub y 1))))
            (fn [x y]
              ;; d(x^y)/dy = log(x) * x^y; log is undefined at x = 0, so
              ;; that case is handled separately
              (if (and (v/number? x) (v/zero? x))
                (if (v/number? y)
                  (if (not (g/negative? y))
                    0
                    (u/illegal "Derivative undefined: expt"))
                  0)
                (g/* (g/log x) (g/expt x y)))))]
  (defmethod g/expt [::differential ::differential] [d n] (expt d n))
  (defmethod g/expt [::v/scalar ::differential] [d n] (expt d n)))
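For reference, these are the standard partial derivatives that the expt cases encode. The special case in the second partial exists because \(\log x\) is undefined at \(x = 0\).

\[
\frac{\partial}{\partial x} x^{y} = y\,x^{y-1},
\qquad
\frac{\partial}{\partial y} x^{y} = x^{y}\log x \quad (x > 0)
\]

At \(x = 0\) with \(y > 0\), \(0^{y} = 0\) for every nearby positive \(y\), so taking the partial with respect to \(y\) to be 0 is consistent. For \(y < 0\), \(0^{y}\) is undefined, which is the situation the error branch is meant to report.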
(defunary g/log
  (lift1 g/log g/invert))
(defunary g/exp
  (lift1 g/exp g/exp))
(defunary g/sin
  (lift1 g/sin g/cos))
(defunary g/cos
  (lift1 g/cos
         (fn [x] (g/negate (g/sin x)))))
(defunary g/tan
  (lift1 g/tan
         (fn [x]
           (g/invert
            (g/square (g/cos x))))))
(defunary g/asin
  (lift1 g/asin
         (fn [x]
           (g/invert
            (g/sqrt (g/sub 1 (g/square x)))))))
(defunary g/acos
  (lift1 g/acos
         (fn [x]
           (g/negate
            (g/invert
             (g/sqrt (g/sub 1 (g/square x))))))))
(defunary g/atan
  (lift1 g/atan (fn [x]
                  (g/invert
                   (g/add 1 (g/square x))))))
(defbinary g/atan
  (lift2 g/atan
         (fn [y x]
           (g/div x (g/add (g/square x)
                           (g/square y))))
         (fn [y x]
           (g/div (g/negate y)
                  (g/add (g/square x)
                         (g/square y))))))
(defunary g/sinh
  (lift1 g/sinh g/cosh))
(defunary g/cosh
  (lift1 g/cosh g/sinh))
(defunary g/tanh
  (lift1 g/tanh
         (fn [x]
           (g/sub 1 (g/square (g/tanh x))))))
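As a hypothetical sanity check that is not part of the namespace, the closed-form derivatives used for the transcendental functions above can be compared against central finite differences on plain doubles using java.lang.Math. central-diff and close? are illustrative helpers, and the step size and tolerance are arbitrary choices. The atan2 check also confirms the (y, x) argument order assumed by the binary g/atan partials.

```clojure
;; Hypothetical sanity check on plain doubles (no Differentials involved):
;; compare the derivative formulas used above with central finite differences.
(defn central-diff
  "Approximates f'(v) as (f(v + h) - f(v - h)) / 2h."
  [f v]
  (let [h 1e-6]
    (/ (- (f (+ v h)) (f (- v h))) (* 2 h))))

(defn close? [a b] (< (Math/abs (- a b)) 1e-6))

(def unary-checks
  (let [x 0.3]
    {:tan  (close? (central-diff #(Math/tan %) x) (/ 1.0 (Math/pow (Math/cos x) 2)))
     :asin (close? (central-diff #(Math/asin %) x) (/ 1.0 (Math/sqrt (- 1.0 (* x x)))))
     :acos (close? (central-diff #(Math/acos %) x) (- (/ 1.0 (Math/sqrt (- 1.0 (* x x))))))
     :atan (close? (central-diff #(Math/atan %) x) (/ 1.0 (+ 1.0 (* x x))))
     :sinh (close? (central-diff #(Math/sinh %) x) (Math/cosh x))
     :cosh (close? (central-diff #(Math/cosh %) x) (Math/sinh x))
     :tanh (close? (central-diff #(Math/tanh %) x) (- 1.0 (Math/pow (Math/tanh x) 2)))}))

;; Math/atan2 takes y first, matching the (fn [y x] ...) partials of g/atan.
(def atan2-checks
  (let [y 1.5, x -0.5, r2 (+ (* x x) (* y y))]
    {:d-dy (close? (central-diff #(Math/atan2 % x) y) (/ x r2))
     :d-dx (close? (central-diff #(Math/atan2 y %) x) (/ (- y) r2))}))
```

Every entry in unary-checks and atan2-checks should be true.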