Converting to TFJS

Package and ship to the browser

One appeal of HAMs is their compactness: all model weights are symmetric, so powerful models fit more easily in RAM. We also believe HAMs are more interpretable and fun to play around with. For these two reasons, we are building some simple helper functions that convert trained HAM models for use in the frontend.

The code below is copied, for the most part, from the example source code in the official tutorial.



convert_jax

 convert_jax (apply_fn:Callable[...,Any],
              input_signatures:Sequence[Tuple[Sequence[Optional[int]],Any]
              ], model_dir:str, polymorphic_shapes:Optional[Sequence[Union
              [str,jax.experimental.jax2tf.shape_poly.PolyShape]]]=None)

Converts a JAX function apply_fn to a TensorflowJS model. Works with functools.partial style models if we don’t need to access the variables in the frontend.

Example usage for an arbitrary function:

import functools as ft
import numpy as np
...
def predict_fn(model, input):
    return model.predict(input)

fn = ft.partial(predict_fn, trained_model)

convert_jax(
    apply_fn=fn,
    input_signatures=[((D1, D2,), np.float32)],
    model_dir=tfjs_model_dir)
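The functools.partial pattern used in this example can be tried out without JAX or TensorFlow. The sketch below is only an illustration: ToyModel is a hypothetical stand-in for a trained model and is not part of hamux. It shows that once the model is bound, the only argument left in the function's signature is the input, which is what the input signatures describe.

```python
import functools as ft
import inspect


class ToyModel:
    """Hypothetical stand-in for a trained model (not part of hamux)."""

    def __init__(self, scale):
        self.scale = scale

    def predict(self, x):
        # Scale every element of the input
        return [self.scale * v for v in x]


def predict_fn(model, x):
    return model.predict(x)


trained_model = ToyModel(scale=2.0)

# Bind the model so that only the data argument remains
fn = ft.partial(predict_fn, trained_model)

# Names of the arguments that are still unbound after the partial application
remaining = list(inspect.signature(fn).parameters)
result = fn([1.0, 2.0])
```

Because the model is captured inside fn, its variables are baked into the converted function rather than exposed as separate inputs, which matches the caveat about not accessing the variables in the frontend.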

Note that when using dynamic shapes, an additional argument polymorphic_shapes should be provided, specifying values for the dynamic (“polymorphic”) dimensions. See here for more details: https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion

This is an adaptation of the original implementation in jax2tf here: https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py

Arguments:
    apply_fn: A JAX function that has one or more arguments, of which the first argument is the model parameters. This function typically is the forward pass of the network (e.g., Module.apply() in Flax).
    input_signatures: The input signatures for the second and remaining arguments to apply_fn (the input). A signature must be a tensorflow.TensorSpec instance, or a (nested) tuple/list/dictionary thereof with a structure matching the second argument of apply_fn.
    model_dir: Directory where the TensorflowJS model will be written to.
    polymorphic_shapes: If given, it will be used as the polymorphic_shapes argument for the second parameter of apply_fn. In this case, only a single input_signatures entry is supported, and it should have None in the polymorphic (dynamic) dimensions.
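To make the relationship between None dimensions and polymorphic_shapes concrete, here is a minimal sketch of a helper that derives a jax2tf-style shape string such as "(b, 784)" from a shape containing None. The helper name poly_shape_spec and its choice of dimension letters are assumptions made for illustration; neither convert_jax nor jax2tf provides this function.

```python
from typing import Optional, Sequence


def poly_shape_spec(shape: Sequence[Optional[int]], names: str = "bcdefgh") -> str:
    """Hypothetical helper: name each None dimension, keep fixed sizes as-is."""
    free_names = iter(names)
    dims = []
    for d in shape:
        # A None entry is a dynamic dimension and receives the next free letter
        dims.append(next(free_names) if d is None else str(d))
    return "(" + ", ".join(dims) + ")"


# A batch of flattened images with an unknown batch size
spec = poly_shape_spec((None, 784))
```

Under these assumptions, the resulting string would go into the polymorphic_shapes sequence, alongside an input signature that has None in the same position.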

Let’s presume the following example model:

import hamux as hmx
import jax
import jax.numpy as jnp
import functools as ft

model, fwd = hmx.create_model("hn_softmax_mnist")
states, model = model.init_states_and_params(jax.random.PRNGKey(0), bs=1)

def simple_batch_fwd(
    x: jnp.ndarray, # Starting point for clamped layer[0]
    dt: float): # Step size through time
    """A simple version of the forward function"""
    # Initialize hidden states to our image
    xs = model.init_states(x.shape[0])
    xs[0] = jnp.array(x)

    updates = model.vupdates(xs)  # Calculate the updates
    new_xs = model.step(
        xs, updates, dt=dt
    )  # Add them to our current states

    # The last layer uses a softmax activation, so this returns class probabilities
    # computed from the updated states
    return model.layers[-1].g(new_xs[-1])

import tensorflow as tf

# Specify where to save the model
tfjs_model_dir = '_archive/hamux_model/'
convert_jax(
    simple_batch_fwd,
    input_signatures=[tf.TensorSpec(states[0].shape, tf.float32), tf.TensorSpec((1,), tf.float32)], # img, dt
    model_dir=tfjs_model_dir,
)
Writing weight file _archive/hamux_model/model.json...
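Once the weight file has been written, it can be useful to check what ended up in the output directory before shipping it. The sketch below assumes the usual TensorFlow.js model.json layout, in which a weightsManifest list names the binary shard files and the tensors they hold. The function summarize_tfjs_model is a hypothetical helper, and the manifest in the demo is hand-written synthetic data, not real converter output.

```python
import json
import tempfile
from pathlib import Path


def summarize_tfjs_model(model_dir):
    """Return (shard file names, number of weight tensors) listed in model.json."""
    model_json = json.loads((Path(model_dir) / "model.json").read_text())
    manifest = model_json["weightsManifest"]
    shards = [path for group in manifest for path in group["paths"]]
    n_weights = sum(len(group["weights"]) for group in manifest)
    return shards, n_weights


# Demo on a tiny synthetic manifest written to a temporary directory
fake_model = {
    "weightsManifest": [
        {
            "paths": ["group1-shard1of1.bin"],
            "weights": [
                {"name": "w", "shape": [2, 2], "dtype": "float32"},
                {"name": "b", "shape": [2], "dtype": "float32"},
            ],
        }
    ]
}
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "model.json").write_text(json.dumps(fake_model))
    shards, n_weights = summarize_tfjs_model(tmp)
```

For the converted model, the same helper could be called as summarize_tfjs_model(tfjs_model_dir). Every shard file it lists needs to be served next to model.json for the browser to load the model.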