Registry

Easily create preconfigured models and prediction functions on a HAM

We define simple helper functions to instantiate HAMs with particular architectural choices. The design is inspired by timm.

A HAM is a general-purpose Associative Memory: it is up to the user to extract the desired information from the system. Hence, every registered model must return both the ham architecture and a fwd function that accomplishes a particular task with that architecture.

The Registry


source

named_partial

 named_partial (f, *args, new_name=None, order=None, **kwargs)

Like functools.partial but also copies over function name and docstring.

If new_name is not None, use that as the name
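
A quick illustration of its behavior (the toy scale function below is hypothetical, not part of the library):

def scale(x, factor):
    """Multiply x by factor"""
    return x * factor

double = named_partial(scale, factor=2, new_name="double")
double(3)          # 6
double.__name__    # "double"; a plain functools.partial object has no __name__
double.__doc__     # "Multiply x by factor"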


source

create_model

 create_model (mname:str, *args, **kwargs)

Retrieve the model by name from all registered models, passing args and kwargs to its factory function.

Type Details
mname str Retrieve this stored model name
args
kwargs

source

register_model

 register_model (fgen:Callable)

Register a factory function that returns a configured model. The function's name acts as the retrieval key and must be unique across models.

Type Details
fgen typing.Callable Function that returns a HAM with desired config

We can now register a model as follows:

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import hamux as hmx
from typing import Tuple

@register_model
def example_classical_hn(img_shape:Tuple, # Shape of the image input
            label_shape:Tuple[int], # Shape of the label layer, e.g. (n_labels,)
            nhid:int=1000, # Number of hidden units in the single hidden layer
            depth:int=4, # Default number of iterations to run the Hopfield Network prediction function
            dt:float=0.4, # Default step size of the system
           ): 
    """Create a 2-layer classical Hopfield Network applied on vectorized inputs and a function showing how to use it"""
    layers = [
        hmx.TanhLayer(img_shape),
        hmx.SoftmaxLayer(label_shape),
    ]

    synapses = [
        hmx.DenseMatrixSynapseWithHiddenLayer(nhid, hidden_lagrangian=hmx.lagrangians.LRelu()),
    ]

    connections = [
        ((0, 1), 0),
    ]

    ham = hmx.HAM(layers, synapses, connections)
    
    def fwd(model, x, depth=depth, dt=dt, rng=None):
        """A pure function to extract desired information from the configured HAM, applied on batched inputs"""
        # Initialize all layer states, then clamp the image layer to our input
        xs = model.init_states(x.shape[0], rng=rng)
        xs[0] = jnp.array(x)

        # Masks allow us to clamp our visible data over time
        masks = jtu.tree_map(lambda x: jnp.ones_like(x, dtype=jnp.int8), xs)
        masks[0] = jnp.zeros_like(masks[0], dtype=jnp.int8)  # Don't evolve images

        for i in range(depth):
            updates = model.vupdates(xs)  # Calculate the updates
            xs = model.step(
                xs, updates, dt=dt, masks=masks
            )  # Add them to our current states

        # The label layer uses a softmax activation, so its activation g gives probabilities
        return model.layers[-1].g(xs[-1])

    return ham, fwd

The model we just registered comes with a default fwd function that predicts label probabilities after 4 steps (though you are free to write any function that extracts a layer's state or activation at any point in time).

img_shape = (32,32); bs = 12
model, fwd = create_model("example_classical_hn", img_shape=img_shape, label_shape=(10,))

_, model = model.init_states_and_params(jax.random.PRNGKey(0))
x = jnp.ones((bs, *img_shape))
probs = fwd(model, x); probs.shape
(12, 10)
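
As noted above, a fwd function is ordinary code; here is a hedged sketch of a variant (not part of the registry) that records the label probabilities after every step, reusing the same HAM calls as example_classical_hn:

def fwd_trajectory(model, x, depth=4, dt=0.4, rng=None):
    """Illustrative variant of `fwd`: return the label probabilities after every step"""
    xs = model.init_states(x.shape[0], rng=rng)
    xs[0] = jnp.array(x)

    masks = jtu.tree_map(lambda s: jnp.ones_like(s, dtype=jnp.int8), xs)
    masks[0] = jnp.zeros_like(masks[0], dtype=jnp.int8)  # Keep the image layer clamped

    probs_over_time = []
    for _ in range(depth):
        updates = model.vupdates(xs)
        xs = model.step(xs, updates, dt=dt, masks=masks)
        probs_over_time.append(model.layers[-1].g(xs[-1]))
    return probs_over_time

trajectory = fwd_trajectory(model, x)
len(trajectory), trajectory[-1].shape  # should be (4, (12, 10)) for the model and x above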

For a simple classification pipeline, our fwd functions are all quite similar. We therefore define a few helper functions to reuse throughout the rest of the model registry; a usage sketch follows their documentation below.


source

fwd_conv

 fwd_conv (model:hamux.ham.HAM, x:jax.Array, depth:int, dt:float,
           rng:Optional[jax.Array]=None)

Forward function where the image input is kept as a 3-channel image (not flattened)

Type Default Details
model HAM HAM where layer[0] is the image input and layer[-1] are the labels
x Array Starting point for clamped layer[0]
depth int Number of iterations for which to run the model
dt float Step size through time
rng typing.Optional[jax.Array] None If provided, initialize states to random instead of 0

source

fwd_vec

 fwd_vec (model:hamux.ham.HAM, x:jax.Array, depth:int, dt:float,
          rng:Optional[jax.Array]=None)

Forward function where the image input is vectorized (flattened)

Type Default Details
model HAM HAM where layer[0] is the image input and layer[-1] are the labels
x Array Starting point for clamped layer[0]
depth int Number of iterations for which to run the model
dt float Step size through time
rng typing.Optional[jax.Array] None If provided, initialize states to random instead of 0

source

simple_fwd

 simple_fwd (model:hamux.ham.HAM, x:jax.Array, depth:int, dt:float,
             rng:Optional[jax.Array]=None)

A simplified version of the forward function, as shown in the paper.

All time constants tau are set to 1 in our architectures, though this is configurable.

Type Default Details
model HAM HAM where layer[0] is the image input and layer[-1] are the labels
x Array Starting point for clamped layer[0]
depth int Number of iterations for which to run the model
dt float Step size through time
rng typing.Optional[jax.Array] None If provided, initialize states to random instead of 0
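
Putting these helpers together, here is a hedged sketch of how a registered factory might reuse fwd_vec rather than writing a bespoke update loop. The factory name example_vec_hn and the baked-in depth/dt values are illustrative, not part of the registry; the architecture simply mirrors example_classical_hn above.

@register_model
def example_vec_hn(img_shape:Tuple=(3072,),  # Flattened CIFAR-sized input
                   label_shape:Tuple=(10,),  # 10 classes
                   nhid:int=1000,            # Hidden units in the single hidden layer
                  ):
    """Illustrative factory: a 2-layer HN whose fwd is just `fwd_vec` with defaults baked in"""
    layers = [
        hmx.TanhLayer(img_shape),
        hmx.SoftmaxLayer(label_shape),
    ]
    synapses = [
        hmx.DenseMatrixSynapseWithHiddenLayer(nhid, hidden_lagrangian=hmx.lagrangians.LRelu()),
    ]
    connections = [((0, 1), 0)]

    ham = hmx.HAM(layers, synapses, connections)
    fwd = named_partial(fwd_vec, depth=4, dt=0.4, new_name="example_vec_hn_fwd")
    return ham, fwd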

Model Registry

2-Layer HN


source

hn_softmax_cifar

 hn_softmax_cifar (nhid=6000)

source

hn_softmax_mnist

 hn_softmax_mnist (nhid=1000)

source

hn_repu5_cifar

 hn_repu5_cifar (nhid=6000)

Vectorized DAM on flattened CIFAR


source

hn_repu5_mnist

 hn_repu5_mnist (nhid=1000)

Vectorized DAM on flattened MNIST


source

hn_relu_cifar

 hn_relu_cifar (nhid:int=6000)

Vectorized HN on flattened CIFAR10

Type Default Details
nhid int 6000 Number of units in the single hidden layer

source

hn_relu_mnist

 hn_relu_mnist (nhid:int=1000)

Vectorized HN on flattened MNIST

Type Default Details
nhid int 1000 Number of units in the single hidden layer

source

hn

 hn (hidden_lagrangian:treex.module.Module, img_shape:Tuple,
     label_shape:Tuple, nhid:int=1000, do_norm:bool=False)

Create a Classical Hopfield Network that is intended to be applied on vectorized inputs

Type Default Details
hidden_lagrangian Module
img_shape typing.Tuple Shape of image input to model
label_shape typing.Tuple Shape of label probabilities, typically (NLABELS,)
nhid int 1000 Number of units in hidden layer
do_norm bool False If True, enforce that all weights are standardized

These models can now be instantiated by their registered names:

xcifar = jnp.ones((1,3, 32,32)) # Per pytorch convention, CHW
xmnist = jnp.ones((1,1,28,28)) # Per pytorch convention, CHW

exhn, exhn_fwd = create_model("hn", hmx.lagrangians.LExp(), (32,32,3), (10,))
_, exhn = exhn.init_states_and_params(jax.random.PRNGKey(22))
exhn_fwd(exhn, xcifar)
Array([[0.31536135, 0.08483113, 0.0897951 , 0.02981309, 0.03241062,
        0.03071734, 0.03666373, 0.00281298, 0.03433144, 0.3432633 ]],      dtype=float32)

Simple Convolution


source

conv_ham_maxpool_cifar

 conv_ham_maxpool_cifar (nhid=1000)

source

conv_ham_avgpool_cifar

 conv_ham_avgpool_cifar (nhid=1000)

source

conv_ham_maxpool_mnist

 conv_ham_maxpool_mnist (nhid=1000)

source

conv_ham_avgpool_mnist

 conv_ham_avgpool_mnist (nhid=1000)

source

conv_ham

 conv_ham (s1, s2, s3, pool_type, nhid=1000)

model, fwd = create_model("conv_ham_avgpool_cifar")
_, model = model.init_states_and_params(jax.random.PRNGKey(0))
fwd(model, xcifar)

Energy Version of Attention

We now introduce a simple model for energy-based attention.


source

energy_attn_cifar

 energy_attn_cifar ()

source

energy_attn_mnist

 energy_attn_mnist ()

source

energy_attn

 energy_attn (s1, s2, nheads_self, nheads_cross)
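
Following the same pattern as the models above, the energy-attention variants can presumably be instantiated and run through the registry like any other (a hedged sketch; it assumes the cifar variant's fwd accepts a batched 3x32x32 input like the conv models):

attn_model, attn_fwd = create_model("energy_attn_cifar")
_, attn_model = attn_model.init_states_and_params(jax.random.PRNGKey(0))
attn_fwd(attn_model, xcifar)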