
Add support for transformers pyfunc #8181

Merged
merged 17 commits into mlflow:master on Apr 14, 2023

Conversation

BenWilson2 (Member)

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

Add pyfunc support to the transformers flavor

How is this patch tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests (describe details, including test results, below)

Does this PR change the documentation?

  • No. You can skip the rest of this section.
  • Yes. Make sure the changed pages / sections render correctly in the documentation preview.

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

Adds pyfunc support for the transformers flavor.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
@BenWilson2 BenWilson2 requested a review from dbczumar April 6, 2023 02:00
@github-actions bot added labels: area/models, area/tracking, rn/feature — Apr 6, 2023
@mlflow-automation (Collaborator) commented Apr 6, 2023:

Documentation preview for daeb811 will be available when this CircleCI job completes successfully.

mlflow/types/utils.py — 5 resolved review threads (outdated)
Comment on lines 150 to 152
isinstance(value, str) for d in data for value in d.values()
):
schema = Schema([ColSpec(type=DataType.string, name=name) for name in data[0].keys()])
@dbczumar (Collaborator) commented Apr 6, 2023:

I think we could support List[Dict[str, <any scalar>]] by doing pd.DataFrame.from_records() on the list of dicts and then using the logic for Pandas DataFrame schema inference above:

    elif isinstance(data, pd.DataFrame):
        schema = Schema(
            [ColSpec(type=_infer_pandas_column(data[col]), name=col) for col in data.columns]
        )
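The suggestion above can be sketched standalone; `records` is an invented example payload, and this is an illustration of the `from_records` approach rather than the mlflow implementation:

```python
import pandas as pd

# Hypothetical payload: a List[Dict[str, <scalar>]] with mixed scalar types.
records = [
    {"question": "What is MLflow?", "score": 0.98},
    {"question": "What is a flavor?", "score": 0.87},
]

# from_records turns the list of dicts into a DataFrame, after which the
# existing column-wise Pandas schema inference can assign a type per column.
df = pd.DataFrame.from_records(records)
```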

Collaborator:

Are there any HF language model types where users need to pass dicts with non-string scalar values (e.g. int values, float values, etc), or do string values cover all cases?

@BenWilson2 (Member Author):

No LLMs require this (but the multimodal ones will in the future). Should we implement this when we work on functionality to support those later?

mlflow/models/signature.py — resolved review thread (outdated)
output_key = "answer"
data = self._parse_json_encoded_dict_payload_to_dict(data, "table")
elif isinstance(self.pipeline, transformers.TokenClassificationPipeline):
output_key = "entity"
Collaborator:

From the docs of https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/pipelines#transformers.TokenClassificationPipeline.__call__, it looks like entity is renamed entity_group when the aggregation strategy is not None.

Can we handle that case, include test coverage, and check the __call__ documentation of other pipeline types to make sure that we don't need condition output key logic for other pipeline types?

@BenWilson2 (Member Author):

Added handling for this, added a test suite to validate that we parse it correctly, and verified by looking through the source code of the master branch that no other pipelines for LLM use cases do modifications of the output naming conventions.
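The key selection this thread describes can be sketched minimally; the helper name is invented, and real token-classification output dicts carry more fields than shown:

```python
# Token-classification pipelines emit per-token dicts keyed by "entity", but
# when an aggregation strategy other than None is configured the key becomes
# "entity_group". A parser therefore has to probe which key is present.
def resolve_entity_key(result_entry: dict) -> str:
    return "entity_group" if "entity_group" in result_entry else "entity"
```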

mlflow/transformers.py — 5 resolved review threads (outdated)
Comment on lines +1246 to +1247
interim_output = self._parse_lists_of_dict_to_list_of_str(raw_output, output_key)
output = self._parse_list_output_for_multiple_candidate_pipelines(interim_output)
Collaborator:

Do we need these functions? From https://github.com/mlflow/mlflow/pull/8181/files#r1159343915, it seems like we can just parse out the first element of the "label" key from each result dictionary and return. This seems like it would be easier to read.

@BenWilson2 (Member Author):

The logic within the latter one has been expanded a bit and I collapsed the two functions into one.

elif isinstance(self.pipeline, transformers.TokenClassificationPipeline):
output = self._parse_tokenizer_output(raw_output, output_key)
else:
output = self._parse_lists_of_dict_to_list_of_str(raw_output, output_key)
Collaborator:

Do you have some examples of pipelines that would hit this case and what the output would look like?

@BenWilson2 (Member Author):

TextGenerator, Text2TextGenerator, QuestionAnswering, and Seq2Seq pipelines all return a mix of List[Dict[str, str]], List[List[str]], or Dict[str, str], all of which the else-branch function here converts to List[str].
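The normalization described in this reply can be sketched as follows; the function and key names are assumptions for illustration, not the actual mlflow helpers:

```python
# Collapse the three return shapes mentioned above -- Dict[str, str],
# List[Dict[str, str]], and List[List[str]] -- into a flat List[str].
def normalize_pipeline_output(raw_output, output_key):
    if isinstance(raw_output, dict):
        return [raw_output[output_key]]
    flattened = []
    for entry in raw_output:
        if isinstance(entry, dict):
            flattened.append(entry[output_key])
        else:  # a nested list of strings
            flattened.extend(entry)
    return flattened
```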

task=task,
inference_config=inference_config
)

Collaborator:

Before releasing this with MLflow 2.3.0, can we make sure that we include documentation for the input / output behavior of pyfunc predict for common pipeline types and pipeline types whose inference behavior may not be intuitive? If possible, can you file an internal ticket to track that?

@BenWilson2 (Member Author):

Added a ticket and will work on a full mapping of inputs/outputs for all supported pyfunc types, making special note of the input types for the 2 pipelines that diverge from the transformers API (they require a Dict input instead of the pseudo-dict string input, i.e. "question: <text> context: <text>").

Comment on lines +621 to +624
if isinstance(pf_input, dict) and all(
not isinstance(value, (dict, list)) for value in pf_input.values()
):
pf_input = pd.DataFrame(pf_input, index=[0])
Collaborator:

I think this is a behavior change for dataframes with numpy array columns, scalar columns, etc.

Can you provide more context here?

@BenWilson2 (Member Author):

I've adjusted this logic to be very explicit about matching on dict where the keys and values are all strings. Without this logic, we can't cast a scalar string to a DataFrame for serving or do signature validation. By supplying the index, we're able to create the DataFrame without an Exception being thrown.
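A quick repro of why the index matters in the snippet above; the payload is an invented example:

```python
import pandas as pd

payload = {"question": "What is MLflow?", "context": "MLflow is an ML platform."}

# Without an index, a dict of scalar values cannot be shaped into rows and
# pandas raises ValueError ("If using all scalar values, you must pass an index").
try:
    pd.DataFrame(payload)
    raised = False
except ValueError:
    raised = True

# Supplying index=[0] creates a single-row DataFrame without an exception.
frame = pd.DataFrame(payload, index=[0])
```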

@dbczumar (Collaborator) left a comment:

Awesome stuff, @BenWilson2 ! Left some initial comments :)

@harupy harupy mentioned this pull request Apr 7, 2023
35 tasks
"context": last.strip(),
}
else:
context, question = last.split("question:")
Member:

Suggested change
context, question = last.split("question:")
context, question = last.split("question:", 1)

@BenWilson2 (Member Author):

I refactored this to use Dict inputs since the string parsing was very confusing and potentially error-prone.

"required keys are provided."
)
else:
first, last = data.split("context:")
@harupy (Member) commented Apr 7, 2023:

Suggested change
first, last = data.split("context:")
first, last = data.split("context:", 1)

Unpacking without specifying maxsplit is unsafe :)

@BenWilson2 (Member Author):

Excellent point :) I changed the string parsing logic to just use Dicts so that we reduce the complexity and supported syntax types for the flavor.
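The point about unsafe unpacking can be demonstrated directly; the sample string is invented:

```python
# If the delimiter appears more than once, a bare split() yields three or more
# parts, and two-target unpacking raises ValueError. maxsplit=1 guarantees the
# result has exactly two elements.
text = "intro context: first part context: second part"
try:
    first, last = text.split("context:")
    safe = True
except ValueError:
    safe = False

first, last = text.split("context:", 1)
```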

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>


@pytest.fixture(scope="module")
def small_vision_model():
architecture = "google/mobilenet_v2_1.0_224"
feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(architecture)
model = transformers.MobileNetV2ForImageClassification.from_pretrained(architecture)
return transformers.pipeline(
yield transformers.pipeline(
Member:

Suggested change
yield transformers.pipeline(
return transformers.pipeline(

We can use return if we don't have any teardown code.

@BenWilson2 (Member Author):

I had tried putting the cache cleanup call at the conclusion of each fixture after the yield, but with module scoping it never got triggered. In the next commit push I'm switching to function scoping (the default) and reverting back to return statements.
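The yield-vs-return trade-off can be illustrated with a bare generator; in real test code the function would carry @pytest.fixture, and the dict here stands in for transformers.pipeline(...):

```python
# `yield` suspends the fixture while the test body runs and resumes it for
# teardown; a plain `return` is sufficient when no teardown is needed.
def small_vision_model():
    pipeline = {"task": "image-classification", "loaded": True}
    yield pipeline                 # the test runs while the generator is suspended
    pipeline["loaded"] = False     # teardown: runs after the test completes

gen = small_vision_model()
model = next(gen)                  # setup: what the test receives
try:
    next(gen)                      # simulate pytest finalizing the fixture
except StopIteration:
    pass
```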

if not isinstance(self.pipeline, transformers.TableQuestionAnsweringPipeline):
return data

if "table" not in data.keys():
Member:

Suggested change
if "table" not in data.keys():
if "table" not in data:

@BenWilson2 (Member Author):

changed in 4 locations.

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
expected_keys = {"question", "context"}
if not expected_keys.intersection(set(data.keys())) == expected_keys:
raise MlflowException(
"Invalid keys were submitted. Keys must be exclusively " f"{expected_keys}"
Member:

Suggested change
"Invalid keys were submitted. Keys must be exclusively " f"{expected_keys}"
f"Invalid keys were submitted. Keys must be exclusively {expected_keys}"

@BenWilson2 (Member Author):

thanks :)
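The check in the excerpt above, rewritten as a standalone sketch (`validate_keys` and `keys_ok` are invented names). Note that the intersection test only verifies the required keys are present; extra keys would still pass:

```python
expected_keys = {"question", "context"}

def validate_keys(data: dict):
    # set.intersection accepts any iterable, so the dict's keys work directly.
    if expected_keys.intersection(data) != expected_keys:
        raise ValueError(f"Invalid keys were submitted. Keys must be exclusively {expected_keys}")

def keys_ok(data: dict) -> bool:
    try:
        validate_keys(data)
        return True
    except ValueError:
        return False
```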

elif (
isinstance(data, list)
and all(isinstance(element, dict) for element in data)
and all(isinstance(value, str) for d in data for value in d.values())
@harupy (Member) commented Apr 12, 2023:

Suggested change
and all(isinstance(value, str) for d in data for value in d.values())
and all(isinstance(k, str) for d in data for k in d)
and all(isinstance(value, str) for d in data for value in d.values())

Do we also need to check the keys?

@BenWilson2 (Member Author):

ah, good catch! Added logic for key validation
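The final form of the check (keys and values both validated) can be sketched standalone; the function name is an assumption:

```python
def is_list_of_string_dicts(data) -> bool:
    # True only for a list of dicts whose keys AND values are all strings.
    return (
        isinstance(data, list)
        and all(isinstance(element, dict) for element in data)
        and all(isinstance(key, str) for d in data for key in d)
        and all(isinstance(value, str) for d in data for value in d.values())
    )
```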

Comment on lines 48 to 51
elif isinstance(predictions, list) and all(
isinstance(value, str) for value in predictions
):
return pd.DataFrame(data=predictions, index=[0])
@harupy (Member) commented Apr 12, 2023:

I got the following error while testing this block. Am I missing something?

>>> pd.DataFrame(data=["a", "b"], index=[0])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/haru/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/pandas/core/frame.py", line 791, in __init__
    mgr = ndarray_to_mgr(
  File "/home/haru/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/pandas/core/internals/construction.py", line 317, in ndarray_to_mgr
    _check_values_indices_shape_match(values, index, columns)
  File "/home/haru/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/pandas/core/internals/construction.py", line 388, in _check_values_indices_shape_match
    raise ValueError(f"Shape of passed values is {passed}, indices imply {implied}")
ValueError: Shape of passed values is (2, 1), indices imply (1, 1)

@BenWilson2 (Member Author):

This should mirror the recent changes to models/utils.py (no need to specify the index; if the value is a string, wrap it in a list). Changed!
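The failure from the review comment, next to the fix adopted in the reply:

```python
import pandas as pd

predictions = ["a", "b"]

# Two values cannot share a one-element index: pandas raises
# "Shape of passed values is (2, 1), indices imply (1, 1)".
try:
    pd.DataFrame(data=predictions, index=[0])
    raised = False
except ValueError:
    raised = True

# Omitting the index lets pandas build the default RangeIndex.
frame = pd.DataFrame(data=predictions)
```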

@@ -141,6 +142,8 @@ def _handle_dataframe_input(input_ex):
input_ex = pd.DataFrame([input_ex], columns=range(len(input_ex)))
else:
input_ex = pd.DataFrame(input_ex)
elif isinstance(input_ex, str):
input_ex = pd.DataFrame([input_ex], index=[0])
Member:

Suggested change
input_ex = pd.DataFrame([input_ex], index=[0])
input_ex = pd.DataFrame([input_ex])

@BenWilson2 (Member Author):

updated :)

@harupy (Member) commented Apr 12, 2023:

It looks like we're downloading too many models.

Comment on lines 319 to 327
from mlflow.transformers import _TransformersWrapper
from transformers import pipeline

en_to_de = pipeline("translation_en_to_de")

data = "MLflow is great!"

inference_pyfunc = _TransformersWrapper(en_to_de)
signature = infer_signature(data, inference_pyfunc.predict(data))
Member:

Can we avoid exposing _TransformersWrapper (private) like this?

@BenWilson2 (Member Author):

I'll add a public function for this. We can't use the raw output from a pipeline to get a correct signature because of the input and output manipulations that are wrapped around the pipeline object, so we need something that can wrap a pipeline in our pyfunc logic.

}

# Verify that no exceptions are thrown
sentence_generation(prompts, **inference_config)
Member:

Where is sentence_generation defined?

@BenWilson2 (Member Author):

Ah, I had renamed it above to sentence_pipeline. Updated!

Comment on lines 1103 to 1107
assert (
isinstance(pd_inference, str)
if isinstance(inference_payload, dict)
else isinstance(inference, list)
)
Member:

Suggested change
assert (
isinstance(pd_inference, str)
if isinstance(inference_payload, dict)
else isinstance(inference, list)
)
if isinstance(inference_payload, dict):
assert isinstance(pd_inference, str)
else:
assert isinstance(inference, list)

Can we use if-else for easier debugging?

@BenWilson2 (Member Author):

definitely

@@ -52,6 +58,22 @@
_IMAGE_PROCESSOR_API_CHANGE_VERSION = "4.26.0"


def clean_cache(threshold=2):
Member:

Suggested change
def clean_cache(threshold=2):
@pytest.fixture(autouse=True)
def clean_cache(threshold=2):

Can we use autouse?

@BenWilson2 (Member Author):

yep! Updated and removed all previous func calls

Comment on lines 1646 to 1648
# NB: do not purge the cache during this test. The pipeline will attempt to refer to
# the cache state within the fixture context which will throw a file not found error
# when attempting to read the stored FastTokenizer for the TextClassificationPipeline.
Member:

Is this because scope was set to "module"?

@BenWilson2 (Member Author):

Indeed. Removed this note and only disabled the cache purge for the two tests that exhibited flakiness when running locally. An attempted cache fetch failed 2 times out of 20, but only for those tests.

inference_pyfunc = _TransformersWrapper(en_to_de)
signature = infer_signature(data, inference_pyfunc.predict(data))

with mlflow.start_run():
Member:

Suggested change
with mlflow.start_run():

mlflow.start_run here can be removed.

@harupy (Member) left a comment:

Left a few more comments, otherwise LGTM!

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
@BenWilson2 (Member Author):

@mlflow-automation autoformat

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
Comment on lines 175 to 178
if [ "${{ matrix.package }}" = "transformers" ]
then
pip install accelerate datasets
fi
Member:

Do we need this? The requirements field for transformers contains both accelerate and datasets.

@BenWilson2 (Member Author):

ah, good point. Removed!

BenWilson2 and others added 2 commits April 12, 2023 20:00
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
Signed-off-by: Ben Wilson <39283302+BenWilson2@users.noreply.github.com>
@jinzhang21 (Collaborator) left a comment:

I truly admire what's being done here, just to cover HF's inability to create a unified API on their own. Nice work, @BenWilson2 !
LGTM! I assume you will add more unit tests for this PR before merging.

BenWilson2 and others added 2 commits April 13, 2023 16:34
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
Signed-off-by: Ben Wilson <39283302+BenWilson2@users.noreply.github.com>
@BenWilson2 (Member Author):

@jinzhang21 finished adding all of the parsing tests (and fixed 2 small edge case parsing bugs). Good to go (provided we're passing with this commit)

@BenWilson2 BenWilson2 merged commit 84b617a into mlflow:master Apr 14, 2023
26 of 28 checks passed