← Back to Optimization

Automatic Differentiation

Day 11: Automatic Differentiation — Companion Code

⬇ Download automatic_differentiation.py

"""
Day 11: Automatic Differentiation — Companion Code
Phase I, Week 2

Run: python automatic_differentiation.py
"""

import numpy as np
from typing import Union


# =============================================================================
# Exercise 1: Dual Number Class
# =============================================================================

class DualNumber:
    """Dual number a + b*ε where ε² = 0.

    Evaluating a function on ``DualNumber(a, 1.0)`` with these overloaded
    operators yields ``f(a) + f'(a)·ε``, i.e. forward-mode automatic
    differentiation. Operands may be other DualNumbers or real scalars
    (Python ``int``/``float`` and NumPy integer/floating scalars). For any
    other operand type the operators return ``NotImplemented``, so Python
    raises its usual ``TypeError``.
    """

    # Real scalar operand types. NumPy integer scalars (e.g. np.int64) are
    # not ``int`` subclasses, so they are listed explicitly.
    _SCALAR_TYPES = (int, float, np.integer, np.floating)

    def __init__(self, real: float, dual: float = 0.0):
        self.real = float(real)  # primal value a
        self.dual = float(dual)  # tangent coefficient b

    def __repr__(self):
        return f"Dual({self.real:.6f} + {self.dual:.6f}ε)"

    def __add__(self, other):
        if isinstance(other, self._SCALAR_TYPES):
            return DualNumber(self.real + other, self.dual)
        if isinstance(other, DualNumber):
            return DualNumber(self.real + other.real, self.dual + other.dual)
        return NotImplemented

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, self._SCALAR_TYPES):
            return DualNumber(self.real - other, self.dual)
        if isinstance(other, DualNumber):
            return DualNumber(self.real - other.real, self.dual - other.dual)
        return NotImplemented

    def __rsub__(self, other):
        # other - self for a real scalar ``other``.
        if isinstance(other, self._SCALAR_TYPES):
            return DualNumber(other - self.real, -self.dual)
        return NotImplemented

    def __mul__(self, other):
        if isinstance(other, self._SCALAR_TYPES):
            return DualNumber(self.real * other, self.dual * other)
        if isinstance(other, DualNumber):
            # (a + bε)(c + dε) = ac + (ad + bc)ε because ε² = 0.
            return DualNumber(self.real * other.real,
                              self.real * other.dual + self.dual * other.real)
        return NotImplemented

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        if isinstance(other, self._SCALAR_TYPES):
            return DualNumber(self.real / other, self.dual / other)
        if isinstance(other, DualNumber):
            # Quotient rule: (a + bε)/(c + dε) = a/c + (bc - ad)/c² ε.
            denom = other.real ** 2
            real = self.real / other.real
            dual = (self.dual * other.real - self.real * other.dual) / denom
            return DualNumber(real, dual)
        return NotImplemented

    def __rtruediv__(self, other):
        # c / (a + bε) = c/a - (c·b / a²)ε for a real scalar c.
        if isinstance(other, self._SCALAR_TYPES):
            return DualNumber(other / self.real,
                              -other * self.dual / self.real ** 2)
        return NotImplemented

    def __pow__(self, n):
        """Raise to a real or dual exponent.

        For a real exponent n: (a + bε)**n = a**n + n·a**(n-1)·b ε.
        For a dual exponent c + dε (with d != 0) the identity
        a**c = exp(c·ln a) is used, which requires a > 0.
        """
        if isinstance(n, self._SCALAR_TYPES):
            if n == 0:
                # x**0 is the constant 1; this also avoids 0.0 ** -1 at a == 0.
                return DualNumber(1.0, 0.0)
            return DualNumber(self.real ** n,
                              n * self.real ** (n - 1) * self.dual)
        if isinstance(n, DualNumber):
            if n.dual == 0.0:
                # Exponent is effectively a constant: use the power rule.
                return self ** n.real
            value = self.real ** n.real
            # d(a**c) = a**c · (dc·ln a + c·da / a)
            dual = value * (n.dual * np.log(self.real)
                            + n.real * self.dual / self.real)
            return DualNumber(value, dual)
        return NotImplemented

    def __rpow__(self, base):
        # c ** (a + bε) = c**a + c**a·ln(c)·b ε for a real base c (c > 0 when b != 0).
        if isinstance(base, self._SCALAR_TYPES):
            value = base ** self.real
            dual = value * np.log(base) * self.dual if self.dual != 0.0 else 0.0
            return DualNumber(value, dual)
        return NotImplemented

    def __neg__(self):
        return DualNumber(-self.real, -self.dual)


def dual_sin(x: DualNumber) -> DualNumber:
    """Sine on dual numbers: sin(a + bε) = sin(a) + b·cos(a)·ε."""
    a, b = x.real, x.dual
    slope = np.cos(a)
    return DualNumber(np.sin(a), slope * b)


