Add support for transformers pyfunc #8181
Conversation
Documentation preview for daeb811 will be available here when this CircleCI job completes successfully. More info
mlflow/types/utils.py
Outdated
isinstance(value, str) for d in data for value in d.values()
):
    schema = Schema([ColSpec(type=DataType.string, name=name) for name in data[0].keys()])
I think we could support List[Dict[str, <any scalar>]] by doing pd.DataFrame.from_records() on the list of dicts and then using the logic for Pandas DataFrame schema inference above:
elif isinstance(data, pd.DataFrame):
schema = Schema(
[ColSpec(type=_infer_pandas_column(data[col]), name=col) for col in data.columns]
)
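A minimal sketch of what the reviewer is suggesting (not the final MLflow implementation): convert the list of dicts to a DataFrame with `from_records`, at which point each column carries a uniform dtype that per-column inference (e.g. `_infer_pandas_column` above) could consume.

```python
import pandas as pd

# Hypothetical payload: List[Dict[str, <any scalar>]] with mixed value types
records = [
    {"question": "What is MLflow?", "score": 0.97},
    {"question": "What is a flavor?", "score": 0.85},
]

# from_records builds one column per dict key, inferring a dtype per column
df = pd.DataFrame.from_records(records)

# Column-wise schema inference can now run over df.columns
print(df.dtypes.to_dict())
```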
Are there any HF language model types where users need to pass dicts with non-string scalar values (e.g. int values, float values, etc), or do string values cover all cases?
No LLMs require this (but the multimodal ones will in the future). Should we implement this when we work on functionality to support those later?
mlflow/transformers.py
Outdated
output_key = "answer"
data = self._parse_json_encoded_dict_payload_to_dict(data, "table")
elif isinstance(self.pipeline, transformers.TokenClassificationPipeline):
    output_key = "entity"
From the docs at https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/pipelines#transformers.TokenClassificationPipeline.__call__, it looks like entity is renamed to entity_group when the aggregation strategy is not None.
Can we handle that case, include test coverage, and check the __call__ documentation of other pipeline types to make sure that we don't need conditional output-key logic for them as well?
Added handling for this, added a test suite to validate that we parse it correctly, and verified by looking through the source code of the master branch that no other pipelines for LLM use cases do modifications of the output naming conventions.
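One way to sketch the conditional output-key handling discussed here (a hypothetical helper, not MLflow's actual code): inspect the returned dicts, since TokenClassificationPipeline emits "entity_group" instead of "entity" when an aggregation strategy is configured.

```python
def resolve_entity_key(pipeline_output):
    # Hypothetical helper: pick the output key based on what the pipeline
    # actually returned, rather than tracking the aggregation strategy.
    first = pipeline_output[0]
    return "entity_group" if "entity_group" in first else "entity"

# Shape of ungrouped output (aggregation strategy is None)
ungrouped = [{"entity": "B-PER", "word": "Ben", "score": 0.99}]
# Shape of grouped output (e.g. aggregation_strategy="simple")
grouped = [{"entity_group": "PER", "word": "Ben Wilson", "score": 0.99}]

assert resolve_entity_key(ungrouped) == "entity"
assert resolve_entity_key(grouped) == "entity_group"
```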
interim_output = self._parse_lists_of_dict_to_list_of_str(raw_output, output_key)
output = self._parse_list_output_for_multiple_candidate_pipelines(interim_output)
Do we need these functions? From https://github.com/mlflow/mlflow/pull/8181/files#r1159343915, it seems like we can just parse out the first element of the "label" key from each result dictionary and return. This seems like it would be easier to read.
The logic within the latter one has been expanded a bit and I collapsed the two functions into one.
elif isinstance(self.pipeline, transformers.TokenClassificationPipeline):
    output = self._parse_tokenizer_output(raw_output, output_key)
else:
    output = self._parse_lists_of_dict_to_list_of_str(raw_output, output_key)
Do you have some examples of pipelines that would hit this case and what the output would look like?
TextGenerator, Text2TextGenerator, QuestionAnswering, and Seq2Seq pipelines all have a mix of return types of List[Dict[str, str]], List[List[str]], or Dict[str, str], all of which the else-branch function here converts to List[str].
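The normalization described above can be sketched roughly like this (a simplified stand-in for the internal `_parse_lists_of_dict_to_list_of_str` helper, not its exact code): collapse the three return shapes into a flat List[str].

```python
def normalize_output(raw_output, output_key):
    # Dict[str, str] -> single-element list
    if isinstance(raw_output, dict):
        return [raw_output[output_key]]
    result = []
    for entry in raw_output:
        # List[Dict[str, str]] -> extract the output key from each dict
        if isinstance(entry, dict):
            result.append(entry[output_key])
        # List[List[str]] -> flatten the inner lists
        elif isinstance(entry, list):
            result.extend(str(item) for item in entry)
    return result

assert normalize_output({"answer": "42"}, "answer") == ["42"]
assert normalize_output([{"generated_text": "hi"}], "generated_text") == ["hi"]
assert normalize_output([["a", "b"]], None) == ["a", "b"]
```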
task=task,
inference_config=inference_config
)
Before releasing this with MLflow 2.3.0, can we make sure that we include documentation for the input / output behavior of pyfunc predict for common pipeline types and pipeline types whose inference behavior may not be intuitive? If possible, can you file an internal ticket to track that?
Added a ticket and will work on a full mapping of all supported pyfunc types' inputs/outputs, making special note of the input types for the 2 pipelines that diverge from the transformers API by requiring a Dict input instead of the pseudo-dict string input that is submitted (i.e. "question: <text> context: <text>").
if isinstance(pf_input, dict) and all(
    not isinstance(value, (dict, list)) for value in pf_input.values()
):
    pf_input = pd.DataFrame(pf_input, index=[0])
I think this is a behavior change for dataframes with numpy array columns, scalar columns, etc.
Can you provide more context here?
I've adjusted this logic to be very explicit about matching on dict where the keys and values are all strings. Without this logic, we can't cast a scalar string to a DataFrame for serving or do signature validation. By supplying the index, we're able to create the DataFrame without an Exception being thrown.
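The index requirement mentioned above can be seen directly: pandas refuses to build a DataFrame from a dict of all-scalar values unless an index is supplied. A short illustration (payload values are made up):

```python
import pandas as pd

payload = {"question": "What is MLflow?", "context": "MLflow is a platform."}

# Without an index, all-scalar values raise:
# ValueError: If using all scalar values, you must pass an index
try:
    pd.DataFrame(payload)
    raised = False
except ValueError:
    raised = True

# Supplying index=[0] creates the single-row frame needed for serving
# and signature validation.
df = pd.DataFrame(payload, index=[0])
```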
Awesome stuff, @BenWilson2 ! Left some initial comments :)
mlflow/transformers.py
Outdated
"context": last.strip(),
}
else:
    context, question = last.split("question:")
context, question = last.split("question:")
context, question = last.split("question:", 1)
I refactored this to use Dict inputs since the string parsing was very confusing and potentially error-prone.
mlflow/transformers.py
Outdated
"required keys are provided."
)
else:
    first, last = data.split("context:")
first, last = data.split("context:")
first, last = data.split("context:", 1)
Unpacking without specifying maxsplit is unsafe :)
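The hazard the reviewer is pointing at: if the delimiter appears more than once in the input, a bare `split` returns more than two parts and two-name unpacking raises ValueError, while `maxsplit=1` always yields at most two parts. A small illustration (the sample text is made up):

```python
text = "context: see the question: 2 above question: what is it?"

# Bare split produces three parts here, so two-name unpacking fails
try:
    first, last = text.split("question:")
    unpack_failed = False
except ValueError:
    unpack_failed = True

# maxsplit=1 guarantees at most two parts, so unpacking is safe
first, last = text.split("question:", 1)
```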
Excellent point :) I changed the string parsing logic to just use Dicts so that we reduce the complexity and supported syntax types for the flavor.
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
@pytest.fixture(scope="module")
def small_vision_model():
    architecture = "google/mobilenet_v2_1.0_224"
    feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(architecture)
    model = transformers.MobileNetV2ForImageClassification.from_pretrained(architecture)
    return transformers.pipeline(
    yield transformers.pipeline(
yield transformers.pipeline(
return transformers.pipeline(
We can use return if we don't have any teardown code.
I had tried putting the cache cleanup call after the yield at the conclusion of each fixture, but with module scoping the teardown never got triggered per test. Switching in the next commit to function scoping (the default) and reverting to return statements.
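The yield-vs-return trade-off discussed here comes down to generator mechanics: in pytest, code after the `yield` runs as teardown when the fixture's scope ends. A plain generator illustrates the same flow (names here are illustrative stand-ins, not the PR's fixtures):

```python
teardown_ran = []

def small_model_fixture():
    # In a real suite this would be decorated with @pytest.fixture
    model = "stand-in for transformers.pipeline(...)"
    yield model                 # pytest hands this value to the test
    teardown_ran.append(True)   # code after yield = teardown (e.g. clean_cache())

gen = small_model_fixture()
model = next(gen)      # setup phase: runs up to the yield
next(gen, None)        # scope end: pytest resumes the generator for teardown
```

With function scope (the default), the teardown runs after every test; with module scope it runs only once per module, which matches the behavior described above.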
mlflow/transformers.py
Outdated
if not isinstance(self.pipeline, transformers.TableQuestionAnsweringPipeline):
    return data

if "table" not in data.keys():
if "table" not in data.keys():
if "table" not in data:
changed in 4 locations.
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
mlflow/transformers.py
Outdated
expected_keys = {"question", "context"}
if not expected_keys.intersection(set(data.keys())) == expected_keys:
    raise MlflowException(
        "Invalid keys were submitted. Keys must be exclusively " f"{expected_keys}"
"Invalid keys were submitted. Keys must be exclusively " f"{expected_keys}"
f"Invalid keys were submitted. Keys must be exclusively {expected_keys}"
thanks :)
elif (
    isinstance(data, list)
    and all(isinstance(element, dict) for element in data)
    and all(isinstance(value, str) for d in data for value in d.values())
and all(isinstance(value, str) for d in data for value in d.values())
and all(isinstance(key, str) for d in data for key in d)
and all(isinstance(value, str) for d in data for value in d.values())
Do we also need to check the keys?
ah, good catch! Added logic for key validation
mlflow/deployments/__init__.py
Outdated
elif isinstance(predictions, list) and all(
    isinstance(value, str) for value in predictions
):
    return pd.DataFrame(data=predictions, index=[0])
I got the following error while testing this block. Am I missing something?
>>> pd.DataFrame(data=["a", "b"], index=[0])
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/haru/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/pandas/core/frame.py", line 791, in __init__
mgr = ndarray_to_mgr(
File "/home/haru/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/pandas/core/internals/construction.py", line 317, in ndarray_to_mgr
_check_values_indices_shape_match(values, index, columns)
File "/home/haru/miniconda3/envs/mlflow-dev-env/lib/python3.8/site-packages/pandas/core/internals/construction.py", line 388, in _check_values_indices_shape_match
raise ValueError(f"Shape of passed values is {passed}, indices imply {implied}")
ValueError: Shape of passed values is (2, 1), indices imply (1, 1)
This should mirror the recent changes to models/utils.py (no need to specify the index; if the value is a string, wrap it in a list). Changed!
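The distinction behind the fix: list data maps one element per row, so no index is needed, and the reviewer's ValueError came from forcing a single-row index onto two rows. A bare string, by contrast, is wrapped in a list first. A short illustration:

```python
import pandas as pd

# A list of predictions: each element becomes a row, no index required
df_list = pd.DataFrame(data=["a", "b"])

# A scalar string prediction: wrap in a list to get a one-row frame
prediction = "ein Satz"
df_str = pd.DataFrame([prediction])
```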
mlflow/models/utils.py
Outdated
@@ -141,6 +142,8 @@ def _handle_dataframe_input(input_ex):
    input_ex = pd.DataFrame([input_ex], columns=range(len(input_ex)))
else:
    input_ex = pd.DataFrame(input_ex)
elif isinstance(input_ex, str):
    input_ex = pd.DataFrame([input_ex], index=[0])
input_ex = pd.DataFrame([input_ex], index=[0])
input_ex = pd.DataFrame([input_ex])
updated :)
It looks like we're downloading too many models.
mlflow/transformers.py
Outdated
from mlflow.transformers import _TransformersWrapper
from transformers import pipeline

en_to_de = pipeline("translation_en_to_de")

data = "MLflow is great!"

inference_pyfunc = _TransformersWrapper(en_to_de)
signature = infer_signature(data, inference_pyfunc.predict(data))
Can we avoid exposing _TransformersWrapper (private) like this?
I'll add a public function for this. We can't use the raw output from a pipeline to get a correct signature due to the manipulations to the input & outputs that are wrapped around the pipeline objects, so we need something that can wrap a pipeline in our pyfunc logic.
mlflow/transformers.py
Outdated
}

# Verify that no exceptions are thrown
sentence_generation(prompts, **inference_config)
Where is sentence_generation defined?
ah, I had renamed above it to 'sentence_pipeline'. Updated!
assert (
    isinstance(pd_inference, str)
    if isinstance(inference_payload, dict)
    else isinstance(inference, list)
)
assert (
    isinstance(pd_inference, str)
    if isinstance(inference_payload, dict)
    else isinstance(inference, list)
)
if isinstance(inference_payload, dict):
    assert isinstance(pd_inference, str)
else:
    assert isinstance(inference, list)
Can we use if-else for easier debugging?
definitely
@@ -52,6 +58,22 @@
_IMAGE_PROCESSOR_API_CHANGE_VERSION = "4.26.0"


def clean_cache(threshold=2):
def clean_cache(threshold=2):
@pytest.fixture(autouse=True)
def clean_cache(threshold=2):
Can we use autouse?
yep! Updated and removed all previous func calls
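A sketch of the autouse pattern adopted here (hypothetical fixture body; the real `clean_cache` helper lives in the test module): marking a fixture `autouse=True` makes pytest apply it to every test in its scope without the test naming it as a parameter, which is why all the explicit cleanup calls could be deleted.

```python
import pytest

@pytest.fixture(autouse=True)
def purge_model_cache():
    # Runs automatically around every test in this module's scope.
    yield
    # Teardown after each test; in the PR this would be something like:
    # clean_cache(threshold=2)
```

Tests then need no reference to the fixture at all; pytest injects it for every test it covers.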
# NB: do not purge the cache during this test. The pipeline will attempt to refer to
# the cache state within the fixture context which will throw a file not found error
# when attempting to read the stored FastTokenizer for the TextClassificationPipeline.
Is this because scope was set to "module"?
indeed. Removed this note and only disabled the cache purge for two tests that exhibited flakiness when I was locally running. There was an attempt to do a cache fetch that popped up 2/20 times but only for those tests.
mlflow/transformers.py
Outdated
inference_pyfunc = _TransformersWrapper(en_to_de)
signature = infer_signature(data, inference_pyfunc.predict(data))

with mlflow.start_run():
with mlflow.start_run():
mlflow.start_run here can be removed.
Left a few more comments, otherwise LGTM!
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
@mlflow-automation autoformat
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
if [ "${{ matrix.package }}" = "transformers" ]
then
  pip install accelerate datasets
fi
Do we need this? The requirements field for transformers contains both accelerate and datasets.
ah, good point. Removed!
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
Signed-off-by: Ben Wilson <39283302+BenWilson2@users.noreply.github.com>
I truly admire what's being done here, just to cover HF's inability to create a unified API on their own. Nice work, @BenWilson2 !
LGTM! I assume you will add more unit tests for this PR before merging.
Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
Signed-off-by: Ben Wilson <39283302+BenWilson2@users.noreply.github.com>
@jinzhang21 finished adding all of the parsing tests (and fixed 2 small edge case parsing bugs). Good to go (provided we're passing with this commit)
Related Issues/PRs
#xxx
What changes are proposed in this pull request?
Add pyfunc support to the transformers flavor
How is this patch tested?
Does this PR change the documentation?
Release Notes
Is this a user-facing change?
Adds pyfunc support for the transformers flavor.
What component(s), interfaces, languages, and integrations does this PR affect?
Components
- area/artifacts: Artifact stores and artifact logging
- area/build: Build and test infrastructure for MLflow
- area/docs: MLflow documentation pages
- area/examples: Example code
- area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
- area/models: MLmodel format, model serialization/deserialization, flavors
- area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
- area/projects: MLproject format, project running backends
- area/scoring: MLflow Model server, model deployment tools, Spark UDFs
- area/server-infra: MLflow Tracking server backend
- area/tracking: Tracking Service, tracking client APIs, autologging
Interface
- area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
- area/windows: Windows support
Language
- language/r: R APIs and clients
- language/java: Java APIs and clients
- language/new: Proposals for new client languages
Integrations
- integrations/azure: Azure and Azure ML integrations
- integrations/sagemaker: SageMaker integrations
- integrations/databricks: Databricks integrations
How should the PR be classified in the release notes? Choose one:
- rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
- rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
- rn/feature - A new user-facing feature worth mentioning in the release notes
- rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
- rn/documentation - A user-facing documentation change worth mentioning in the release notes