```python
@register_model
def example_classical_hn(img_shape:Tuple[int],   # Vector input size
                         label_shape:Tuple[int], # Number of labels
                         nhid:int=1000,          # Number of hidden units in the single hidden layer
                         depth:int=4,            # Default number of iterations to run the Hopfield Network prediction function
                         dt:float=0.4,           # Default step size of the system
                         ):
    """Create a 2-layer classical Hopfield Network applied on vectorized inputs and a function showing how to use it"""
    layers = [
        hmx.TanhLayer(img_shape),
        hmx.SoftmaxLayer(label_shape),
    ]
    synapses = [
        hmx.DenseMatrixSynapseWithHiddenLayer(nhid, hidden_lagrangian=hmx.lagrangians.LRelu()),
    ]
    connections = [
        ((0, 1), 0),
    ]

    ham = hmx.HAM(layers, synapses, connections)

    def fwd(model, x, depth=depth, dt=dt, rng=None):
        """A pure function to extract desired information from the configured HAM, applied on batched inputs"""
        # Initialize hidden states to our image
        xs = model.init_states(x.shape[0], rng=rng)
        xs[0] = jnp.array(x)

        # Masks allow us to clamp our visible data over time
        masks = jtu.tree_map(lambda x: jnp.ones_like(x, dtype=jnp.int8), xs)
        masks[0] = jnp.zeros_like(masks[0], dtype=jnp.int8)  # Don't evolve images

        for i in range(depth):
            updates = model.vupdates(xs)  # Calculate the updates
            xs = model.step(xs, updates, dt=dt, masks=masks)  # Add them to our current states

        # All labels have a softmax activation function as the last layer, spitting out probabilities
        return model.layers[-1].g(xs[-1])

    return ham, fwd
```
Registry
We create very simple helper functions to instantiate HAMs with particular architectural choices, inspired by timm.
A HAM is a fundamentally general-purpose architecture: it is a general-purpose Associative Memory, and it is up to the user to extract the desired information from the system. Hence, every registered model must return the ham architecture and a fwd function that accomplishes a task with that architecture.
The Registry
named_partial
named_partial (f, *args, new_name=None, order=None, **kwargs)
Like functools.partial, but also copies over the function name and docstring. If new_name is not None, use that as the name.
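A minimal sketch of how such a helper could be implemented with only the standard library (illustrative, not necessarily the library's exact code; the order argument is omitted here):

```python
import functools

def named_partial(f, *args, new_name=None, **kwargs):
    """Sketch: like functools.partial, but keep (or override) the wrapped function's name and docstring."""
    fnew = functools.partial(f, *args, **kwargs)
    functools.update_wrapper(fnew, f)  # copies __name__, __doc__, etc. onto the partial object
    if new_name is not None:
        fnew.__name__ = new_name
    return fnew
```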
create_model
create_model (mname:str, *args, **kwargs)
Retrieve the model mname from all registered models, passing args and kwargs to the factory function.
| | Type | Details |
|---|---|---|
| mname | str | Retrieve this stored model name |
| args | | |
| kwargs | | |
register_model
register_model (fgen:Callable)
Register a function that returns a model configuration factory function. The name of the function acts as the retrieval key and must be unique across models.
| | Type | Details |
|---|---|---|
| fgen | typing.Callable | Function that returns a HAM with desired config |
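For intuition, a registry like this can be built from a plain dictionary keyed by function name. The sketch below is a hypothetical minimal version, not necessarily HAMUX's exact implementation:

```python
# Hypothetical minimal registry (illustrative only)
MODEL_REGISTRY = {}

def register_model(fgen):
    """Store the factory under its function name; names must be unique across models."""
    assert fgen.__name__ not in MODEL_REGISTRY, f"Duplicate model name: {fgen.__name__}"
    MODEL_REGISTRY[fgen.__name__] = fgen
    return fgen  # return unchanged so it works as a decorator

def create_model(mname, *args, **kwargs):
    """Look up the factory by name and call it, returning (ham, fwd)."""
    return MODEL_REGISTRY[mname](*args, **kwargs)
```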
We can now register a model, as we did with example_classical_hn at the top of this page.
The model that we just created comes with a default fwd function that predicts label probabilities after 4 steps (though feel free to write any function that extracts a layer state/activation at any point in time).
```python
img_shape = (32,32); bs = 12
model, fwd = create_model("example_classical_hn", img_shape=img_shape, label_shape=(10,))

_, model = model.init_states_and_params(jax.random.PRNGKey(0))
x = jnp.ones((bs, *img_shape))
probs = fwd(model, x); probs.shape
```
```
(12, 10)
```
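As noted above, nothing forces fwd to return label probabilities; any layer state or activation can be read out at any point in time. A hypothetical variant (our naming and readout choice) that returns the final state of every layer, using the same HAM API as fwd above:

```python
def fwd_all_states(model, x, depth=4, dt=0.4, rng=None):
    """Sketch: run the same clamped dynamics but return every layer's final state."""
    xs = model.init_states(x.shape[0], rng=rng)
    xs[0] = jnp.array(x)
    masks = jtu.tree_map(lambda s: jnp.ones_like(s, dtype=jnp.int8), xs)
    masks[0] = jnp.zeros_like(masks[0], dtype=jnp.int8)  # keep the image clamped
    for _ in range(depth):
        updates = model.vupdates(xs)
        xs = model.step(xs, updates, dt=dt, masks=masks)
    return xs  # every layer's state after `depth` steps
```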
For the simple pipeline of classification, our fwd functions are quite similar. We therefore create some helper functions to use throughout the rest of our model configurations.
fwd_conv
fwd_conv (model:hamux.ham.HAM, x:jax.Array, depth:int, dt:float, rng:Optional[jax.Array]=None)
Where the image input is kept as a 3 channel image
| | Type | Default | Details |
|---|---|---|---|
| model | HAM | | HAM where layer[0] is the image input and layer[-1] are the labels |
| x | Array | | Starting point for clamped layer[0] |
| depth | int | | Number of iterations for which to run the model |
| dt | float | | Step size through time |
| rng | typing.Optional[jax.Array] | None | If provided, initialize states to random instead of 0 |
fwd_vec
fwd_vec (model:hamux.ham.HAM, x:jax.Array, depth:int, dt:float, rng:Optional[jax.Array]=None)
Where the image input is vectorized
| | Type | Default | Details |
|---|---|---|---|
| model | HAM | | HAM where layer[0] is the image input and layer[-1] are the labels |
| x | Array | | Starting point for clamped layer[0] |
| depth | int | | Number of iterations for which to run the model |
| dt | float | | Step size through time |
| rng | typing.Optional[jax.Array] | None | If provided, initialize states to random instead of 0 |
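These generic forward passes can then be specialized per registered model. One plausible pattern (illustrative only, not necessarily how every registered model is built) is to bind default settings with named_partial:

```python
# Illustrative: bind default depth/dt onto the generic vectorized forward pass
fwd = named_partial(fwd_vec, depth=4, dt=0.4, new_name="fwd_vec_default")
# fwd(model, x) now runs 4 update steps with step size 0.4
```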
simple_fwd
simple_fwd (model:hamux.ham.HAM, x:jax.Array, depth:int, dt:float, rng:Optional[jax.Array]=None)
A simple version of the forward function for showing in the paper.
All time constants tau are set to 1 in our architecture, but they are configurable.
| | Type | Default | Details |
|---|---|---|---|
| model | HAM | | HAM where layer[0] is the image input and layer[-1] are the labels |
| x | Array | | Starting point for clamped layer[0] |
| depth | int | | Number of iterations for which to run the model |
| dt | float | | Step size through time |
| rng | typing.Optional[jax.Array] | None | If provided, initialize states to random instead of 0 |
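For reference (our notation, assuming the standard Euler discretization of the HAM's energy-descent dynamics), each layer state is stepped as

$$x_\ell^{(t+1)} = x_\ell^{(t)} + \frac{dt}{\tau_\ell}\, u_\ell^{(t)},$$

where $u_\ell^{(t)}$ is the update computed for layer $\ell$ (what model.vupdates returns). With every $\tau_\ell = 1$, as in the models here, dt alone controls the step size.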
Model Registry
2 Layer HN
hn_softmax_cifar
hn_softmax_cifar (nhid=6000)
hn_softmax_mnist
hn_softmax_mnist (nhid=1000)
hn_repu5_cifar
hn_repu5_cifar (nhid=6000)
Vectorized DAM on flattened CIFAR
hn_repu5_mnist
hn_repu5_mnist (nhid=1000)
Vectorized DAM on flattened MNIST
hn_relu_cifar
hn_relu_cifar (nhid:int=6000)
Vectorized HN on flattened CIFAR10
| | Type | Default | Details |
|---|---|---|---|
| nhid | int | 6000 | Number of units in the single hidden layer |
hn_relu_mnist
hn_relu_mnist (nhid:int=1000)
Vectorized HN on flattened MNIST
| | Type | Default | Details |
|---|---|---|---|
| nhid | int | 1000 | Number of units in the single hidden layer |
hn
hn (hidden_lagrangian:treex.module.Module, img_shape:Tuple, label_shape:Tuple, nhid:int=1000, do_norm:bool=False)
Create a Classical Hopfield Network that is intended to be applied on vectorized inputs
| | Type | Default | Details |
|---|---|---|---|
| hidden_lagrangian | Module | | |
| img_shape | typing.Tuple | | Shape of image input to model |
| label_shape | typing.Tuple | | Shape of label probabilities, typically (NLABELS,) |
| nhid | int | 1000 | Number of units in hidden layer |
| do_norm | bool | False | If provided, enforce that all weights are standardized |
These models can now be instantiated by their strings:
```python
xcifar = jnp.ones((1,3,32,32))  # Per pytorch convention, CHW
xmnist = jnp.ones((1,1,28,28))  # Per pytorch convention, CHW

exhn, exhn_fwd = create_model("hn", hmx.lagrangians.LExp(), (32,32,3), (10,))
_, exhn = exhn.init_states_and_params(jax.random.PRNGKey(22))
exhn_fwd(exhn, xcifar)
```
```
Array([[0.31536135, 0.08483113, 0.0897951 , 0.02981309, 0.03241062,
        0.03071734, 0.03666373, 0.00281298, 0.03433144, 0.3432633 ]], dtype=float32)
```
Simple Convolution
conv_ham_maxpool_cifar
conv_ham_maxpool_cifar (nhid=1000)
conv_ham_avgpool_cifar
conv_ham_avgpool_cifar (nhid=1000)
conv_ham_maxpool_mnist
conv_ham_maxpool_mnist (nhid=1000)
conv_ham_avgpool_mnist
conv_ham_avgpool_mnist (nhid=1000)
conv_ham
conv_ham (s1, s2, s3, pool_type, nhid=1000)
```python
model, fwd = create_model("conv_ham_avgpool_cifar")
_, model = model.init_states_and_params(jax.random.PRNGKey(0))
fwd(model, xcifar)
```
Energy Version of Attention
We now introduce a simple model for energy-based attention.
energy_attn_cifar
energy_attn_cifar ()
energy_attn_mnist
energy_attn_mnist ()
energy_attn
energy_attn (s1, s2, nheads_self, nheads_cross)
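As with the other registered models, these can be instantiated by name. A hedged usage sketch, assuming the same (model, x) calling convention as the convolutional models above and an appropriately shaped input:

```python
# Sketch only: the exact expected input shape depends on the energy_attn architecture
model, fwd = create_model("energy_attn_cifar")
_, model = model.init_states_and_params(jax.random.PRNGKey(0))
probs = fwd(model, xcifar)  # assuming a CIFAR-shaped NCHW input as in the conv example
```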