Comparing changes

base repository: caikit/caikit-nlp
base: v0.4.13
head repository: caikit/caikit-nlp
compare: v0.4.14
  • 15 commits
  • 7 files changed
  • 4 contributors

Commits on Jun 6, 2024

  1. add get_route_info

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 6, 2024
    e278915
  2. lazily create model_connection and _client

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 6, 2024
    e473c33
  3. lazy load model_connection and tgis client for peft

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 6, 2024
    5578b7a
  4. remove commented out code

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 6, 2024
    71371bd

Commits on Jun 7, 2024

  1. Address review comments

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 7, 2024
    21192c0
  2. Expand test_get_route_info

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 7, 2024
    1c23527
  3. Lazily create generation client

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 7, 2024
    a4e8539
  4. Update minimum caikit-tgis-backend version

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 7, 2024
    be72a79
  5. Add debug logs

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 7, 2024
    d5893d9
  6. Linting

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 7, 2024
    9848943
  7. linting

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 7, 2024
    5df0533
  8. Update caikit_nlp/toolkit/text_generation/tgis_utils.py

    Co-authored-by: Gabe Goodhart <gabe.l.hart@gmail.com>
    Signed-off-by: Mynhardt Burger <mynhardt@gmail.com>
    mynhardtburger and gabe-l-hart authored Jun 7, 2024
    6bae66a
  9. review comments

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 7, 2024
    46ae073
  10. remove unreachable code

    Signed-off-by: Mynhardt Burger <Mynhardt.Burger@ibm.com>
    mynhardtburger committed Jun 7, 2024
    527b455
  11. Merge pull request #363 from mynhardtburger/lazy-model-connection

    Lazy model connection & x-route-info
    gabe-l-hart authored Jun 7, 2024
    3b2f8fd
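
Taken together, these commits (merged as PR #363) make the TGIS model connection lazy and add a per-request routing override carried in the `x-route-info` HTTP header / gRPC metadata key. Below is a minimal client-side sketch of the override, assuming a caikit HTTP runtime is serving the model; the endpoint path and payload shape are assumptions for illustration, not taken from this diff.

```python
# Sketch: send an x-route-info header so the runtime registers the named
# host as this model's remote TGIS connection before generating.
# ASSUMPTIONS: the endpoint path and JSON payload shape are illustrative.
import requests

resp = requests.post(
    "http://localhost:8080/api/v1/task/text-generation",  # assumed route
    json={"model_id": "my-peft-model", "inputs": "Hello world"},
    headers={"x-route-info": "my-tgis-host.example.svc:8033"},
)
print(resp.status_code, resp.json())
```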
59 changes: 51 additions & 8 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -14,7 +14,9 @@
"""This file contains a distributed backend implementation for leveraging the PEFT-trained
prompt vectors in TGIS generation requests.
"""

# Standard
from functools import cached_property
from typing import Iterable, List, Optional, Tuple, Union
import os

@@ -32,6 +34,7 @@
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend import TGISBackend
import alog

@@ -40,6 +43,7 @@
from ...toolkit.text_generation.tgis_utils import (
GENERATE_FUNCTION_TGIS_ARGS,
TGISGenerationClient,
get_route_info,
)
from ...toolkit.verbalizer_utils import render_verbalizer
from . import PeftPromptTuning
@@ -68,15 +72,14 @@ def __init__(
prompt_artifacts: Optional[List[str]] = None,
) -> None:
super().__init__()
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
self._client = None

self._tgis_backend = tgis_backend
if enable_backend:
error.type_check(
"<NLP33971947E>", TGISBackend, tgis_backend=self._tgis_backend
)
# get_client will also launch a local TGIS process and get the model
# loaded when using the local TGIS backend
self._client = tgis_backend.get_client(base_model_name)

# Tell the backend to load all of the available prompt files
if prompt_artifacts:
@@ -107,6 +110,14 @@ def __del__(self):
if tgis_backend and prompt_cache_id and model_id:
tgis_backend.unload_prompt_artifacts(model_id, prompt_cache_id)

@cached_property
def _client(self):
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
if self._tgis_backend:
return self._tgis_backend.get_client(self.base_model_name)

@classmethod
def load(cls, model_path: str, load_backend: BackendBase) -> "PeftPromptTuningTGIS":
"""Load a TGIS Peft Prompt Tuning distributed module. Note that we do not
@@ -182,7 +193,7 @@ def save(self, model_path: str):
)

# pylint: disable=duplicate-code
@TextGenerationTask.taskmethod()
@TextGenerationTask.taskmethod(context_arg="context")
def run(
self,
text: str,
@@ -206,6 +217,7 @@ def run(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
@@ -221,6 +233,8 @@ def run(
self.enable_backend,
"Backend must be configured and loaded with this module before executing `run` call.",
)
self._register_model_connection_with_context(context)

verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
return self.tgis_generation_client.unary_generate(
text=verbalized_text,
@@ -244,7 +258,7 @@ def run(
stop_sequences=stop_sequences,
)

@TextGenerationTask.taskmethod(output_streaming=True)
@TextGenerationTask.taskmethod(output_streaming=True, context_arg="context")
def run_stream_out(
self,
text: str,
@@ -268,6 +282,7 @@ def run_stream_out(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing against the model running in TGIS
@@ -283,6 +298,9 @@ def run_stream_out(
"Backend must be configured and loaded with this module \
before executing `run_stream_out` call.",
)

self._register_model_connection_with_context(context)

verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
return self.tgis_generation_client.stream_generate(
text=verbalized_text,
@@ -306,10 +324,11 @@ def run_stream_out(
stop_sequences=stop_sequences,
)

@TokenizationTask.taskmethod()
@TokenizationTask.taskmethod(context_arg="context")
def run_tokenizer(
self,
text: str,
context: Optional[RuntimeServerContextType] = None,
) -> TokenizationResults:
"""Run tokenization task against the model running in TGIS.
@@ -320,6 +339,30 @@ def run_tokenizer(
TokenizationResults
The token count
"""

self._register_model_connection_with_context(context)

return self.tgis_generation_client.unary_tokenize(
text=text,
)

def _register_model_connection_with_context(
self, context: Optional[RuntimeServerContextType]
):
"""
Register a remote model connection with the configured TGISBackend if there is
a context override provided.
"""
if self._tgis_backend:
if route_info := get_route_info(context):
log.debug(
"<NLP10705560D> Registering remote model connection with context "
"override: 'hostname: %s'",
route_info,
)
self._tgis_backend.register_model_connection(
self.base_model_name,
{"hostname": route_info},
fill_with_defaults=True,
)
self._model_loaded = True
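
The switch to `cached_property` above defers client creation until the first inference call, so bootstrap-and-save flows never open a TGIS connection. A self-contained sketch of the same pattern, where `FakeBackend` and `LazyModule` are stand-in names rather than anything from this diff:

```python
# Minimal sketch of the lazy-client pattern: the client is built on first
# attribute access and cached on the instance, so flows that never run
# inference never create a connection. FakeBackend is hypothetical.
from functools import cached_property


class FakeBackend:
    def get_client(self, model_name: str):
        print(f"creating client for {model_name}")  # happens only once
        return object()


class LazyModule:
    def __init__(self, tgis_backend, model_name):
        self._tgis_backend = tgis_backend  # may be None for save-only flows
        self.model_name = model_name

    @cached_property
    def _client(self):
        # Evaluated once on first access; later accesses hit the cache.
        if self._tgis_backend:
            return self._tgis_backend.get_client(self.model_name)
        return None


module = LazyModule(FakeBackend(), "my-model")
assert module._client is module._client  # "creating client" prints only once
```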
65 changes: 50 additions & 15 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -14,6 +14,7 @@


# Standard
from functools import cached_property
from typing import Iterable, List, Optional, Tuple, Union
import os

@@ -30,6 +31,7 @@
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend import TGISBackend
import alog

@@ -43,6 +45,7 @@
from ...toolkit.text_generation.tgis_utils import (
GENERATE_FUNCTION_TGIS_ARGS,
TGISGenerationClient,
get_route_info,
)
from .text_generation_local import TextGeneration

@@ -86,28 +89,33 @@ def __init__(
# Set _model_loaded as False by default. This will only get set to True if
# we enable the tgis_backend and we are able to fetch the client successfully.
self._model_loaded = False
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
self._client = None
if tgis_backend:
self._client = tgis_backend.get_client(model_name)
# mark that the model is loaded so that we can unload it later
self._model_loaded = True
self.tgis_backend = tgis_backend

self._tgis_backend = tgis_backend
self._bos_token = bos_token
self._sep_token = sep_token
self._eos_token = eos_token
self._pad_token = pad_token
self.tgis_generation_client = TGISGenerationClient(
self.model_name, self._eos_token, self._client, self.PRODUCER_ID
)

def __del__(self):
# nothing to unload if we didn't finish loading
if self._model_loaded and self.tgis_backend:
self.tgis_backend.unload_model(self.model_name)
if self._model_loaded and self._tgis_backend:
self._tgis_backend.unload_model(self.model_name)

@cached_property
def _client(self):
# Lazily configure/create the internal tgis backend client
if self._tgis_backend:
return self._tgis_backend.get_client(self.model_name)

@cached_property
def tgis_generation_client(self):
# Lazily create the generation client
# This in turn calls self._client which also lazily gets the tgis backend client
return TGISGenerationClient(
self.model_name, self._eos_token, self._client, self.PRODUCER_ID
)

@classmethod
def bootstrap(cls, model_path: str, load_backend: Union[BackendBase, None] = None):
@@ -207,7 +215,7 @@ def save(self, model_path: str):
)

# pylint: disable=duplicate-code
@TextGenerationTask.taskmethod()
@TextGenerationTask.taskmethod(context_arg="context")
def run(
self,
text: str,
@@ -231,6 +239,7 @@ def run(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
@@ -240,6 +249,8 @@ def run(
GeneratedTextResult
Generated text result produced by TGIS.
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.unary_generate(
text=text,
@@ -263,7 +274,7 @@ def run(
stop_sequences=stop_sequences,
)

@TextGenerationTask.taskmethod(output_streaming=True)
@TextGenerationTask.taskmethod(output_streaming=True, context_arg="context")
def run_stream_out(
self,
text: str,
@@ -287,6 +298,7 @@ def run_stream_out(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing for text generation module.
@@ -295,6 +307,7 @@ def run_stream_out(
Returns:
Iterable[GeneratedTextStreamResult]
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.stream_generate(
@@ -319,10 +332,11 @@ def run_stream_out(
stop_sequences=stop_sequences,
)

@TokenizationTask.taskmethod()
@TokenizationTask.taskmethod(context_arg="context")
def run_tokenizer(
self,
text: str,
context: Optional[RuntimeServerContextType] = None,
) -> TokenizationResults:
"""Run tokenization task against the model running in TGIS.
@@ -333,7 +347,28 @@ def run_tokenizer(
TokenizationResults
The token count
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.unary_tokenize(
text=text,
)

def _register_model_connection_with_context(
self, context: Optional[RuntimeServerContextType]
):
"""
Register a remote model connection with the configured TGISBackend if there is
a context override provided.
"""
if self._tgis_backend:
if route_info := get_route_info(context):
log.debug(
"<NLP15770311D> Registering remote model connection with context "
"override: 'hostname: %s'",
route_info,
)
self._tgis_backend.register_model_connection(
self.model_name, {"hostname": route_info}, fill_with_defaults=True
)
self._model_loaded = True
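
As in the PEFT module, the request context now flows into `run`, `run_stream_out`, and `run_tokenizer` through `taskmethod(context_arg="context")`, and `_register_model_connection_with_context` points the backend at the host named by `x-route-info` before the call executes. A rough sketch of that registration step in isolation, with a hypothetical `StubBackend` standing in for `TGISBackend` (only the connection-dict shape and `fill_with_defaults` flag come from the diff above):

```python
# Rough sketch of the per-request override step. StubBackend is a stand-in
# for TGISBackend, not a real class from this repository.
from typing import Optional


class StubBackend:
    def __init__(self):
        self.connections = {}

    def register_model_connection(self, model_name, conn, fill_with_defaults=True):
        self.connections[model_name] = conn


def register_with_context(backend, model_name: str, route_info: Optional[str]):
    # Mirrors _register_model_connection_with_context: act only when a
    # backend is configured and the request carried an override.
    if backend and route_info:
        backend.register_model_connection(
            model_name, {"hostname": route_info}, fill_with_defaults=True
        )


backend = StubBackend()
register_with_context(backend, "flan-t5-xl", "tgis-prod.example.svc:8033")
print(backend.connections)  # {'flan-t5-xl': {'hostname': 'tgis-prod.example.svc:8033'}}
```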
39 changes: 36 additions & 3 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is for helper functions related to TGIS.
"""
"""This file is for helper functions related to TGIS."""

# Standard
from typing import Iterable
from typing import Iterable, Optional

# Third Party
import fastapi
import grpc

# First Party
@@ -33,6 +34,7 @@
TokenizationResults,
TokenStreamDetails,
)
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend.protobufs import generation_pb2
import alog

@@ -84,6 +86,9 @@
grpc.StatusCode.UNAUTHENTICATED: CaikitCoreStatusCode.UNAUTHORIZED,
}

# HTTP Header / gRPC Metadata key used to identify a route override
ROUTE_INFO_HEADER_KEY = "x-route-info"


def raise_caikit_core_exception(rpc_error: grpc.RpcError):
"""Helper to wrap logic of converting from grpc.RpcError ->
@@ -683,3 +688,31 @@ def unary_tokenize(
return TokenizationResults(
token_count=response.token_count,
)


def get_route_info(
context: Optional[RuntimeServerContextType],
) -> Optional[str]:
"""
Returns the value of the "x-route-info" header (HTTP) or metadata key (gRPC)
from the context if it was found, otherwise None (including when context
itself is None).
"""
if context is None:
return None

if isinstance(context, grpc.ServicerContext):
route_info = dict(context.invocation_metadata()).get(ROUTE_INFO_HEADER_KEY)
if route_info:
return route_info
elif isinstance(context, fastapi.Request):
route_info = context.headers.get(ROUTE_INFO_HEADER_KEY)
if route_info:
return route_info
else:
error.log_raise(
"<NLP92615097E>",
ValueError(f"context is of an unsupported type: {type(context)}"),
)
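
A quick usage sketch for the new `get_route_info` helper with a FastAPI request context; the `/probe` route is illustrative only. For gRPC, the same helper reads `context.invocation_metadata()`, and any other context type raises a ValueError through `error.log_raise`.

```python
# Illustrative only: expose a probe route that echoes the x-route-info
# header as seen by get_route_info (None when the header is absent).
import fastapi

from caikit_nlp.toolkit.text_generation.tgis_utils import get_route_info

app = fastapi.FastAPI()


@app.get("/probe")
def probe(request: fastapi.Request):
    return {"route_info": get_route_info(request)}
```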
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ classifiers=[
]
dependencies = [
"caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0",
"caikit-tgis-backend>=0.1.27,<0.2.0",
"caikit-tgis-backend>=0.1.33,<0.2.0",
# TODO: loosen dependencies
"grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking
"grpcio-reflection>=1.62.2",