Lagrangians

The well-behaved energy of associative memories is captured by the Lagrangian of the neurons

We begin with the Lagrangian, the fundamental building block of any neuron layer.

Functional interface

The Lagrangian functions are central to understanding the energy that neuron layers contribute to an associative memory; they can be thought of as the integrals of common activation functions (e.g., ReLUs and softmaxes). All Lagrangians are functions of the form:

\[\mathcal{L}: \mathbb{R}^{D_1 \times \ldots \times D_n} \to \mathbb{R}\]

where \(x \in \mathbb{R}^{D_1 \times \ldots \times D_n}\) can be a tensor of arbitrary shape and \(\mathcal{L}\) can be optionally parameterized (e.g., the LayerNorm’s learnable bias and scale). It is important that our Lagrangians be convex and differentiable.

We want to rely on JAX’s autograd to automatically differentiate our Lagrangians into activation functions. For certain Lagrangians, the naively autodiff-ed gradient (e.g., of lagr_sigmoid(x).sum() or lagr_tanh(x).sum()) is numerically unstable. In these cases, we follow JAX’s documentation guidelines and define custom_jvps to fix this behavior.
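
As a quick illustration of this convention (a minimal sketch using plain jax; the toy_lagrangian helper below is hypothetical, not library code), differentiating a convex Lagrangian recovers the corresponding activation function:

import jax
import jax.numpy as jnp

# A convex Lagrangian whose gradient is the ReLU: L(x) = 1/2 * sum(max(x, 0)^2)
def toy_lagrangian(x):
    return 0.5 * jnp.sum(jnp.maximum(x, 0) ** 2)

x = jnp.array([-2.0, -0.5, 0.0, 1.0, 3.0])
jax.grad(toy_lagrangian)(x)  # elementwise max(x, 0): [0., 0., 0., 1., 3.]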


source

lagr_identity

 lagr_identity (x)

The Lagrangian whose activation function is simply the identity.

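The formula is not rendered above. Assuming the standard construction, the identity activation arises as the gradient of a quadratic Lagrangian:

\[\mathcal{L}(x) = \frac{1}{2}\sum_i x_i^2 \quad \Rightarrow \quad \frac{\partial \mathcal{L}}{\partial x_i} = x_i\]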


source

lagr_repu

 lagr_repu (x, n)

Rectified Power Unit of degree n

|  | Details |
|---|---|
| x |  |
| n | Degree of the polynomial in the power unit |

\[\frac{1}{n} \max(x,0)^n\]
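
Differentiating elementwise recovers the rectified power unit activation:

\[\frac{\partial}{\partial x}\left[\frac{1}{n}\max(x,0)^n\right] = \max(x,0)^{n-1}\]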


source

lagr_relu

 lagr_relu (x)

Rectified Linear Unit. Same as the RePU of degree 2

\[\frac{1}{2} \max(x,0)^2\]


source

lagr_softmax_unstable

 lagr_softmax_unstable (x, beta:float=1.0, axis:int=-1)

The Lagrangian of the softmax, i.e., the logsumexp. Here the log(sum(exp(...))) part is coded manually, which can lead to numerical instabilities (see the sketch after the parameter table).

The benefit is in porting HAMUX models to the browser, since TensorFlow.js does not currently support JAX’s logsumexp.

|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |
| axis | int | -1 | Dimension over which to apply logsumexp |
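
A minimal sketch of the instability (plain jax, not the library implementation): the naive log(sum(exp(...))) overflows for large inputs, while the max-shifted form used by jax.scipy.special.logsumexp stays finite.

import jax.numpy as jnp

x = jnp.array([1000.0, 1001.0])

naive = jnp.log(jnp.sum(jnp.exp(x)))                    # exp(1000.) overflows to inf in float32
shift = jnp.max(x)
stable = shift + jnp.log(jnp.sum(jnp.exp(x - shift)))   # ~1001.31
naive, stable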

source

lagr_softmax

 lagr_softmax (x, beta:float=1.0, axis:int=-1)

The Lagrangian of the softmax, i.e., the logsumexp

|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |
| axis | int | -1 | Dimension over which to apply logsumexp |

\[\frac{1}{\beta} \log \sum\limits_i \exp(\beta x_i)\]

We do not plot the logsumexp because it reduces over an axis rather than acting elementwise.
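
A small sketch (plain jax, not the library code) confirming that the gradient of this Lagrangian is the softmax:

import jax
import jax.numpy as jnp

beta = 2.0
x = jnp.array([1.0, 2.0, 3.0])

# (1/beta) * logsumexp(beta * x); its gradient w.r.t. x is softmax(beta * x)
lagrangian = lambda v: jax.scipy.special.logsumexp(beta * v) / beta
jnp.allclose(jax.grad(lagrangian)(x), jax.nn.softmax(beta * x))  # True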


source

lagr_exp

 lagr_exp (x, beta:float=1.0)

Exponential activation function, as in Demircigil et al. It operates elementwise.

|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |

\[ \frac{1}{\beta} \exp(\beta x)\]

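Differentiating elementwise recovers the exponential activation:

\[\frac{\partial}{\partial x}\left[\frac{1}{\beta}\exp(\beta x)\right] = \exp(\beta x)\]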


source

lagr_rexp

 lagr_rexp (x, beta:float=1.0)

Rectified exponential activation function

|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |

\[\frac{1}{\beta} \exp(\beta \bar{x}) - \bar{x} \ \ \ \ \text{where} \ \ \bar{x} = \max(x,0)\]
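
Differentiating elementwise gives a rectified exponential activation that is zero for non-positive inputs:

\[\frac{\partial}{\partial x}\left[\frac{1}{\beta}\exp(\beta \bar{x}) - \bar{x}\right] = \begin{cases}\exp(\beta x) - 1 & x > 0 \\ 0 & x \le 0\end{cases}\]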

The Lagrangians of the tanh and the sigmoid are a bit more numerically unstable, so we need to define custom gradients for them. We show how this is done for the tanh, forwarding the gradient computation to the optimized jnp.tanh function:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def _lagr_tanh(x, beta=1.0):
    # Lagrangian of the tanh: (1 / beta) * log(cosh(beta * x))
    return 1 / beta * jnp.log(jnp.cosh(beta * x))

@_lagr_tanh.defjvp
def _lagr_tanh_defjvp(primals, tangents):
    x, beta = primals
    x_dot, beta_dot = tangents
    primal_out = _lagr_tanh(x, beta)
    # Forward the derivative in x to the numerically stable jnp.tanh;
    # beta is treated as a constant (its tangent is ignored).
    tangent_out = jnp.tanh(beta * x) * x_dot
    return primal_out, tangent_out

def lagr_tanh(x,
              beta=1.0): # Inverse temperature
    """Lagrangian of the tanh activation function"""
    return _lagr_tanh(x, beta)
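
Continuing the sketch above, a quick check that the custom JVP indeed yields the tanh activation (the .sum() is only there because jax.grad requires a scalar output):

x = jnp.linspace(-3, 3, 7)
grad_fn = jax.grad(lambda v: lagr_tanh(v).sum())
jnp.allclose(grad_fn(x), jnp.tanh(x))  # True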

source

lagr_tanh

 lagr_tanh (x, beta=1.0)

Lagrangian of the tanh activation function

|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |

We define a similar custom JVP for the sigmoid; its interface remains just as simple.


source

lagr_sigmoid

 lagr_sigmoid (x, beta=1.0, scale=1.0)

The Lagrangian of the sigmoid activation function

|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |
| scale | float | 1.0 | Amount to stretch the range of the sigmoid’s Lagrangian |
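
The formula is not rendered above. Assuming the usual construction (a scaled softplus; a hedged reconstruction rather than the library’s exact expression), the Lagrangian and its gradient are

\[\mathcal{L}(x) = \frac{\text{scale}}{\beta}\,\log\left(1 + \exp(\beta x)\right) \qquad \Rightarrow \qquad \frac{\partial \mathcal{L}}{\partial x} = \frac{\text{scale}}{1 + \exp(-\beta x)}\]

i.e., a sigmoid stretched to the range \([0, \text{scale}]\).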


source

lagr_layernorm

 lagr_layernorm (x:jax.Array, gamma:float=1.0,
                 delta:Union[float,jax.Array]=0.0, axis=-1, eps=1e-05)

Lagrangian of the layer norm activation function

|  | Type | Default | Details |
|---|---|---|---|
| x | Array |  |  |
| gamma | float | 1.0 | Scale the stdev |
| delta | typing.Union[float, jax.Array] | 0.0 | Shift the mean |
| axis | int | -1 | Which axis to normalize |
| eps | float | 1e-05 | Prevent division by 0 |
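
The formula is not rendered above. Following the standard layer-norm Lagrangian construction (a hedged reconstruction; the library’s exact expression may differ in details such as the placement of \(\epsilon\)), it takes the form

\[\mathcal{L}(x) = D\,\gamma\,\sqrt{\mathrm{Var}(x) + \epsilon} + \sum_i \delta_i x_i \qquad \Rightarrow \qquad \frac{\partial \mathcal{L}}{\partial x_i} = \gamma\,\frac{x_i - \bar{x}}{\sqrt{\mathrm{Var}(x) + \epsilon}} + \delta_i\]

where \(D\) is the size of the normalized axis and \(\bar{x}\), \(\mathrm{Var}(x)\) are the mean and variance along that axis, so the activation is the familiar LayerNorm.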

source

lagr_spherical_norm

 lagr_spherical_norm (x:jax.Array, gamma:float=1.0,
                      delta:Union[float,jax.Array]=0.0, axis=-1,
                      eps=1e-05)

Lagrangian of the spherical norm activation function

|  | Type | Default | Details |
|---|---|---|---|
| x | Array |  |  |
| gamma | float | 1.0 | Scale the stdev |
| delta | typing.Union[float, jax.Array] | 0.0 | Shift the mean |
| axis | int | -1 | Which axis to normalize |
| eps | float | 1e-05 | Prevent division by 0 |
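
The spherical norm follows the same pattern (again a hedged reconstruction): its Lagrangian is essentially the L2 norm, and its gradient L2-normalizes the input:

\[\mathcal{L}(x) = \gamma\,\sqrt{\sum_i x_i^2 + \epsilon} + \sum_i \delta_i x_i \qquad \Rightarrow \qquad \frac{\partial \mathcal{L}}{\partial x_i} = \gamma\,\frac{x_i}{\sqrt{\sum_j x_j^2 + \epsilon}} + \delta_i\]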

Parameterized Lagrangians

It is beneficial to treat Lagrangians as modules with their own learnable parameters (e.g., a trainable inverse temperature beta).


source

LSphericalNorm

 LSphericalNorm (gamma=1.0, delta=0.0, eps=1e-05)

Reduced Lagrangian whose activation function is the spherical L2 norm

Parameterized by (gamma, delta), a scale and a shift

|  | Type | Default | Details |
|---|---|---|---|
| gamma | float | 1.0 | Inverse temperature, for the sharpness of the exponent |
| delta | float | 0.0 |  |
| eps | float | 1e-05 |  |

source

LLayerNorm

 LLayerNorm (gamma=1.0, delta=0.0, eps=1e-05)

Reduced Lagrangian whose activation function is the layer norm

Parameterized by (gamma, delta), a scale and a shift

|  | Type | Default | Details |
|---|---|---|---|
| gamma | float | 1.0 | Inverse temperature, for the sharpness of the exponent |
| delta | float | 0.0 |  |
| eps | float | 1e-05 |  |

source

LTanh

 LTanh (beta=1.0, min_beta=1e-06)

Reduced Lagrangian whose activation function is the tanh

Parameterized by (beta)

|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature, for the sharpness of the exponent |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |

source

LRexp

 LRexp (beta=1.0, min_beta=1e-06)

Reduced Lagrangian whose activation function is the rectified exponential function

Parameterized by (beta)

|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature, for the sharpness of the exponent |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |

source

LExp

 LExp (beta=1.0, min_beta=1e-06)

Reduced Lagrangian whose activation function is the exponential function

Parameterized by (beta)

|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature, for the sharpness of the exponent |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |

source

LSoftmaxUnstable

 LSoftmaxUnstable (beta=1.0, axis=-1, min_beta=1e-06)

Reduced Lagrangian whose activation function is the softmax, using a manually coded logsumexp for browser compatibility

Parameterized by (beta)

|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature |
| axis | int | -1 | Axis over which to apply the softmax |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |

source

LSoftmax

 LSoftmax (beta=1.0, axis=-1, min_beta=1e-06)

Reduced Lagrangian whose activation function is the softmax

Parameterized by (beta)

|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature |
| axis | int | -1 | Axis over which to apply the softmax |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |

source

LSigmoid

 LSigmoid (beta=1.0, scale=1.0, min_beta=1e-06)

Reduced Lagrangian whose activation function is the sigmoid

Parameterized by (beta)

|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature |
| scale | float | 1.0 | Amount to stretch the sigmoid |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |

source

LRelu

 LRelu ()

Reduced Lagrangian whose activation function is the rectified linear unit


source

LRepu

 LRepu (n=2.0)

Reduced Lagrangian whose activation function is the rectified polynomial unit of specified degree n

|  | Type | Default | Details |
|---|---|---|---|
| n | float | 2.0 | The degree of the RePU. By default, set to the ReLU configuration |

source

LIdentity

 LIdentity ()

Reduced Lagrangian whose activation function is the identity function

The modules can be initialized and used as follows:

lag = LRepu(n=4)
x = jax.random.normal(jax.random.PRNGKey(1), (7,))
lag(x)
Array(8.217218, dtype=float32)

To apply the Lagrangian to a batch of samples, we take advantage of jax.vmap:

lag = jax.vmap(LRepu(n=4)); batch_size = 4
x = jax.random.normal(jax.random.PRNGKey(1), (batch_size,7))
lag(x)
Array([0.79066116, 0.634871  , 0.13261071, 3.3233747 ], dtype=float32)