Lagrangians

The well-behaved energy of associative memories is captured by the Lagrangian of the neurons

The dynamic state of each 🌀neuron layer is constrained by a convex Lagrangian.

Functional interface

Lagrangian functions are fundamental to the energy of 🌀neuron layers. These convex functions can be seen as the integral of common activation functions (e.g., relus and softmaxes). All Lagrangians are functions of the form:

\[\mathcal{L}(\mathbf{x};\ldots) \mapsto \mathbb{R}\]

where \(\mathbf{x} \in \mathbb{R}^{D_1 \times \ldots \times D_n}\) can be a tensor of arbitrary shape and \(\mathcal{L}\) can be optionally parameterized (e.g., the LayerNorm’s learnable bias and scale). Lagrangians must be convex and differentiable.

We want to rely on JAX’s autograd to automatically differentiate our Lagrangians into activation functions. For certain Lagrangians, the naively autodiff-ed function of the defined Lagrangian is numerically unstable (e.g., lagr_sigmoid(x) and lagr_tanh(x)). In these cases, we follow JAX’s documentation guidelines to define custom_jvps to fix this behavior.

Let’s look at what some of these Lagrangians look like in practice.

Elementwise Lagrangians

Though we define Lagrangians for an entire tensor, these special “elementwise Lagrangians” take a special form: they are simply the sum of the convex, differentiable function applied elementwise to the underlying tensor. This makes it easy to plot and visualize them.


lagr_identity

 lagr_identity (x:jax.Array)

The Lagrangian whose activation function is simply the identity.

Type Details
x Array Input tensor
Returns Float Output scalar

\[ \begin{align*} \mathcal{L}_\text{identity}(\mathbf{x}) &= \frac{1}{2} \sum_i x_i^2 \\ \partial_{x_i} \mathcal{L}_\text{identity}(\mathbf{x}) &= x_i \end{align*} \]


lagr_repu

 lagr_repu (x:jax.Array, n:float)

Rectified Power Unit of degree n

Type Details
x Array Input tensor
n float Degree of the polynomial in the power unit
Returns Float Output scalar

\[ \begin{align*} \mathcal{L}_\text{RePU}(\mathbf{x}; n) &= \frac{1}{n} \sum_i \max(x_i, 0)^n \\ \partial_{x_i} \mathcal{L}_\text{RePU}(\mathbf{x}; n) &= \max(x_i, 0)^{n-1} \end{align*} \]


lagr_relu

 lagr_relu (x:jax.Array)

Rectified Linear Unit. Same as lagr_repu of degree 2

Type Details
x Array Input tensor
Returns Float Output scalar

\[ \begin{align*} \mathcal{L}_\text{relu}(\mathbf{x}) &= \frac{1}{2} \sum_i \max(x_i, 0)^2 \\ \partial_{x_i} \mathcal{L}_\text{relu}(\mathbf{x}) &= \max(x_i, 0) \end{align*} \]


lagr_exp

 lagr_exp (x:jax.Array, beta:float=1.0)

Exponential activation function, as in Demicirgil et al.. Operates elementwise

Type Default Details
x Array Input tensor
beta float 1.0 Inverse temperature
Returns Float Output scalar

\[ \begin{align*} \mathcal{L}_\text{exp}(\mathbf{x}; \beta) &= \frac{1}{\beta} \sum_i e^{\beta x_i} \\ \partial_{x_i} \mathcal{L}_\text{exp}(\mathbf{x}; \beta) &= e^{\beta x_i} \end{align*} \]


lagr_tanh

 lagr_tanh (x:jax.Array, beta:float=1.0)

Lagrangian of the tanh activation function

Type Default Details
x Array Input tensor
beta float 1.0 Inverse temperature
Returns Float Output scalar

\[ \begin{align*} \mathcal{L}_\text{tanh}(\mathbf{x}; \beta) &= \frac{1}{\beta} \sum_i \log(\cosh(\beta x_i)) \\ \partial_{x_i} \mathcal{L}_\text{tanh}(\mathbf{x}; \beta) &= \tanh(\beta x_i) \end{align*} \]


