(Hyper)Synapses
Hypersynapses are the only way that neuron layers can communicate with each other. They generalize the pairwise synapse of the Hopfield paradigm, making it more flexible than ever before:
Before | HAMUX |
---|---|
Synapses connect only one neuron layer to another | Hypersynapses can connect arbitrary numbers of layers |
Synapses are simple matrix multiplications | Hypersynapses can be almost any operation, e.g., convolutions, pooling, attention, \(\ldots\) |
Synapses are shallow | Hypersynapses can be deep! E.g., a sequence of convolutions, pooling, and activation functions |
At its core, a hypersynapse’s energy is completely defined by its alignment function \(\mathcal{F}\) that converts any number of layer activations \((g^1, g^2, \ldots)\) into a scalar describing its alignment:
\[ E_{\text{synapse}} = -\mathcal{F},\ \ \ \ \text{where}\ \ \ \ \mathcal{F}(g^1, g^2, \ldots) \in \mathbb{R}. \]
The hypersynapse’s energy is typically HIGH when all connected layers are “incongruous” and LOW when all connected layers are “aligned” as defined by its operation.
All `Synapse`s conform to the following simple API: just define a `__call__` function that describes the scalar alignment of different activations, and the energy is calculated for you.
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp
import treex as tx

class Synapse(tx.Module, ABC):
    """The simple interface class for any synapse. Define an alignment function through `__call__` that returns a scalar.
    The energy is simply the negative of this function.
    """

    def energy(self, *gs):
        return -self(*gs)

    @abstractmethod
    def __call__(self, *gs):
        """The alignment function of a synapse"""
        pass
Synapse
Synapse (name:Optional[str]=None)
The simple interface class for any synapse. Define an alignment function through `__call__` that returns a scalar.
The energy is simply the negative of this function.
Synapse.energy
Synapse.energy (*gs)
Dense Synapse
The simplest of synapses is a dense alignment synapse. In feedforward networks, dense operations take an input and return an output. In HAMUX, dense operations align the activations \(g^1 \in \mathbb{R}^{D_1}\) and \(g^2 \in \mathbb{R}^{D_2}\) as follows:
\[\mathcal{F}_\text{dense} = g^1_i W_{ij} g^2_j\]
And would be implemented as follows:
class SimpleDenseSynapse(Synapse):
    """The simplest of dense synapses that connects two layers (with vectorized activations) together"""

    W: jnp.ndarray = tx.Parameter.node()  # treex's preferred way of declaring an attribute as a parameter

    def __call__(self, g1, g2):
        if self.initializing():
            self.W = nn.initializers.normal(0.02)(tx.next_key(), g1.shape + g2.shape)
        return g1 @ self.W @ g2
SimpleDenseSynapse
SimpleDenseSynapse (name:Optional[str]=None)
The simplest of dense synapses that connects two layers (with vectorized activations) together
g1 = jnp.ones(4, dtype=jnp.float32); g2 = jnp.ones(5, dtype=jnp.float32)
syn = SimpleDenseSynapse().init(jax.random.PRNGKey(0), (g1, g2))
syn(g1, g2)
DeviceArray(-0.09093504, dtype=float32)
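As a quick sanity check (a hypothetical follow-up cell reusing `syn`, `g1`, and `g2` from above), the energy is just the negative of the alignment, exactly as defined by `Synapse.energy`:

assert jnp.allclose(syn.energy(g1, g2), -syn(g1, g2))  # energy == -alignment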
When building HAMs in practice, we typically want to follow this pattern: subclass the minimal `Synapse` class and overwrite the `Synapse.__call__` method with our desired alignment function.
We extend this simple concept into a more robust synapse that can linearly connect \(>2\) layers and optionally flatten layer activations.
DenseSynapse
DenseSynapse (stdinit:float=0.02, flatten_args=True)
A dense synapse that aligns the representations of any number of `gs`.
The one learnable parameter `W` is a tensor with a dimension for each connected layer. In the case of 2 layers, this is the traditional learnable matrix synapse; in the case of N>2 layers, it is a new kind of synapse whose learnable parameter is an N-dimensional tensor.
By default, this synapse will flatten all inputs as needed to treat all activations as vectors.
The number of layers we can align with this synapse is capped at the maximum tensor rank JAX supports (<255), but you'll probably run out of memory first.
g1 = jnp.ones(4, dtype=jnp.float32); g2 = jnp.ones(5, dtype=jnp.float32); g3 = jnp.ones(6, dtype=jnp.float32)
syn = DenseSynapse().init(jax.random.PRNGKey(0), (g1, g2, g3))
assert syn.W.shape == (4, 5, 6)
syn(g1, g2, g3)
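Under the hood, the three-layer alignment is the natural generalization of \(g^1_i W_{ij} g^2_j\) to a rank-3 contraction. A minimal sketch (assuming `DenseSynapse` simply contracts `W` against every connected activation):

# Sketch: rank-3 generalization of the dense alignment g1_i W_ij g2_j
# (an assumption, based on W having one dimension per connected layer).
alignment = jnp.einsum("i,j,k,ijk->", g1, g2, g3, syn.W)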
We can even implement a `DenseSynapse` with a hidden layer (Lagrangian) inside the alignment function. This is how we can implement layers that do not need to hold state through time.
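One plausible way to write such an alignment (a sketch of the idea, not the library's exact implementation) is to evaluate the hidden layer's Lagrangian \(L_h\) directly on a linear mix of the connected activations:
\[ \mathcal{F}(g^1, g^2) = L_h\!\left(W^1 g^1 + W^2 g^2\right). \]
Because the hidden layer operates at \(\tau = 0\), it contributes to the energy only through its Lagrangian and never needs to be integrated through time.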
ConvSynapse
ConvSynapse (kernel_size:Union[int,Iterable[int]], strides:Optional[Iterable[int]]=None, padding:Union[str,Iterable[Tuple[int,int]]]='SAME', input_dilation:Optional[Iterable[int]]=None, kernel_dilation:Optional[Iterable[int]]=None, feature_group_count:int=1, use_bias:bool=True, dtype:Any=<class 'jax.numpy.float32'>, param_dtype:Any=<class 'jax.numpy.float32'>, precision:Any=None, kernel_init:Callable[[Any,Tuple[int,...],Any],Any]=<function init>, bias_init:Callable[[Any,Tuple[int,...],Any],Any]=<function zeros>)
A convolutional, binary synapse. Can automatically detect the number of output features from the 2 layers it connects
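A usage sketch, following the same `init` pattern as the dense examples above (the activation shapes and channel counts here are assumptions for illustration):

# Sketch: align an image-shaped activation with a same-resolution feature map.
# Shapes are assumptions; with the default padding='SAME' the spatial dims line up.
g_img = jnp.ones((32, 32, 3), dtype=jnp.float32)    # (H, W, C_in)
g_feat = jnp.ones((32, 32, 16), dtype=jnp.float32)  # (H, W, C_out)
syn = ConvSynapse(kernel_size=(3, 3)).init(jax.random.PRNGKey(0), (g_img, g_feat))
energy = syn.energy(g_img, g_feat)  # scalar alignment energy of the pair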
Convolutional synapses can also contain pooling:
ConvSynapseWithPool
ConvSynapseWithPool (kernel_size:Union[int,Iterable[int]], pool_window=(5, 5), pool_stride=(2, 2), pool_type='avg', strides:Optional[Iterable[int]]=None, padding:Union[str,Iterable[Tuple[int,int]]]='SAME', input_dilation:Optional[Iterable[int]]=None, kernel_dilation:Optional[Iterable[int]]=None, feature_group_count:int=1, use_bias:bool=True, dtype:Any=<class 'jax.numpy.float32'>, param_dtype:Any=<class 'jax.numpy.float32'>, precision:Any=None, kernel_init:Callable[[Any,Tuple[int,...],Any],Any]=<function init>, bias_init:Callable[[Any,Tuple[int,...],Any],Any]=<function zeros>)
A convolutional, binary synapse. Can automatically detect the number of output features from the 2 layers it connects
We can also create synapses that model the attention operations of modern networks.
BinaryMixerSynapse
BinaryMixerSynapse (num_heads:int=1, zspace_dim:int=64, stdinit:float=0.02, hidden_lagrangian:treex.module.Module=LSoftmax { axis: -1, beta: 1.0, Parameter min_beta: 1e-06, name: "lsoftmax", str })
A generalized binary synapse of quadratic order. This synapse is very similar to the Attention synapse but uses a single weight matrix instead of a query and key matrix.
We can specify any Lagrangian non-linearity for the hidden neuron layer (which operates with tau=0), but we default to the Softmax Lagrangian.
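Ignoring heads, a plausible form of this alignment (an assumption based on the description above, not the exact implementation) is
\[ \mathcal{F}(g^1, g^2) = L_{\text{hid}}\!\left(g^1 W (g^2)^\top\right), \]
where \(L_{\text{hid}}\) is the hidden layer's Lagrangian (the softmax Lagrangian by default) and \(W\) is the single mixing matrix that replaces the separate query and key matrices.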
SelfAttentionSynapse
SelfAttentionSynapse (num_heads:int=1, zspace_dim:int=64, stdinit:float=0.02, hidden_lagrangian:treex.module.Module=LSoftmax { axis: -1, beta: 1.0, Parameter min_beta: 1e-06, name: "lsoftmax", str }, do_qk_norm:bool=False)
A special case of the AttentionSynapse where both inputs are of the same layer
AttentionSynapse
AttentionSynapse (num_heads:int=1, zspace_dim:int=64, stdinit:float=0.02, hidden_lagrangian:treex.module.Module=LSoftmax { axis: -1, beta: 1.0, Parameter min_beta: 1e-06, name: "lsoftmax", str }, do_qk_norm:bool=False)
A generalized synapse of quadratic order, whose update rule looks very similar to the Attention operation of Transformers.
We can specify any Lagrangian non-linearity for the hidden neuron layer (which operates with tau=0), but we default to the Softmax Lagrangian.
To replicate a configuration similar to the famous BERT-base model and the “Attention Is All You Need” paper:
zspace = 64
syn = AttentionSynapse(zspace_dim=zspace, num_heads=12, hidden_lagrangian=LSoftmax(beta=1/jnp.sqrt(zspace)))
This configuration connects two layers of shapes (Nq, Dq) and (Nk, Dk), typically neuron layers with layernorm Lagrangians.
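A usage sketch of that configuration (the token counts and embedding dimensions below are assumptions for illustration):

# Sketch: initialize the BERT-like AttentionSynapse against two token-shaped activations.
gq = jnp.ones((128, 768), dtype=jnp.float32)  # (Nq, Dq): query-side layer activations
gk = jnp.ones((196, 768), dtype=jnp.float32)  # (Nk, Dk): key-side layer activations
syn = syn.init(jax.random.PRNGKey(0), (gq, gk))
energy = syn.energy(gq, gk)                   # scalar energy of the attention alignment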