Lagrangians
We begin with the lagrangian, the fundamental unit of any neuron layer.
Functional interface
The Lagrangian functions are central to understanding the energy that neuron layers provide to an associative memory, and they can be thought of as integrals of common activation functions (e.g., relu and softmax). All Lagrangians are functions of the form:
\[\mathcal{L}(x;\ldots) \mapsto \mathbb{R}\]
where \(x \in \mathbb{R}^{D_1 \times \ldots \times D_n}\) can be a tensor of arbitrary shape and \(\mathcal{L}\) can optionally be parameterized (e.g., by the LayerNorm’s learnable bias and scale). It is important that our Lagrangians be convex and differentiable.
We want to rely on JAX’s autograd to automatically differentiate our Lagrangians into activation functions. For certain Lagrangians, the naively autodiff-ed version of the defined Lagrangian is numerically unstable (e.g., lagr_sigmoid(x).sum() and lagr_tanh(x).sum()). In these cases, we follow JAX’s documentation guidelines and define custom_jvps to fix this behavior.
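As a quick illustration of this relationship (a minimal sketch using the relu formula defined later in this section, not the library’s own implementation), differentiating a summed Lagrangian recovers its activation function:

```python
import jax
import jax.numpy as jnp

# Minimal sketch: the relu Lagrangian, 1/2 * max(x, 0)^2, summed over the input.
def lagr_relu_sum(x):
    return jnp.sum(0.5 * jnp.maximum(x, 0) ** 2)

x = jnp.array([-1.0, 0.5, 2.0])
# The gradient of the Lagrangian is the activation function itself.
print(jax.grad(lagr_relu_sum)(x))  # [0.  0.5 2. ]
print(jax.nn.relu(x))              # [0.  0.5 2. ]
```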
lagr_identity
lagr_identity (x)
The Lagrangian whose activation function is simply the identity.
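For intuition (a hedged sketch, not necessarily the library’s exact code): the convex function whose gradient is the identity is \(\frac{1}{2}x^2\), so an identity Lagrangian can be written as:

```python
import jax
import jax.numpy as jnp

# Sketch: grad of sum(x^2 / 2) is x, i.e., the identity activation.
def lagr_identity_sketch(x):
    return jnp.sum(0.5 * x ** 2)

x = jnp.array([1.0, -2.0, 3.0])
print(jax.grad(lagr_identity_sketch)(x))  # [ 1. -2.  3.]
```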
lagr_repu
lagr_repu (x, n)
Rectified Power Unit of degree n
|  | Details |
|---|---|
| x |  |
| n | Degree of the polynomial in the power unit |
\[\frac{1}{n} \max(x,0)^n\]
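To check the degree-\(n\) case numerically (a sketch following the formula above, not the library call itself; the derivative of \(\frac{1}{n}\max(x,0)^n\) is \(\max(x,0)^{n-1}\)):

```python
import jax
import jax.numpy as jnp

# Sketch: repu Lagrangian of degree n, summed; its gradient is max(x, 0)**(n - 1).
def lagr_repu_sum(x, n):
    return jnp.sum(jnp.maximum(x, 0) ** n / n)

x = jnp.array([-1.0, 0.5, 2.0])
print(jax.grad(lagr_repu_sum)(x, 3))  # [0.   0.25 4.  ]
print(jnp.maximum(x, 0) ** 2)         # same values
```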
lagr_relu
lagr_relu (x)
Rectified Linear Unit. Same as repu of degree 2
\[\frac{1}{2} \max(x,0)^2\]
lagr_softmax_unstable
lagr_softmax_unstable (x, beta:float=1.0, axis:int=-1)
The lagrangian of the softmax – the logsumexp. Here, however, the log(sum(exp(...))) part is coded manually, which can lead to numerical instabilities. The benefit is in porting HAMUX models into the browser, because TensorFlowJS does not currently support JAX’s logsumexp. A sketch of the instability follows the table below.
|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |
| axis | int | -1 | Dimension over which to apply logsumexp |
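To see why the manual form is unstable (a hedged sketch, not the library’s code): the naive log(sum(exp(x))) overflows for large inputs, while the max-shifted form used by standard logsumexp implementations stays finite.

```python
import jax.numpy as jnp

def naive_logsumexp(x, axis=-1):
    # Overflows: exp(1000.0) is inf in float32.
    return jnp.log(jnp.sum(jnp.exp(x), axis=axis))

def shifted_logsumexp(x, axis=-1):
    # Subtracting the max keeps exp() in a safe range.
    c = jnp.max(x, axis=axis, keepdims=True)
    return jnp.squeeze(c, axis) + jnp.log(jnp.sum(jnp.exp(x - c), axis=axis))

x = jnp.array([1000.0, 1001.0, 1002.0])
print(naive_logsumexp(x))    # inf
print(shifted_logsumexp(x))  # ~1002.4076
```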
lagr_softmax
lagr_softmax (x, beta:float=1.0, axis:int=-1)
The lagrangian of the softmax – the logsumexp
|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |
| axis | int | -1 | Dimension over which to apply logsumexp |
\[\frac{1}{\beta} \log \sum\limits_i \exp(\beta x_i)\]
We do not plot the logsumexp because it contains an implicit summation over its input.
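A quick check (a sketch built on jax.scipy.special.logsumexp rather than the library’s own lagr_softmax): the gradient of the logsumexp Lagrangian is the softmax activation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Sketch: (1/beta) * logsumexp(beta * x); its gradient is softmax(beta * x).
def lagr_softmax_sketch(x, beta=1.0):
    return logsumexp(beta * x) / beta

x = jnp.array([1.0, 2.0, 3.0])
print(jax.grad(lagr_softmax_sketch)(x))  # [0.09003057 0.24472848 0.66524094]
print(jax.nn.softmax(x))                 # same values
```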
lagr_exp
lagr_exp (x, beta:float=1.0)
Exponential activation function, as in Demircigil et al. Operates elementwise
|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |
\[ \frac{1}{\beta} \exp(\beta x)\]
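As a sanity check (a sketch following the formula above, not the library call itself), differentiating the summed Lagrangian recovers the elementwise exponential activation \(\exp(\beta x)\):

```python
import jax
import jax.numpy as jnp

# Sketch: elementwise exponential Lagrangian, summed so jax.grad returns a vector.
def lagr_exp_sum(x, beta=1.0):
    return jnp.sum(jnp.exp(beta * x) / beta)

x = jnp.array([-1.0, 0.0, 1.0])
print(jax.grad(lagr_exp_sum)(x))  # [0.36787945 1.         2.7182817 ]
```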
lagr_rexp
lagr_rexp (x, beta:float=1.0)
Rectified exponential activation function
|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |
\[\frac{1}{\beta} \exp(\beta \bar{x}) - \bar{x} \ \ \ \ \text{where} \ \ \bar{x} = \max(x,0)\]
The lagrangians of the tanh and the sigmoid are a bit more numerically unstable, so we need to define custom gradients for them. We show how this is done for the tanh case, forwarding the gradient computation to the optimized jnp.tanh function:
import jax
import jax.numpy as jnp

@jax.custom_jvp
def _lagr_tanh(x, beta=1.0):
    return 1 / beta * jnp.log(jnp.cosh(beta * x))

@_lagr_tanh.defjvp
def _lagr_tanh_defjvp(primals, tangents):
    x, beta = primals
    x_dot, beta_dot = tangents
    primal_out = _lagr_tanh(x, beta)
    # Forward the gradient computation to the optimized jnp.tanh
    tangent_out = jnp.tanh(beta * x) * x_dot
    return primal_out, tangent_out

def lagr_tanh(x, beta=1.0):  # Inverse temperature
    """Lagrangian of the tanh activation function"""
    return _lagr_tanh(x, beta)
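A quick usage check (continuing from the definitions above): the gradient of the summed Lagrangian reproduces tanh.

```python
x = jnp.array([-2.0, 0.0, 2.0])
print(jax.grad(lambda x: jnp.sum(lagr_tanh(x)))(x))  # matches jnp.tanh(x)
print(jnp.tanh(x))
```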
lagr_tanh
lagr_tanh (x, beta=1.0)
Lagrangian of the tanh activation function
|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |
We define a similar custom JVP for the sigmoid; its interface remains just as simple. A hedged sketch of such a definition follows the table below.
lagr_sigmoid
lagr_sigmoid (x, beta=1.0, scale=1.0)
The lagrangian of the sigmoid activation function
|  | Type | Default | Details |
|---|---|---|---|
| x |  |  |  |
| beta | float | 1.0 | Inverse temperature |
| scale | float | 1.0 | Amount to stretch the range of the sigmoid’s lagrangian |
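For reference, a hedged sketch of what a sigmoid Lagrangian with a custom JVP can look like (assuming the softplus form \(\frac{\text{scale}}{\beta}\log(1 + \exp(\beta x))\); this is an illustration, not necessarily the library’s exact implementation):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def _lagr_sigmoid_sketch(x, beta=1.0, scale=1.0):
    # Softplus form of the sigmoid Lagrangian (illustrative assumption).
    return scale / beta * jnp.log(1 + jnp.exp(beta * x))

@_lagr_sigmoid_sketch.defjvp
def _lagr_sigmoid_sketch_jvp(primals, tangents):
    x, beta, scale = primals
    x_dot, *_ = tangents
    primal_out = _lagr_sigmoid_sketch(x, beta, scale)
    # Route the derivative through the numerically stable jax.nn.sigmoid.
    tangent_out = scale * jax.nn.sigmoid(beta * x) * x_dot
    return primal_out, tangent_out

x = jnp.array([-2.0, 0.0, 2.0])
print(jax.grad(lambda x: jnp.sum(_lagr_sigmoid_sketch(x, 1.0, 1.0)))(x))  # scale * sigmoid(beta * x)
```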
lagr_layernorm
lagr_layernorm (x:jax.Array, gamma:float=1.0, delta:Union[float,jax.Array]=0.0, axis=-1, eps=1e-05)
Lagrangian of the layer norm activation function
|  | Type | Default | Details |
|---|---|---|---|
| x | Array |  |  |
| gamma | float | 1.0 | Scale the stdev |
| delta | typing.Union[float, jax.Array] | 0.0 | Shift the mean |
| axis | int | -1 | Which axis to normalize |
| eps | float | 1e-05 | Prevent division by 0 |
lagr_spherical_norm
lagr_spherical_norm (x:jax.Array, gamma:float=1.0, delta:Union[float,jax.Array]=0.0, axis=-1, eps=1e-05)
Lagrangian of the spherical norm activation function
|  | Type | Default | Details |
|---|---|---|---|
| x | Array |  |  |
| gamma | float | 1.0 | Scale the stdev |
| delta | typing.Union[float, jax.Array] | 0.0 | Shift the mean |
| axis | int | -1 | Which axis to normalize |
| eps | float | 1e-05 | Prevent division by 0 |
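The intuition behind both normalization Lagrangians is that the gradient of a norm is the normalized vector. A minimal sketch of that identity (ignoring gamma, delta, and eps, which the library versions also handle):

```python
import jax
import jax.numpy as jnp

# Sketch: the gradient of the L2 norm is the unit (spherically normalized) vector.
def l2_norm(x):
    return jnp.sqrt(jnp.sum(x ** 2))

x = jnp.array([3.0, 4.0])
print(jax.grad(l2_norm)(x))  # [0.6 0.8]
print(x / l2_norm(x))        # [0.6 0.8]
```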
Parameterized Lagrangians
It is beneficial to consider lagrangians as modules with their own learnable parameters.
LSphericalNorm
LSphericalNorm (gamma=1.0, delta=0.0, eps=1e-05)
Reduced Lagrangian whose activation function is the spherical L2 norm
Parameterized by (gamma, delta), a scale and a shift
|  | Type | Default | Details |
|---|---|---|---|
| gamma | float | 1.0 | Scale the stdev |
| delta | float | 0.0 | Shift the mean |
| eps | float | 1e-05 | Prevent division by 0 |
LLayerNorm
LLayerNorm (gamma=1.0, delta=0.0, eps=1e-05)
Reduced Lagrangian whose activation function is the layer norm
Parameterized by (gamma, delta), a scale and a shift
|  | Type | Default | Details |
|---|---|---|---|
| gamma | float | 1.0 | Scale the stdev |
| delta | float | 0.0 | Shift the mean |
| eps | float | 1e-05 | Prevent division by 0 |
LTanh
LTanh (beta=1.0, min_beta=1e-06)
Reduced Lagrangian whose activation function is the tanh
Parameterized by (beta)
|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature, for the sharpness of the exponent |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |
LRexp
LRexp (beta=1.0, min_beta=1e-06)
Reduced Lagrangian whose activation function is the rectified exponential function
Parameterized by (beta)
|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature, for the sharpness of the exponent |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |
LExp
LExp (beta=1.0, min_beta=1e-06)
Reduced Lagrangian whose activation function is the exponential function
Parameterized by (beta)
|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature, for the sharpness of the exponent |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |
LSoftmaxUnstable
LSoftmaxUnstable (beta=1.0, axis=-1, min_beta=1e-06)
Reduced Lagrangian whose activation function is the softmax, using a manually coded logsumexp for browser compatibility
Parameterized by (beta)
|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature |
| axis | int | -1 | Axis over which to apply the softmax |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |
LSoftmax
LSoftmax (beta=1.0, axis=-1, min_beta=1e-06)
Reduced Lagrangian whose activation function is the softmax
Parameterized by (beta)
|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature |
| axis | int | -1 | Axis over which to apply the softmax |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |
LSigmoid
LSigmoid (beta=1.0, scale=1.0, min_beta=1e-06)
Reduced Lagrangian whose activation function is the sigmoid
Parameterized by (beta)
|  | Type | Default | Details |
|---|---|---|---|
| beta | float | 1.0 | Inverse temperature |
| scale | float | 1.0 | Amount to stretch the sigmoid. |
| min_beta | float | 1e-06 | Minimal accepted value of beta. For energy dynamics, it is important that beta be positive. |
LRelu
LRelu ()
Reduced Lagrangian whose activation function is the rectified linear unit
LRepu
LRepu (n=2.0)
Reduced Lagrangian whose activation function is the rectified polynomial unit of specified degree n
|  | Type | Default | Details |
|---|---|---|---|
| n | float | 2.0 | The degree of the RePU. By default, set to the ReLU configuration |
LIdentity
LIdentity ()
Reduced Lagrangian whose activation function is the identity function
The modules can be initialized and used as follows:
lag = LRepu(n=4)
x = jax.random.normal(jax.random.PRNGKey(1), (7,))
lag(x)
Array(8.217218, dtype=float32)
To apply the lagrangian to a batch of samples, we should take advantage of jax.vmap:
batch_size = 4
lag = jax.vmap(LRepu(n=4))
x = jax.random.normal(jax.random.PRNGKey(1), (batch_size, 7))
lag(x)
Array([0.79066116, 0.634871 , 0.13261071, 3.3233747 ], dtype=float32)
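Since activations are gradients of Lagrangians, the same pattern composes with jax.grad (a usage sketch assuming, as above, that LRepu instances are plain callables returning a scalar per sample):

```python
import jax

# Hypothetical usage sketch: per-sample activations from the batched Lagrangian.
act = jax.vmap(jax.grad(LRepu(n=4)))
activations = act(x)  # shape (batch_size, 7); elementwise max(x, 0)**3
```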