
Comparing changes

base repository: caikit/caikit-nlp
base: v0.4.9
head repository: caikit/caikit-nlp
compare: v0.4.10
  • 9 commits
  • 5 files changed
  • 4 contributors

Commits on May 14, 2024

  1. deps: pin grpc-* to a recent version to avoid pip backtracking

    Signed-off-by: Daniele Trifirò <dtrifiro@redhat.com>
    dtrifiro committed May 14, 2024

    f377397
  2. Dockerfile: create venv with updated pip/setuptools

    Signed-off-by: Daniele Trifirò <dtrifiro@redhat.com>
    dtrifiro committed May 14, 2024

    a251d59
  3. Dockerfile: use docker build cache

    Signed-off-by: Daniele Trifirò <dtrifiro@redhat.com>
    dtrifiro committed May 14, 2024

    cbe8d07
  4. Merge pull request #356 from dtrifiro/prevent-pip-backtracking-ibm

    prevent pip backtracking
    gkumbhat authored May 14, 2024

    2823875
  5. HandleRpcError: Add RpcError handling in tgis_utils

    #354
    
    Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
    gabe-l-hart committed May 14, 2024
    b01c0e4
  6. HandleRpcError: Tweaks to error code mapping for better user/internal

    #354
    
    Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
    gabe-l-hart committed May 14, 2024
    ebdfa38
  7. Merge pull request #355 from gabe-l-hart/HandleRpcError-354

    HandleRpcError: Add RpcError handling in tgis_utils
    gkumbhat authored May 14, 2024

    a5bf25e

Commits on May 15, 2024

  1. gha: remove useless steps in build-image

    Signed-off-by: Daniele Trifirò <dtrifiro@redhat.com>
    dtrifiro committed May 15, 2024

    1b41528
  2. Merge pull request #357 from dtrifiro/use-build-cache

    Dockerfile: use build cache
    evaline-ju authored May 15, 2024

    769812f
Showing with 237 additions and 49 deletions.
  1. +10 −10 .github/workflows/build-image.yml
  2. +7 −3 Dockerfile
  3. +85 −36 caikit_nlp/toolkit/text_generation/tgis_utils.py
  4. +3 −0 pyproject.toml
  5. +132 −0 tests/toolkit/text_generation/test_tgis_utils.py
20 changes: 10 additions & 10 deletions .github/workflows/build-image.yml
@@ -16,14 +16,14 @@ jobs:
     name: Build Image
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
-      - uses: actions/setup-python@v4
-        with:
-          python-version: 3.9
-      - name: Setup tox
-        run: |
-          pip install -U pip wheel
-          pip install tox
+      - uses: actions/checkout@v3
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
       - name: Build image
-        run: |
-          docker build -t caikit-nlp:latest .
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          tags: "caikit-nlp:latest"
+          load: true
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
10 changes: 7 additions & 3 deletions Dockerfile
@@ -15,18 +15,22 @@ COPY pyproject.toml .
 COPY tox.ini .
 COPY caikit_nlp caikit_nlp
 # .git is required for setuptools-scm get the version
-RUN --mount=source=.git,target=.git,type=bind tox -e build
+RUN --mount=source=.git,target=.git,type=bind \
+    --mount=type=cache,target=/root/.cache/pip \
+    tox -e build


 FROM base as deploy

-RUN python -m venv /opt/caikit/
+RUN python -m venv --upgrade-deps /opt/caikit/

 ENV VIRTUAL_ENV=/opt/caikit
 ENV PATH="$VIRTUAL_ENV/bin:$PATH"

 COPY --from=builder /build/dist/caikit_nlp*.whl /tmp/
-RUN pip install --no-cache /tmp/caikit_nlp*.whl && rm /tmp/caikit_nlp*.whl
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install /tmp/caikit_nlp*.whl && \
+    rm /tmp/caikit_nlp*.whl

 COPY LICENSE /opt/caikit/
 COPY README.md /opt/caikit/
121 changes: 85 additions & 36 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
@@ -16,8 +16,15 @@
 # Standard
 from typing import Iterable

+# Third Party
+import grpc
+
 # First Party
 from caikit.core.exceptions import error_handler
+from caikit.core.exceptions.caikit_core_exception import (
+    CaikitCoreException,
+    CaikitCoreStatusCode,
+)
 from caikit.interfaces.nlp.data_model import (
     GeneratedTextResult,
     GeneratedTextStreamResult,
@@ -41,20 +48,53 @@
         Whether or not the source string should be contained in the generated output,
         e.g., as a prefix.
     input_tokens: bool
-        Whether or not to include list of input tokens.
+        Whether or not to include list of input tokens.
     generated_tokens: bool
-        Whether or not to include list of individual generated tokens.
+        Whether or not to include list of individual generated tokens.
     token_logprobs: bool
-        Whether or not to include logprob for each returned token.
+        Whether or not to include logprob for each returned token.
         Applicable only if generated_tokens == true and/or input_tokens == true
     token_ranks: bool
-        Whether or not to include rank of each returned token.
-        Applicable only if generated_tokens == true and/or input_tokens == true
+        Whether or not to include rank of each returned token.
+        Applicable only if generated_tokens == true and/or input_tokens == true
 """.format(
     GENERATE_FUNCTION_ARGS
 )


+# Mapping from grpc status codes to caikit status codes. There is not a 1:1
+# mapping at the moment, so this conversion is lossy!
+GRPC_TO_CAIKIT_CORE_STATUS = {
+    grpc.StatusCode.CANCELLED: CaikitCoreStatusCode.CONNECTION_ERROR,
+    grpc.StatusCode.UNKNOWN: CaikitCoreStatusCode.UNKNOWN,
+    grpc.StatusCode.INVALID_ARGUMENT: CaikitCoreStatusCode.INVALID_ARGUMENT,
+    grpc.StatusCode.DEADLINE_EXCEEDED: CaikitCoreStatusCode.CONNECTION_ERROR,
+    grpc.StatusCode.NOT_FOUND: CaikitCoreStatusCode.NOT_FOUND,
+    grpc.StatusCode.ALREADY_EXISTS: CaikitCoreStatusCode.INVALID_ARGUMENT,
+    grpc.StatusCode.PERMISSION_DENIED: CaikitCoreStatusCode.FORBIDDEN,
+    grpc.StatusCode.RESOURCE_EXHAUSTED: CaikitCoreStatusCode.INVALID_ARGUMENT,
+    grpc.StatusCode.FAILED_PRECONDITION: CaikitCoreStatusCode.INVALID_ARGUMENT,
+    grpc.StatusCode.ABORTED: CaikitCoreStatusCode.CONNECTION_ERROR,
+    grpc.StatusCode.OUT_OF_RANGE: CaikitCoreStatusCode.INVALID_ARGUMENT,
+    grpc.StatusCode.UNIMPLEMENTED: CaikitCoreStatusCode.UNKNOWN,
+    grpc.StatusCode.INTERNAL: CaikitCoreStatusCode.FATAL,
+    grpc.StatusCode.UNAVAILABLE: CaikitCoreStatusCode.CONNECTION_ERROR,
+    grpc.StatusCode.DATA_LOSS: CaikitCoreStatusCode.CONNECTION_ERROR,
+    grpc.StatusCode.UNAUTHENTICATED: CaikitCoreStatusCode.UNAUTHORIZED,
+}
+
+
+def raise_caikit_core_exception(rpc_error: grpc.RpcError):
+    """Helper to wrap logic of converting from grpc.RpcError ->
+    CaikitCoreException
+    """
+    caikit_status_code = GRPC_TO_CAIKIT_CORE_STATUS.get(
+        rpc_error.code(), CaikitCoreStatusCode.UNKNOWN
+    )
+    error_message = rpc_error.details() or f"Unknown RpcError: {rpc_error}"
+    raise CaikitCoreException(caikit_status_code, error_message) from rpc_error
+
+
 def validate_inf_params(
     text,
     preserve_input_text,
@@ -391,7 +431,10 @@ def unary_generate(

         # Currently, we send a batch request of len(x)==1, so we expect one response back
         with alog.ContextTimer(log.trace, "TGIS request duration: "):
-            batch_response = self.tgis_client.Generate(request)
+            try:
+                batch_response = self.tgis_client.Generate(request)
+            except grpc.RpcError as err:
+                raise_caikit_core_exception(err)

         error.value_check(
             "<NLP38899018E>",
@@ -532,37 +575,40 @@ def stream_generate(
         )

         # stream GenerationResponse
-        stream_response = self.tgis_client.GenerateStream(request)
-
-        for stream_part in stream_response:
-            details = TokenStreamDetails(
-                finish_reason=stream_part.stop_reason,
-                generated_tokens=stream_part.generated_token_count,
-                seed=stream_part.seed,
-                input_token_count=stream_part.input_token_count,
-            )
-            token_list = []
-            if stream_part.tokens is not None:
-                for token in stream_part.tokens:
-                    token_list.append(
-                        GeneratedToken(
-                            text=token.text, logprob=token.logprob, rank=token.rank
+        try:
+            stream_response = self.tgis_client.GenerateStream(request)
+
+            for stream_part in stream_response:
+                details = TokenStreamDetails(
+                    finish_reason=stream_part.stop_reason,
+                    generated_tokens=stream_part.generated_token_count,
+                    seed=stream_part.seed,
+                    input_token_count=stream_part.input_token_count,
+                )
+                token_list = []
+                if stream_part.tokens is not None:
+                    for token in stream_part.tokens:
+                        token_list.append(
+                            GeneratedToken(
+                                text=token.text, logprob=token.logprob, rank=token.rank
+                            )
                         )
-                    )
-            input_token_list = []
-            if stream_part.input_tokens is not None:
-                for token in stream_part.input_tokens:
-                    input_token_list.append(
-                        GeneratedToken(
-                            text=token.text, logprob=token.logprob, rank=token.rank
+                input_token_list = []
+                if stream_part.input_tokens is not None:
+                    for token in stream_part.input_tokens:
+                        input_token_list.append(
+                            GeneratedToken(
+                                text=token.text, logprob=token.logprob, rank=token.rank
+                            )
                         )
-                    )
-            yield GeneratedTextStreamResult(
-                generated_text=stream_part.text,
-                tokens=token_list,
-                input_tokens=input_token_list,
-                details=details,
-            )
+                yield GeneratedTextStreamResult(
+                    generated_text=stream_part.text,
+                    tokens=token_list,
+                    input_tokens=input_token_list,
+                    details=details,
+                )
+        except grpc.RpcError as err:
+            raise_caikit_core_exception(err)

     def unary_tokenize(
         self,
@@ -598,7 +644,10 @@ def unary_tokenize(

         # Currently, we send a batch request of len(x)==1, so we expect one response back
         with alog.ContextTimer(log.trace, "TGIS request duration: "):
-            batch_response = self.tgis_client.Tokenize(request)
+            try:
+                batch_response = self.tgis_client.Tokenize(request)
+            except grpc.RpcError as err:
+                raise_caikit_core_exception(err)

         error.value_check(
             "<NLP38899081E>",
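
The tgis_utils.py change follows one pattern in three places: look up the RPC status code in a lossy mapping, fall back to UNKNOWN for anything unmapped, and re-raise with `raise ... from` so the original error stays attached. A minimal sketch of that pattern follows. It deliberately does not import grpc or caikit; `RpcCode`, `CoreCode`, `FakeRpcError`, `CoreError`, `remote_stream`, and the reduced mapping are hypothetical stand-ins, not the library types. It also illustrates why the streaming case plausibly wraps the whole loop: a lazy stream can fail during iteration rather than at the call that creates it.

```python
from enum import Enum


class RpcCode(Enum):
    """Hypothetical stand-in for grpc.StatusCode."""
    INVALID_ARGUMENT = "invalid argument"
    UNAVAILABLE = "unavailable"
    DATA_LOSS = "data loss"
    UNIMPLEMENTED = "unimplemented"


class CoreCode(Enum):
    """Hypothetical stand-in for CaikitCoreStatusCode."""
    INVALID_ARGUMENT = 1
    CONNECTION_ERROR = 2
    UNKNOWN = 3


class FakeRpcError(Exception):
    """Hypothetical stand-in for grpc.RpcError, exposing code() and details()."""

    def __init__(self, code, details=""):
        super().__init__(details)
        self._code = code
        self._details = details

    def code(self):
        return self._code

    def details(self):
        return self._details


class CoreError(Exception):
    """Hypothetical stand-in for CaikitCoreException."""

    def __init__(self, status_code, message):
        super().__init__(message)
        self.status_code = status_code
        self.message = message


# Several RPC codes collapse onto one core code, so the mapping is lossy.
RPC_TO_CORE = {
    RpcCode.INVALID_ARGUMENT: CoreCode.INVALID_ARGUMENT,
    RpcCode.UNAVAILABLE: CoreCode.CONNECTION_ERROR,
    RpcCode.DATA_LOSS: CoreCode.CONNECTION_ERROR,
}


def raise_core_error(rpc_error):
    """Translate an RPC error, defaulting to UNKNOWN for unmapped codes."""
    status = RPC_TO_CORE.get(rpc_error.code(), CoreCode.UNKNOWN)
    message = rpc_error.details() or f"Unknown RpcError: {rpc_error}"
    raise CoreError(status, message) from rpc_error


def remote_stream(fail_with=None):
    """Lazy stream: nothing runs, and nothing fails, until it is iterated."""
    yield "first"
    if fail_with is not None:
        raise FakeRpcError(fail_with, "backend went away")
    yield "second"


def stream_generate(fail_with=None):
    # The try covers the loop, not just the call that creates the stream,
    # because the failure surfaces while iterating.
    try:
        for part in remote_stream(fail_with):
            yield part.upper()
    except FakeRpcError as err:
        raise_core_error(err)
```

In this sketch a consumer receives the first part, then sees a `CoreError` with `CONNECTION_ERROR` instead of the raw RPC error. Had the `try` covered only the `remote_stream(...)` call, the mid-stream failure would have escaped untranslated.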
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -17,6 +17,9 @@ dependencies = [
     "caikit[runtime-grpc,runtime-http]>=0.26.17,<0.27.0",
     "caikit-tgis-backend>=0.1.27,<0.2.0",
     # TODO: loosen dependencies
+    "grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking
+    "grpcio-reflection>=1.62.2",
+    "grpcio-health-checking>=1.62.2",
     "accelerate>=0.22.0",
     "datasets>=2.4.0",
     "huggingface-hub",
132 changes: 132 additions & 0 deletions tests/toolkit/text_generation/test_tgis_utils.py
@@ -0,0 +1,132 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for tgis_utils
+"""
+# Standard
+from typing import Iterable, Optional, Type
+
+# Third Party
+import grpc
+import grpc._channel
+import pytest
+
+# First Party
+from caikit.core.data_model import ProducerId
+from caikit.core.exceptions.caikit_core_exception import CaikitCoreException
+from caikit_tgis_backend.protobufs import generation_pb2
+
+# Local
+from caikit_nlp.toolkit.text_generation import tgis_utils
+
+## Helpers #####################################################################
+
+
+class MockTgisClient:
+    """Mock of a TGIS client that doesn't actually call anything"""
+
+    def __init__(
+        self,
+        status_code: Optional[grpc.StatusCode],
+        error_message: str = "Yikes",
+    ):
+        self._status_code = status_code
+        self._error_message = error_message
+
+    def _maybe_raise(self, error_type: Type[grpc.RpcError], *args):
+        if self._status_code not in [None, grpc.StatusCode.OK]:
+            raise error_type(
+                grpc._channel._RPCState(
+                    [], [], [], code=self._status_code, details=self._error_message
+                ),
+                *args,
+            )
+
+    def Generate(
+        self,
+        request: generation_pb2.BatchedGenerationRequest,
+    ) -> generation_pb2.BatchedGenerationResponse:
+        self._maybe_raise(grpc._channel._InactiveRpcError)
+        return generation_pb2.BatchedGenerationResponse()
+
+    def GenerateStream(
+        self,
+        request: generation_pb2.SingleGenerationRequest,
+    ) -> Iterable[generation_pb2.GenerationResponse]:
+        self._maybe_raise(grpc._channel._MultiThreadedRendezvous, None, None, None)
+        yield generation_pb2.GenerationResponse()
+
+    def Tokenize(
+        self,
+        request: generation_pb2.BatchedTokenizeRequest,
+    ) -> generation_pb2.BatchedTokenizeResponse:
+        self._maybe_raise(grpc._channel._InactiveRpcError)
+        return generation_pb2.BatchedTokenizeResponse()
+
+
+## TGISGenerationClient ########################################################
+
+
+@pytest.mark.parametrize(
+    "status_code",
+    [code for code in grpc.StatusCode if code != grpc.StatusCode.OK],
+)
+@pytest.mark.parametrize(
+    "method", ["unary_generate", "stream_generate", "unary_tokenize"]
+)
+def test_TGISGenerationClient_rpc_errors(status_code, method):
+    """Test that raised errors in downstream RPCs are converted to
+    CaikitCoreException correctly
+    """
+    tgis_client = MockTgisClient(status_code)
+    gen_client = tgis_utils.TGISGenerationClient(
+        "foo",
+        "bar",
+        tgis_client,
+        ProducerId("foobar"),
+    )
+    with pytest.raises(CaikitCoreException) as context:
+        kwargs = (
+            dict(
+                preserve_input_text=True,
+                input_tokens=True,
+                generated_tokens=True,
+                token_logprobs=True,
+                token_ranks=True,
+                max_new_tokens=20,
+                min_new_tokens=20,
+                truncate_input_tokens=True,
+                decoding_method="GREEDY",
+                top_k=None,
+                top_p=None,
+                typical_p=None,
+                temperature=None,
+                seed=None,
+                repetition_penalty=0.5,
+                max_time=None,
+                exponential_decay_length_penalty=None,
+                stop_sequences=["asdf"],
+            )
+            if method.endswith("_generate")
+            else dict()
+        )
+        res = getattr(gen_client, method)(text="foobar", **kwargs)
+        if method.startswith("stream_"):
+            next(res)
+
+    assert (
+        context.value.status_code == tgis_utils.GRPC_TO_CAIKIT_CORE_STATUS[status_code]
+    )
+    rpc_err = context.value.__context__
+    assert isinstance(rpc_err, grpc.RpcError)
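
The final assertions in the test read the original error from `__context__`. As a small illustration of why that works, the sketch below uses two hypothetical exception classes (`LowLevelError` and `HighLevelError`, not part of the change) to show what Python records when an exception is raised with `from` inside an `except` block: both `__cause__` and `__context__` point at the original error.

```python
class LowLevelError(Exception):
    """Hypothetical stand-in for the downstream RPC error."""


class HighLevelError(Exception):
    """Hypothetical stand-in for the translated exception."""


def translate():
    try:
        raise LowLevelError("socket closed")
    except LowLevelError as err:
        # Explicit chaining: sets __cause__; __context__ is set implicitly
        # because we are still handling err when the new exception is raised.
        raise HighLevelError("request failed") from err


def capture():
    """Run translate() and return the exception it raises."""
    try:
        translate()
    except HighLevelError as exc:
        return exc


exc = capture()
print(type(exc.__cause__).__name__, exc.__cause__ is exc.__context__)
```

A test could therefore check either attribute here. `__cause__` is the one that states the chaining was intentional, while `__context__` would also be populated by a bare `raise` inside the handler.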