HAM
We have now provided the two primary components: layers and synapses. This module assembles them into a single network governed by an energy function.
HAM Basics
We connect layers and synapses in a hypergraph to describe the energy function. A hypergraph is a generalization of the familiar graph in that edges (synapses) can connect multiple nodes (layers). This graph, complete with the operations of the synapses and the activation behavior of the layers, fully defines the energy function for a given collection of neuron states.
HAM
HAM (layers, synapses, connections)
A Treex module that assembles layers and synapses, tied together by connections, into a single energy-based network.
HAM.activations
HAM.activations (xs:jax.Array)
Turn a collection of states into a collection of activations
| | Type | Details |
|---|---|---|
| xs | Array | Collection of states for each layer |
HAM.energy
HAM.energy (xs:jax.Array)
The complete energy of the HAM
| | Type | Details |
|---|---|---|
| xs | Array | Collection of states for each layer |
HAM.synapse_energy
HAM.synapse_energy (gs:jax.Array)
The total contribution of the synapses to the energy of the HAM. A function of the activations gs rather than the states xs.
| | Type | Details |
|---|---|---|
| gs | Array | Collection of activations of each layer |
HAM.layer_energy
HAM.layer_energy (xs:jax.Array)
The total contribution of the layers to the energy of the HAM.
| | Type | Details |
|---|---|---|
| xs | Array | Collection of states for each layer |
As is typical for JAX frameworks, the parameters of a HAM need to be initialized. Unlike in other machine learning libraries, the states of each layer (that is, the dynamical variables of our system) also need to be initialized. The notation \(\mathbf{x}\) denotes the collection of all layer states, and \(x^\alpha\) refers to the state of the layer at index \(\alpha\) in that collection.
We provide this functionality with the following helper functions:
HAM.init_states_and_params
HAM.init_states_and_params (param_key, bs=None, state_key=None)
Initialize the states and parameters of every layer and synapse in the network
| | Type | Default | Details |
|---|---|---|---|
| param_key | | | RNG seed for random initialization of the parameters |
| bs | NoneType | None | Batch size of the states to initialize, if needed |
| state_key | NoneType | None | RNG seed for random initialization of the states, if non-zero initializations are desired |
HAM.init_states
HAM.init_states (bs=None, rng=None)
Initialize the states of every layer in the network
| | Type | Default | Details |
|---|---|---|---|
| bs | NoneType | None | Batch size of the states to initialize, if needed |
| rng | NoneType | None | RNG seed for random initialization of the states, if non-zero initializations are desired |
We build and test the following small 3-layer HAM network throughout this notebook.
from hamux.layers import *
from hamux.synapses import *
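The rendered notebook hides its remaining imports; the block below lists the ones assumed by the code in this section (the exact import path for HAM is a guess and may differ in your version).

# Assumed supporting imports (hidden in the rendered notebook).
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import hypernetx as hnetx
from fastcore.test import test_eq     # used for the energy decomposition check below
from fastcore.basics import patch     # used to patch visualize onto HAM below
from hamux import HAM                 # assumption: HAM may live in a submodule instead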
layers = [
    TanhLayer((2,)),
    ReluLayer((3,)),
    SoftmaxLayer((4,)),
]
synapses = [
    DenseSynapse(),
    DenseSynapse(),
]
connections = [
    ([0, 1], 0),
    ([1, 2], 1),
]
ham = HAM(layers, synapses, connections)

# Initialize states and parameters from specified layer shapes
xs, ham = ham.init_states_and_params(jax.random.PRNGKey(0), state_key=jax.random.PRNGKey(1));
print("NLayers: ", ham.n_layers)
print("NSynapses: ", ham.n_synapses)
print("NConnections: ", ham.n_connections)
print("Taus of each layer: ", ham.layer_taus)
NLayers: 3
NSynapses: 2
NConnections: 2
Taus of each layer: [1.0, 1.0, 1.0]
The energy of a HAM is fully defined by the individual energies of its layers and synapses.
\[E_\text{HAM} = E_\text{Layers} + E_\text{Synapses}\]
print("\n".join([f"x{i}: {x.shape}" for i,x in enumerate(xs)])) # The dynamic variables
print("\n".join([f"W{i}: {S.W.shape}" for i,S in enumerate(ham.synapses)])) # The dynamic variables
x0: (2,)
x1: (3,)
x2: (4,)
W0: (2, 3)
W1: (3, 4)
E_L = ham.layer_energy(xs); E_L
DeviceArray(0.071, dtype=float32)
gs = ham.activations(xs)
E_S = ham.synapse_energy(gs); E_S
DeviceArray(-0.033, dtype=float32)
test_eq(ham.energy(xs), E_L + E_S)
The update rule for each of the layer states is simply defined as follows:
\[\tau \frac{dx^\alpha}{dt} = -\frac{\partial E_\text{HAM}}{\partial g^\alpha}\]
JAX is wonderful. Autograd does this accurately and efficiently for us.
HAM.updates
HAM.updates (xs:jax.Array)
The negative of dEdg, giving the direction of energy descent for each layer's state
| | Type | Details |
|---|---|---|
| xs | Array | Collection of states for each layer |
HAM.dEdg
HAM.dEdg (xs:jax.Array)
Calculate the gradient of the system energy with respect to the activations.
Notice that we use a property of the Legendre transform as a shortcut: \(\partial E_\text{layer} / \partial g = x\), so autograd only needs to differentiate the synapse energy.
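Below is a minimal sketch of how this shortcut could be realized with autograd, assuming the total energy splits as E = E_layer(x) + E_synapse(g) with g = activations(x); it is an illustration, not the library's implementation.

# Minimal sketch (not the library implementation): with dE_layer/dg = x, the
# full gradient w.r.t. the activations is x plus the autograd gradient of the
# synapse energy.
def dEdg_sketch(ham, xs):
    gs = ham.activations(xs)                      # g = activation(x) for every layer
    dEsyn_dg = jax.grad(ham.synapse_energy)(gs)   # dE_synapse/dg via autograd
    return jtu.tree_map(lambda x, dg: x + dg, xs, dEsyn_dg)

# The update direction for each layer is then the negated gradient, -dEdg_sketch(ham, xs).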
Let’s observe the energy descent in practice:
@jax.jit
def step(ham, xs, dt):
    energy = ham.energy(xs)
    updates = ham.updates(xs)
    alphas = [dt / tau for tau in ham.layer_taus]
    next_states = jtu.tree_map(lambda x, u, alpha: x + alpha*u, xs, updates, alphas)
    return energy, next_states

energies = []
allx = []
x = ham.init_states(rng=jax.random.PRNGKey(5))
for i in range(100):
    energy, x = step(ham, x, 0.1)
    energies.append(energy)
    allx.append(x)
fig, ax = plt.subplots(1)
ax.plot(np.stack(energies))
ax.set_ylabel("Energy")
ax.set_xlabel("Timesteps")
ax.set_title("Energy over time for our 3-layer HAM. Random weights")
plt.show(fig)
# Per-step change of each layer's state (diff over the time axis), averaged over neurons
x0diffs = jnp.abs(jnp.diff(jnp.stack([x[0] for x in allx]), axis=0)).mean(-1)
x1diffs = jnp.abs(jnp.diff(jnp.stack([x[1] for x in allx]), axis=0)).mean(-1)
x2diffs = jnp.abs(jnp.diff(jnp.stack([x[2] for x in allx]), axis=0)).mean(-1)

fig2, ax = plt.subplots(1)
ax.plot(np.stack([x0diffs, x1diffs, x2diffs], axis=-1))
ax.set_title("Avg $|\Delta x|$ over Time")
ax.set_xlabel("Time steps")
ax.set_ylabel("$\Delta x$")
plt.show(fig2)
So our 3-layer HAM with randomly initialized weights and states converges to a fixed energy, at which point the states of each layer no longer change.
We implement this simple step function as a method on the class, though more advanced optimizers from the JAX ecosystem (e.g., optax) can easily be used instead; a sketch of the optax variant follows.
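As a sketch of that flexibility (assuming ham.dEdg returns a gradient pytree with the same structure as the states, and ignoring the per-layer tau scaling for brevity), an optax optimizer can drive the same relaxation:

import optax

# Sketch only: relax the states with an optax optimizer instead of the
# hand-rolled step above. The states form the "parameter" pytree and
# dEdg plays the role of the gradient.
opt = optax.sgd(learning_rate=0.1)
xs_opt = ham.init_states(rng=jax.random.PRNGKey(5))
opt_state = opt.init(xs_opt)

@jax.jit
def optax_step(xs, opt_state):
    grads = ham.dEdg(xs)                          # gradient of the energy w.r.t. activations
    deltas, opt_state = opt.update(grads, opt_state, xs)
    return optax.apply_updates(xs, deltas), opt_state

for _ in range(100):
    xs_opt, opt_state = optax_step(xs_opt, opt_state)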
HAM.step
HAM.step (xs:List[jax.Array], updates:List[jax.Array], dt:float=0.1, masks:Optional[List[jax.Array]]=None)
A discrete step down the energy using step size dt, scaled by the tau of each layer
| | Type | Default | Details |
|---|---|---|---|
| xs | typing.List[jax.Array] | | Collection of current states for each layer |
| updates | typing.List[jax.Array] | | Collection of update directions for each state |
| dt | float | 0.1 | Step size to take in the direction of the updates |
| masks | typing.Optional[typing.List[jax.Array]] | None | Boolean mask: 0 for clamped neurons, 1 elsewhere. A pytree identical to xs. Optional. |
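A sketch of how the masks argument might be used (assuming step returns the next collection of states): clamp the visible layer while the remaining layers relax.

# Sketch: clamp layer 0 (mask entries 0) and let the other layers evolve (mask entries 1).
masks = [jnp.zeros_like(xs[0])] + [jnp.ones_like(x) for x in xs[1:]]
updates = ham.updates(xs)
xs_next = ham.step(xs, updates, dt=0.1, masks=masks)  # assumed to return the updated states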
It is particularly useful to apply all of these functions to a batched collection of states, something JAX makes easy through its jax.vmap functionality. We prefix the vectorized versions of the above methods with a v.
HAM.vupdates
HAM.vupdates (xs:List[jax.Array])
A vectorized version of updates
| | Type | Details |
|---|---|---|
| xs | typing.List[jax.Array] | Collection of states for each layer |
HAM.vdEdg
HAM.vdEdg (xs:List[jax.Array])
A vectorized version of dEdg
| | Type | Details |
|---|---|---|
| xs | typing.List[jax.Array] | Collection of states for each layer |
HAM.venergy
HAM.venergy (xs:List[jax.Array])
A vectorized version of energy
| | Type | Details |
|---|---|---|
| xs | typing.List[jax.Array] | Collection of states for each layer |
HAM.vactivations
HAM.vactivations (xs:List[jax.Array])
A vectorized version of activations
| | Type | Details |
|---|---|---|
| xs | typing.List[jax.Array] | Collection of states for each layer |
batch_size = 5
vxs = ham.init_states(bs=batch_size, rng=jax.random.PRNGKey(2))
print("\n".join([f"x{i}: {x.shape}" for i,x in enumerate(xs)])) # The dynamic variables
x0: (2,)
x1: (3,)
x2: (4,)
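As a quick sanity check, the v-prefixed methods are expected to agree with jax.vmap applied to their un-batched counterparts over the leading batch axis (an assumption of this sketch, not something the API listing above guarantees):

# Sketch: compare the batched energy method against manually vmapping the scalar energy.
venergy_method = ham.venergy(vxs)
venergy_manual = jax.vmap(ham.energy)(vxs)
# test_eq(venergy_method, venergy_manual)  # expected to match if venergy == vmap(energy)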
We repeat the above energy analysis for batched samples. For an untrained network, all samples achieve the same energy.
@jax.jit
def step(ham, xs, dt):
    energy = ham.venergy(xs)
    updates = ham.vupdates(xs)
    alphas = [dt / tau for tau in ham.layer_taus]
    next_states = jtu.tree_map(lambda x, u, alpha: x + alpha*u, xs, updates, alphas)
    return energy, next_states

energies = []; allx = []; dt = 0.03; x = vxs
for i in range(100):
    energy, x = step(ham, x, dt)
    energies.append(energy)
    allx.append(x)

senergies = jnp.stack(energies)
fig, ax = plt.subplots(1)
ax.plot(senergies)
ax.set_xlabel(f"Time step (dt={dt})")
ax.set_ylabel("Energy")
ax.set_title("System energy over time.\nRandom initial states on 3-layer, untrained HAM")
ax.legend([f"Sample {i}" for i in range(senergies.shape[-1])])
plt.show(fig)
We create several helper functions to save and load our state dict; a brief usage sketch follows the API listing below.
HAM.load_ckpt
HAM.load_ckpt (ckpt_f:Union[str,pathlib.Path])
Load from file name
| | Type | Details |
|---|---|---|
| ckpt_f | typing.Union[str, pathlib.Path] | Filename of checkpoint to load |
HAM.save_state_dict
HAM.save_state_dict (fname:Union[str,pathlib.Path], overwrite:bool=True)
Save the state dictionary for a HAM
| | Type | Default | Details |
|---|---|---|---|
| fname | typing.Union[str, pathlib.Path] | | Filename of checkpoint to save |
| overwrite | bool | True | Overwrite an existing file of the same name? |
HAM.to_state_dict
HAM.to_state_dict ()
Convert HAM to state dictionary of parameters and connections
HAM.load_state_dict
HAM.load_state_dict (state_dict:Any)
Load the state dictionary for a HAM
| | Type | Details |
|---|---|---|
| state_dict | typing.Any | The dictionary of all parameters, saved by save_state_dict |
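A brief sketch of a save/load round trip (the file name is illustrative, and we assume the load methods return the restored module in the usual functional Treex style):

# Sketch: round-trip the parameters through a state dict and through a checkpoint file.
state = ham.to_state_dict()                    # parameters and connections as a plain dict
ham_restored = ham.load_state_dict(state)      # assumed to return the restored HAM

ham.save_state_dict("ham_example.ckpt")        # illustrative file name
ham_from_disk = ham.load_ckpt("ham_example.ckpt")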
Visualizing a HAM
We employ the hypernetx package to create a simple visualization of the layers (nodes) and synapses (hyperedges).
@patch
def visualize(self:HAM):
    """Return a simple hypergraph object of the connections."""
    nodenames = [l.name for l in self.layers]
    edgenames = [s.name for s in self.synapses]
    graph = {f"{edgenames[syn_idx]} ({syn_idx})": [f"{nodenames[i]} ({i})" for i in layer_idxs] for layer_idxs, syn_idx in self.connections}
    H = hnetx.Hypergraph(graph, name=self.name)
    return H
Example Model Architectures
We can use our primitive visualization function to see the connections of our models. We showcase these architectures only to illustrate what is possible to build; hierarchical models have not yet been tested.
Simple Hopfield Network
Consisting of one visible layer and one hidden layer
layers = [
    IdentityLayer((768,)),    # Visible Layer
    SoftmaxLayer((1000,)),    # Hidden Layer
]
synapses = [
    DenseSynapse()
]
connections = [
    ((0,1), 0)
]

ham = HAM(layers, synapses, connections)
_, ham = ham.init_states_and_params(jax.random.PRNGKey(0))
hnetx.draw(ham.visualize())
Simple Convolutional Layer
layers = [
    LayerNormLayer((32,32,3)),   # Visible Layer
    TanhLayer((16,16,128,)),
    SigmoidLayer((8,8,256)),
    SoftmaxLayer((4,4,256)),
]
synapses = [
    ConvSynapse((2,2), strides=(2,2)),
    ConvSynapse((2,2), strides=(2,2)),
    ConvSynapse((2,2), strides=(2,2)),
]
connections = [
    ((0,1), 0),
    ((1,2), 1),
    ((2,3), 2),
]

ham = HAM(layers, synapses, connections)
_, ham = ham.init_states_and_params(jax.random.PRNGKey(0))
hnetx.draw(ham.visualize())
Energy Transformer (explicit Hopfield Module)
layers = [
    LayerNormLayer((768,)),   # Visible Layer
    ReluLayer((3072,)),       # Hidden layer of CHN
]
synapses = [
    DenseSynapse(),
    AttentionSynapse(num_heads=3, zspace_dim=224)
]
connections = [
    ((0,1), 0),
    ((0,0), 1),
]

ham = HAM(layers, synapses, connections)
_, ham = ham.init_states_and_params(jax.random.PRNGKey(0))
hnetx.draw(ham.visualize())
Energy Transformer (implicit Hopfield Module)
layers = [
    LayerNormLayer((768,)),   # Visible Layer
]
synapses = [
    DenseMatrixSynapseWithHiddenLayer(nhid=3072),
    SelfAttentionSynapse(num_heads=3, zspace_dim=224)
]
connections = [
    ((0,), 0),
    ((0,), 1),
]

ham = HAM(layers, synapses, connections)
_, ham = ham.init_states_and_params(jax.random.PRNGKey(0))
hnetx.draw(ham.visualize())
Attention and Convolutions
layers = [
    LayerNormLayer((32,32,3)),   # Visible Layer
    TanhLayer((16,16,128,)),
    SigmoidLayer((8,8,256)),
    SoftmaxLayer((4,4,256)),
]
synapses = [
    ConvSynapse((2,2), strides=(2,2)),
    ConvSynapse((2,2), strides=(2,2)),
    ConvSynapse((2,2), strides=(2,2)),
    AttentionSynapse(num_heads=3, zspace_dim=64),
    AttentionSynapse(num_heads=3, zspace_dim=64),
    AttentionSynapse(num_heads=3, zspace_dim=64),
]
connections = [
    ((0,1), 0),
    ((1,2), 1),
    ((2,3), 2),
    ((0,2), 3),
    ((1,3), 4),
    ((0,3), 5),
]

ham = HAM(layers, synapses, connections)
_, ham = ham.init_states_and_params(jax.random.PRNGKey(0))
hnetx.draw(ham.visualize())
N-ary synapses
layers = [
    LayerNormLayer((32,32,3)),   # Visible Layer
    TanhLayer((16,16,128,)),
    SigmoidLayer((8,8,256)),
    SoftmaxLayer((4,4,256)),
]
synapses = [
    ConvSynapse((2,2), strides=(2,2)),
    ConvSynapse((2,2), strides=(2,2)),
    ConvSynapse((2,2), strides=(2,2)),
    AttentionSynapse(num_heads=3, zspace_dim=64),
    DenseMatrixSynapseWithHiddenLayer(200)
]
connections = [
    ((0,1), 0),
    ((1,2), 1),
    ((2,3), 2),
    ((0,2), 3),
    ((0,2,3), 4)
]

ham = HAM(layers, synapses, connections)
_, ham = ham.init_states_and_params(jax.random.PRNGKey(0))
hnetx.draw(ham.visualize())