(Hyper)Synapses

Using Modern Deep Learning operations with Energy.

Hypersynapses are the only way that neuron layers communicate with each other. They generalize the pairwise synapse of the classical Hopfield paradigm in several ways:

| Before HAMUX | With HAMUX |
| --- | --- |
| Synapses connect only one neuron layer to another | Hypersynapses can connect arbitrary numbers of layers |
| Synapses are simple matrix multiplications | Hypersynapses can be almost any operation, e.g., convolutions, pooling, attention, \(\ldots\) |
| Synapses are shallow | Hypersynapses can be deep! E.g., a sequence of convolutions, pooling, and activation functions |

At its core, a hypersynapse’s energy is completely defined by its alignment function \(\mathcal{F}\) that converts any number of layer activations \((g^1, g^2, \ldots)\) into a scalar describing its alignment:

\[ E_{\text{synapse}} = -\mathcal{F},\ \ \ \ \text{where}\ \ \ \ \mathcal{F}: (g^1, g^2, \ldots) \mapsto \mathbb{R}. \]

The hypersynapse’s energy is typically HIGH when all connected layers are “incongruous” and LOW when all connected layers are “aligned” as defined by its operation.

All Synapses conform to the following simple API. Just define a __call__ function to describe the scalar alignment of different activations. Energy is calculated for you.

from abc import ABC, abstractmethod
import treex as tx

class Synapse(tx.Module, ABC):
    """The simple interface class for any synapse. Define an alignment function through `__call__` that returns a scalar.

    The energy is simply the negative of this function.
    """

    def energy(self, *gs):
        """The synapse's energy is the negative of its alignment."""
        return -self(*gs)

    @abstractmethod
    def __call__(self, *gs):
        """The alignment function of a synapse"""
        pass

source

Synapse

 Synapse (name:Optional[str]=None)

The simple interface class for any synapse. Define an alignment function through __call__ that returns a scalar.

The energy is simply the negative of this function.


source

Synapse.energy

 Synapse.energy (*gs)

Dense Synapse

The simplest of synapses is a dense alignment synapse. In feedforward networks, dense operations take an input and return an output. In HAMUX, dense operations align the activations \(g^1 \in \mathbb{R}^{D_1}\) and \(g^2 \in \mathbb{R}^{D_2}\) as follows:

\[\mathcal{F}_\text{dense} = g^1_i W_{ij} g^2_j\]

And would be implemented as follows:

import jax
import jax.numpy as jnp
from flax import linen as nn  # used here only for its parameter initializers

class SimpleDenseSynapse(Synapse):
    """The simplest of dense synapses that connects two layers (with vectorized activations) together"""
    W: jnp.ndarray = tx.Parameter.node() # treex's preferred way of declaring an attribute as a parameter
    def __call__(self, g1, g2):
        if self.initializing():
            # Lazily create W with shape (D1, D2) from the connected activations
            self.W = nn.initializers.normal(0.02)(tx.next_key(), g1.shape + g2.shape)
        return g1 @ self.W @ g2

source

SimpleDenseSynapse

 SimpleDenseSynapse (name:Optional[str]=None)

The simplest of dense synapses that connects two layers (with vectorized activations) together

g1 = jnp.ones(4, dtype=jnp.float32); g2 = jnp.ones(5, dtype=jnp.float32)
syn = SimpleDenseSynapse().init(jax.random.PRNGKey(0), (g1, g2))
syn(g1, g2)
DeviceArray(-0.09093504, dtype=float32)

When building HAMs in practice, we typically follow this pattern: subclass the minimal Synapse class and override Synapse.__call__ with our desired alignment function.
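
For example, a minimal, hypothetical, parameter-free synapse that penalizes disagreement between two same-shaped activations could be sketched like this (the class name and the distance-based alignment are illustrative, not part of HAMUX):

# Hypothetical sketch: alignment is the negative squared distance between two
# activations, so the energy is LOW when g1 and g2 agree.
class DistanceSynapse(Synapse):
    """Illustrative synapse whose alignment rewards similar activations."""
    def __call__(self, g1, g2):
        return -jnp.sum((g1 - g2) ** 2)  # scalar alignment

ga, gb = jnp.ones(4), jnp.zeros(4)
dist_syn = DistanceSynapse().init(jax.random.PRNGKey(0), (ga, gb))
dist_syn.energy(ga, gb)  # 4.0: high energy because the activations disagree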

We extend this simple concept into a more robust synapse that can linearly connect \(>2\) layers and optionally flatten layer activations.


source

DenseSynapse

 DenseSynapse (stdinit:float=0.02, flatten_args=True)

A dense synapse that aligns the representations of any number of gs.

The one learnable parameter W is a tensor with a dimension for each connected layer. In the case of 2 layers, this is the traditional learnable matrix synapse. For N>2 layers, it is a new kind of synapse whose learnable parameter is an N-dimensional tensor.

By default, this will flatten all inputs as needed to treat all activations as vectors.

The number of layers we can align with this synapse is capped by the maximum number of dimensions a JAX array can have (<255), but you'll probably run out of memory first.

g1 = jnp.ones(4, dtype=jnp.float32); g2 = jnp.ones(5, dtype=jnp.float32); g3 = jnp.ones(6, dtype=jnp.float32)
syn = DenseSynapse().init(jax.random.PRNGKey(0), (g1, g2, g3))
syn(g1, g2, g3)
assert syn.W.shape == (4,5,6)
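
Under the hood, the three-way alignment above is a full contraction of all three activations against the shared tensor W. A sketch of the equivalent computation, assuming the same convention as the two-layer case (this is illustration, not library code):

# The order-3 alignment contracts every activation against W
alignment = jnp.einsum("i,j,k,ijk->", g1, g2, g3, syn.W)
energy = -alignment  # a synapse's energy is always the negative alignment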

We can even implement a DenseSynapse with a hidden layer (Lagrangian) inside the alignment function. This is how we can implement layers that do not need to hold state through time.
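
Conceptually, such a synapse projects the connected activations into a hidden space and scores the result with a scalar Lagrangian; because the hidden layer runs with tau=0, it never has to be integrated through time. A rough, hypothetical sketch (using logsumexp as a stand-in softmax Lagrangian; this is not the library's implementation):

class SketchHiddenSynapse(Synapse):
    """Hypothetical sketch: project both activations into a shared hidden space,
    then score the hidden pre-activations with a scalar Lagrangian."""
    W1: jnp.ndarray = tx.Parameter.node()
    W2: jnp.ndarray = tx.Parameter.node()

    def __init__(self, nhid: int):
        super().__init__()
        self.nhid = nhid

    def __call__(self, g1, g2):
        if self.initializing():
            self.W1 = nn.initializers.normal(0.02)(tx.next_key(), (g1.shape[0], self.nhid))
            self.W2 = nn.initializers.normal(0.02)(tx.next_key(), (g2.shape[0], self.nhid))
        h = g1 @ self.W1 + g2 @ self.W2  # hidden pre-activations (a tau=0 layer)
        return jax.nn.logsumexp(h)       # scalar alignment via a softmax-style Lagrangian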


source

DenseTensorSynapseWithHiddenLayer

 DenseTensorSynapseWithHiddenLayer (nhid:int, num_heads=1, stdinit:float=0.02,
                                    hidden_lagrangian:treex.module.Module=LRelu())

A generalized DenseTensorSynapse that has a hidden lagrangian (non-linearity).

We can specify a Lagrangian non-linearity for the hidden neuron layer with tau=0 and shape (nhid,).

The Lagrangian can have its own learnable parameters, for example:

from hamux.core.lagrangians import *

# `gs` is assumed to be the tuple of connected activations, e.g. (g1, g2, g3) from the cells above
syn = DenseTerminusSynapse(20, hidden_lagrangian=LSoftmax(beta_init=0.2)).init(jax.random.PRNGKey(0), tuple(gs))

source

DenseMatrixSynapseWithHiddenLayer

 DenseMatrixSynapseWithHiddenLayer (nhid:int, stdinit:float=0.02,
                                    hidden_lagrangian:treex.module.Module=LRelu(),
                                    do_ravel=True, do_norm=False)

A modified DenseSynapse that has a hidden lagrangian (non-linearity).

We can specify a Lagrangian non-linearity for the hidden neuron layer with tau=0 and shape (nhid,).

Unlike the DenseTensorSynapseWithHiddenLayer, this synapse treats the connected layers as if they were concatenated along a single visible dimension instead of giving each layer its own dimension of the tensor space.
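
Schematically (reusing g1 and g2 from the cells above; this is illustration, not library code), the two hidden-layer variants build their hidden pre-activations differently:

nhid = 8
W_tensor = jnp.zeros((nhid, 4, 5))   # tensor variant: one index of W per connected layer
W_matrix = jnp.zeros((nhid, 4 + 5))  # matrix variant: layers share one concatenated visible axis

h_tensor = jnp.einsum("mij,i,j->m", W_tensor, g1, g2)
h_matrix = W_matrix @ jnp.concatenate([jnp.ravel(g1), jnp.ravel(g2)])
# Either hidden vector is then scored by the hidden Lagrangian to produce the scalar alignment.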


source

ConvSynapse

 ConvSynapse (kernel_size:Union[int,Iterable[int]],
              strides:Optional[Iterable[int]]=None,
              padding:Union[str,Iterable[Tuple[int,int]]]='SAME',
              input_dilation:Optional[Iterable[int]]=None,
              kernel_dilation:Optional[Iterable[int]]=None,
              feature_group_count:int=1, use_bias:bool=True,
              dtype:Any=<class 'jax.numpy.float32'>,
              param_dtype:Any=<class 'jax.numpy.float32'>,
              precision:Any=None,
              kernel_init:Callable[[Any,Tuple[int,...],Any],Any]=<function init>,
              bias_init:Callable[[Any,Tuple[int,...],Any],Any]=<function zeros>)

A convolutional, binary synapse. Can automatically detect the number of output features from the two layers it connects.
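
For instance, a hedged usage sketch (the layer shapes and the channels-last layout are assumptions for illustration, not the library's required convention):

g_img = jnp.ones((32, 32, 3), dtype=jnp.float32)    # image-like layer activation (channels-last assumed)
g_feat = jnp.ones((32, 32, 16), dtype=jnp.float32)  # feature-map layer activation
conv_syn = ConvSynapse(kernel_size=(3, 3)).init(jax.random.PRNGKey(0), (g_img, g_feat))
conv_syn(g_img, g_feat)  # scalar alignment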

Convolutional synapses can also contain pooling:


source

ConvSynapseWithPool

 ConvSynapseWithPool (kernel_size:Union[int,Iterable[int]],
                      pool_window=(5, 5), pool_stride=(2, 2),
                      pool_type='avg',
                      strides:Optional[Iterable[int]]=None,
                      padding:Union[str,Iterable[Tuple[int,int]]]='SAME',
                      input_dilation:Optional[Iterable[int]]=None,
                      kernel_dilation:Optional[Iterable[int]]=None,
                      feature_group_count:int=1, use_bias:bool=True,
                      dtype:Any=<class 'jax.numpy.float32'>,
                      param_dtype:Any=<class 'jax.numpy.float32'>,
                      precision:Any=None,
                      kernel_init:Callable[[Any,Tuple[int,...],Any],Any]=<function init>,
                      bias_init:Callable[[Any,Tuple[int,...],Any],Any]=<function zeros>)

A convolutional, binary synapse. Can automatically detect the number of output features from the two layers it connects.
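
A construction-only sketch of the pooled variant, spelling out the documented pooling defaults for clarity:

# Convolution followed by average pooling; the pool arguments shown are the documented defaults.
pool_syn = ConvSynapseWithPool(kernel_size=(3, 3), pool_window=(5, 5), pool_stride=(2, 2), pool_type="avg")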

We can also create synapses that model the attention operations of modern networks:


source

BinaryMixerSynapse

 BinaryMixerSynapse (num_heads:int=1, zspace_dim:int=64, stdinit:float=0.02,
                     hidden_lagrangian:treex.module.Module=LSoftmax(axis=-1, beta=1.0, min_beta=1e-06))

A generalized binary synapse of quadratic order. This synapse is very similar to the AttentionSynapse but uses a single weight matrix instead of separate query and key matrices.

We can specify any Lagrangian non-linearity for the hidden neuron layer (which operates with tau=0), but we default to the Softmax Lagrangian.
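
A construction-only sketch using the documented arguments (the particular head count and z-space size are arbitrary choices for illustration):

mixer = BinaryMixerSynapse(num_heads=4, zspace_dim=32)  # single shared weight matrix, softmax hidden Lagrangian by default
# or with an explicit hidden Lagrangian, mirroring the attention example below:
# mixer = BinaryMixerSynapse(num_heads=4, zspace_dim=32, hidden_lagrangian=LSoftmax(beta=1.0))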


source

SelfAttentionSynapse

 SelfAttentionSynapse (num_heads:int=1, zspace_dim:int=64, stdinit:float=0.02,
                       hidden_lagrangian:treex.module.Module=LSoftmax(axis=-1, beta=1.0, min_beta=1e-06),
                       do_qk_norm:bool=False)

A special case of the AttentionSynapse where both inputs come from the same layer.


source

AttentionSynapse

 AttentionSynapse (num_heads:int=1, zspace_dim:int=64, stdinit:float=0.02,
                   hidden_lagrangian:treex.module.Module=LSoftmax(axis=-1, beta=1.0, min_beta=1e-06),
                   do_qk_norm:bool=False)

A generalized synapse of quadratic order, whose update rule looks very similar to the Attention operation of Transformers.

We can specify any Lagrangian non-linearity for the hidden neuron layer (which operates with tau=0), but we default to the Softmax Lagrangian.

To replicate a configuration similar to the famous BERT-base models and the “Attention Is All You Need” paper:

zspace = 64
syn = AttentionSynapse(zspace_dim=zspace, num_heads=12, hidden_lagrangian=LSoftmax(beta=1/jnp.sqrt(zspace)))

This synapse connects two layers of shapes \((N_q, D_q)\) and \((N_k, D_k)\), typically paired with layernorm Lagrangians on the connected layers.
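
Continuing the configuration above, initialization follows the same pattern as the other synapses; the concrete \(N_q\), \(N_k\), and feature widths below are illustrative assumptions:

g_q = jnp.ones((16, 768), dtype=jnp.float32)  # (Nq, Dq): 16 query tokens of width 768
g_k = jnp.ones((20, 768), dtype=jnp.float32)  # (Nk, Dk): 20 key tokens of width 768
syn = syn.init(jax.random.PRNGKey(0), (g_q, g_k))
syn(g_q, g_k)  # scalar alignment between the query and key layers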