Skip to content

Commit

Permalink
merge g3 cl (#8138)
Browse files Browse the repository at this point in the history
  • Loading branch information
pyu10055 committed Jan 23, 2024
1 parent 7a1cc67 commit 45f1a86
Showing 1 changed file with 1 addition and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@
from typing import Any, Callable, Optional, Sequence, Tuple, Union

from jax.experimental import jax2tf
from jax.experimental.export import shape_poly
import tensorflow as tf
from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion


_TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
Array = Any
DType = Any
PolyShape = shape_poly.PolyShape


class _ReusableSavedModelWrapper(tf.train.Checkpoint):
Expand Down Expand Up @@ -60,7 +58,7 @@ def convert_jax(
*,
input_signatures: Sequence[Tuple[Sequence[Union[int, None]], DType]],
model_dir: str,
polymorphic_shapes: Optional[Sequence[Union[str, PolyShape]]] = None,
polymorphic_shapes: Optional[Sequence[str]] = None,
**tfjs_converter_params):
"""Converts a JAX function `jax_apply_fn` and model parameters to a TensorflowJS model.
Expand Down

0 comments on commit 45f1a86

Please sign in to comment.