lagr_sigmoid

 lagr_sigmoid (x:jax.Array, beta:float=1.0)

The lagrangian of the sigmoid activation function

Type Default Details
x Array Input tensor
beta float 1.0 Inverse temperature
Returns Float Output scalar

\[ \begin{align*} \mathcal{L}_\text{sigmoid}(\mathbf{x}; \beta) &= \frac{1}{\beta} \sum_i \log(e^{\beta x_i} + 1) \\ \partial_{x_i} \mathcal{L}_\text{sigmoid}(\mathbf{x}; \beta) &= \frac{1}{1+e^{-\beta x_i}} \end{align*} \]

Lagrangians for competing units

However, we can also have Lagrangians defined on neuron layers where every unit competes. There are many forms of activation functions in modern Deep Learning with this structure; e.g., softmaxes, layernorms, etc. normalize their input by some value. These activation functions can similarly be described by Lagrangians (and there is a nice interpretation of these kinds of activation functions as competing hidden units, but it is harder to plot them.


lagr_softmax

 lagr_softmax (x:jax.Array, beta:float=1.0, axis:int=-1)

The lagrangian of the softmax – the logsumexp

Type Default Details
x Array Input tensor
beta float 1.0 Inverse temperature
axis int -1 Dimension over which to apply logsumexp
Returns Float Output scalar

\[ \begin{align*} \mathcal{L}_\text{softmax}(\mathbf{x}; \beta) &= \frac{1}{\beta} \log \sum_i e^{\beta x_i} \\ \partial_{x_i} \mathcal{L}_\text{softmax}(\mathbf{x}; \beta) &= \frac{e^{\beta x_i}}{\sum_j e^{\beta x_j}} \end{align*} \]

We plot its activations (the softmax) for a vector of length 10 below.


lagr_layernorm

 lagr_layernorm (x:jax.Array, gamma:float=1.0,
                 delta:Union[float,jax.Array]=0.0, axis:int=-1,
                 eps:float=1e-05)

*Lagrangian of the layer norm activation function.

gamma must be a float, not a vector.*

Type Default Details
x Array Input tensor
gamma float 1.0 Scale the stdev
delta Union 0.0 Shift the mean
axis int -1 Which axis to normalize
eps float 1e-05 Prevent division by 0
Returns Float Output scalar

\[ \begin{align*} \mathcal{L}_\text{layernorm}(\mathbf{x}; \gamma, \delta) &= D \gamma \sqrt{\text{Var}(\mathbf{x}) + \epsilon} + \sum_i \delta_i x_i \\ \partial_{x_i} \mathcal{L}_\text{layernorm}(\mathbf{x}; \gamma, \delta) &= \gamma \frac{x_i - \text{Mean}(\mathbf{x})}{\sqrt{\text{Var}(\mathbf{x}) + \epsilon}} + \delta_i \end{align*} \]


lagr_spherical_norm

 lagr_spherical_norm (x:jax.Array, gamma:float=1.0,
                      delta:Union[float,jax.Array]=0.0, axis:int=-1,
                      eps:float=1e-05)

Lagrangian of the spherical norm (L2 norm) activation function

Type Default Details
x Array input tensor
gamma float 1.0 Scale the stdev
delta Union 0.0 Shift the mean
axis int -1 Which axis to normalize
eps float 1e-05 Prevent division by 0
Returns Float Output scalar

\[ \begin{align*} \mathcal{L}_\text{L2norm}(\mathbf{x}; \gamma, \delta) &= \gamma \sqrt{\sum_i x_i^2 + \epsilon} + \sum_i \delta_i x_i \\ \partial_{x_i} \mathcal{L}_\text{L2norm}(\mathbf{x}; \gamma, \delta) &= \gamma \frac{x_i}{\sqrt{\sum_j x_j^2 + \epsilon}} + \delta_i \end{align*} \]