def dual_cos(x: DualNumber) -> DualNumber:
    """Cosine on dual numbers: cos(a + bε) = cos(a) - b·sin(a)·ε."""
    a, b = x.real, x.dual
    slope = -np.sin(a)
    return DualNumber(np.cos(a), slope * b)


def dual_exp(x: DualNumber) -> DualNumber:
    """Exponential on dual numbers: exp(a + bε) = e^a + b·e^a·ε.

    e^a is computed once and reused for both parts.
    """
    e_a = np.exp(x.real)
    tangent = e_a * x.dual
    return DualNumber(e_a, tangent)


def dual_log(x: DualNumber) -> DualNumber:
    """Natural log on dual numbers: ln(a + bε) = ln(a) + (b / a)·ε."""
    a, b = x.real, x.dual
    tangent = b / a
    return DualNumber(np.log(a), tangent)


def dual_sqrt(x: DualNumber) -> DualNumber:
    """Square root on dual numbers: sqrt(a + bε) = √a + b / (2√a)·ε."""
    root = np.sqrt(x.real)
    denominator = 2 * root
    return DualNumber(root, x.dual / denominator)


def test_dual_numbers():
    """Exercise 1: compare dual-number derivatives with closed-form answers.

    For three functions, prints the value and derivative carried by the dual
    number next to the analytic result.
    """
    print("=== Exercise 1: Dual Number Class ===")

    # sin at x = 1: a dual part of 1 makes the result carry d/dx.
    sin_out = dual_sin(DualNumber(1.0, 1.0))  # 1 + ε
    print(f"  sin(1 + ε) = {sin_out}")
    print(f"  f(1)  = {sin_out.real:.10f}, exact = {np.sin(1.0):.10f}")
    print(f"  f'(1) = {sin_out.dual:.10f}, exact = {np.cos(1.0):.10f}")

    # x³ + 2x at x = 3: f(3) = 27 + 6 = 33, f'(3) = 3·9 + 2 = 29.
    t = DualNumber(3.0, 1.0)
    poly_out = t ** 3 + 2 * t
    poly_val, poly_der = 33.0, 29.0
    print(f"\n  x³+2x at x=3: {poly_out}")
    print(f"  f(3)  = {poly_out.real:.1f}, exact = {poly_val:.1f}")
    print(f"  f'(3) = {poly_out.dual:.1f}, exact = {poly_der:.1f}")

    # exp(sin(x)) at x = 0.5; the chain rule gives cos(x)·exp(sin(x)).
    s = np.sin(0.5)
    ref_val = np.exp(s)
    ref_der = np.cos(0.5) * np.exp(s)
    comp_out = dual_exp(dual_sin(DualNumber(0.5, 1.0)))
    print(f"\n  exp(sin(0.5)): {comp_out}")
    print(f"  f(0.5)  = {comp_out.real:.10f}, exact = {ref_val:.10f}")
    print(f"  f'(0.5) = {comp_out.dual:.10f}, exact = {ref_der:.10f}")


# =============================================================================
# Exercise 2: Forward-Mode AD for Multi-Variable Functions
# =============================================================================

def forward_mode_gradient(f, x: np.ndarray) -> tuple:
    """Compute value and gradient of a scalar f: R^n -> R using forward-mode AD.

    Runs one forward pass per input coordinate. On pass i, coordinate i is
    seeded with dual part 1 and every other coordinate with 0, so the dual
    part of f's output is the partial derivative df/dx_i. A dual number's
    real part never depends on the seeds, so f(x) is read from the first
    pass rather than paid for with an extra evaluation.

    Args:
        f: callable taking a list of ``DualNumber`` (one per coordinate) and
            returning a ``DualNumber``. A plain number is treated as a
            constant, i.e. zero derivative.
        x: point of evaluation; any length-n indexable sequence of reals.

    Returns:
        ``(value, grad)``: f(x) as a float, and a float ``np.ndarray`` of
        shape (n,) holding the partial derivatives.
    """
    def _value_and_slope(out):
        # A DualNumber carries (f, df/dx_i); a plain number means f ignored
        # its inputs, so its derivative is zero.
        if isinstance(out, DualNumber):
            return out.real, out.dual
        return float(out), 0.0

    n = len(x)
    grad = np.zeros(n)
    value = None

    for i in range(n):
        # Seed coordinate i with dual part 1 and every other one with 0.
        x_dual = [DualNumber(xj, 1.0 if j == i else 0.0)
                  for j, xj in enumerate(x)]
        real, slope = _value_and_slope(f(x_dual))
        grad[i] = slope
        if value is None:
            value = real

    if value is None:
        # n == 0: no seeded pass ran, so evaluate f once for its value.
        value, _ = _value_and_slope(f([]))

    return value, grad


