← Back to Optimization

Lie Groups

Lie Groups for Robotics — Python Implementations

⬇ Download lie_groups.py

"""
Lie Groups for Robotics — Python Implementations
===================================================
Companion code for Lie group theory applied to robot state estimation
and pose-graph optimization.

Covers:
  - SO(2): 2D rotation group (exp, log, compose)
  - SO(3): 3D rotation group (Rodrigues, quaternions, Jacobians)
  - SE(2): 2D rigid-body motions (exp, log, inverse, compose)
  - SE(3): 3D rigid-body motions (exp, log, adjoint, Jacobian)
  - Manifold Gauss-Newton optimization on SE(2)
  - Pose-graph optimization on SE(2)
  - Utilities: validation, error metrics, random sampling

Usage:
    python lie_groups.py          # Run demos with visualizations
    python lie_groups.py --test   # Run self-tests only

Dependencies: numpy (core), matplotlib (demo plots only).
No scipy required for any core operation.
"""

import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for CI
import matplotlib.pyplot as plt
import sys


# ============================================================
# 1. SO(2) — 2D Rotations
# ============================================================

def so2_exp(theta):
    """Map a planar rotation angle to its 2×2 rotation matrix.

    Parameters
    ----------
    theta : float
        Counter-clockwise rotation angle in radians.

    Returns
    -------
    R : ndarray, shape (2, 2)
        Element of SO(2): [[cos θ, -sin θ], [sin θ, cos θ]].
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    first_row = [cos_t, -sin_t]
    second_row = [sin_t, cos_t]
    return np.array([first_row, second_row])


def so2_log(R):
    """Recover the rotation angle of a 2×2 rotation matrix.

    Parameters
    ----------
    R : ndarray, shape (2, 2)
        Rotation matrix in SO(2).

    Returns
    -------
    theta : float
        Angle from ``arctan2(R[1, 0], R[0, 0])``, i.e. in [-π, π].
    """
    sin_part = R[1, 0]
    cos_part = R[0, 0]
    return np.arctan2(sin_part, cos_part)


def so2_compose(R1, R2):
    """Group product of two planar rotations.

    Parameters
    ----------
    R1, R2 : ndarray, shape (2, 2)
        Rotation matrices in SO(2).

    Returns
    -------
    R : ndarray, shape (2, 2)
        The matrix product R1 · R2 (apply R2 first, then R1).
    """
    return np.matmul(R1, R2)


# ============================================================
# 2. SO(3) — 3D Rotations
# ============================================================

def hat(omega):
    """Build the skew-symmetric matrix [ω]× of a 3-vector.

    The result satisfies hat(omega) @ x == cross(omega, x).

    Parameters
    ----------
    omega : array-like, shape (3,)
        Angular velocity / axis-angle vector (only the first three entries are read).

    Returns
    -------
    Omega : ndarray, shape (3, 3)
        Element of so(3).
    """
    w = np.asarray(omega, dtype=float)
    wx, wy, wz = w[0], w[1], w[2]
    return np.array([
        [0, -wz, wy],
        [wz, 0, -wx],
        [-wy, wx, 0],
    ])


def vee(Omega):
    """Extract the 3-vector from a skew-symmetric matrix (inverse of ``hat``).

    Parameters
    ----------
    Omega : ndarray, shape (3, 3)
        Skew-symmetric matrix.

    Returns
    -------
    omega : ndarray, shape (3,)
        [Omega[2, 1], Omega[0, 2], Omega[1, 0]], so that hat(omega) == Omega.
    """
    index_pairs = ((2, 1), (0, 2), (1, 0))
    return np.array([Omega[r, c] for r, c in index_pairs])


def so3_exp(omega):
    """Rodrigues' formula: axis-angle vector → 3×3 rotation matrix.

    For θ = ‖ω‖ below 1e-10 the first-order expansion I + [ω]× is returned;
    otherwise R = I + sin θ K + (1 − cos θ) K² with K = [ω/θ]×.

    Parameters
    ----------
    omega : array-like, shape (3,)
        Rotation vector; its direction is the axis and its norm the angle.

    Returns
    -------
    R : ndarray, shape (3, 3)
        Rotation matrix in SO(3).
    """
    rotvec = np.asarray(omega, dtype=float)
    angle = np.linalg.norm(rotvec)
    identity = np.eye(3)

    if angle < 1e-10:
        return identity + hat(rotvec)

    sin_a = np.sin(angle)
    one_minus_cos = 1.0 - np.cos(angle)
    U = hat(rotvec / angle)  # skew matrix of the unit axis
    return identity + sin_a * U + one_minus_cos * (U @ U)


def so3_log(R):
    """Logarithmic map: SO(3) → R³ axis-angle vector.

    The angle is computed as ``arctan2(sin θ, cos θ)`` with
    ``cos θ = (tr R − 1)/2`` and ``sin θ = ‖w‖/2``, where
    ``w = vee(R − Rᵀ) = 2 sin θ · n``. Unlike ``arccos`` of the trace alone,
    this stays well conditioned as θ → π. With ``arccos``, a 1e-16 error in
    the trace turns into a ~1e-4 relative error in sin θ at θ = π − 1e-6.

    Three regimes are handled:

    * θ ≈ 0: first-order ω ≈ w / 2.
    * θ ≈ π: the skew part is too small to give a reliable direction, so the
      axis is read from the symmetric part
      (R + Rᵀ)/2 + I = (1 + cos θ) I + (1 − cos θ) n nᵀ. Its sign is chosen to
      agree with w, so that rotations by π − δ about n and about −n are told
      apart.
    * otherwise: ω = θ / (2 sin θ) · w.

    Parameters
    ----------
    R : ndarray, shape (3, 3)
        Rotation matrix in SO(3).

    Returns
    -------
    omega : ndarray, shape (3,)
        Axis-angle vector with norm in [0, π].
    """
    R = np.asarray(R, dtype=float)
    # w = vee(R - R^T) = 2 sin(θ) n for a proper rotation.
    w = np.array([R[2, 1] - R[1, 2],
                  R[0, 2] - R[2, 0],
                  R[1, 0] - R[0, 1]])
    sin_theta = 0.5 * np.linalg.norm(w)
    cos_theta = 0.5 * (np.trace(R) - 1.0)
    theta = np.arctan2(sin_theta, cos_theta)  # in [0, π] since sin_theta >= 0

    if theta < 1e-10:
        # Near identity: R ≈ I + [ω]×, so the antisymmetric part gives ω directly.
        return 0.5 * w

    if np.pi - theta < 1e-6:
        # θ ≈ π: the symmetric part (R + R^T)/2 + I ≈ 2 n n^T. The column with the
        # largest norm is the most accurate estimate of ±n.
        M = 0.5 * (R + R.T) + np.eye(3)
        norms = np.linalg.norm(M, axis=0)
        idx = np.argmax(norms)
        n = M[:, idx] / norms[idx]
        # The skew part (tiny but nonzero unless θ == π exactly) fixes the sign.
        if np.dot(n, w) < 0.0:
            n = -n
        return n * theta

    # General case
    return (theta / (2.0 * sin_theta)) * w


def angle_axis_to_quaternion(omega):
    """Turn a rotation vector into a unit quaternion ordered [w, x, y, z].

    For ‖ω‖ < 1e-10 the identity quaternion is returned.

    Parameters
    ----------
    omega : array-like, shape (3,)
        Axis-angle vector.

    Returns
    -------
    q : ndarray, shape (4,)
        Unit quaternion [cos(θ/2), sin(θ/2)·axis].
    """
    rotvec = np.asarray(omega, dtype=float)
    angle = np.linalg.norm(rotvec)

    if angle < 1e-10:
        return np.array([1.0, 0.0, 0.0, 0.0])

    half_angle = angle / 2.0
    unit_axis = rotvec / angle
    vector_part = np.sin(half_angle) * unit_axis
    return np.concatenate(([np.cos(half_angle)], vector_part))


def quaternion_to_rotation(q):
    """Convert a quaternion [w, x, y, z] to a 3×3 rotation matrix.

    The input is normalized first, so non-unit quaternions are accepted.

    Parameters
    ----------
    q : array-like, shape (4,)
        Quaternion [w, x, y, z].

    Returns
    -------
    R : ndarray, shape (3, 3)
        Rotation matrix.
    """
    quat = np.asarray(q, dtype=float)
    w, x, y, z = quat / np.linalg.norm(quat)

    # Pairwise products reused across the matrix entries.
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    return np.array([
        [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
        [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
        [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)],
    ])


def quaternion_slerp(q1, q2, t):
    """Spherical linear interpolation between two quaternions.

    Both inputs are normalized. If their dot product is negative, q2 is
    negated so the interpolation follows the shorter arc. When the two are
    nearly parallel (cosine > 0.9995), normalized linear interpolation is
    used instead.

    Parameters
    ----------
    q1, q2 : array-like, shape (4,)
        Quaternions [w, x, y, z].
    t : float
        Interpolation parameter; 0 gives q1 and 1 gives q2.

    Returns
    -------
    q : ndarray, shape (4,)
        Interpolated unit quaternion.
    """
    start = np.asarray(q1, dtype=float)
    end = np.asarray(q2, dtype=float)
    start = start / np.linalg.norm(start)
    end = end / np.linalg.norm(end)

    cos_omega = np.dot(start, end)
    if cos_omega < 0:
        # q and -q describe the same rotation; flip to take the short way round.
        end, cos_omega = -end, -cos_omega
    cos_omega = np.clip(cos_omega, -1.0, 1.0)

    if cos_omega > 0.9995:
        blended = start + t * (end - start)
    else:
        omega = np.arccos(cos_omega)
        sin_omega = np.sin(omega)
        sin_t_omega = np.sin(omega * t)
        weight_start = np.cos(omega * t) - cos_omega * sin_t_omega / sin_omega
        weight_end = sin_t_omega / sin_omega
        blended = weight_start * start + weight_end * end

    return blended / np.linalg.norm(blended)


def so3_left_jacobian(omega):
    """Left Jacobian J_l of SO(3).

    J_l(ω) = I + (1 - cos θ)/θ² [ω]× + (θ - sin θ)/θ³ [ω]ײ,   θ = ‖ω‖.

    Both scalar coefficients cancel catastrophically in floating point when
    θ is small. ``1 - cos θ`` rounds to 0 once θ ≲ 1e-8, and ``θ - sin θ``
    loses most of its significant digits well before that. They are
    therefore evaluated in cancellation-free form:
    ``1 - cos θ = 2 sin²(θ/2)``, and a Taylor series for ``(θ - sin θ)/θ³``
    when θ < 1e-2.

    Parameters
    ----------
    omega : array-like, shape (3,)
        Axis-angle vector.

    Returns
    -------
    J : ndarray, shape (3, 3)
        Left Jacobian matrix.
    """
    omega = np.asarray(omega, dtype=float)
    theta = np.linalg.norm(omega)
    I = np.eye(3)

    if theta < 1e-10:
        # Taylor: J_l ≈ I + 0.5 [ω]×  (the [ω]ײ term is O(θ²) and negligible here)
        return I + 0.5 * hat(omega)

    K = hat(omega)
    K2 = K @ K

    # (1 - cos θ)/θ² computed as 2 sin²(θ/2)/θ² to avoid cancellation near θ = 0.
    half_sin = np.sin(0.5 * theta)
    a = 2.0 * half_sin * half_sin / (theta * theta)

    # (θ - sin θ)/θ³ = 1/6 - θ²/120 + θ⁴/5040 - ...; the series is exact to
    # double precision for θ < 1e-2, where the closed form loses digits.
    if theta < 1e-2:
        t2 = theta * theta
        b = 1.0 / 6.0 - t2 / 120.0 + t2 * t2 / 5040.0
    else:
        b = (theta - np.sin(theta)) / theta ** 3

    return I + a * K + b * K2


def so3_right_jacobian(omega):
    """Right Jacobian J_r of SO(3), obtained from the identity J_r(ω) = J_l(-ω).

    Parameters
    ----------
    omega : array-like, shape (3,)
        Axis-angle vector.

    Returns
    -------
    J : ndarray, shape (3, 3)
        Right Jacobian matrix.
    """
    negated = -np.asarray(omega, dtype=float)
    return so3_left_jacobian(negated)


# ============================================================
# 3. SE(2) — 2D Rigid Motions
# ============================================================

def se2_exp(xi):
    """Exponential map: R³ twist [v1, v2, theta] → 3×3 SE(2) matrix.

    T = [[R(θ), V(θ) v], [0, 1]] with

        V(θ) = [[a, -b], [b, a]],   a = sin θ / θ,   b = (1 - cos θ) / θ.

    ``b`` is evaluated as ``2 sin²(θ/2) / θ`` because ``1 - cos θ`` rounds to
    0 for θ ≲ 1e-8. For |θ| < 1e-10, ``a`` and ``b`` come from their Taylor
    series. The rotation block always uses cos θ / sin θ, so tiny rotations
    are not dropped.

    Parameters
    ----------
    xi : array-like, shape (3,)
        Twist vector [v1, v2, theta].

    Returns
    -------
    T : ndarray, shape (3, 3)
        Homogeneous transformation in SE(2).
    """
    xi = np.asarray(xi, dtype=float)
    v1, v2, theta = xi

    s, c = np.sin(theta), np.cos(theta)

    if np.abs(theta) < 1e-10:
        # Taylor series: sin θ/θ = 1 - θ²/6 + ..., (1 - cos θ)/θ = θ/2 - ...
        a = 1.0 - theta * theta / 6.0
        b = 0.5 * theta
    else:
        a = s / theta
        # 1 - cos θ = 2 sin²(θ/2): no cancellation when cos θ ≈ 1.
        b = 2.0 * np.sin(0.5 * theta) ** 2 / theta

    T = np.eye(3)
    # Rotation block
    T[0, 0] = c
    T[0, 1] = -s
    T[1, 0] = s
    T[1, 1] = c
    # Translation block: V(θ) @ [v1, v2]
    T[0, 2] = a * v1 - b * v2
    T[1, 2] = b * v1 + a * v2
    return T


def se2_log(T):
    """Logarithmic map: SE(2) → R³ twist [v1, v2, theta].

    Inverts ``T[:2, 2] = V(θ) v`` with V(θ) = a I + b J, where
    a = sin θ / θ, b = (1 - cos θ) / θ and J = [[0, -1], [1, 0]]. Since
    J² = -I, V⁻¹ = (a I - b J) / (a² + b²), which is well defined for
    θ ∈ [-π, π]. ``b`` uses ``2 sin²(θ/2)`` to avoid cancellation, and tiny
    angles use the Taylor series. The recovered θ is always returned, even
    when it is tiny.

    Parameters
    ----------
    T : ndarray, shape (3, 3)
        SE(2) transformation matrix.

    Returns
    -------
    xi : ndarray, shape (3,)
        Twist vector [v1, v2, theta].
    """
    theta = np.arctan2(T[1, 0], T[0, 0])
    t = T[:2, 2]

    if np.abs(theta) < 1e-10:
        # Taylor series: sin θ/θ = 1 - θ²/6 + ..., (1 - cos θ)/θ = θ/2 - ...
        a = 1.0 - theta * theta / 6.0
        b = 0.5 * theta
    else:
        a = np.sin(theta) / theta
        b = 2.0 * np.sin(0.5 * theta) ** 2 / theta

    # Closed-form inverse of V = [[a, -b], [b, a]].
    det = a * a + b * b
    v1 = (a * t[0] + b * t[1]) / det
    v2 = (-b * t[0] + a * t[1]) / det
    return np.array([v1, v2, theta])


def se2_compose(T1, T2):
    """Group product of two planar rigid motions.

    Parameters
    ----------
    T1, T2 : ndarray, shape (3, 3)
        SE(2) transformation matrices.

    Returns
    -------
    T : ndarray, shape (3, 3)
        The matrix product T1 · T2 (apply T2 first, then T1).
    """
    return np.matmul(T1, T2)


def se2_inverse(T):
    """Invert a planar rigid motion without a general matrix inverse.

    Uses [[R, t], [0, 1]]⁻¹ = [[Rᵀ, -Rᵀ t], [0, 1]].

    Parameters
    ----------
    T : ndarray, shape (3, 3)
        SE(2) transformation matrix.

    Returns
    -------
    T_inv : ndarray, shape (3, 3)
        Inverse transformation.
    """
    rot_t = T[:2, :2].T
    result = np.eye(3)
    result[:2, :2] = rot_t
    result[:2, 2] = -(rot_t @ T[:2, 2])
    return result


# ============================================================
# 4. SE(3) — 3D Rigid Motions
# ============================================================

def se3_hat(xi):
    """Lift a 6-vector twist into its 4×4 se(3) matrix.

    The twist is ordered translation first, then rotation:
    xi = [v1, v2, v3, omega1, omega2, omega3].

    Parameters
    ----------
    xi : array-like, shape (6,)
        Twist vector [v; omega].

    Returns
    -------
    Xi : ndarray, shape (4, 4)
        [[hat(omega), v], [0, 0]].
    """
    twist = np.asarray(xi, dtype=float)
    Xi = np.zeros((4, 4))
    Xi[:3, :3] = hat(twist[3:])
    Xi[:3, 3] = twist[:3]
    return Xi


def se3_vee(Xi):
    """Flatten a 4×4 se(3) matrix back to a 6-vector twist (inverse of se3_hat).

    Parameters
    ----------
    Xi : ndarray, shape (4, 4)
        Element of se(3).

    Returns
    -------
    xi : ndarray, shape (6,)
        Twist vector [v; omega].
    """
    translational = Xi[:3, 3]
    rotational = vee(Xi[:3, :3])
    return np.concatenate([translational, rotational])


def se3_exp(xi):
    """Exponential map: 6-vector twist [v; omega] → 4×4 SE(3) matrix.

    T = [[exp(ω), V v], [0, 1]], where V is the SO(3) left Jacobian
    J_l(ω) = I + (1-cosθ)/θ² [ω]× + (θ-sinθ)/θ³ [ω]ײ.
    For ‖ω‖ < 1e-10 the rotation is approximated by I + [ω]×, and the
    translation is taken as v unchanged.

    Parameters
    ----------
    xi : array-like, shape (6,)
        Twist vector [v; omega].

    Returns
    -------
    T : ndarray, shape (4, 4)
        SE(3) transformation matrix.
    """
    twist = np.asarray(xi, dtype=float)
    v, omega = twist[:3], twist[3:]
    T = np.eye(4)

    if np.linalg.norm(omega) < 1e-10:
        T[:3, :3] = np.eye(3) + hat(omega)
        T[:3, 3] = v
    else:
        T[:3, :3] = so3_exp(omega)
        T[:3, 3] = so3_left_jacobian(omega) @ v
    return T


def se3_log(T):
    """Logarithmic map: SE(3) matrix → 6-vector twist [v; omega].

    The rotation part comes from ``so3_log``. The translation is mapped back
    through the inverse of the SO(3) left Jacobian, unless the rotation is
    negligible (‖ω‖ < 1e-10); in that case the translation is used as-is.

    Parameters
    ----------
    T : ndarray, shape (4, 4)
        SE(3) transformation matrix.

    Returns
    -------
    xi : ndarray, shape (6,)
        Twist vector [v; omega].
    """
    rot, trans = T[:3, :3], T[:3, 3]
    omega = so3_log(rot)

    if np.linalg.norm(omega) < 1e-10:
        v = trans
    else:
        v = np.linalg.solve(so3_left_jacobian(omega), trans)
    return np.concatenate([v, omega])


def se3_compose(T1, T2):
    """Group product of two 3D rigid motions.

    Parameters
    ----------
    T1, T2 : ndarray, shape (4, 4)
        SE(3) transformation matrices.

    Returns
    -------
    T : ndarray, shape (4, 4)
        The matrix product T1 · T2 (apply T2 first, then T1).
    """
    return np.matmul(T1, T2)


def se3_inverse(T):
    """Invert a 3D rigid motion without a general matrix inverse.

    Uses [[R, t], [0, 1]]⁻¹ = [[Rᵀ, -Rᵀ t], [0, 1]].

    Parameters
    ----------
    T : ndarray, shape (4, 4)
        SE(3) transformation matrix.

    Returns
    -------
    T_inv : ndarray, shape (4, 4)
        Inverse transformation.
    """
    rot_t = T[:3, :3].T
    result = np.eye(4)
    result[:3, :3] = rot_t
    result[:3, 3] = -(rot_t @ T[:3, 3])
    return result


def se3_adjoint(T):
    """6×6 Adjoint matrix of an SE(3) element.

    For T = [[R, t], [0, 1]] and twists ordered [v; omega] (translation
    first, as in ``se3_hat``):

        Ad(T) = [[R, [t]× R],
                 [0, R      ]]

    It is characterised by the identity

        Ad(T) @ xi = se3_vee(T @ se3_hat(xi) @ T^{-1}).

    Under the convention that T_ab maps frame-b coordinates into frame a
    (p_a = T_ab @ p_b), Ad(T_ab) re-expresses a twist written in frame b in
    frame a:

        xi_a = Ad(T_ab) @ xi_b

    Parameters
    ----------
    T : ndarray, shape (4, 4)
        SE(3) transformation matrix.

    Returns
    -------
    Ad : ndarray, shape (6, 6)
        Adjoint matrix.
    """
    R = T[:3, :3]
    t = T[:3, 3]
    Ad = np.zeros((6, 6))
    Ad[:3, :3] = R
    # Coupling term: rotating a twist also moves its linear part by t × (R ω).
    Ad[:3, 3:] = hat(t) @ R
    Ad[3:, 3:] = R
    return Ad


def se3_left_jacobian(xi):
    """Left Jacobian of SE(3), evaluated in closed form.

    With xi = [rho; phi] (translation first, matching ``se3_hat``), the
    Jacobian has the block structure

        J_l(xi) = [[J_l(phi), Q(rho, phi)],
                   [0,        J_l(phi)   ]]

    where J_l(phi) is the SO(3) left Jacobian and, with F = [phi]×,
    P = [rho]× and theta = |phi| (Barfoot, "State Estimation for Robotics",
    eq. 7.86):

        Q = 1/2 P + A (F P + P F + F P F)
                  + B (F F P + P F F - 3 F P F)
                  + C (F P F F + F F P F)

        A = (theta - sin theta) / theta^3
        B = (theta^2 + 2 cos theta - 2) / (2 theta^4)
        C = (2 theta - 3 sin theta + theta cos theta) / (2 theta^5)

    This is the exact sum of the series sum_n ad(xi)^n / (n+1)! with
    ad(xi) = [[F, P], [0, F]]. It is not a truncation of that series, so it
    stays accurate for large twists. For theta < 1e-2, A, B and C are
    evaluated from their Taylor series, because the closed forms cancel
    catastrophically there.

    Parameters
    ----------
    xi : array-like, shape (6,)
        Twist vector [v; omega].

    Returns
    -------
    J : ndarray, shape (6, 6)
        Left Jacobian of SE(3).
    """
    xi = np.asarray(xi, dtype=float)
    rho = xi[:3]
    phi = xi[3:]
    theta = np.linalg.norm(phi)

    if theta < 1e-2:
        # Taylor expansions (next terms are O(theta^6), ~1e-18 relative here).
        t2 = theta * theta
        t4 = t2 * t2
        a = 1.0 / 6.0 - t2 / 120.0 + t4 / 5040.0
        b = 1.0 / 24.0 - t2 / 720.0 + t4 / 40320.0
        c = 1.0 / 120.0 - t2 / 2520.0 + t4 / 120960.0
    else:
        s, co = np.sin(theta), np.cos(theta)
        a = (theta - s) / theta ** 3
        b = (theta ** 2 + 2.0 * co - 2.0) / (2.0 * theta ** 4)
        c = (2.0 * theta - 3.0 * s + theta * co) / (2.0 * theta ** 5)

    F = hat(phi)
    P = hat(rho)
    FP = F @ P
    PF = P @ F
    FPF = FP @ F
    Q = (0.5 * P
         + a * (FP + PF + FPF)
         + b * (F @ FP + PF @ F - 3.0 * FPF)
         + c * (FPF @ F + F @ FPF))

    J_rot = so3_left_jacobian(phi)
    J = np.zeros((6, 6))
    J[:3, :3] = J_rot
    J[:3, 3:] = Q
    J[3:, 3:] = J_rot
    return J


# ============================================================
# 5. Manifold Optimization
# ============================================================

def manifold_gauss_newton(residual_fn, jacobian_fn, x0, max_iter=50, tol=1e-8):
    """Gauss-Newton minimisation of ||r(T)||² over SE(2) poses.

    Each iteration linearises the residual in the tangent space at the
    current pose. It solves the normal equations (JᵀJ) δ = -Jᵀ r, falling
    back to least squares if JᵀJ is singular, and retracts on the right:
    T ← T · se2_exp(δ). Iteration stops once ||δ|| < tol or after max_iter
    steps.

    Parameters
    ----------
    residual_fn : callable
        r(T) → residual vector, given a 3×3 SE(2) pose.
    jacobian_fn : callable
        J(T) → Jacobian of r w.r.t. a 3D right tangent-space perturbation.
    x0 : ndarray, shape (3, 3)
        Initial SE(2) pose (copied, never modified).
    max_iter : int
        Maximum iterations.
    tol : float
        Convergence tolerance on the step norm.

    Returns
    -------
    x : ndarray, shape (3, 3)
        Optimized SE(2) pose.
    cost_history : list of float
        ||r||² at the start of each iteration, followed by the final cost.
    """
    pose = x0.copy()
    history = []

    for _ in range(max_iter):
        res = residual_fn(pose)
        history.append(np.dot(res, res))

        jac = jacobian_fn(pose)
        normal_matrix = jac.T @ jac
        rhs = -(jac.T @ res)
        try:
            step = np.linalg.solve(normal_matrix, rhs)
        except np.linalg.LinAlgError:
            # Singular normal equations: take the minimum-norm least-squares step.
            step = np.linalg.lstsq(jac, -res, rcond=None)[0]

        if np.linalg.norm(step) < tol:
            break

        pose = se2_compose(pose, se2_exp(step))

    res = residual_fn(pose)
    history.append(np.dot(res, res))
    return pose, history


def _se2_edge_error(T_i, T_j, T_ij_meas):
    """Tangent-space residual of one edge: log(T_ij_meas⁻¹ · T_i⁻¹ · T_j)."""
    T_ij_pred = se2_compose(se2_inverse(T_i), T_j)
    return se2_log(se2_compose(se2_inverse(T_ij_meas), T_ij_pred))


def _se2_edge_jacobians(T_i, T_j, T_ij_meas, e, eps=1e-7):
    """Forward-difference Jacobians of an edge residual w.r.t. T_i and T_j.

    Each pose is perturbed on the right, T ← T · se2_exp(eps · e_k), matching
    the retraction used for the Gauss-Newton update. ``e`` is the
    unperturbed residual of the edge.
    """
    J_i = np.zeros((3, 3))
    J_j = np.zeros((3, 3))
    for k in range(3):
        d = np.zeros(3)
        d[k] = eps
        step = se2_exp(d)
        e_i = _se2_edge_error(se2_compose(T_i, step), T_j, T_ij_meas)
        e_j = _se2_edge_error(T_i, se2_compose(T_j, step), T_ij_meas)
        J_i[:, k] = (e_i - e) / eps
        J_j[:, k] = (e_j - e) / eps
    return J_i, J_j


def pose_graph_optimize_se2(poses, edges, max_iter=30):
    """Simple pose-graph optimization on SE(2).

    Optimizes a set of poses given relative pose measurements between them,
    using Gauss-Newton on the manifold with an identity information matrix
    per edge. Pose 0 is fixed as the anchor. The Jacobians are
    forward-difference approximations.

    Parameters
    ----------
    poses : list of ndarray, shape (3, 3)
        Initial SE(2) pose estimates (copied, never modified).
    edges : list of (int, int, ndarray)
        Each edge is (i, j, T_ij_measured) where T_ij_measured is the
        measured relative pose from i to j.
    max_iter : int
        Maximum Gauss-Newton iterations.

    Returns
    -------
    poses_opt : list of ndarray
        Optimized poses.
    cost_history : list of float
        Total cost at the start of each iteration, followed by the final cost.
    """
    n = len(poses)
    poses_opt = [T.copy() for T in poses]
    cost_history = []
    # Pose 0 is the fixed anchor; every other pose contributes 3 DOF.
    dim = 3 * max(n - 1, 0)

    def offset(pose_id):
        # Start row/column of a free pose's 3×3 block in H (and its slice of b).
        return 3 * (pose_id - 1)

    for _ in range(max_iter):
        # Build the linear system H delta = -b
        H = np.zeros((dim, dim))
        b = np.zeros(dim)
        total_cost = 0.0

        for i, j, T_ij_meas in edges:
            e = _se2_edge_error(poses_opt[i], poses_opt[j], T_ij_meas)
            total_cost += np.dot(e, e)
            J_i, J_j = _se2_edge_jacobians(poses_opt[i], poses_opt[j], T_ij_meas, e)

            # The anchor (pose 0) has no slot in H / b.
            if i > 0:
                oi = offset(i)
                H[oi:oi + 3, oi:oi + 3] += J_i.T @ J_i
                b[oi:oi + 3] += J_i.T @ e
            if j > 0:
                oj = offset(j)
                H[oj:oj + 3, oj:oj + 3] += J_j.T @ J_j
                b[oj:oj + 3] += J_j.T @ e
            if i > 0 and j > 0:
                oi, oj = offset(i), offset(j)
                H[oi:oi + 3, oj:oj + 3] += J_i.T @ J_j
                H[oj:oj + 3, oi:oi + 3] += J_j.T @ J_i

        cost_history.append(total_cost)

        if dim == 0:
            # Nothing besides the anchor to optimize.
            break

        # Small diagonal damping keeps H invertible for weakly constrained graphs.
        try:
            delta = np.linalg.solve(H + 1e-6 * np.eye(dim), -b)
        except np.linalg.LinAlgError:
            break

        if np.linalg.norm(delta) < 1e-8:
            break

        # Apply updates (pose 0 stays fixed)
        for k in range(1, n):
            ok = offset(k)
            poses_opt[k] = se2_compose(poses_opt[k], se2_exp(delta[ok:ok + 3]))

    # Final cost
    final_cost = 0.0
    for i, j, T_ij_meas in edges:
        e = _se2_edge_error(poses_opt[i], poses_opt[j], T_ij_meas)
        final_cost += np.dot(e, e)
    cost_history.append(final_cost)

    return poses_opt, cost_history


# ============================================================
# 6. Utilities
# ============================================================

def rotation_error(R1, R2):
    """Angle in radians of the relative rotation R1ᵀ R2 (geodesic distance on SO(3)).

    Parameters
    ----------
    R1, R2 : ndarray, shape (3, 3)
        Rotation matrices.

    Returns
    -------
    err : float
        ||so3_log(R1ᵀ R2)||, in radians.
    """
    relative = R1.T @ R2
    return np.linalg.norm(so3_log(relative))


def pose_error(T1, T2):
    """Split the discrepancy between two SE(3) poses into rotation and translation parts.

    Parameters
    ----------
    T1, T2 : ndarray, shape (4, 4)
        SE(3) transformation matrices.

    Returns
    -------
    rot_err : float
        Geodesic rotation error in radians (see ``rotation_error``).
    trans_err : float
        Euclidean distance between the translation columns.
    """
    angular = rotation_error(T1[:3, :3], T2[:3, :3])
    displacement = T1[:3, 3] - T2[:3, 3]
    return angular, np.linalg.norm(displacement)


def is_valid_rotation(R, tol=1e-6):
    """Test membership in SO(3): shape 3×3, RᵀR ≈ I and det R ≈ +1.

    Parameters
    ----------
    R : ndarray
        Matrix to test.
    tol : float
        Tolerance for both the orthogonality and determinant checks.

    Returns
    -------
    valid : bool
    """
    if R.shape == (3, 3):
        orthogonality_gap = np.linalg.norm(R.T @ R - np.eye(3))
        determinant_gap = np.abs(np.linalg.det(R) - 1.0)
        return orthogonality_gap < tol and determinant_gap < tol
    return False


def is_valid_se3(T, tol=1e-6):
    """Test membership in SE(3): 4×4, valid rotation block, bottom row ≈ [0, 0, 0, 1].

    Parameters
    ----------
    T : ndarray
        Matrix to test.
    tol : float
        Tolerance for the rotation and bottom-row checks.

    Returns
    -------
    valid : bool
    """
    if T.shape != (4, 4) or not is_valid_rotation(T[:3, :3], tol):
        return False
    return np.allclose(T[3, :], [0, 0, 0, 1], atol=tol)


def random_rotation():
    """Sample a rotation uniformly (Haar measure) from SO(3).

    A 4-D standard normal vector is normalized, which gives a unit
    quaternion uniformly distributed on S³. Because S³ double-covers SO(3),
    the corresponding rotation is uniform on SO(3).

    Pairing a uniform axis with a uniform angle in [0, π] is *not* uniform.
    The Haar density of the rotation angle is proportional to 1 − cos θ, so
    that approach over-represents small rotations.

    Uses NumPy's global random state.

    Returns
    -------
    R : ndarray, shape (3, 3)
        Random rotation matrix.
    """
    q = np.random.randn(4)
    q = q / np.linalg.norm(q)
    return quaternion_to_rotation(q)


def random_se3(max_angle=np.pi, max_trans=1.0):
    """Draw a random rigid motion using NumPy's global random state.

    The rotation has a Gaussian-direction axis and an angle drawn uniformly
    from [0, max_angle]. The translation has a Gaussian direction and a norm
    drawn uniformly from [0, max_trans]. The resulting twist is passed
    through ``se3_exp``.

    Parameters
    ----------
    max_angle : float
        Maximum rotation angle.
    max_trans : float
        Maximum translation norm (of the twist's translational part).

    Returns
    -------
    T : ndarray, shape (4, 4)
        Random SE(3) element.
    """
    # Draw order (axis, angle, direction, magnitude) is kept fixed so seeded runs reproduce.
    direction = np.random.randn(3)
    direction /= np.linalg.norm(direction)
    rotvec = direction * np.random.uniform(0, max_angle)

    offset = np.random.randn(3)
    offset = offset / np.linalg.norm(offset) * np.random.uniform(0, max_trans)

    return se3_exp(np.concatenate([offset, rotvec]))


# ============================================================
# 7. Self-Tests
# ============================================================

def run_tests():
    """Run all self-tests for Lie group operations.

    Seeds NumPy's global RNG with 42 so the randomized checks are
    reproducible. It then checks each section of the module in turn,
    printing a ``[PASS]`` line per group. The first failing check raises
    ``AssertionError`` with a diagnostic message.
    """
    np.random.seed(42)
    EPS = 1e-7  # tolerance for identities expected to hold to ~machine precision

    # --- SO(2) round-trip ---
    # Includes θ = π and θ just above -π to exercise the arctan2 branch cut.
    for theta in [0.0, 0.3, -1.5, np.pi, -np.pi + 0.01]:
        R = so2_exp(theta)
        theta_rec = so2_log(R)
        R_rec = so2_exp(theta_rec)
        assert np.allclose(R, R_rec, atol=EPS), f"SO(2) round-trip failed for θ={theta}"

    # SO(2) composition: planar rotation angles add, 0.3 + 0.7 = 1.0
    R1 = so2_exp(0.3)
    R2 = so2_exp(0.7)
    R12 = so2_compose(R1, R2)
    assert np.allclose(so2_log(R12), 1.0, atol=EPS), "SO(2) compose failed"
    print("  [PASS] SO(2) exp/log/compose")

    # --- SO(3) round-trip ---
    # Compares matrices rather than vectors, so the ±axis ambiguity of log at θ = π is harmless.
    test_omegas = [
        np.array([0.0, 0.0, 0.0]),                    # identity
        np.array([0.1, 0.2, 0.3]),                     # small angle
        np.array([1.0, 0.0, 0.0]),                     # 1 rad about x
        np.array([0.0, 0.0, np.pi - 0.01]),            # near π
        np.array([0.0, np.pi, 0.0]),                   # exactly π about y
        np.array([1e-12, 0.0, 0.0]),                   # near-zero
    ]
    for omega in test_omegas:
        R = so3_exp(omega)
        assert is_valid_rotation(R), f"so3_exp produced invalid rotation for ω={omega}"
        omega_rec = so3_log(R)
        R_rec = so3_exp(omega_rec)
        assert np.allclose(R, R_rec, atol=1e-6), \
            f"SO(3) round-trip failed for ω={omega}\n  R={R}\n  R_rec={R_rec}"
    print("  [PASS] SO(3) exp/log round-trip")

    # --- Hat / Vee ---
    w = np.array([1.0, 2.0, 3.0])
    assert np.allclose(vee(hat(w)), w), "hat/vee round-trip failed"
    print("  [PASS] hat/vee round-trip")

    # --- Quaternion ↔ rotation ---
    # Random axis-angle vectors; both conversion paths must agree on the matrix.
    for _ in range(20):
        omega = np.random.randn(3) * np.random.uniform(0.01, np.pi)
        R_orig = so3_exp(omega)
        q = angle_axis_to_quaternion(omega)
        R_from_q = quaternion_to_rotation(q)
        assert np.allclose(R_orig, R_from_q, atol=1e-6), \
            f"Quaternion round-trip failed for ω={omega}"
    print("  [PASS] quaternion ↔ rotation round-trip")

    # --- Quaternion SLERP ---
    # Halfway between identity and 90° about z should be 45° about z.
    q1 = angle_axis_to_quaternion(np.array([0.0, 0.0, 0.0]))
    q2 = angle_axis_to_quaternion(np.array([0.0, 0.0, np.pi / 2]))
    q_mid = quaternion_slerp(q1, q2, 0.5)
    R_mid = quaternion_to_rotation(q_mid)
    R_expected = so3_exp(np.array([0.0, 0.0, np.pi / 4]))
    assert np.allclose(R_mid, R_expected, atol=1e-6), "SLERP midpoint failed"
    # Endpoints (t=1 compared as rotations, since q and -q are equivalent)
    assert np.allclose(quaternion_slerp(q1, q2, 0.0), q1, atol=EPS), "SLERP t=0 failed"
    assert np.allclose(quaternion_to_rotation(quaternion_slerp(q1, q2, 1.0)),
                       quaternion_to_rotation(q2), atol=1e-6), "SLERP t=1 failed"
    print("  [PASS] quaternion SLERP")

    # --- SE(2) round-trip ---
    test_se2_twists = [
        np.array([0.0, 0.0, 0.0]),
        np.array([1.0, 2.0, 0.0]),        # pure translation
        np.array([0.0, 0.0, 0.5]),        # pure rotation
        np.array([1.0, -0.5, 1.2]),       # general
        np.array([0.3, 0.7, np.pi - 0.1]),
    ]
    for xi in test_se2_twists:
        T = se2_exp(xi)
        xi_rec = se2_log(T)
        T_rec = se2_exp(xi_rec)
        assert np.allclose(T, T_rec, atol=1e-6), f"SE(2) round-trip failed for ξ={xi}"
    print("  [PASS] SE(2) exp/log round-trip")

    # SE(2) inverse
    T = se2_exp(np.array([1.0, -0.5, 0.8]))
    T_inv = se2_inverse(T)
    assert np.allclose(T @ T_inv, np.eye(3), atol=EPS), "SE(2) inverse failed"
    print("  [PASS] SE(2) inverse")

    # --- SE(3) round-trip ---
    # Twists are ordered [v; omega] (translation first).
    test_se3_twists = [
        np.zeros(6),
        np.array([1, 0, 0, 0, 0, 0]),              # pure translation
        np.array([0, 0, 0, 0.5, 0.3, 0.1]),        # pure rotation
        np.array([1, 2, 3, 0.4, -0.2, 0.6]),       # general
        np.array([0.1, -0.2, 0.3, 0, 0, np.pi - 0.01]),  # near π
    ]
    for xi in test_se3_twists:
        T = se3_exp(xi)
        assert is_valid_se3(T), f"se3_exp produced invalid SE(3) for ξ={xi}"
        xi_rec = se3_log(T)
        T_rec = se3_exp(xi_rec)
        assert np.allclose(T, T_rec, atol=1e-5), \
            f"SE(3) round-trip failed for ξ={xi}\n  max err={np.max(np.abs(T - T_rec))}"
    print("  [PASS] SE(3) exp/log round-trip")

    # SE(3) inverse
    T = se3_exp(np.array([1, 2, 3, 0.4, -0.2, 0.6]))
    T_inv = se3_inverse(T)
    assert np.allclose(T @ T_inv, np.eye(4), atol=EPS), "SE(3) inverse failed"
    print("  [PASS] SE(3) inverse")

    # SE(3) Adjoint identity: Ad(T) @ xi = se3_vee(T @ se3_hat(xi) @ T^{-1})
    T = random_se3(max_angle=1.5, max_trans=2.0)
    xi = np.random.randn(6) * 0.5
    Ad_T = se3_adjoint(T)
    lhs = Ad_T @ xi
    Xi_mat = se3_hat(xi)
    T_inv = se3_inverse(T)
    rhs_mat = T @ Xi_mat @ T_inv
    rhs = se3_vee(rhs_mat)
    assert np.allclose(lhs, rhs, atol=1e-6), \
        f"Adjoint identity failed: lhs={lhs}, rhs={rhs}"
    print("  [PASS] SE(3) Adjoint identity")

    # --- Jacobian identities ---
    # Left and right SO(3) Jacobians are related by J_r(omega) = R^T @ J_l(omega),
    # where R = exp(omega).
    omega = np.array([0.4, -0.3, 0.6])
    R = so3_exp(omega)
    J_l = so3_left_jacobian(omega)
    J_r = so3_right_jacobian(omega)
    assert np.allclose(J_r, R.T @ J_l, atol=1e-6), "J_r = R^T J_l identity failed"
    print("  [PASS] SO(3) Jacobian identities")

    # Numerical Jacobian check for so3_exp
    # The derivative of the exponential map satisfies:
    #   R^{-1} dR/dε = hat(J_r δ)  (body-frame Jacobian)
    # So: so3_log(R(ω)^T R(ω+εδ)) / ε → J_r(ω) δ  as ε → 0
    # (forward differences, so agreement is only to O(eps); hence atol=1e-4)
    omega = np.array([0.3, -0.5, 0.2])
    J_r_analytic = so3_right_jacobian(omega)
    J_num = np.zeros((3, 3))
    eps = 1e-5
    R0 = so3_exp(omega)
    for k in range(3):
        d = np.zeros(3)
        d[k] = eps
        R_pert = so3_exp(omega + d)
        diff = so3_log(R0.T @ R_pert)
        J_num[:, k] = diff / eps
    assert np.allclose(J_num, J_r_analytic, atol=1e-4), \
        f"Numerical Jacobian check failed:\n  J_num=\n{J_num}\n  J_r=\n{J_r_analytic}"
    print("  [PASS] SO(3) Right Jacobian numerical check")

    # --- Pose graph optimization on SE(2) ---
    # Simple 3-node loop. T20_true closes the loop exactly, so the noise-free
    # measurements are mutually consistent and the optimum is the true poses.
    T01_true = se2_exp(np.array([1.0, 0.0, 0.3]))
    T12_true = se2_exp(np.array([0.5, 0.5, -0.2]))
    T20_true = se2_inverse(se2_compose(T01_true, T12_true))

    # True poses
    pose0 = np.eye(3)
    pose1 = se2_compose(pose0, T01_true)
    pose2 = se2_compose(pose1, T12_true)

    # Noisy initial estimates (pose 0 is the fixed anchor, so it is left exact)
    pose1_init = se2_compose(pose1, se2_exp(np.array([0.15, -0.1, 0.05])))
    pose2_init = se2_compose(pose2, se2_exp(np.array([-0.1, 0.2, -0.08])))

    poses_init = [pose0, pose1_init, pose2_init]
    edges = [
        (0, 1, T01_true),
        (1, 2, T12_true),
        (2, 0, T20_true),  # loop closure
    ]

    poses_opt, cost_hist = pose_graph_optimize_se2(poses_init, edges, max_iter=30)

    # Check convergence: optimized poses should be close to ground truth
    assert np.allclose(poses_opt[1], pose1, atol=0.05), \
        f"Pose graph: pose 1 not converged.\n  opt={poses_opt[1]}\n  true={pose1}"
    assert np.allclose(poses_opt[2], pose2, atol=0.05), \
        f"Pose graph: pose 2 not converged.\n  opt={poses_opt[2]}\n  true={pose2}"
    assert cost_hist[-1] < cost_hist[0] * 0.01, \
        f"Pose graph cost did not decrease enough: {cost_hist[0]:.6f} → {cost_hist[-1]:.6f}"
    print("  [PASS] SE(2) pose-graph optimization (3-node loop)")

    # --- Validation utilities ---
    R = random_rotation()
    assert is_valid_rotation(R), "random_rotation() produced invalid SO(3)"
    T = random_se3()
    assert is_valid_se3(T), "random_se3() produced invalid SE(3)"

    # The error between a pose and itself must vanish in both components.
    rot_e, trans_e = pose_error(T, T)
    assert rot_e < EPS and trans_e < EPS, "pose_error(T, T) should be zero"
    print("  [PASS] Utilities (validation, error, random)")

    print("\n  All tests passed!")


# ============================================================
# 8. Demos
# ============================================================

def _demo_so3_composition():
    """Demo 1: compose two SO(3) rotations, print them, and plot their frames.

    Builds 45° rotations about z and y with ``so3_exp``, prints each matrix,
    the log of their product and a validity check, then draws the rotated
    basis vectors (columns of each matrix) and saves
    ``demo_so3_composition.png`` in the current working directory.
    """
    print("=" * 60)
    print("Demo 1: SO(3) Rotation Composition")
    print("=" * 60)

    omega1 = np.array([0.0, 0.0, np.pi / 4])   # 45° about z
    omega2 = np.array([0.0, np.pi / 4, 0.0])    # 45° about y
    R1 = so3_exp(omega1)
    R2 = so3_exp(omega2)
    R12 = R1 @ R2

    print(f"  R1 (45° about z):\n{R1}")
    print(f"  R2 (45° about y):\n{R2}")
    print(f"  R1 @ R2:\n{R12}")
    print(f"  log(R1 @ R2) = {so3_log(R12)}")
    print(f"  Valid SO(3)? {is_valid_rotation(R12)}")

    # Column k of a rotation matrix is the image of basis vector e_k, so
    # drawing the columns in r/g/b shows where the x/y/z axes end up.
    fig = plt.figure(figsize=(12, 4))
    axes_labels = ["R1 (z 45°)", "R2 (y 45°)", "R1·R2"]
    rotations = [R1, R2, R12]
    colors = ['r', 'g', 'b']
    origin = np.zeros(3)

    for idx, (R, label) in enumerate(zip(rotations, axes_labels)):
        ax = fig.add_subplot(1, 3, idx + 1, projection='3d')
        for k, c in enumerate(colors):
            ax.quiver(*origin, *R[:, k], color=c, arrow_length_ratio=0.15, linewidth=2)
        ax.set_xlim([-1.2, 1.2])
        ax.set_ylim([-1.2, 1.2])
        ax.set_zlim([-1.2, 1.2])
        ax.set_title(label)
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")

    fig.tight_layout()
    fig.savefig("demo_so3_composition.png", dpi=120)
    plt.close(fig)  # pyplot keeps every figure alive until explicitly closed
    print("  → Saved demo_so3_composition.png\n")


def _demo_se3_interpolation():
    """Demo 2: interpolate between two SE(3) poses through the exp/log maps.

    Uses T(t) = T_start · exp(t · log(T_start⁻¹ · T_end)) for t in [0, 1]
    (constant-twist interpolation), records the translation of each
    intermediate pose, and saves an XY-projection plot to
    ``demo_se3_interpolation.png`` in the current working directory.
    """
    print("=" * 60)
    print("Demo 2: SE(3) Pose Interpolation")
    print("=" * 60)

    T_start = np.eye(4)
    xi_end = np.array([2.0, 1.0, 0.5, 0.0, 0.0, np.pi / 2])
    T_end = se3_exp(xi_end)

    # Relative twist taking T_start to T_end.
    T_rel = se3_compose(se3_inverse(T_start), T_end)
    xi_rel = se3_log(T_rel)

    n_steps = 11
    ts = np.linspace(0, 1, n_steps)
    positions = np.zeros((n_steps, 3))

    print(f"  Interpolating from identity to T_end over {n_steps} steps...")
    for k, t in enumerate(ts):
        T_interp = se3_compose(T_start, se3_exp(t * xi_rel))
        positions[k] = T_interp[:3, 3]  # translation column of the 4x4 pose

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(positions[:, 0], positions[:, 1], 'bo-', markersize=6, label='Interpolated path')
    ax.plot(positions[0, 0], positions[0, 1], 'gs', markersize=12, label='Start')
    ax.plot(positions[-1, 0], positions[-1, 1], 'r^', markersize=12, label='End')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_title("SE(3) Pose Interpolation (XY projection)")
    ax.legend()
    ax.grid(True)
    ax.set_aspect('equal')
    fig.tight_layout()
    fig.savefig("demo_se3_interpolation.png", dpi=120)
    plt.close(fig)
    print("  → Saved demo_se3_interpolation.png\n")


def _demo_pose_graph():
    """Demo 3: optimize a 4-node SE(2) pose graph with a loop closure.

    Ground truth is a unit square. The relative measurements are exact
    copies of the ground-truth relative poses, and the initial estimates are
    the ground-truth poses with independent random perturbations drawn from
    the global NumPy RNG. Prints the cost history summary and saves
    ``demo_pose_graph.png`` in the current working directory.
    """
    print("=" * 60)
    print("Demo 3: SE(2) Pose-Graph Optimization (4-node loop)")
    print("=" * 60)

    # Ground-truth square path: move 1 along x, then (turn 90°, move 1) twice.
    T01_gt = se2_exp(np.array([1.0, 0.0, 0.0]))
    T12_gt = se2_exp(np.array([0.0, 0.0, np.pi / 2]))
    T12_gt = se2_compose(T12_gt, se2_exp(np.array([1.0, 0.0, 0.0])))
    T23_gt = se2_exp(np.array([0.0, 0.0, np.pi / 2]))
    T23_gt = se2_compose(T23_gt, se2_exp(np.array([1.0, 0.0, 0.0])))

    p0_gt = np.eye(3)
    p1_gt = se2_compose(p0_gt, T01_gt)
    p2_gt = se2_compose(p1_gt, T12_gt)
    p3_gt = se2_compose(p2_gt, T23_gt)

    # Relative measurements: noise-free copies of the ground truth, so the
    # graph is perfectly consistent and the optimal cost is zero.
    T01_meas = T01_gt.copy()
    T12_meas = T12_gt.copy()
    T23_meas = T23_gt.copy()
    T30_meas = se2_inverse(se2_compose(se2_compose(T01_gt, T12_gt), T23_gt))

    def noise():
        """Random SE(2) perturbation; tangent-space std (0.1, 0.1, 0.03)."""
        return se2_exp(np.random.randn(3) * np.array([0.1, 0.1, 0.03]))

    # Initial estimates: each pose is perturbed independently (no drift
    # accumulation). Pose 0 stays at ground truth and anchors the graph.
    p1_init = se2_compose(p1_gt, noise())
    p2_init = se2_compose(p2_gt, noise())
    p3_init = se2_compose(p3_gt, noise())

    poses_init = [p0_gt, p1_init, p2_init, p3_init]
    edges = [
        (0, 1, T01_meas),
        (1, 2, T12_meas),
        (2, 3, T23_meas),
        (3, 0, T30_meas),  # loop closure
    ]

    poses_opt, cost_hist = pose_graph_optimize_se2(poses_init, edges, max_iter=50)

    def extract_xy(poses):
        """Stack the (x, y) translations and repeat the first to close the loop."""
        return np.array([T[:2, 2] for T in poses] + [poses[0][:2, 2]])

    xy_gt = extract_xy([p0_gt, p1_gt, p2_gt, p3_gt])
    xy_init = extract_xy(poses_init)
    xy_opt = extract_xy(poses_opt)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ax = axes[0]
    ax.plot(xy_gt[:, 0], xy_gt[:, 1], 'g-o', label='Ground truth', linewidth=2)
    ax.plot(xy_init[:, 0], xy_init[:, 1], 'r--s', label='Initial (noisy)', linewidth=1.5)
    ax.plot(xy_opt[:, 0], xy_opt[:, 1], 'b-^', label='Optimized', linewidth=2)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_title("Pose-Graph Optimization")
    ax.legend()
    ax.grid(True)
    ax.set_aspect('equal')

    ax = axes[1]
    # With exact measurements the cost may reach 0.0. Non-positive values
    # break a log axis, so floor them with the same 1e-15 used for the
    # printed reduction ratio below.
    cost_plot = np.maximum(np.asarray(cost_hist, dtype=float), 1e-15)
    ax.semilogy(cost_plot, 'k-o', markersize=3)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Total cost (log scale)")
    ax.set_title("Convergence")
    ax.grid(True)

    fig.tight_layout()
    fig.savefig("demo_pose_graph.png", dpi=120)
    plt.close(fig)
    print(f"  Initial cost:   {cost_hist[0]:.6f}")
    print(f"  Final cost:     {cost_hist[-1]:.6f}")
    print(f"  Cost reduction: {cost_hist[0] / max(cost_hist[-1], 1e-15):.1f}x")
    print("  → Saved demo_pose_graph.png\n")


def run_demo():
    """Demonstrate key Lie group operations with visualizations.

    Runs, in order: SO(3) rotation composition, SE(3) pose interpolation,
    and SE(2) pose-graph optimization. Writes ``demo_so3_composition.png``,
    ``demo_se3_interpolation.png`` and ``demo_pose_graph.png`` to the
    current working directory.
    """
    # Seed the global NumPy RNG so the pose-graph demo's random initial
    # perturbations (drawn via np.random.randn) are reproducible.
    np.random.seed(0)

    _demo_so3_composition()
    _demo_se3_interpolation()
    _demo_pose_graph()

    print("=" * 60)
    print("All demos complete. Check generated PNG files.")
    print("=" * 60)


# ============================================================
# Entry Point
# ============================================================

if __name__ == "__main__":
    # Passing "--test" anywhere on the command line runs only the self-tests.
    # Any other invocation runs the plotting demos.
    wants_self_tests = "--test" in sys.argv
    if not wants_self_tests:
        run_demo()
    else:
        print("Running Lie Groups self-tests...\n")
        run_tests()