One appeal of HAMs is their general compactness: all model weights are symmetric, so powerful models can more easily fit in RAM. We also believe HAMs are more interpretable and fun to play around with. For these two reasons, we are building some simple helper functions to export trained HAM models to the frontend as TensorFlow.js models.
The code below is largely copied from the example source code in the official tutorial.
Converts a JAX function apply_fn to a TensorFlow.js model. It works with models wrapped via functools.partial, as long as we don't need to access the model variables in the frontend.
Example usage for an arbitrary function:
import functools as ft
import tensorflow as tf
...

def predict_fn(model, input):
    return model.predict(input)

# Bind the trained model so that fn only takes the input
fn = ft.partial(predict_fn, trained_model)

convert_jax(
    apply_fn=fn,
    input_signatures=[tf.TensorSpec((D1, D2), tf.float32)],
    model_dir=tfjs_model_dir)
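The following NumPy-only sketch shows why binding with functools.partial hides the parameters from convert_jax. TinyModel, its predict method, and the 2x3 input shape are hypothetical stand-ins for trained_model and (D1, D2). After binding, fn takes only the input, and that input is what input_signatures has to describe.

```python
import functools as ft

import numpy as np


class TinyModel:
    """Hypothetical stand-in for a trained model: a single linear layer."""

    def __init__(self, W: np.ndarray, b: np.ndarray):
        self.W = W
        self.b = b

    def predict(self, x: np.ndarray) -> np.ndarray:
        return x @ self.W + self.b


def predict_fn(model, x):
    return model.predict(x)


# Bind the "trained" model as the first argument; fn now takes only the input.
trained_model = TinyModel(W=np.eye(3, dtype=np.float32), b=np.ones(3, dtype=np.float32))
fn = ft.partial(predict_fn, trained_model)

x = np.zeros((2, 3), dtype=np.float32)  # one input of shape (D1, D2) = (2, 3)
y = fn(x)  # the parameters never appear in the call signature
```

Here a single signature describing a (2, 3) float32 array would be enough, because the weights travel inside fn rather than as an argument.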
Note that when using dynamic shapes, you should pass an additional polymorphic_shapes argument that specifies values for the dynamic (“polymorphic”) dimensions. See here for more details: https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion
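The next sketch shows how a TensorSpec-style shape and a jax2tf polymorphic shape string relate. It assumes the jax2tf string syntax `"(b, 784)"`, where `b` names the dynamic batch dimension, and an illustrative flattened 784-pixel input. make_poly_spec is a hypothetical helper, not part of jax2tf or of this library. Each dynamic axis becomes None in the input signature and a named symbol in the polymorphic shape.

```python
from typing import Dict, Optional, Sequence, Tuple


def make_poly_spec(
    shape: Sequence[int], dynamic: Dict[int, str]
) -> Tuple[Tuple[Optional[int], ...], str]:
    """Hypothetical helper: derive a TensorSpec-style shape and a jax2tf-style
    polymorphic shape string from a concrete example shape.

    `dynamic` maps an axis index to a symbolic dimension name, e.g. {0: "b"}.
    """
    for axis in dynamic:
        if not 0 <= axis < len(shape):
            raise ValueError(f"axis {axis} out of range for shape {tuple(shape)}")
    # Dynamic axes become None in the input signature ...
    spec_shape = tuple(None if i in dynamic else int(d) for i, d in enumerate(shape))
    # ... and a named symbol in the polymorphic shape string.
    dims = [dynamic[i] if i in dynamic else str(int(d)) for i, d in enumerate(shape)]
    return spec_shape, "(" + ", ".join(dims) + ")"


spec_shape, poly = make_poly_spec((1, 784), {0: "b"})
# spec_shape == (None, 784), poly == "(b, 784)"
```

You would then pass spec_shape to tf.TensorSpec inside input_signatures and use poly as the polymorphic shape. Check the convert_jax source for whether it expects a single string or a list.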
This is an adaptation of the original jax2tf implementation: https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py
Arguments:
apply_fn: A JAX function, typically the forward pass of the network (e.g., Module.apply() in Flax). In the original jax2tf version the first argument is the model parameters. This adaptation expects the parameters to be bound already (e.g., via functools.partial), so every argument of apply_fn is an input.
input_signatures: The input signatures for the arguments of apply_fn. Each signature must be a tensorflow.TensorSpec instance, or a (nested) tuple/list/dictionary thereof whose structure matches the corresponding argument.
model_dir: Directory the TensorFlow.js model will be written to.
polymorphic_shapes: If given, it is used as the polymorphic_shapes argument for the input of apply_fn. In this case only a single input signature is supported, and it should have None in the polymorphic (dynamic) dimensions.
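The sketch below shows what "matching a signature" means, with None acting as a wildcard for dynamic dimensions. Spec is a hypothetical stand-in for tf.TensorSpec; in real TensorFlow code, TensorSpec.is_compatible_with plays this role. The 784-wide image batch and the (1,) dt input are illustrative assumptions.

```python
from typing import NamedTuple, Optional, Sequence, Tuple

import numpy as np


class Spec(NamedTuple):
    """Hypothetical stand-in for tf.TensorSpec: a shape (None = dynamic) and a dtype."""
    shape: Tuple[Optional[int], ...]
    dtype: type


def matches(spec: Spec, value: np.ndarray) -> bool:
    """True if value has the spec's dtype, rank, and every fixed dimension."""
    if value.dtype != np.dtype(spec.dtype) or value.ndim != len(spec.shape):
        return False
    return all(s is None or s == d for s, d in zip(spec.shape, value.shape))


def check_inputs(signatures: Sequence[Spec], args: Sequence[np.ndarray]) -> bool:
    """One signature per (parameter-free) argument of apply_fn."""
    return len(signatures) == len(args) and all(
        matches(s, a) for s, a in zip(signatures, args)
    )


sigs = [Spec((None, 784), np.float32), Spec((1,), np.float32)]  # image batch, dt
ok = check_inputs(sigs, [np.zeros((8, 784), np.float32), np.array([0.1], np.float32)])
bad = check_inputs(sigs, [np.zeros((8, 28, 28), np.float32), np.array([0.1], np.float32)])
```

A batch of 8 flattened images matches because the batch axis is None. An unflattened 28x28 batch fails on rank.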
Consider the following example model:
import hamux as hmx
import jax
import jax.numpy as jnp
import functools as ft

model, fwd = hmx.create_model("hn_softmax_mnist")
states, model = model.init_states_and_params(jax.random.PRNGKey(0), bs=1)

def simple_batch_fwd(
    x: jnp.ndarray,  # Starting point for clamped layer[0]
    dt: float,  # Step size through time
):
    """A simple version of the forward function"""
    # Initialize hidden states, clamping layer[0] to our image x
    xs = model.init_states(x.shape[0])
    xs[0] = jnp.array(x)
    updates = model.vupdates(xs)  # Calculate the updates
    new_xs = model.step(xs, updates, dt=dt)  # Add them to our current states
    # The last (label) layer uses a softmax activation, so g returns probabilities
    return model.layers[-1].g(new_xs[-1])
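The body of simple_batch_fwd amounts to one step through time followed by a readout of the label layer. The NumPy sketch below reproduces that computation under stated assumptions. It assumes model.step is an explicit Euler update (x + dt * update for each layer). The state and update arrays are hypothetical placeholders, not HAMUX's energy-derived updates. softmax stands in for model.layers[-1].g. The zero update on the first layer mimics keeping it clamped to the image.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def euler_step(xs, updates, dt):
    """Assumed form of one step: x_new = x + dt * update, layer by layer."""
    return [x + dt * u for x, u in zip(xs, updates)]


# Hypothetical two-layer state: a clamped 4-pixel "image" and 3 label logits.
xs = [np.ones((1, 4)), np.zeros((1, 3))]
updates = [np.zeros((1, 4)), np.array([[1.0, 0.0, -1.0]])]  # placeholder updates
new_xs = euler_step(xs, updates, dt=0.5)
probs = softmax(new_xs[-1])  # read out the updated label layer
```

The readout takes the states returned by the step. Reading from the pre-step states would ignore the update entirely.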
import tensorflow as tf

# Specify where to save the model
tfjs_model_dir = '_archive/hamux_model/'

convert_jax(
    simple_batch_fwd,
    input_signatures=[
        tf.TensorSpec(states[0].shape, tf.float32),  # img
        tf.TensorSpec((1,), tf.float32),  # dt
    ],
    model_dir=tfjs_model_dir,
)
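After conversion, it can help to check what actually landed in model_dir. The helper below is hypothetical. It assumes the usual TensorFlow.js layout: a model.json file whose weightsManifest entries list binary weight shard files under "paths". It uses only the standard library.

```python
import json
import os
from typing import Dict, List


def summarize_tfjs_dir(model_dir: str) -> Dict[str, object]:
    """Hypothetical helper: report the format and weight shards of a TF.js export.

    Assumes model_dir contains model.json with a weightsManifest list whose
    entries name shard files under "paths", relative to model_dir.
    """
    with open(os.path.join(model_dir, "model.json")) as f:
        manifest = json.load(f)
    shards: List[str] = [
        path for group in manifest.get("weightsManifest", []) for path in group["paths"]
    ]
    missing = [p for p in shards if not os.path.exists(os.path.join(model_dir, p))]
    return {"format": manifest.get("format"), "shards": shards, "missing_shards": missing}
```

For the example above, summarize_tfjs_dir(tfjs_model_dir) would list the exported shards. A non-empty missing_shards list suggests an incomplete export.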