# Building HAMUX



We develop a *modular energy* perspective of Associative Memories where
the energy of any model from this family can be decomposed into
standardized components: [**neuron layers**](#sec-neurons) that encode
dynamic variables, and [**hypersynapses**](#sec-hypersynapses) that
encode their interactions.

**The total energy of the system is the sum of the individual component
energies: the energies of all the neuron layers plus the energies of all
the interactions between those layers.**

In computer science terms, neurons and synapses form a
[**hypergraph**](https://en.wikipedia.org/wiki/Hypergraph), where each
neuron layer is a node and each hypersynapse is a hyperedge.

This framework of energy-based building blocks for memory not only
clarifies how existing methods for building Associative Memories relate
to each other (e.g., the Classical Hopfield Network
\[@hopfield1982neural, @hopfield1984Neurons\], Dense Associative Memory
\[@krotov2016dense\]), but it also provides a systematic language for
designing new architectures (e.g., Energy Transformers
\[@hoover2024energy\], Neuron-Astrocyte Networks
\[@kozachkov2023neuron\]).

We begin by introducing the building blocks of Associative Memory:
[**neurons**](#sec-neurons) and [**hypersynapses**](#sec-hypersynapses).

# Neurons

> Neurons turn [Lagrangians](./01_lagrangians.ipynb) into the dynamic
> building blocks of memory.

<div>

> **TL;DR**
>
> 1.  A *neuron* is a fancy term to describe the dynamic (fast moving)
>     variables in an associative memory.
> 2.  Every neuron has an internal state **x** that evolves over time
>     and an activation $\hat{\mathbf{x}}$ that affects the rest of the
>     network.
> 3.  In the complete hypergraph of Associative Memory, our neurons are
>     the *nodes* while our [hypersynapses](#sec-hypersynapses) are the
>     *hyperedges* (since a synapse can connect more than two nodes,
>     they cannot be regular “edges”).
> 4.  A neuron is just a Lagrangian assigned to a tensor of data.

</div>

<figure>
<img src="figures/NeuronOverview.png" width="350"
alt="A neuron layer is a Lagrangian function on top of data, where the Lagrangian defines the activations of that neuron" />
<figcaption aria-hidden="true">A neuron layer is a Lagrangian function
on top of data, where the Lagrangian defines the activations of that
neuron</figcaption>
</figure>

A **neuron layer** (node of the Associative Memory) is a fancy term to
describe the dynamic variables in AM. Each neuron layer has an
**internal state** **x** which evolves over time and an **activation**
$\hat{\mathbf{x}}$ that forwards a signal to the rest of the network.
Think of neurons like the *activation functions* of standard neural
networks, where **x** are the *pre-activations* and $\hat{\mathbf{x}}$
(the outputs) are the *activations* e.g.,
$\hat{\mathbf{x}} = \texttt{ReLU}(\mathbf{x})$.

To define a neuron layer’s energy, AMs employ two mathematical
tools from physics: convex **Lagrangian functions** and the **Legendre
transform**. For each neuron layer, we define a convex, scalar-valued
Lagrangian ℒ<sub>*x*</sub>(**x**). The Legendre transform 𝒯 of this
Lagrangian produces the dual variable $\hat{\mathbf{x}}$ (our
activations) and the dual energy $E_x(\hat{\mathbf{x}})$ (our new
energy) as in:

<span id="eq-neuron-legendre-transform">
$$
\begin{align}
    \hat{\mathbf{x}} &= \nabla \mathcal{L}\_x(\mathbf{x}) \quad \text{(activation function)} \\
    E_x(\hat{\mathbf{x}}) = \mathcal{T}(\mathcal{L}\_x) &= \langle \mathbf{x}, \hat{\mathbf{x}} \rangle - \mathcal{L}\_x(\mathbf{x}) \quad \text{(dual energy)}
\end{align}
 \qquad(1)$$
</span>

where ⟨⋅, ⋅⟩ is the inner product (element-wise multiplication followed
by a sum). Because ℒ<sub>*x*</sub> is convex, the Jacobian of the
activations
$\frac{\partial \hat{\mathbf{x}}}{\partial \mathbf{x}} = \nabla^2 \mathcal{L}\_x(\mathbf{x})$
(i.e., the Hessian of the Lagrangian) is positive semi-definite.
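As a concrete illustration, the sketch below applies Equation 1 to a log-sum-exp Lagrangian, chosen here only as an example because its gradient is the softmax. The names `lagrangian`, `xhat`, `energy`, and `min_eig` are illustrative and not part of `hamux`. The code computes the activations with autograd, evaluates the dual energy, and inspects the smallest eigenvalue of the Hessian.

```python
import jax
import jax.numpy as jnp

def lagrangian(x):
    # Convex, scalar-valued Lagrangian (illustrative choice: log-sum-exp)
    return jax.nn.logsumexp(x)

x = jnp.array([0.5, -1.0, 2.0])

# Activations are the gradient of the Lagrangian w.r.t. the internal state
xhat = jax.grad(lagrangian)(x)

# Dual energy from the Legendre transform: <x, xhat> - L(x)
energy = jnp.dot(x, xhat) - lagrangian(x)

# Jacobian of the activations, i.e., the Hessian of the Lagrangian
hessian = jax.hessian(lagrangian)(x)
min_eig = jnp.linalg.eigvalsh(hessian).min()
```

For this particular Lagrangian the smallest eigenvalue is numerically zero: shifting every entry of **x** by the same constant leaves the softmax unchanged, so the Hessian has no negative eigenvalues but is not strictly positive definite.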

<div>

> **Notational conventions**
>
> There are lots of named components inside the neuron layer. As a
> notational convention, each neuron layer is identified by a single
> letter (e.g., X or Y). We say a neuron layer X has an internal state
> $\mathbf{x} \in \mathcal{X}$ and Lagrangian
> ℒ<sub>*x*</sub>(**x**), alongside an activation
> $\hat{\mathbf{x}} \in \hat{\mathcal{X}}$ and total energy
> $E_x(\hat{\mathbf{x}})$ constrained through the Legendre transform of
> the Lagrangian. Meanwhile, neuron layer Y has an internal state
> **y** ∈ 𝒴 and Lagrangian ℒ<sub>*y*</sub>(**y**), alongside an
> activation $\hat{\mathbf{y}} \in \hat{\mathcal{Y}}$ and total energy
> $E_y(\hat{\mathbf{y}})$.
>
> Because it is often nice to think of the activations as being a
> non-linear function of the internal states, we can also write
> $\hat{\mathbf{x}} = \sigma_x(\mathbf{x})$, where
> *σ*<sub>*x*</sub>(⋅) := ∇ℒ<sub>*x*</sub>(⋅).

</div>

The dual energy $E_x(\hat{\mathbf{x}})$ has another nice property: *its
gradient equals the internal states*. Thus, when we minimize the energy
of our neurons in the absence of any other signal, the internal states
decay exponentially to zero. This decay keeps the dynamic behavior of
our system bounded and well-behaved, especially for very large values of
**x**.

<span id="eq-exponential-decay">
$$
\frac{d \mathbf{x}}{dt} = - \nabla\_{\hat{\mathbf{x}}} E_x(\hat{\mathbf{x}}) = - \mathbf{x}.
 \qquad(2)$$
</span>
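A minimal numerical sketch of Equation 2, assuming a simple forward-Euler integrator (the step size `dt` and the number of steps are arbitrary illustrative choices): integrating $\frac{d\mathbf{x}}{dt} = -\mathbf{x}$ should track the analytic solution $\mathbf{x}(t) = \mathbf{x}(0)\, e^{-t}$ up to discretization error.

```python
import jax.numpy as jnp

x0 = jnp.array([10.0, -4.0, 0.5])  # arbitrary initial internal state
dt, n_steps = 0.01, 500            # illustrative integration settings

x = x0
for _ in range(n_steps):
    # Forward-Euler step of dx/dt = -x (no synaptic input)
    x = x + dt * (-x)

# Analytic solution of the same ODE at time t = dt * n_steps
analytic = x0 * jnp.exp(-dt * n_steps)
```

Even the large initial value shrinks toward zero, which is the bounded behavior described above.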

## Neurons as Code

The methods to implement neurons are remarkably simple:

------------------------------------------------------------------------

### NeuronLayer

>  NeuronLayer (lagrangian:Callable, shape:Tuple[int])

*Neuron layers represent dynamic variables that evolve during inference
(i.e., memory retrieval/error correction)*

<details open class="code-fold">
<summary>Exported source</summary>

``` python
class NeuronLayer(eqx.Module):
    """Neuron layers represent dynamic variables that evolve during inference (i.e., memory retrieval/error correction)"""
    lagrangian: Callable # The scalar-valued Lagrangian function:  x |-> R
    shape: Tuple[int] # The shape of the neuron layer
```

</details>

Remember that, at its core, a `NeuronLayer` object is nothing more than
a Lagrangian (see [example Lagrangians](./01_lagrangians.ipynb) for
examples) function on top of (shaped) data. All the other methods of the
`NeuronLayer` class just provide conveniences on top of this core
functionality.

------------------------------------------------------------------------

### NeuronLayer.activations

>  NeuronLayer.activations (x)

*Use autograd to compute the activations of the neuron layer from the
Lagrangian*

<details open class="code-fold">
<summary>Exported source</summary>

``` python
@patch
def activations(self: NeuronLayer, x): 
    """Use autograd to compute the activations of the neuron layer from the Lagrangian"""
    return jax.grad(self.lagrangian)(x)
```

</details>

The `NeuronLayer.activations` is the gradient of the Lagrangian with
respect to the states. This is easily computed via jax autograd.

<div>

> **Test the activations**
>
> For example, we can test the activations of a few different
> Lagrangians.
>
> ``` python
> from hamux.lagrangians import lagr_relu, lagr_softmax, lagr_sigmoid, lagr_identity
> from pprint import pp
> ```
>
> ``` python
> D = 10
>
> # Identity activation
> nn = NeuronLayer(lagrangian=lambda x: jnp.sum(0.5 * x**2), shape=(D,)) # Identity activation
> xtest = jr.normal(jr.key(0), (D,))
> assert jnp.allclose(nn.activations(xtest), xtest)
>
> # ReLU activation
> nn = NeuronLayer(lagrangian=lagr_relu, shape=(D,))
> xtest = jr.normal(jr.key(1), (D,))
> assert jnp.allclose(nn.activations(xtest), jnp.maximum(0, xtest))
>
> # Softmax activation
> nn = NeuronLayer(lagrangian=lagr_softmax, shape=(D,))
> xtest = jr.normal(jr.key(2), (D,))
> assert jnp.allclose(nn.activations(xtest), jax.nn.softmax(xtest))
> ```

</div>

------------------------------------------------------------------------

### NeuronLayer.init

>  NeuronLayer.init (bs:Optional[int]=None)

*Return an empty initial neuron state*

<details open class="code-fold">
<summary>Exported source</summary>

``` python
@patch
def init(self: NeuronLayer, bs: Optional[int] = None):
    """Return an empty initial neuron state"""
    if bs is None or bs == 0: return jnp.zeros(self.shape) # No batch dimension
    return jnp.zeros((bs, *self.shape))
```

</details>

The `NeuronLayer.init` method is a convenience method that initializes
an empty collection of neuron layer states. We generally want to
populate this state with values from some piece of data.

It’s used like:

``` python
D = 784 # e.g., rasterized MNIST image size
nn = NeuronLayer(lagrangian=lambda x: jnp.sum(0.5 * x**2), shape=(D,))  # Lagrangian must return a scalar
x_unbatched = nn.init()
print("Unbatched shape:", x_unbatched.shape)
x_batched = nn.init(bs=2)
print("Batched shape:", x_batched.shape)
```

    Unbatched shape: (784,)
    Batched shape: (2, 784)

## Legendre transform and Neuron Energy

The energy is the Legendre transform of the Lagrangian. Consider some
scalar-valued function *F* : 𝒳 ↦ ℝ for which we want to compute its
dual representation $\hat{F}: \hat{\mathcal{X}} \mapsto \mathbb{R}$
under the Legendre Transform. The Legendre transform 𝒯 of *F* transforms
both the function *F* and its argument **x** into a dual formulation *F̂*
and $\hat{\mathbf{x}} = \sigma(\mathbf{x}) = \nabla F(\mathbf{x})$. The
transform is defined as:

$$
\hat{F}(\hat{\mathbf{x}}) = \langle \mathbf{x}, \hat{\mathbf{x}} \rangle - F(\mathbf{x}).
$$

Note that *F̂* is only a function of $\hat{\mathbf{x}}$, since **x** is
determined by $\mathbf{x} = \sigma^{(-1)}(\hat{\mathbf{x}})$. You can
confirm this for yourself by computing
$\frac{\partial \hat{F}}{\partial \mathbf{x}}$ and checking that the
answer is 0.

The code for the Legendre transform is easy to implement in jax as a
higher order function. We’ll assume that we always have the original
variable **x** so that we don’t need to compute *σ*<sup>(−1)</sup>.
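A small worked example may help here. Assume, purely for illustration, $F(\mathbf{x}) = \sum_i e^{x_i}$, for which the inverse activation is known in closed form: $\hat{\mathbf{x}} = e^{\mathbf{x}}$, so $\mathbf{x} = \log \hat{\mathbf{x}}$ and $\hat{F}(\hat{\mathbf{x}}) = \sum_i (\hat{x}_i \log \hat{x}_i - \hat{x}_i)$. The sketch below (function names are hypothetical) checks that the formula evaluated with the original variable **x** agrees with this closed form, and that the gradient of *F̂* recovers **x**.

```python
import jax
import jax.numpy as jnp

def F(x):
    # Illustrative convex function with gradient exp(x)
    return jnp.sum(jnp.exp(x))

def Fhat_closed_form(xhat):
    # Legendre dual of F, written only in terms of xhat
    return jnp.sum(xhat * jnp.log(xhat) - xhat)

x = jnp.array([0.1, -0.7, 1.3])
xhat = jax.grad(F)(x)  # dual variable: sigma(x) = exp(x)

# Evaluate the transform using the original variable x
Fhat_via_x = jnp.dot(x, xhat) - F(x)

# Gradient of the dual w.r.t. xhat should give back x
grad_Fhat = jax.grad(Fhat_closed_form)(xhat)
```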

------------------------------------------------------------------------

### legendre_transform

>  legendre_transform (F:Callable)

*Transform scalar F(x) into the dual Fhat(xhat, x) using the Legendre
transform*

<table>
<thead>
<tr>
<th></th>
<th><strong>Type</strong></th>
<th><strong>Details</strong></th>
</tr>
</thead>
<tbody>
<tr>
<td>F</td>
<td>Callable</td>
<td>The function to transform</td>
</tr>
</tbody>
</table>

<details open class="code-fold">
<summary>Exported source</summary>

``` python
def legendre_transform(
    F: Callable # The function to transform
    ):
    "Transform scalar F(x) into the dual Fhat(xhat, x) using the Legendre transform"

    # We define custom gradient rules to give jax some autograd shortcuts
    @jax.custom_jvp
    def Fhat(xhat, x): return jnp.multiply(xhat, x).sum() - F(x)

    @Fhat.defjvp
    def Fhat_jvp(primals, tangents):
        (xhat, x), (dxhat, dx) = primals, tangents
        o, do = Fhat(xhat, x), jnp.multiply(x, dxhat).sum()
        return o, do

    return Fhat
```

</details>
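To see what the custom differentiation rule buys us, it can help to compare against a naive version of the same formula. In the sketch below, `naive_Fhat` and the softplus-based `F` are illustrative stand-ins, not library code. With plain autograd, the gradient w.r.t. `x` is `xhat - grad(F)(x)`, which vanishes only when `xhat` is exactly the activation of `x`; the custom rule above instead reports a zero gradient w.r.t. `x` for any input pair.

```python
import jax
import jax.numpy as jnp

def F(x):
    # Illustrative convex Lagrangian whose gradient is the sigmoid
    return jnp.sum(jax.nn.softplus(x))

def naive_Fhat(xhat, x):
    # Same formula as the Legendre transform, differentiated by plain autograd
    return jnp.sum(xhat * x) - F(x)

x = jnp.array([1.0, 2.0, 3.0])
xhat_matched = jax.grad(F)(x)            # xhat = sigma(x)
xhat_other = jnp.array([0.2, 0.2, 0.2])  # arbitrary value, not sigma(x)

# Gradient w.r.t. x is zero only for the matched pair
grad_x_matched = jax.grad(naive_Fhat, argnums=1)(xhat_matched, x)
grad_x_other = jax.grad(naive_Fhat, argnums=1)(xhat_other, x)

# Gradient w.r.t. xhat is x in either case
grad_xhat = jax.grad(naive_Fhat, argnums=0)(xhat_other, x)
```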

<div>

> **Test the Legendre transform**
>
> Let’s test if the `legendre_transform` automatic gradients are what we
> expect:
>
> <details class="code-fold">
> <summary>Test the Legendre transform</summary>
>
> ``` python
> x = jnp.array([1., 2, 3])
> F = lagr_sigmoid
> g = jax.nn.sigmoid(x)
> Fhat = legendre_transform(F)
>
> assert jnp.allclose(jax.grad(Fhat)(g, x), x)
> assert jnp.all(jax.grad(Fhat, argnums=1)(g, x) == 0.)
>
> x = jnp.array([1., 2, 3])
> F = lagr_identity
> g = x
> Fhat = legendre_transform(F)
>
> assert jnp.allclose(jax.grad(Fhat)(g, x), x)
> assert jnp.all(jax.grad(Fhat, argnums=1)(g, x) == 0.)
> ```
>
> </details>

</div>

The Legendre transform is the final piece of the puzzle to describe the
energy of a neuron layer.

------------------------------------------------------------------------

### NeuronLayer.energy

>  NeuronLayer.energy (xhat, x)

*The energy of the neuron layer is the Legendre transform of the
Lagrangian*

<details open class="code-fold">
<summary>Exported source</summary>

``` python
@patch
def energy(self: NeuronLayer, xhat, x): 
    """The energy of the neuron layer is the Legendre transform of the Lagrangian"""
    return legendre_transform(self.lagrangian)(xhat, x)
```

</details>

The energy is the Legendre transform of the Lagrangian:

$$
E\_\text{neuron} = \langle \mathbf{x}, \hat{\mathbf{x}} \rangle - \mathcal{L}\_x(\mathbf{x})
$$

<details class="code-fold">
<summary>Test the energy</summary>

``` python
nn = NeuronLayer(lagrangian=lagr_sigmoid, shape=(3,))
x = jnp.array([1., 2, 3])
xhat = nn.activations(x)
assert jnp.allclose(jax.grad(nn.energy, argnums=0)(xhat, x), x)
assert jnp.allclose(jax.grad(nn.energy, argnums=1)(xhat, x), 0.)
```

</details>

<div>

> **Additional methods**
>
> We alias a few of the methods for convenience.
>
> <details open class="code-fold">
> <summary>Exported source</summary>
>
> ``` python
> NeuronLayer.sigma = NeuronLayer.activations
> NeuronLayer.E = NeuronLayer.energy
> ```
>
> </details>
>
> And wrap a few other python conveniences.
>
> ------------------------------------------------------------------------
>
> ### NeuronLayer.\_\_post_init\_\_
>
> >  NeuronLayer.__post_init__ ()
>
> *Ensure the neuron shape is a tuple*
>
> ------------------------------------------------------------------------
>
> ### NeuronLayer.\_\_repr\_\_
>
> >  NeuronLayer.__repr__ ()
>
> *Look nice when inspected*

</div>

# Hypersynapses

> Hypersynapses modulate signals between one or more neuron layers.

<figure>
<img src="figures/HypersynapseOverview.png" width="400"
alt="A hypersynapse is a scalar valued energy function defined on top of the activations of connected neuron layers" />
<figcaption aria-hidden="true">A hypersynapse is a scalar valued energy
function defined on top of the activations of connected neuron
layers</figcaption>
</figure>

The activations of one `NeuronLayer` are sent to other neurons via
communication channels called **hypersynapses**. At its most general, a
hypersynapse is a scalar valued energy function defined on top of the
activations of connected neuron layers. For example, a hypersynapse
connecting neuron layers X and Y has an **interaction energy**
$E\_{xy}(\hat{\mathbf{x}}, \hat{\mathbf{y}}; \mathbf{\Xi})$, where **Ξ**
represents the **synaptic weights** or learnable parameters.

$E\_{xy}(\hat{\mathbf{x}}, \hat{\mathbf{y}}; \mathbf{\Xi})$ encodes the
desired relationship between activations $\hat{\mathbf{x}}$ and
$\hat{\mathbf{y}}$. When this energy is low, the activations satisfy the
relationship encoded by the synaptic weights **Ξ**. During energy
minimization, the system adjusts the activations to reduce all energy
terms, which means synapses effectively *pull* the connected neuron
layers toward configurations encoded in the parameters that minimize
their interaction energy.

<div>

> **Hypersynapse notation conventions**
>
> For synapses connecting multiple layers, we subscript with the
> identifiers of all connected layers. For example:
>
> - *E*<sub>*x**y*</sub> — synapse connecting layers X and Y
> - *E*<sub>*x**y**z*</sub> — synapse connecting layers X, Y, and Z.  
> - *E*<sub>*x**y**z*…</sub> — synapses connecting more than three
>   layers are possible, but rare.
>
> However, synapses can also connect a layer to itself
> (self-connections). To avoid confusion with neuron layer energy
> *E*<sub>*x*</sub>, we use curly brackets for synaptic
> self-connections. For example, *E*<sub>{*x*}</sub> represents the
> interaction energy of a synapse that connects layer X to itself.

</div>

### How biological are hypersynapses?

Hypersynapses in `hamux` differ from biological synapses in two
fundamental ways:

- **Hypersynapses can connect any number of layers simultaneously**,
  while biological synapses connect only two neurons. This officially
  makes hypersynapses “hyperedges” in graph theory terms.
- **Hypersynapses are undirected**, meaning that all connected layers
  influence each other bidirectionally during energy minimization.
  Meanwhile, biological synapses are unidirectional, meaning signal
  flows from a presynaptic to postsynaptic neuron.

Because of these differences, we choose the distinct term
“hypersynapses” to distinguish them from biological synapses.

<figure>
<img src="figures/hamux-undirected-synapses.png" width="400"
alt="Hypersynapses are represented as undirected (hyper)edges in a hypergraph. Shown is an example pairwise synapse, which is a single energy function E_{xy}(\hat{\mathbf{x}}, \hat{\mathbf{y}}; \mathbf{\Xi}) defined on the activations \hat{\mathbf{x}} and \hat{\mathbf{y}} from connected nodes, which necessarily propagate signal to both connected nodes. Here, the signal is defined as the negative gradient of the interaction energy w.r.t. the connected layer’s activations (e.g., layer \mathsf{X} receives signal \mathcal{I}_x = -\nabla_{\hat{\mathbf{x}}} E_{xy}(\hat{\mathbf{x}}, \hat{\mathbf{y}}; \mathbf{\Xi})). This is in contrast to biological synapses which are directional and only propagate signal in one direction from layer \mathsf{X} to \mathsf{Y}, needing a separate synapse to bring information back from \mathsf{Y} to \mathsf{X}" />
<figcaption aria-hidden="true"><strong>Hypersynapses are represented as
undirected (hyper)edges in a hypergraph.</strong> Shown is an example
pairwise synapse, which is a single energy function <span
class="math inline">$E_{xy}(\hat{\mathbf{x}}, \hat{\mathbf{y}};
\mathbf{\Xi})$</span> defined on the activations <span
class="math inline">$\hat{\mathbf{x}}$</span> and <span
class="math inline">$\hat{\mathbf{y}}$</span> from connected nodes,
which necessarily propagate signal to both connected nodes. Here, the
signal is defined as the negative gradient of the interaction energy
w.r.t. the
connected layer’s activations (e.g., layer <span
class="math inline">X</span> receives signal <span
class="math inline">$\mathcal{I}_x = -\nabla_{\hat{\mathbf{x}}}
E_{xy}(\hat{\mathbf{x}}, \hat{\mathbf{y}}; \mathbf{\Xi})$</span>). This
is in contrast to biological synapses which are directional and only
propagate signal in one direction from layer <span
class="math inline">X</span> to <span class="math inline">Y</span>,
needing a separate synapse to bring information back from <span
class="math inline">Y</span> to <span
class="math inline">X</span></figcaption>
</figure>

The undirected nature of hypersynapses fundamentally distinguishes AM
from traditional neural networks. Whereas feed-forward networks follow a
directed computational graph with clear input-to-output flow, AMs have
no inherent concept of “forward” or “backward” directions. All connected
layers influence each other bidirectionally during energy minimization,
with information propagating from deeper layers to shallower layers as
readily as the other way around.

Unlike the `NeuronLayer`’s energies, the interaction energies of the
hypersynapses are completely unconstrained: *any function* that takes
activations as input and returns a scalar is admissible and will have
well-behaved dynamics. The interaction energy of a synapse may choose to
introduce its own non-linearities beyond those handled by the neuron
layers. When this occurs, the energy minimization dynamics must compute
gradients through these “synaptic non-linearities”, unlike the case
where all non-linearities are abstracted into the `NeuronLayer`
Lagrangians.
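As a sketch of such a “synaptic non-linearity”, consider a hypothetical self-connection whose energy applies a log-sum-exp over the overlaps between the activations and a set of stored patterns. The function name `lse_synapse_energy`, the parameter `beta`, and the shapes are assumptions made for illustration; they are not taken from `hamux`. The signal this synapse sends back to the layer is a softmax-weighted combination of the stored patterns, a non-linearity that lives in the synapse rather than in any neuron layer.

```python
import jax
import jax.numpy as jnp

def lse_synapse_energy(xhat, Xi, beta=1.0):
    # Hypothetical interaction energy with its own non-linearity:
    # E = -(1 / beta) * logsumexp(beta * Xi @ xhat)
    return -jax.nn.logsumexp(beta * (Xi @ xhat)) / beta

Xi = jax.random.normal(jax.random.PRNGKey(0), (4, 6))  # 4 stored patterns of size 6
xhat = jax.random.normal(jax.random.PRNGKey(1), (6,))  # activations of one layer

energy = lse_synapse_energy(xhat, Xi)

# Signal to the layer: negative gradient of the energy w.r.t. its activations
signal = -jax.grad(lse_synapse_energy)(xhat, Xi)

# Closed form for beta = 1: softmax-weighted sum of the stored patterns
expected = Xi.T @ jax.nn.softmax(Xi @ xhat)
```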

<div>

> **TL;DR**
>
> 1.  A *hypersynapse* describes the “strain energy” between the
>     activations of one or more neuron layers. The lower that energy,
>     the more aligned the activations.
> 2.  In the complete hypergraph of Associative Memory, our neurons are
>     the nodes and our hypersynapses are the *hyperedges*.

</div>

## Hypersynapse implementation

> **Hypersynapses are just callable `equinox.Module` with trainable
> parameters.** Any differentiable, scalar-valued function, implemented
> in a `__call__` method, will work.

That’s it. Many things can be hypersynapse energies. Here are two
examples that may look familiar to those with a background in ML.

------------------------------------------------------------------------

### LinearSynapse

>  LinearSynapse (W:jax.Array)

*The energy synapse corollary of the linear layer in standard neural
networks*

<details open class="code-fold">
<summary>Exported source</summary>

``` python
class LinearSynapse(eqx.Module):
    """The energy synapse corollary of the linear layer in standard neural networks"""
    W: jax.Array
    def __call__(self, xhat1:jax.Array, xhat2:jax.Array):
        "Compute the interaction energy between activations `xhat1` and `xhat2`."
        # Best to use batch-dim agnostic operations
        return -jnp.einsum("...c,...d,cd->...", xhat1, xhat2, self.W)

    @classmethod
    def rand_init(cls, key: jax.Array, x1_dim: int, x2_dim: int):
        Winit = 0.02 * jr.normal(key, (x1_dim, x2_dim))
        return cls(W=Winit)
```

</details>

Take the negative gradient w.r.t. either of the input activations and
you have a linear layer.

``` python
syn = LinearSynapse.rand_init(jr.key(0), 10, 20)
xhat1 = jr.normal(jr.key(1), (10,))
xhat2 = jr.normal(jr.key(2), (20,))

print("Energy:", syn(xhat1, xhat2))
ff_compute = -jnp.einsum("...c,cd->...d", xhat1, syn.W)
fb_compute = -jnp.einsum("...d,cd->...c", xhat2, syn.W)
assert jnp.allclose(jax.grad(syn, argnums=1)(xhat1, xhat2), ff_compute)
assert jnp.allclose(jax.grad(syn, argnums=0)(xhat1, xhat2), fb_compute)
```

    Energy: 0.045816228

*The linear synapse tries to align the activations of its two
connected layers, and taking the negative gradient w.r.t. either of the
activations gives you the standard linear layer output.*

We may want to add biases to the network. We can do so in two ways.

------------------------------------------------------------------------

### BiasSynapse

>  BiasSynapse (b:jax.Array)

*Energy defines constant input to a neuron layer*

------------------------------------------------------------------------

### LinearSynapseWithBias

>  LinearSynapseWithBias (W:jax.Array, b:jax.Array)

*A linear synapse with a bias*

``` python
D1, D2 = 10, 20
W = 0.02 * jr.normal(jr.key(0), (D1, D2))
b = jnp.arange(D2)+1
linear_syn_with_bias = LinearSynapseWithBias(W, b)

# Gradients match how linear layers work
xhat1 = jr.normal(jr.key(3), (D1,))
xhat2 = jr.normal(jr.key(4), (D2,))

expected_forward = W.T @ xhat1 + b
expected_backward = W @ xhat2

forward_signal = -jax.grad(linear_syn_with_bias, argnums=1)(xhat1, xhat2)
backward_signal = -jax.grad(linear_syn_with_bias, argnums=0)(xhat1, xhat2)
assert jnp.allclose(forward_signal, expected_forward)
assert jnp.allclose(backward_signal, expected_backward)

# Could also use a dedicated bias synapse
bias_syn = BiasSynapse(b=b)
assert jnp.allclose(-jax.grad(bias_syn)(xhat2), bias_syn.b)
```
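The exported source of `BiasSynapse` and `LinearSynapseWithBias` is not shown here, so the following is only a sketch of energies that would reproduce the gradients tested above. It assumes the bias energy is $-\langle \mathbf{b}, \hat{\mathbf{x}} \rangle$ and that the biased linear energy adds this term for the second layer; the function names are hypothetical, and plain functions are used instead of `equinox` modules.

```python
import jax
import jax.numpy as jnp

def bias_energy(xhat, b):
    # Assumed form: a constant input b to the layer corresponds to E = -<b, xhat>
    return -jnp.dot(b, xhat)

def linear_with_bias_energy(xhat1, xhat2, W, b):
    # Assumed form: bilinear alignment term plus a bias on the second layer
    return -jnp.einsum("c,d,cd->", xhat1, xhat2, W) + bias_energy(xhat2, b)

D1, D2 = 3, 4
W = 0.02 * jax.random.normal(jax.random.PRNGKey(0), (D1, D2))
b = jnp.arange(D2, dtype=jnp.float32) + 1
xhat1 = jax.random.normal(jax.random.PRNGKey(1), (D1,))
xhat2 = jax.random.normal(jax.random.PRNGKey(2), (D2,))

# Negative gradients are the signals sent to each connected layer
forward_signal = -jax.grad(linear_with_bias_energy, argnums=1)(xhat1, xhat2, W, b)
backward_signal = -jax.grad(linear_with_bias_energy, argnums=0)(xhat1, xhat2, W, b)
bias_signal = -jax.grad(bias_energy)(xhat2, b)
```

Under these assumptions, the forward signal is the usual affine map `W.T @ xhat1 + b`, while the backward signal carries no bias because the bias term does not depend on `xhat1`.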

Finally, we can even consider convolutional synapses. We have to get a
bit creative when defining the energy so that it can use the efficient
forward convolution implementations in jax.

------------------------------------------------------------------------

### ConvSynapse

>  ConvSynapse (W:jax.Array, window_strides:Tuple[int,int], padding:str)

*The energy corollary of a convolutional layer in standard neural
networks*

The negative gradient w.r.t. `xhat2` is what we call a “forward
convolution”. The negative gradient w.r.t. `xhat1` is a [“transposed
convolution”](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_transpose.html)!
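One way such an energy could be written, sketched here under the assumption that the energy is the negative inner product between a forward convolution of `xhat1` and `xhat2`, is shown below. The helper names, the `OIHW` kernel layout, and the fixed stride and padding are illustrative choices, not the library's implementation.

```python
import jax
import jax.numpy as jnp

def forward_conv(xhat1, W):
    # Standard forward convolution on a single (H, W, C_in) activation
    out = jax.lax.conv_general_dilated(
        xhat1[None], W, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "OIHW", "NHWC"),
    )
    return out[0]  # drop the temporary batch dimension

def conv_energy(xhat1, xhat2, W):
    # Assumed energy: negative alignment between conv(xhat1) and xhat2
    return -jnp.sum(forward_conv(xhat1, W) * xhat2)

H, Wd, C_in, C_out = 8, 8, 3, 5
W = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (C_out, C_in, 3, 3))
xhat1 = jax.random.normal(jax.random.PRNGKey(1), (H, Wd, C_in))
xhat2 = jax.random.normal(jax.random.PRNGKey(2), (H, Wd, C_out))

energy = conv_energy(xhat1, xhat2, W)
grad_xhat2 = jax.grad(conv_energy, argnums=1)(xhat1, xhat2, W)
grad_xhat1 = jax.grad(conv_energy, argnums=0)(xhat1, xhat2, W)
```

Because this energy is linear in each argument, autograd supplies the transposed convolution for `xhat1` without it ever being written by hand.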

<details class="code-fold">
<summary>Simple test for ConvSynapse</summary>

``` python
key = jr.key(42)
H, W, C_in, C_out = 8, 8, 3, 5
filter_h, filter_w = 3, 3

# Create a ConvSynapse
conv_syn = ConvSynapse.from_conv_params(
    key=key,
    channels_out=C_out,
    channels_in=C_in, 
    filter_shape=(filter_h, filter_w),
    window_strides=(1, 1),
    padding="SAME"
)

# Create test activations
xhat1 = jr.normal(jr.key(1), (H, W, C_in))  # Input activation
xhat2 = jr.normal(jr.key(2), (H, W, C_out))  # Output activation

# Test energy computation
energy = conv_syn(xhat1, xhat2)
print(f"Energy: {energy}")
assert isinstance(energy, jax.Array) and energy.shape == ()

# The negative gradient w.r.t. xhat2 is a standard convolution
conv_result = conv_syn.forward_conv(xhat1[None])[0]  # Remove batch dim
grad_xhat2 = jax.grad(conv_syn, argnums=1)(xhat1, xhat2)
assert jnp.allclose(-grad_xhat2, conv_result)

# Test that the conv_transpose is the same as the gradient w.r.t. xhat1
conv_transpose_result = jax.lax.conv_transpose(xhat2[None], conv_syn.W, strides=conv_syn.window_strides, padding=conv_syn.padding, dimension_numbers=("NHWC", "OIHW", "NHWC"), transpose_kernel=True)[0]
grad_xhat1 = jax.grad(conv_syn, argnums=0)(xhat1, xhat2)
assert jnp.allclose(conv_transpose_result, -grad_xhat1)
```

</details>

    Energy: -8.785972595214844

# Energy Hypergraphs

The above sections have described the building blocks of an Associative
Memory. What remains is to build the hypergraph that assembles them into
a complete Associative Memory.

The rules of the building blocks give us a **single total energy** where
the update rules are **local** and the system’s energy is **guaranteed
to decrease**. See
<a href="#fig-hamux-overview" class="quarto-xref">Figure 1</a> for a
graphical depiction of the hypergraph of an Associative Memory.

<img src="./figures/hamux_overview.png" width="700" />

The total energy is structured such that the activations of a neuron
layer affect only connected hypersynapses and itself. Let
$\hat{\mathbf{x}}\_\ell$ and **x**<sub>ℓ</sub> represent the activations
and internal states of neuron layer ℓ, and let `N`(ℓ) represent the set
of hypersynapses that connect to neuron layer ℓ. The following update
rule describes how neuron internal states **x**<sub>ℓ</sub> minimize the
total energy using only local signals:

<span id="eq-hamux-local-update">
$$
\tau\_\ell\frac{d \mathbf{x}\_\ell}{dt} = - \frac{\partial E\_\text{total}}{\partial \hat{\mathbf{x}}\_\ell} = - \left(\sum\_{s \in \mathtt{N}(\ell)} \frac{\partial E^\text{synapse}\_s}{\partial \hat{\mathbf{x}}\_\ell}\right) - \frac{\partial E^\text{neuron}\_\ell}{\partial \hat{\mathbf{x}}\_\ell} = \mathcal{I}\_{x\_\ell} - \mathbf{x}\_\ell, 
 \qquad(3)$$
</span>

where
$\mathcal{I}\_{x\_\ell} := - \sum\_{s \in \mathtt{N}(\ell)} \nabla\_{\hat{\mathbf{x}}\_\ell} E^\text{synapse}\_s$
is the **total synaptic input current** to neuron layer ℓ, which is
fundamentally local and serves to minimize the energy of connected
hypersynapses. The time constant for neurons in layer ℓ is
denoted by *τ*<sub>ℓ</sub>.
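The update rule can be checked numerically on a toy system. The sketch below assumes two layers (`"a"` with an identity activation and `"b"` with a softmax activation), one pairwise linear synapse, and unit time constants; all names are local stand-ins rather than `hamux` objects. It verifies that the negative gradient of the total energy w.r.t. the activations equals the input current minus the internal state.

```python
import jax
import jax.numpy as jnp

# Illustrative Lagrangians (local stand-ins, not imported from hamux)
lagrangians = {
    "a": lambda x: 0.5 * jnp.sum(x**2),  # identity activation
    "b": lambda x: jax.nn.logsumexp(x),  # softmax activation
}
W = 0.5 * jax.random.normal(jax.random.PRNGKey(0), (3, 4))

def total_energy(xhats, xs):
    # Neuron energies: Legendre transform of each Lagrangian
    e_neurons = sum(jnp.dot(xs[k], xhats[k]) - lagrangians[k](xs[k]) for k in xs)
    # One pairwise synapse connecting layers "a" and "b"
    e_synapse = -xhats["a"] @ W @ xhats["b"]
    return e_neurons + e_synapse

xs = {
    "a": jax.random.normal(jax.random.PRNGKey(1), (3,)),
    "b": jax.random.normal(jax.random.PRNGKey(2), (4,)),
}
xhats = {k: jax.grad(lagrangians[k])(xs[k]) for k in xs}

# Right-hand side of the update rule: negative gradient w.r.t. the activations
dE_dxhat = jax.grad(total_energy)(xhats, xs)
velocity = jax.tree_util.tree_map(lambda g: -g, dE_dxhat)

# The same quantity written as "input current minus internal state"
currents = {"a": W @ xhats["b"], "b": W.T @ xhats["a"]}
```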

The central result is that the dynamical equations in
<a href="#eq-hamux-local-update" class="quarto-xref">Equation 3</a>
decrease the global energy of the network. To demonstrate this,
consider the total time derivative of the energy:

$$
\frac{dE\_\text{total}}{dt} = \sum\limits\_{\ell=1}^L \frac{\partial E\_\text{total}}{\partial \hat{\mathbf{x}}\_\ell} \frac{\partial \hat{\mathbf{x}}\_\ell}{\partial \mathbf{x}\_\ell} \frac{d\mathbf{x}\_\ell}{dt} = -\sum\limits\_{\ell=1}^L \tau\_\ell \frac{d \mathbf{x} \_\ell }{dt} \frac{\partial^2 \mathcal{L}\_x}{\partial \mathbf{x}\_\ell \partial \mathbf{x}\_\ell} \frac{d\mathbf{x}\_\ell}{dt} \leq 0,
$$

where we expressed the partial derivative of the energy w.r.t. the
activations through the velocity of the neuron’s internal states using
<a href="#eq-hamux-local-update" class="quarto-xref">Equation 3</a>. The
Hessian matrix
$\frac{\partial^2 \mathcal{L}\_x}{\partial \mathbf{x}\_\ell \partial \mathbf{x}\_\ell}$
is a square matrix whose side length is the number of neurons in layer
ℓ. As long as this matrix is positive semi-definite, a property
resulting from the convexity of the Lagrangian, the total energy of the
network is guaranteed to either decrease or stay constant; it can never
increase.
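The energy descent property can also be observed numerically. The following sketch integrates the update rule with forward Euler on a toy two-layer system and records the total energy at every step. The layer choices (identity and sigmoid activations), the weight scale, the step size `dt`, and unit time constants are all illustrative assumptions. Forward Euler only approximates the continuous-time dynamics, so the guarantee carries over only when the step size is small enough.

```python
import jax
import jax.numpy as jnp

# Illustrative Lagrangians (local stand-ins, not imported from hamux)
lagrangians = {
    "a": lambda x: 0.5 * jnp.sum(x**2),          # identity activation
    "b": lambda x: jnp.sum(jax.nn.softplus(x)),  # sigmoid activation
}
W = 0.5 * jax.random.normal(jax.random.PRNGKey(0), (3, 4))

def activations(xs):
    return {k: jax.grad(lagrangians[k])(xs[k]) for k in xs}

def total_energy(xhats, xs):
    e_neurons = sum(jnp.dot(xs[k], xhats[k]) - lagrangians[k](xs[k]) for k in xs)
    return e_neurons - xhats["a"] @ W @ xhats["b"]

@jax.jit
def step(xs, dt=0.05):
    # One forward-Euler step of dx/dt = -dE/dxhat (unit time constants)
    xhats = activations(xs)
    grads = jax.grad(total_energy)(xhats, xs)
    new_xs = {k: xs[k] - dt * grads[k] for k in xs}
    return new_xs, total_energy(xhats, xs)

xs = {
    "a": jax.random.normal(jax.random.PRNGKey(1), (3,)),
    "b": jax.random.normal(jax.random.PRNGKey(2), (4,)),
}
energies = []
for _ in range(200):
    xs, e = step(xs)
    energies.append(float(e))

# Largest step-to-step change in energy (should not be positive beyond float noise)
max_increase = max(e1 - e0 for e0, e1 in zip(energies[:-1], energies[1:]))
```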

Additionally, if the energy of the network is bounded from below, the
dynamics in
<a href="#eq-hamux-local-update" class="quarto-xref">Equation 3</a> are
guaranteed to lead the trajectories to fixed manifolds corresponding to
local minima of the energy. If the fixed manifolds are zero-dimensional,
i.e., they are fixed point attractors, the velocity field will vanish
once the network arrives at the local minimum. This corresponds to
Hessians being strictly positive definite. Alternatively, if the
Lagrangians have zero modes, resulting in the existence of zero
eigenvalues of the Hessian matrices, the network may converge to the
fixed manifolds, but the velocity fields may stay non-zero while the
network’s state moves along that manifold.

## Energy Hypergraph Implementation

The local, summing structure of *E*<sup>total</sup> is expressible
in code as a hypergraph. We roll our own implementation in JAX to keep
things simple.

------------------------------------------------------------------------

### HAM

>  HAM (neurons:Dict[str,__main__.NeuronLayer],
>           hypersynapses:Dict[str,equinox._module.Module],
>           connections:List[Tuple[Tuple,str]])

*A Hypergraph wrapper connecting all dynamic states (neurons) and
learnable parameters (synapses) for our associative memory*

<details open class="code-fold">
<summary>Exported source</summary>

``` python
class HAM(eqx.Module): 
    "A Hypergraph wrapper connecting all dynamic states (neurons) and learnable parameters (synapses) for our associative memory" 
    neurons: Dict[str, NeuronLayer]
    hypersynapses: Dict[str, eqx.Module] 
    connections: List[Tuple[Tuple, str]]
```

</details>

We describe a HAM using plain python data structures for our `neurons`,
`hypersynapses`, and edge list of `connections`. This makes each object
fully compatible with `jax`’s tree mapping utilities, which will help
keep our hypergraph code succinct.

For example, we can create a simple HAM with two neurons and one
hypersynapse:

``` python
n1_dim, n2_dim = 10, 100
neurons = {
    "n1": NeuronLayer(lagrangian=lagr_sigmoid, shape=(n1_dim,)),
    "n2": NeuronLayer(lagrangian=lambda x: lagr_softmax(x, axis=-1), shape=(n2_dim,)),
}
hypersynapses = {
    "s1": LinearSynapse.rand_init(jax.random.key(0), n1_dim, n2_dim),
}
connections = [
    (("n1", "n2"), "s1"), # Read as: "Connect neurons n1 and n2 via synapse s1"
]
ham = HAM(neurons, hypersynapses, connections)
```

Let’s start with some basic description of the hypergraph, describing
the data object we want to create.

------------------------------------------------------------------------

### HAM.n_connections

>  HAM.n_connections ()

*Total number of connections*

<details open class="code-fold">
<summary>Exported source</summary>

``` python
@patch(as_prop=True)
def n_neurons(self:HAM) -> int:
   "Total number of neurons"
   return len(self.neurons)

@patch(as_prop=True)
def n_hypersynapses(self:HAM) -> int:
   "Total number of hypersynapses"
   return len(self.hypersynapses)

@patch(as_prop=True)
def n_connections(self:HAM) -> int:
   "Total number of connections"
   return len(self.connections)
```

</details>

------------------------------------------------------------------------

### HAM.n_hypersynapses

>  HAM.n_hypersynapses ()

*Total number of hypersynapses*

------------------------------------------------------------------------

### HAM.n_neurons

>  HAM.n_neurons ()

*Total number of neurons*

------------------------------------------------------------------------

### HAM.init_states

>  HAM.init_states (bs:Optional[int]=None)

*Initialize all neuron states*

<details open class="code-fold">
<summary>Exported source</summary>

``` python
@patch
def init_states(self: HAM, bs: Optional[int] = None):
    """Initialize all neuron states"""
    if bs is not None and bs > 0: warn("Vectorize with `ham.vectorize()` before processing batched states")
    xs = {k: v.init(bs) for k, v in self.neurons.items()}
    return xs
```

</details>

Initialize all the dynamic neuron states at once, optionally with a
batch size. This makes it easy to treat the whole collection of neuron
states as a single object (a dictionary of arrays keyed by layer name).

``` python
xs = ham.init_states()
print(jtu.tree_map(lambda x: f"Shape: {x.shape}", xs))
```

    {'n1': 'Shape: (10,)', 'n2': 'Shape: (100,)'}

Index into these freshly initialized states by layer name to replace
them with real data.

``` python
example_data = jr.normal(jr.key(4), xs['n1'].shape)
xs["n1"] = example_data
```

<div>

> **Variable naming conventions**
>
> Throughout this code, we universally use the `xs` variable to refer to
> the collection of neuron internal states and the `xhats` variable to
> refer to the collection of neuron activations.
>
> Additionally, whenever a function `f` takes both `xs` and `xhats` as
> arguments, we assume the `xhats` are passed first in the argument
> order, i.e., `f(xhats, xs, *args, **kwargs)`. This is because most AM
> operations do gradient descent on the activations, not the internal
> states, and `jax.grad` differentiates with respect to the 0-th
> positional argument by default.

</div>
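The argument-order convention can be illustrated with a minimal sketch. The function `toy_energy` below is a hypothetical stand-in written only for this example (it is not the HAM energy); the point is that `jax.grad`, with its default `argnums=0`, returns gradients for the first argument only, and that the result has the same dictionary structure as `xhats`.

``` python
import jax
import jax.numpy as jnp

def toy_energy(xhats, xs):
    """Hypothetical scalar energy that follows the (xhats, xs) argument order."""
    # Per layer: <xhat, x> - 0.5 * ||xhat||^2 (illustrative only, not the HAM energy)
    return sum(jnp.dot(xhats[k], xs[k]) - 0.5 * jnp.sum(xhats[k] ** 2) for k in xhats)

xs = {"n1": jnp.array([1.0, 2.0]), "n2": jnp.array([3.0])}
xhats = {"n1": jnp.array([0.5, -0.5]), "n2": jnp.array([2.0])}

# `argnums` defaults to 0, so this differentiates w.r.t. `xhats` and treats `xs` as constant
dE_dxhats = jax.grad(toy_energy)(xhats, xs)
print(jax.tree_util.tree_map(lambda g: g.shape, dE_dxhats))
```

Had the order been `f(xs, xhats)`, the same call would need an explicit `argnums=1`.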

------------------------------------------------------------------------

### HAM.activations

>  HAM.activations (xs)

*Convert hidden states of each neuron into their activations*

<details open class="code-fold">
<summary>Exported source</summary>

``` python
@patch
def activations(self: HAM, xs):
    """Convert hidden states of each neuron into their activations"""
    xhats = {k: v.sigma(xs[k]) for k, v in self.neurons.items()}
    return xhats
```

</details>

From the states, we can compute the activations of each neuron as a
single collection:

``` python
xhats = ham.activations(xs)
assert jnp.all(xhats['n1'] > 0) and jnp.all(xhats['n1'] < 1), "Sigmoid neurons should be between 0 and 1"
assert jnp.isclose(xhats['n2'].sum(), 1.0), "Softmax neurons should sum to 1"
```

------------------------------------------------------------------------

### HAM.energy

>  HAM.energy (xhats, xs)

*The complete energy of the HAM*

------------------------------------------------------------------------

### HAM.energy_tree

>  HAM.energy_tree (xhats, xs)

*Return energies for each individual component*

------------------------------------------------------------------------

### HAM.connection_energies

>  HAM.connection_energies (xhats)

*Get the energy for each connection*

------------------------------------------------------------------------

### HAM.neuron_energies

>  HAM.neuron_energies (xhats, xs)

*Retrieve the energies of each neuron in the HAM*

From the activations, we can collect the energies of all the neurons and
connections in the HAM. We organize these into an energy tree, from
which we compute the total energy of the entire HAM.

The complete energy of the HAM is the sum of all the individual energies
in the `HAM.energy_tree`.

``` python
xhats = ham.activations(xs)
pp(ham.energy_tree(xhats, xs))
```

    {'neurons': {'n1': Array(-6.215218, dtype=float32),
                 'n2': Array(-4.6051702, dtype=float32)},
     'connections': [Array(-0.00045318, dtype=float32)]}

``` python
pp(ham.energy(xhats, xs))
```

    Array(-10.820841, dtype=float32)
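As a sanity check, the printed total can be reproduced by summing the leaves of the printed energy tree. The sketch below is dependency-free: `sum_leaves` is a hypothetical helper written for this example, and the numbers are copied from the output above as plain Python floats.

``` python
def sum_leaves(tree):
    """Recursively sum the numeric leaves of nested dicts, lists, and tuples."""
    if isinstance(tree, dict):
        return sum(sum_leaves(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(sum_leaves(v) for v in tree)
    return float(tree)

# Values copied from the printed `energy_tree` above
energy_tree = {
    "neurons": {"n1": -6.215218, "n2": -4.6051702},
    "connections": [-0.00045318],
}

total = sum_leaves(energy_tree)
print(total)  # agrees with the reported total energy up to float32 precision
```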

------------------------------------------------------------------------

### HAM.dEdact

>  HAM.dEdact (xhats, xs, return_energy=False)

*Calculate gradient of system energy w.r.t. each activation*

A small helper function to make it easier to compute the gradient of the
energy w.r.t. the activations.

When the neuron states are updated along the negative of this gradient,
the energy is guaranteed to decrease monotonically over time and to be
bounded from below.
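The descent behavior can be sketched on a toy system. Everything below is an assumption made for illustration and does not use the HAM classes: a single layer with Lagrangian $\frac{1}{2}\lVert x \rVert^2$ (so the activation equals the state), a neuron energy of the form $\hat{x} \cdot x - L(x)$, a hypothetical linear interaction $-\hat{x} \cdot b$, and a plain Euler step of size `0.1` on the states.

``` python
import jax
import jax.numpy as jnp

def lagrangian(x):
    """Toy Lagrangian whose gradient is the identity activation."""
    return 0.5 * jnp.sum(x ** 2)

sigma = jax.grad(lagrangian)  # activation: xhat = dL/dx = x

def energy(xhat, x, b):
    """Toy energy: neuron term <xhat, x> - L(x) plus a linear interaction -<xhat, b>."""
    return jnp.dot(xhat, x) - lagrangian(x) - jnp.dot(xhat, b)

# Energy and its gradient w.r.t. the activations (argument 0) in one call
energy_and_grad = jax.value_and_grad(energy)

b = jnp.array([1.0, -2.0, 0.5])
x = jnp.zeros(3)
energies = []
for _ in range(20):
    xhat = sigma(x)
    E, dE_dxhat = energy_and_grad(xhat, x, b)
    energies.append(float(E))
    x = x - 0.1 * dE_dxhat  # step the state along the negative activation gradient

print(energies[0], energies[-1])
```

In this toy setting the recorded energies never increase and stay above the minimum value $-\frac{1}{2}\lVert b \rVert^2$.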

# Vectorizing the Energy

To scale these models, we generally want to run the same model on
batches of data and activations. We can do this by creating a
`VectorizedHAM` object whose functions all expect a leading batch
dimension in the neuron states and activations.

------------------------------------------------------------------------

### VectorizedHAM

>  VectorizedHAM (_ham:equinox._module.Module)

*Re-expose HAM API with vectorized inputs. No new HAM behaviors should
be implemented in this class.*

------------------------------------------------------------------------

### HAM.unvectorize

>  HAM.unvectorize ()

*Unvectorize to work on single inputs*

<details open class="code-fold">
<summary>Exported source</summary>

``` python
@patch
def vectorize(self: HAM):
    """Vectorize to work on batches of inputs"""
    return VectorizedHAM(self)

@patch
def unvectorize(self: HAM):
    """Unvectorize to work on single inputs"""
    return self
```

</details>

------------------------------------------------------------------------

### HAM.vectorize

>  HAM.vectorize ()

*Vectorize to work on batches of inputs*

Now our `HAM` logic works on batches of inputs using `jax.vmap`.

``` python
# Instead of this
ham = HAM(neurons=neurons, hypersynapses=hypersynapses, connections=connections)
test_warns(lambda: ham.init_states(bs=5), show=False)
vxs = ham.init_states(bs=5)
test_fail(ham.activations, args=vxs)
```

    UserWarning: Vectorize with `ham.vectorize()` before processing batched states
      if bs is not None and bs > 0: warn("Vectorize with `ham.vectorize()` before processing batched states")

``` python
# Do this
vham = ham.vectorize()
vxs = vham.init_states(bs=5)
vxhats = vham.activations(vxs)
assert all(g.shape[0] == 5 for g in vxhats.values()), "All activations should have batch dim"
```
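To show what `jax.vmap` contributes here, the following sketch lifts a single-sample function over a dictionary of batched states. The function `single_activations` is a hypothetical stand-in for per-sample logic, not the `VectorizedHAM` implementation; the layer names and shapes mirror the example above.

``` python
import jax
import jax.numpy as jnp

def single_activations(xs):
    """Activations for ONE sample: each entry of `xs` has no batch dimension."""
    return {
        "n1": jax.nn.sigmoid(xs["n1"]),
        "n2": jax.nn.softmax(xs["n2"], axis=-1),
    }

# vmap maps over the leading axis of every leaf in the input dictionary
batched_activations = jax.vmap(single_activations)

vxs = {"n1": jnp.zeros((5, 10)), "n2": jnp.zeros((5, 100))}
vxhats = batched_activations(vxs)
print(jax.tree_util.tree_map(lambda a: a.shape, vxhats))
```

Writing the logic once for a single sample and adding the batch dimension with `vmap` keeps the per-sample code free of explicit batch-axis bookkeeping.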