def test_forward_mode():
    """Exercise 2: forward-mode gradient of f(x, y) = x*y + sin(x) vs. the analytic one."""
    print("\n=== Exercise 2: Forward-Mode AD ===")

    def objective(v):
        # v holds one DualNumber per coordinate: v[0] = x, v[1] = y.
        return v[0] * v[1] + dual_sin(v[0])

    where = np.array([1.5, -0.7])
    x0, y0 = where[0], where[1]
    # ∂f/∂x = y + cos(x), ∂f/∂y = x
    analytic = np.array([y0 + np.cos(x0), x0])

    fval, ad_grad = forward_mode_gradient(objective, where)
    err = np.linalg.norm(ad_grad - analytic)

    print(f"  f(1.5, -0.7) = {fval:.6f}")
    print(f"  AD gradient:    [{ad_grad[0]:.6f}, {ad_grad[1]:.6f}]")
    print(f"  Exact gradient: [{analytic[0]:.6f}, {analytic[1]:.6f}]")
    print(f"  Error: {err:.2e}")


# =============================================================================
# Exercise 3: Reverse-Mode AD (Simple Computation Graph)
# =============================================================================

class Var:
    """Variable in a computation graph for reverse-mode AD.

    Every operation returns a new ``Var`` that stores its operands in
    ``_children``. For each operand it also stores, in ``_grad_fns``, a
    zero-argument callable returning the local partial derivative
    d(this node)/d(operand). ``backward()`` applies the chain rule from an
    output node to every node it depends on.

    Plain numbers used as operands (Python or NumPy scalars) are wrapped as
    constant leaf ``Var`` nodes.
    """
    # Global tape: every Var created, in creation order. It is kept for
    # inspection and compatibility (see clear_tape); backward() walks the
    # graph directly instead of relying on it.
    _tape = []

    def __init__(self, value: float, children=(), grad_fns=()):
        self.value = float(value)
        self.grad = 0.0  # after out.backward(): d(out)/d(self)
        self._children = children
        self._grad_fns = grad_fns
        Var._tape.append(self)

    def __repr__(self):
        return f"Var({self.value:.6f})"

    def _topological_order(self):
        """Return every node reachable from self, children before parents.

        Uses an iterative post-order DFS so deep graphs do not hit Python's
        recursion limit.
        """
        order = []
        visited = set()
        stack = [(self, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                # All of node's children have already been appended.
                order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            stack.extend((child, False) for child in node._children
                         if child not in visited)
        return order

    def backward(self):
        """Backpropagate from this node.

        Sets ``.grad`` of every node this one depends on to d(self)/d(node).
        Only nodes reachable from ``self`` are touched, and their gradients
        are reset first. Repeating ``backward`` therefore gives the same
        result, and running it on another output that shares inputs does not
        pick up stale gradients.
        """
        order = self._topological_order()
        for node in order:
            node.grad = 0.0
        self.grad = 1.0
        # Parents before children: each node's grad is final before it is pushed down.
        for node in reversed(order):
            for child, grad_fn in zip(node._children, node._grad_fns):
                child.grad += node.grad * grad_fn()

    @staticmethod
    def clear_tape():
        Var._tape = []

    @staticmethod
    def _lift(other):
        """Return *other* unchanged if it is a Var, else wrap it as a constant leaf."""
        return other if isinstance(other, Var) else Var(other)

    def __add__(self, other):
        other = Var._lift(other)
        return Var(self.value + other.value,
                   (self, other),
                   (lambda: 1.0, lambda: 1.0))

    def __radd__(self, other):
        return self.__add__(other)

    def __mul__(self, other):
        other = Var._lift(other)
        s, o = self, other
        return Var(self.value * other.value,
                   (self, other),
                   (lambda: o.value, lambda: s.value))

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        other = Var._lift(other)
        s, o = self, other
        # d(s/o)/ds = 1/o, d(s/o)/do = -s/o²
        return Var(s.value / o.value,
                   (s, o),
                   (lambda: 1.0 / o.value, lambda: -s.value / o.value ** 2))

    def __rtruediv__(self, other):
        return Var._lift(other) / self

    def __neg__(self):
        return Var(-self.value, (self,), (lambda: -1.0,))

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        # other - self for a plain number ``other``.
        return Var._lift(other) - self

    def __pow__(self, n):
        s = self
        # n == 0 is the constant 1; guard avoids evaluating 0.0 ** -1.
        return Var(self.value ** n,
                   (self,),
                   (lambda: n * s.value ** (n - 1) if n != 0 else 0.0,))


def var_sin(x: Var) -> Var:
    """Reverse-mode sine: the node's value is sin(x) and its local derivative is cos(x).

    The derivative is evaluated lazily, when ``backward`` calls it.
    """
    def d_sin():
        return np.cos(x.value)

    return Var(np.sin(x.value), children=(x,), grad_fns=(d_sin,))


def var_exp(x: Var) -> Var:
    """Reverse-mode exponential. Since d/dx exp(x) = exp(x), the forward value is reused as the derivative."""
    e_x = np.exp(x.value)

    def d_exp():
        return e_x

    return Var(e_x, children=(x,), grad_fns=(d_exp,))


def test_reverse_mode():
    """Exercise 3: one reverse sweep on f(x, y) = x*y + sin(x) + y² yields both partials."""
    print("\n=== Exercise 3: Reverse-Mode AD ===")

    Var.clear_tape()  # start from an empty tape
    x, y = Var(1.5), Var(-0.7)
    f_out = x * y + var_sin(x) + y ** 2
    f_out.backward()

    want_dx = y.value + np.cos(x.value)  # ∂f/∂x = y + cos(x)
    want_dy = x.value + 2 * y.value      # ∂f/∂y = x + 2y

    print(f"  f(1.5, -0.7) = {f_out.value:.6f}")
    print(f"  Reverse df/dx = {x.grad:.6f}, exact = {want_dx:.6f}")
    print(f"  Reverse df/dy = {y.grad:.6f}, exact = {want_dy:.6f}")
    print("  Note: ONE backward pass gave BOTH partial derivatives!")


# =============================================================================
# Exercise 4: JAX Gradients and Hessians
# =============================================================================

def jax_demo():
    """Demonstrate JAX AD capabilities (if available).

    With JAX installed, this prints the gradient and Hessian of the
    Rosenbrock function at (0.5, 0.5) and checks that ``jax.jacfwd`` and
    ``jax.jacrev`` agree on a small R^2 -> R^3 map.

    Without JAX, it computes the same Rosenbrock gradient with this file's
    ``forward_mode_gradient`` (dual numbers) and ``Var`` (reverse mode), so
    the two modes can be compared.
    """
    print("\n=== Exercise 4: JAX AD ===")

    def rosenbrock(x):
        # (x0 - 1)^2 equals (1 - x0)^2 exactly. This spelling only needs
        # "operand - scalar", which jnp arrays, DualNumber and Var all support.
        return (x[0] - 1)**2 + 100 * (x[1] - x[0]**2)**2

    # Keep the try minimal: only a failed import should trigger the fallback.
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print("  JAX not installed. Install with: pip install jax jaxlib")
        print("  Showing manual forward/reverse comparison instead.")

        point = np.array([0.5, 0.5])
        value, grad_fwd = forward_mode_gradient(rosenbrock, point)

        Var.clear_tape()
        inputs = [Var(v) for v in point]
        rosenbrock(inputs).backward()
        grad_rev = [v.grad for v in inputs]

        print(f"\n  Rosenbrock at [0.5, 0.5]: f(x) = {value:.4f}")
        print(f"  forward-mode grad = [{grad_fwd[0]:.4f}, {grad_fwd[1]:.4f}]")
        print(f"  reverse-mode grad = [{grad_rev[0]:.4f}, {grad_rev[1]:.4f}]")

        print("\n  Forward-mode (n passes for n inputs):")
        print("    Best for f: R^n -> R^m where n << m")
        print("  Reverse-mode (1 pass for scalar output):")
        print("    Best for f: R^n -> R where n is large")
        return

    x0 = jnp.array([0.5, 0.5])

    # Gradient (reverse mode)
    grad_f = jax.grad(rosenbrock)
    g = grad_f(x0)
    print(f"  Rosenbrock at [0.5, 0.5]:")
    print(f"  f(x)  = {rosenbrock(x0):.4f}")
    print(f"  grad  = [{g[0]:.4f}, {g[1]:.4f}]")

    # Hessian
    hess_f = jax.hessian(rosenbrock)
    H = hess_f(x0)
    print(f"  Hessian = \n    {H}")

    # Forward- and reverse-mode Jacobians of an R^2 -> R^3 map must agree.
    def f_vec(x):
        return jnp.array([jnp.sin(x[0]) * x[1],
                          jnp.exp(x[0] + x[1]),
                          x[0]**2 - x[1]])

    J_fwd = jax.jacfwd(f_vec)(x0)
    J_rev = jax.jacrev(f_vec)(x0)
    print(f"\n  Forward Jacobian == Reverse Jacobian: "
          f"{jnp.allclose(J_fwd, J_rev)}")


# =============================================================================
# Main
# =============================================================================

if __name__ == "__main__":
    # Run every exercise demo in order.
    for demo in (test_dual_numbers, test_forward_mode,
                 test_reverse_mode, jax_demo):
        demo()