Building HAMUX

The building blocks of energy-based Associative Memory

We develop a modular energy perspective of Associative Memories where the energy of any model from this family can be decomposed into standardized components: neuron layers that encode dynamic variables, and hypersynapses that encode their interactions.

The total energy of the system is the sum of the individual component energies: the energies of all the layers plus the energies of all the interactions between those layers.
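
Concretely, writing \(E^\text{neuron}_\ell\) for the energy of neuron layer \(\ell\) and \(E^\text{synapse}_s\) for the energy of hypersynapse \(s\) (the notation used in Equation 3 below), this decomposition reads:

\[ E_\text{total} = \sum_{\ell} E^\text{neuron}_\ell + \sum_{s} E^\text{synapse}_s. \]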

In computer science terms, neurons and synapses form a hypergraph, where each neuron layer is a node and each hypersynapse is a hyperedge.

This framework of energy-based building blocks for memory not only clarifies how existing methods for building Associative Memories relate to each other (e.g., the Classical Hopfield Network [@hopfield1982neural, @hopfield1984Neurons], Dense Associative Memory [@krotov2016dense]), but it also provides a systematic language for designing new architectures (e.g., Energy Transformers [@hoover2024energy], Neuron-Astrocyte Networks [@kozachkov2023neuron]).

We begin by introducing the building blocks of Associative Memory: neurons and hypersynapses.

Neurons

Neurons turn Lagrangians into the dynamic building blocks of memory.

TL;DR
  1. A neuron is a fancy term to describe the dynamic (fast moving) variables in an associative memory.
  2. Every neuron has an internal state \(\mathbf{x}\) that evolves over time and an activation \(\hat{\mathbf{x}}\) that affects the rest of the network.
  3. In the complete hypergraph of Associative Memory, our neurons are the nodes while our hypersynapses are the hyperedges (since a synapse can connect more than two nodes, they cannot be regular “edges”).
  4. A neuron is just a Lagrangian assigned to a tensor of data.

A neuron layer is a Lagrangian function on top of data, where the Lagrangian defines the activations of that neuron

A neuron layer (node of the Associative Memory) is a fancy term to describe the dynamic variables in AM. Each neuron layer has an internal state \(\mathbf{x}\) which evolves over time and an activation \(\hat{\mathbf{x}}\) that forwards a signal to the rest of the network. Think of neurons like the activation functions of standard neural networks, where \(\mathbf{x}\) are the pre-activations and \(\hat{\mathbf{x}}\) (the outputs) are the activations, e.g., \(\hat{\mathbf{x}} = \texttt{ReLU}(\mathbf{x})\).

In order to define a neuron layer’s energy, AMs employ two mathematical tools from physics: convex Lagrangian functions and the Legendre transform. For each neuron layer, we define a convex, scalar-valued Lagrangian \(\mathcal{L}_x(\mathbf{x})\). The Legendre transform \(\mathcal{T}\) of this Lagrangian produces the dual variable \(\hat{\mathbf{x}}\) (our activations) and the dual energy \(E_x(\hat{\mathbf{x}})\) (our new energy) as in:

\[ \begin{align} \hat{\mathbf{x}} &= \nabla \mathcal{L}_x(\mathbf{x}) \quad \text{(activation function)} \\ E_x(\hat{\mathbf{x}}) = \mathcal{T}(\mathcal{L}_x) &= \langle \mathbf{x}, \hat{\mathbf{x}} \rangle - \mathcal{L}_x(\mathbf{x}) \quad \text{(dual energy)} \end{align} \tag{1}\]

where \(\langle \cdot, \cdot \rangle\) is the element-wise inner product. Because \(\mathcal{L}_x\) is convex, the Jacobian of the activations \(\frac{\partial \hat{\mathbf{x}}}{\partial \mathbf{x}} = \nabla^2 \mathcal{L}_x(\mathbf{x})\) (i.e., the Hessian of the Lagrangian) is positive semi-definite.
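
As a quick numerical sanity check (a minimal sketch using an inline sigmoid-style Lagrangian, not a hamux routine), jax.hessian lets us verify this property directly:

import jax
import jax.numpy as jnp

# An inline convex Lagrangian: L(x) = sum(log(1 + exp(x))); its gradient is the sigmoid activation
lagr = lambda x: jnp.sum(jnp.log1p(jnp.exp(x)))

x = jnp.array([-1.0, 0.5, 2.0])
H = jax.hessian(lagr)(x)  # Jacobian of the activations = Hessian of the Lagrangian
assert jnp.all(jnp.linalg.eigvalsh(H) >= 0), "Convexity implies non-negative eigenvalues"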

Notational conventions

There are lots of named components inside the neuron layer. As a notational convention, each neuron layer is identified by a single letter (e.g., \(\mathsf{X}\) or \(\mathsf{Y}\)). We say a neuron layer \(\mathsf{X}\) has an internal state \(\mathbf{x} \in \mathcal{X}\) and Lagrangian \(\mathcal{L}_x(\mathbf{x})\), alongside an activation \(\hat{\mathbf{x}} \in \hat{\mathcal{X}}\) and dual energy \(E_x(\hat{\mathbf{x}})\) constrained through the Legendre transform of the Lagrangian. Meanwhile, neuron layer \(\mathsf{Y}\) has an internal state \(\mathbf{y} \in \mathcal{Y}\) and Lagrangian \(\mathcal{L}_y(\mathbf{y})\), alongside an activation \(\hat{\mathbf{y}} \in \hat{\mathcal{Y}}\) and dual energy \(E_y(\hat{\mathbf{y}})\).

Because it is often nice to think of the activations as being a non-linear function of the internal states, we can also write \(\hat{\mathbf{x}} = \sigma_x(\mathbf{x})\), where \(\sigma_x(\cdot) := \nabla \mathcal{L}_x (\cdot)\).

The dual energy \(E_x(\hat{\mathbf{x}})\) has another nice property: its gradient equals the internal states. Thus, when we minimize the energy of our neurons (in the absence of any other signal), we observe exponential decay. This is nice to keep the dynamic behavior of our system bounded and well-behaved, especially for very large values of \(\mathbf{x}\).

\[ \frac{d \mathbf{x}}{dt} = - \nabla_{\hat{\mathbf{x}}} E_x(\hat{\mathbf{x}}) = - \mathbf{x}. \tag{2}\]
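
A minimal sketch of this decay (explicit Euler integration with an arbitrary step size, not a hamux routine):

import jax.numpy as jnp

x = jnp.array([3.0, -2.0, 1.5])  # some initial internal state
dt = 0.1                         # arbitrary Euler step size
for _ in range(50):
    x = x + dt * (-x)            # dx/dt = -x, Equation 2 with no synaptic input
print(x)                         # close to zero: the internal states decay toward the origin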

Neurons as Code

The methods to implement neurons are remarkably simple:


NeuronLayer

 NeuronLayer (lagrangian:Callable, shape:Tuple[int])

Neuron layers represent dynamic variables that evolve during inference (i.e., memory retrieval/error correction)

Exported source
class NeuronLayer(eqx.Module):
    """Neuron layers represent dynamic variables that evolve during inference (i.e., memory retrieval/error correction)"""
    lagrangian: Callable # The scalar-valued Lagrangian function:  x |-> R
    shape: Tuple[int] # The shape of the neuron layer

Remember that, at its core, a NeuronLayer object is nothing more than a Lagrangian function (see the example Lagrangians) on top of (shaped) data. All the other methods of the NeuronLayer class just provide conveniences on top of this core functionality.


NeuronLayer.activations

 NeuronLayer.activations (x)

Use autograd to compute the activations of the neuron layer from the Lagrangian

Exported source
@patch
def activations(self: NeuronLayer, x): 
    """Use autograd to compute the activations of the neuron layer from the Lagrangian"""
    return jax.grad(self.lagrangian)(x)

The NeuronLayer.activations is the gradient of the Lagrangian with respect to the states. This is easily computed via jax autograd.

For example, we can test the activations of a few different Lagrangians.

from hamux.lagrangians import lagr_relu, lagr_softmax, lagr_sigmoid, lagr_identity
from pprint import pp
D = 10

# Identity activation
nn = NeuronLayer(lagrangian=lambda x: jnp.sum(0.5 * x**2), shape=(D,)) # Identity activation
xtest = jr.normal(jr.key(0), (D,))
assert jnp.allclose(nn.activations(xtest), xtest)

# ReLU activation
nn = NeuronLayer(lagrangian=lagr_relu, shape=(D,))
xtest = jr.normal(jr.key(1), (D,))
assert jnp.allclose(nn.activations(xtest), jnp.maximum(0, xtest))

# Softmax activation
nn = NeuronLayer(lagrangian=lagr_softmax, shape=(D,))
xtest = jr.normal(jr.key(2), (D,))
assert jnp.allclose(nn.activations(xtest), jax.nn.softmax(xtest))

NeuronLayer.init

 NeuronLayer.init (bs:Optional[int]=None)

Return an empty initial neuron state

Exported source
@patch
def init(self: NeuronLayer, bs: Optional[int] = None):
    """Return an empty initial neuron state"""
    if bs is None or bs == 0: return jnp.zeros(self.shape) # No batch dimension
    return jnp.zeros((bs, *self.shape))

The NeuronLayer.init method is a convenience method that initializes an empty collection of neuron layer states. We generally want to populate this state with values from some piece of data.

It’s used like:

D = 784 # e.g., rasterized MNIST image size
nn = NeuronLayer(lagrangian=lambda x: jnp.sum(0.5 * x**2), shape=(D,))
x_unbatched = nn.init()
print("Unbatched shape:", x_unbatched.shape)
x_batched = nn.init(bs=2)
print("Batched shape:", x_batched.shape)
Unbatched shape: (784,)
Batched shape: (2, 784)

Legendre transform and Neuron Energy

The energy is the Legendre transform of the Lagrangian. Consider some scalar-valued function \(F: \mathcal{X} \mapsto \mathbb{R}\) for which we want to compute its dual representation \(\hat{F}: \hat{\mathcal{X}} \mapsto \mathbb{R}\) under the Legendre transform. The Legendre transform \(\mathcal{T}\) of \(F\) transforms both the function \(F\) and its argument \(\mathbf{x}\) into a dual formulation \(\hat{F}\) and \(\hat{\mathbf{x}} = \sigma(\mathbf{x}) = \nabla F(\mathbf{x})\). The transform is defined as:

\[ \hat{F}(\hat{\mathbf{x}}) = \langle \mathbf{x}, \hat{\mathbf{x}} \rangle - F(\mathbf{x}). \]

Note that \(\hat{F}\) is only a function of \(\hat{\mathbf{x}}\) (here \(\mathbf{x}\) is computed as \(\mathbf{x} = \sigma^{(-1)}(\hat{\mathbf{x}})\)). You can confirm this for yourself by trying to compute \(\frac{\partial \hat{F}}{\partial \mathbf{x}}\) and checking that the answer is \(0\).
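
As a quick worked example, take \(F(\mathbf{x}) = \tfrac{1}{2}\langle \mathbf{x}, \mathbf{x} \rangle\) (the behavior of lagr_identity used in the tests below). Then \(\hat{\mathbf{x}} = \nabla F(\mathbf{x}) = \mathbf{x}\), and

\[ \hat{F}(\hat{\mathbf{x}}) = \langle \mathbf{x}, \hat{\mathbf{x}} \rangle - \tfrac{1}{2}\langle \mathbf{x}, \mathbf{x} \rangle = \tfrac{1}{2}\langle \hat{\mathbf{x}}, \hat{\mathbf{x}} \rangle, \]

which indeed depends on \(\mathbf{x}\) only through \(\hat{\mathbf{x}}\).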

The code for the Legendre transform is easy to implement in jax as a higher order function. We’ll assume that we always have the original variable \(\mathbf{x}\) so that we don’t need to compute \(\sigma^{(-1)}\).


legendre_transform

 legendre_transform (F:Callable)

Transform scalar F(x) into the dual Fhat(xhat, x) using the Legendre transform

Type Details
F Callable The function to transform
Exported source
def legendre_transform(
    F: Callable # The function to transform
    ):
    "Transform scalar F(x) into the dual Fhat(xhat, x) using the Legendre transform"

    # We define custom gradient rules to give jax some autograd shortcuts
    @jax.custom_jvp
    def Fhat(xhat, x): return jnp.multiply(xhat, x).sum() - F(x)

    @Fhat.defjvp
    def Fhat_jvp(primals, tangents):
        (xhat, x), (dxhat, dx) = primals, tangents
        o, do = Fhat(xhat, x), jnp.multiply(x, dxhat).sum()
        return o, do

    return Fhat

Let’s test if the legendre_transform automatic gradients are what we expect:

Test the Legendre transform
x = jnp.array([1., 2, 3])
F = lagr_sigmoid
g = jax.nn.sigmoid(x)
Fhat = legendre_transform(F)

assert jnp.allclose(jax.grad(Fhat)(g, x), x)
assert jnp.all(jax.grad(Fhat, argnums=1)(g, x) == 0.)

x = jnp.array([1., 2, 3])
F = lagr_identity
g = x
Fhat = legendre_transform(F)

assert jnp.allclose(jax.grad(Fhat)(g, x), x)
assert jnp.all(jax.grad(Fhat, argnums=1)(g, x) == 0.)

The Legendre transform is the final piece of the puzzle to describe the energy of a neuron layer.


NeuronLayer.energy

 NeuronLayer.energy (xhat, x)

The energy of the neuron layer is the Legendre transform of the Lagrangian

Exported source
@patch
def energy(self: NeuronLayer, xhat, x): 
    """The energy of the neuron layer is the Legendre transform of the Lagrangian"""
    return legendre_transform(self.lagrangian)(xhat, x)

The energy is the Legendre transform of the Lagrangian:

\[ E_\text{neuron} = \langle \mathbf{x}, \hat{\mathbf{x}} \rangle - \mathcal{L}_x(\mathbf{x}) \]

Test the energy
nn = NeuronLayer(lagrangian=lagr_sigmoid, shape=(3,))
x = jnp.array([1., 2, 3])
xhat = nn.activations(x)
assert jnp.allclose(jax.grad(nn.energy, argnums=0)(xhat, x), x)
assert jnp.allclose(jax.grad(nn.energy, argnums=1)(xhat, x), 0.)

We alias a few of the methods for convenience.

Exported source
NeuronLayer.sigma = NeuronLayer.activations
NeuronLayer.E = NeuronLayer.energy

And wrap a few other python conveniences.


NeuronLayer.__post_init__

 NeuronLayer.__post_init__ ()

Ensure the neuron shape is a tuple


NeuronLayer.__repr__

 NeuronLayer.__repr__ ()

Look nice when inspected

Hypersynapses

Hypersynapses modulate signals between one or more neuron layers.

A hypersynapse is a scalar valued energy function defined on top of the activations of connected neuron layers

The activations of one NeuronLayer are sent to other neurons via communication channels called hypersynapses. At its most general, a hypersynapse is a scalar valued energy function defined on top of the activations of connected neuron layers. For example, a hypersynapse connecting neuron layers \(\mathsf{X}\) and \(\mathsf{Y}\) has an interaction energy \(E_{xy}(\hat{\mathbf{x}}, \hat{\mathbf{y}}; \mathbf{\Xi})\), where \(\mathbf{\Xi}\) represents the synaptic weights or learnable parameters.

\(E_{xy}(\hat{\mathbf{x}}, \hat{\mathbf{y}}; \mathbf{\Xi})\) encodes the desired relationship between activations \(\hat{\mathbf{x}}\) and \(\hat{\mathbf{y}}\). When this energy is low, the activations satisfy the relationship encoded by the synaptic weights \(\mathbf{\Xi}\). During energy minimization, the system adjusts the activations to reduce all energy terms, which means synapses effectively pull the connected neuron layers toward configurations encoded in the parameters that minimize their interaction energy.

Hypersynapse notation conventions

For synapses connecting multiple layers, we subscript with the identifiers of all connected layers. For example:

  • \(E_{xy}\) — synapse connecting layers \(\mathsf{X}\) and \(\mathsf{Y}\)
  • \(E_{xyz}\) — synapse connecting layers \(\mathsf{X}\), \(\mathsf{Y}\), and \(\mathsf{Z}\).
  • \(E_{xyz\ldots}\) — synapses connecting more than three layers are possible, but rare.

However, synapses can also connect a layer to itself (self-connections). To avoid confusion with neuron layer energy \(E_x\), we use curly brackets for synaptic self-connections. For example, \(E_{\{x\}}\) represents the interaction energy of a synapse that connects layer \(\mathsf{X}\) to itself.

How biological are hypersynapses?

Hypersynapses in hamux differ from biological synapses in two fundamental ways:

  • Hypersynapses can connect any number of layers simultaneously, while biological synapses connect only two neurons. This officially makes hypersynapses “hyperedges” in graph theory terms.
  • Hypersynapses are undirected, meaning that all connected layers influence each other bidirectionally during energy minimization. Meanwhile, biological synapses are unidirectional, meaning signal flows from a presynaptic to postsynaptic neuron.

Because of these differences, we choose the distinct term “hypersynapses” to distinguish them from biological synapses.

Hypersynapses are represented as undirected (hyper)edges in a hypergraph. Shown is an example pairwise synapse: a single energy function \(E_{xy}(\hat{\mathbf{x}}, \hat{\mathbf{y}}; \mathbf{\Xi})\) defined on the activations \(\hat{\mathbf{x}}\) and \(\hat{\mathbf{y}}\) of the connected nodes, which necessarily propagates signal to both of them. Here, the signal to a connected layer is defined as the negative gradient of the interaction energy w.r.t. that layer’s activations (e.g., layer \(\mathsf{X}\) receives signal \(\mathcal{I}_x = -\nabla_{\hat{\mathbf{x}}} E_{xy}(\hat{\mathbf{x}}, \hat{\mathbf{y}}; \mathbf{\Xi})\)). This is in contrast to biological synapses, which are directional and only propagate signal in one direction (from layer \(\mathsf{X}\) to \(\mathsf{Y}\)), needing a separate synapse to bring information back from \(\mathsf{Y}\) to \(\mathsf{X}\).

The undirected nature of hypersynapses fundamentally distinguishes AM from traditional neural networks. Whereas feed-forward networks follow a directed computational graph with clear input-to-output flow, AMs have no inherent concept of “forward” or “backward” directions. All connected layers influence each other bidirectionally during energy minimization, with information propagating from deeper layers to shallower layers as readily as the other way around.

Unlike the NeuronLayer energies, the interaction energies of the hypersynapses are completely unconstrained: any function that takes activations as input and returns a scalar is admissible and will have well-behaved dynamics. The interaction energy of a synapse may choose to introduce its own non-linearities beyond those handled by the neuron layers. When this occurs, the energy minimization dynamics must compute gradients through these “synaptic non-linearities”, unlike the case where all non-linearities are abstracted into the NeuronLayer Lagrangians.

TL;DR
  1. A hypersynapse describes the “strain energy” between the activations of one or more neuron layers. The lower that energy, the more aligned the activations.
  2. In the complete hypergraph of Associative Memory, our neurons are the nodes and our hypersynapses are the hyperedges.

Hypersynapse implementation

Hypersynapses are just callable equinox.Module objects with trainable parameters. Any differentiable, scalar-valued function, implemented in a __call__ method, will work.

That’s it. Many things can be hypersynapse energies. Here are two examples that may look familiar to those with a background in ML.


LinearSynapse

 LinearSynapse (W:jax.Array)

The energy synapse corollary of the linear layer in standard neural networks

Exported source
class LinearSynapse(eqx.Module):
    """The energy synapse corrolary of the linear layer in standard neural networks"""
    W: jax.Array
    def __call__(self, xhat1:jax.Array, xhat2:jax.Array):
        "Compute the interaction energy between activations `xhat1` and `xhat2`."
        # Best to use batch-dim agnostic operations
        return -jnp.einsum("...c,...d,cd->...", xhat1, xhat2, self.W)

    @classmethod
    def rand_init(cls, key: jax.Array, x1_dim: int, x2_dim: int):
        Winit = 0.02 * jr.normal(key, (x1_dim, x2_dim))
        return cls(W=Winit)

Take the gradient w.r.t. either of the input activations and you have a linear layer.

syn = LinearSynapse.rand_init(jr.key(0), 10, 20)
xhat1 = jr.normal(jr.key(1), (10,))
xhat2 = jr.normal(jr.key(2), (20,))

print("Energy:", syn(xhat1, xhat2))
ff_compute = -jnp.einsum("...c,cd->...d", xhat1, syn.W)
fb_compute = -jnp.einsum("...d,cd->...c", xhat2, syn.W)
assert jnp.allclose(jax.grad(syn, argnums=1)(xhat1, xhat2), ff_compute)
assert jnp.allclose(jax.grad(syn, argnums=0)(xhat1, xhat2), fb_compute)
Energy: 0.045816228

The linear synapse tries to align the activations of its two connected layers, and taking the negative gradient w.r.t. either activation gives the standard linear layer output.

We may want to add biases to the network. We can do so in two ways.


BiasSynapse

 BiasSynapse (b:jax.Array)

Energy defines constant input to a neuron layer


LinearSynapseWithBias

 LinearSynapseWithBias (W:jax.Array, b:jax.Array)

A linear synapse with a bias

D1, D2 = 10, 20
W = 0.02 * jr.normal(jr.key(0), (D1, D2))
b = jnp.arange(D2)+1
linear_syn_with_bias = LinearSynapseWithBias(W, b)

# Gradients match how linear layers work
xhat1 = jr.normal(jr.key(3), (D1,))
xhat2 = jr.normal(jr.key(4), (D2,))

expected_forward = W.T @ xhat1 + b
expected_backward = W @ xhat2

forward_signal = -jax.grad(linear_syn_with_bias, argnums=1)(xhat1, xhat2)
backward_signal = -jax.grad(linear_syn_with_bias, argnums=0)(xhat1, xhat2)
assert jnp.allclose(forward_signal, expected_forward)
assert jnp.allclose(backward_signal, expected_backward)

# Could also use a dedicated bias synapse
bias_syn = BiasSynapse(b=b)
assert jnp.allclose(-jax.grad(bias_syn)(xhat2), bias_syn.b)

Finally, we can even consider convolutional synapses. We have to get a bit creative when defining the energy here so that we can use jax’s efficient forward convolution implementations.


ConvSynapse

 ConvSynapse (W:jax.Array, window_strides:Tuple[int,int], padding:str)

The energy corollary of a convolutional layer in standard neural networks

The negative gradient w.r.t. xhat2 is what we call a “forward convolution”. The negative gradient w.r.t. xhat1 is a “transposed convolution”!

Simple test for ConvSynapse
key = jr.key(42)
H, W, C_in, C_out = 8, 8, 3, 5
filter_h, filter_w = 3, 3

# Create a ConvSynapse
conv_syn = ConvSynapse.from_conv_params(
    key=key,
    channels_out=C_out,
    channels_in=C_in, 
    filter_shape=(filter_h, filter_w),
    window_strides=(1, 1),
    padding="SAME"
)

# Create test activations
xhat1 = jr.normal(jr.key(1), (H, W, C_in))  # Input activation
xhat2 = jr.normal(jr.key(2), (H, W, C_out))  # Output activation

# Test energy computation
energy = conv_syn(xhat1, xhat2)
print(f"Energy: {energy}")
assert isinstance(energy, jax.Array) and energy.shape == ()

# The negative gradient w.r.t. xhat2 is a standard convolution
conv_result = conv_syn.forward_conv(xhat1[None])[0]  # Remove batch dim
grad_xhat2 = jax.grad(conv_syn, argnums=1)(xhat1, xhat2)
assert jnp.allclose(-grad_xhat2, conv_result)

# Test that the conv_transpose is the same as the gradient w.r.t. xhat1
conv_transpose_result = jax.lax.conv_transpose(xhat2[None], conv_syn.W, strides=conv_syn.window_strides, padding=conv_syn.padding, dimension_numbers=("NHWC", "OIHW", "NHWC"), transpose_kernel=True)[0]
grad_xhat1 = jax.grad(conv_syn, argnums=0)(xhat1, xhat2)
assert jnp.allclose(conv_transpose_result, -grad_xhat1)
Energy: -8.785972595214844
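
The synapse energies above are all bilinear in the activations, but as discussed earlier, a hypersynapse is free to introduce its own non-linearities. Below is a sketch in the spirit of Dense Associative Memory [@krotov2016dense]; the class DenseSynapse and its fields are illustrative and not part of hamux’s exported API.

import equinox as eqx
import jax
import jax.random as jr

class DenseSynapse(eqx.Module):
    """Sketch: a self-connection synapse whose energy applies its own non-linearity (a logsumexp)"""
    Xi: jax.Array      # (n_memories, dim) stored patterns
    beta: float = 1.0  # inverse temperature of the logsumexp

    def __call__(self, xhat: jax.Array):
        # Low energy when `xhat` is similar to one of the stored patterns in Xi
        return -jax.nn.logsumexp(self.beta * self.Xi @ xhat) / self.beta

syn = DenseSynapse(Xi=0.1 * jr.normal(jr.key(0), (16, 10)))
xhat = jr.normal(jr.key(1), (10,))
print(syn(xhat))                  # a scalar interaction energy
print(jax.grad(syn)(xhat).shape)  # (10,): the dynamics differentiate through the logsumexp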

Energy Hypergraphs

The above sections have described the building blocks of an Associative Memory. What remains is to build the hypergraph that assembles them into a complete Associative Memory.

These building blocks assemble into a single total energy whose update rules are local and whose value is guaranteed to decrease. See Figure 1 for a graphical depiction of the hypergraph of an Associative Memory.

Figure 1: hamux hypergraph diagrams are a graphical depiction of an AM whose total energy is the sum of the neuron layer (node) and hypersynapse (hyperedge) energies. Inference is done recurrently, modeled by a system of differential equations where each neuron layer’s hidden state updates to minimize the total energy. When all non-linearities are captured in the dynamic neurons, inference becomes a local computation that avoids differentiating through non-linearities.

The total energy is structured such that the activations of a neuron layer affect only connected hypersynapses and itself. Let \(\hat{\mathbf{x}}_\ell\) and \(\mathbf{x}_\ell\) represent the activations and internal states of neuron layer \(\ell\), and let \(\mathtt{N}(\ell)\) represent the set of hypersynapses that connect to neuron layer \(\ell\). The following update rule describes how neuron internal states \(\mathbf{x}_\ell\) minimize the total energy using only local signals:

\[ \tau_\ell\frac{d \mathbf{x}_\ell}{dt} = - \frac{\partial E_\text{total}}{\partial \hat{\mathbf{x}}_\ell} = - \left(\sum_{s \in \mathtt{N}(\ell)} \frac{\partial E^\text{synapse}_s}{\partial \hat{\mathbf{x}}_\ell}\right) - \frac{\partial E^\text{neuron}_\ell}{\partial \hat{\mathbf{x}}_\ell} = \mathcal{I}_{x_\ell} - \mathbf{x}_\ell, \tag{3}\]

where \(\mathcal{I}_{x_\ell} := - \sum_{s \in \mathtt{N}(\ell)} \nabla_{\hat{\mathbf{x}}_\ell} E^\text{synapse}_s\) is the total synaptic input current to neuron layer \(\ell\), which is fundamentally local and serves to minimize the energy of the connected hypersynapses. The time constant for neurons in layer \(\ell\) is denoted by \(\tau_\ell\).

The central result is that the dynamical equations in Equation 3 decrease the global energy of the network. To demonstrate this, consider the total time derivative of the energy

\[ \frac{dE_\text{total}}{dt} = \sum\limits_{\ell=1}^L \frac{\partial E_\text{total}}{\partial \hat{\mathbf{x}}_\ell} \frac{\partial \hat{\mathbf{x}}_\ell}{\partial \mathbf{x}_\ell} \frac{d\mathbf{x}_\ell}{dt} = -\sum\limits_{\ell=1}^L \tau_\ell \frac{d \mathbf{x} _\ell }{dt} \frac{\partial^2 \mathcal{L}_x}{\partial \mathbf{x}_\ell \partial \mathbf{x}_\ell} \frac{d\mathbf{x}_\ell}{dt} \leq 0, \]

where we expressed the partial derivative of the energy w.r.t. the activations through the velocity of the neurons’ internal states using Equation 3. The Hessian matrix \(\frac{\partial^2 \mathcal{L}_x}{\partial \mathbf{x}_\ell \partial \mathbf{x}_\ell}\) is a square matrix whose dimensions equal the number of neurons in layer \(\ell\). As long as this matrix is positive semi-definite, a property resulting from the convexity of the Lagrangian, the total energy of the network is guaranteed to either decrease or stay constant; an increase of the energy is not allowed.

Additionally, if the energy of the network is bounded from below, the dynamics in Equation 3 are guaranteed to lead the trajectories to fixed manifolds corresponding to local minima of the energy. If the fixed manifolds have zero dimension, i.e., they are fixed point attractors, the velocity field will vanish once the network arrives at the local minimum. This corresponds to the Hessians being strictly positive definite. Alternatively, if the Lagrangians have zero modes, resulting in the existence of zero eigenvalues of the Hessian matrices, the network may converge to the fixed manifolds, but the velocity fields may stay non-zero while the network’s state moves along that manifold.

Energy Hypergraph Implementation

The local, summing structure of \(E^\text{total}\) is expressible in code as a hypergraph. We roll our own implementation in JAX to keep things simple.


HAM

 HAM (neurons:Dict[str,__main__.NeuronLayer],
      hypersynapses:Dict[str,equinox._module.Module],
      connections:List[Tuple[Tuple,str]])

A Hypergraph wrapper connecting all dynamic states (neurons) and learnable parameters (synapses) for our associative memory

Exported source
class HAM(eqx.Module): 
    "A Hypergraph wrapper connecting all dynamic states (neurons) and learnable parameters (synapses) for our associative memory" 
    neurons: Dict[str, NeuronLayer]
    hypersynapses: Dict[str, eqx.Module] 
    connections: List[Tuple[Tuple, str]]

We describe a HAM using plain python data structures for our neurons, hypersynapses, and edge list of connections. This makes each object fully compatible with jax’s tree-mapping utilities, which will help keep our hypergraph code super succinct.

For example, we can create a simple HAM with two neurons and one hypersynapse:

n1_dim, n2_dim = 10, 100
neurons = {
    "n1": NeuronLayer(lagrangian=lagr_sigmoid, shape=(n1_dim,)),
    "n2": NeuronLayer(lagrangian=lambda x: lagr_softmax(x, axis=-1), shape=(n2_dim,)),
}
hypersynapses = {
    "s1": LinearSynapse.rand_init(jax.random.key(0), n1_dim, n2_dim),
}
connections = [
    (("n1", "n2"), "s1"), # Read as: "Connect neurons n1 and n2 via synapse s1"
]
ham = HAM(neurons, hypersynapses, connections)

Let’s start with some basic properties that describe the hypergraph object we just created.


HAM.n_connections

 HAM.n_connections ()

Total number of connections

Exported source
@patch(as_prop=True)
def n_neurons(self:HAM) -> int:
   "Total number of neurons"
   return len(self.neurons)

@patch(as_prop=True)
def n_hypersynapses(self:HAM) -> int:
   "Total number of hypersynapses"
   return len(self.hypersynapses)

@patch(as_prop=True)
def n_connections(self:HAM) -> int:
   "Total number of connections"
   return len(self.connections)

HAM.n_hypersynapses

 HAM.n_hypersynapses ()

Total number of hypersynapses


HAM.n_neurons

 HAM.n_neurons ()

Total number of neurons


HAM.init_states

 HAM.init_states (bs:Optional[int]=None)

Initialize all neuron states

Exported source
@patch
def init_states(self: HAM, bs: Optional[int] = None):
    """Initialize all neuron states"""
    if bs is not None and bs > 0: warn("Vectorize with `ham.vectorize()` before processing batched states")
    xs = {k: v.init(bs) for k, v in self.neurons.items()}
    return xs

Initialize all the dynamic neuron states at once, optionally with a batch size. This makes it easy to treat the whole collection of neuron states as a single tensor.

xs = ham.init_states()
print(jtu.tree_map(lambda x: f"Shape: {x.shape}", xs))
{'n1': 'Shape: (10,)', 'n2': 'Shape: (100,)'}

Key into these empty states to replace the states with real data.

example_data = jr.normal(jr.key(4), xs['n1'].shape)
xs["n1"] = example_data
Variable naming conventions

Throughout this code, we universally use the xs variable to refer to the collection of neuron internal states and the xhats variable to refer to the collection of neuron activations.

Additionally, whenever a function f takes both xs and xhats as arguments, we assume the xhats are passed first in the argument order, i.e., f(xhats, xs, *args, **kwargs). This is because most AM operations take gradients w.r.t. the activations, not the internal states, and the 0-th positional arg is the default argument for jax.grad.
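
For example, using the HAM.energy and HAM.activations methods documented just below (a sketch of the convention, assuming the ham and xs objects defined above):

# Because `xhats` is the 0-th positional argument, jax.grad differentiates the energy
# w.r.t. the activations by default
xhats = ham.activations(xs)
grads = jax.grad(ham.energy)(xhats, xs)  # a dict of gradients with the same keys as `xhats`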


HAM.activations

 HAM.activations (xs)

Convert hidden states of each neuron into their activations

Exported source
@patch
def activations(self: HAM, xs):
    """Convert hidden states of each neuron into their activations"""
    xhats = {k: v.sigma(xs[k]) for k, v in self.neurons.items()}
    return xhats

From the states, we can compute the activations of each neuron as a single collection:

xhats = ham.activations(xs)
assert jnp.all(xhats['n1'] > 0) and jnp.all(xhats['n1'] < 1), "Sigmoid neurons should be between 0 and 1"
assert jnp.isclose(xhats['n2'].sum(), 1.0), "Softmax neurons should sum to 1"

HAM.energy

 HAM.energy (xhats, xs)

The complete energy of the HAM


HAM.energy_tree

 HAM.energy_tree (xhats, xs)

Return energies for each individual component


HAM.connection_energies

 HAM.connection_energies (xhats)

Get the energy for each connection


HAM.neuron_energies

 HAM.neuron_energies (xhats, xs)

Retrieve the energies of each neuron in the HAM

From the activations, we can collect all the energies of the neurons and the connections in the HAM. We can organize these into an energy tree from which we compute the total energy of the entire HAM.

The complete energy of the HAM is the sum of all the individual energies from the HAM.energy_tree.

xhats = ham.activations(xs)
pp(ham.energy_tree(xhats, xs))
{'neurons': {'n1': Array(-6.215218, dtype=float32),
             'n2': Array(-4.6051702, dtype=float32)},
 'connections': [Array(-0.00045318, dtype=float32)]}
pp(ham.energy(xhats, xs))
Array(-10.820841, dtype=float32)

HAM.dEdact

 HAM.dEdact (xhats, xs, return_energy=False)

Calculate gradient of system energy w.r.t. each activation

A small helper function to make it easier to compute the gradient of the energy w.r.t. the activations.

This energy is guaranteed to monotonically decrease over time and to be bounded from below.
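
As a sketch of how this plays out in practice (an explicit-Euler discretization of Equation 3 with \(\tau_\ell = 1\) and an arbitrary step size, assuming HAM.dEdact returns a dict of gradients keyed like xhats):

step_size = 0.1
xs = ham.init_states()  # start from empty states
for t in range(10):
    xhats = ham.activations(xs)
    grads = ham.dEdact(xhats, xs)  # dE_total/dxhat for each neuron layer
    xs = jtu.tree_map(lambda x, g: x - step_size * g, xs, grads)
    print(t, ham.energy(ham.activations(xs), xs))  # should be non-increasing across steps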

Vectorizing the Energy

To scale these models, we generally want to operate on batches of data and activations using the same model. We can do this by creating a VectorizedHAM object whose functions all expect a batch dimension in the neuron states and activations.


VectorizedHAM

 VectorizedHAM (_ham:equinox._module.Module)

Re-expose HAM API with vectorized inputs. No new HAM behaviors should be implemented in this class.


HAM.unvectorize

 HAM.unvectorize ()

Unvectorize to work on single inputs

Exported source
@patch
def vectorize(self: HAM):
    """Vectorize to work on batches of inputs"""
    return VectorizedHAM(self)

@patch
def unvectorize(self: HAM):
    """Unvectorize to work on single inputs"""
    return self

HAM.vectorize

 HAM.vectorize ()

Vectorize to work on batches of inputs

Now our HAM logic works on batches of inputs using jax.vmap.

# Instead of this
ham = HAM(neurons=neurons, hypersynapses=hypersynapses, connections=connections)
test_warns(lambda: ham.init_states(bs=5), show=False)
vxs = ham.init_states(bs=5)
test_fail(ham.activations, args=vxs)
UserWarning: Vectorize with `ham.vectorize()` before processing batched states
  if bs is not None and bs > 0: warn("Vectorize with `ham.vectorize()` before processing batched states")
# Do this
vham = ham.vectorize()
vxs = vham.init_states(bs=5)
vxhats = vham.activations(vxs)
assert all(g.shape[0] == 5 for g in vxhats.values()), "All activations should have batch dim"