Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also compare across forks, or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: caikit/caikit-nlp
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: v0.4.10
Choose a base ref
...
head repository: caikit/caikit-nlp
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: v0.4.11
Choose a head ref
  • 6 commits
  • 4 files changed
  • 1 contributor

Commits on May 16, 2024

  1. 🧵 Add timeout configuration for TGIS streaming request as an experiment

    Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
    gkumbhat committed May 16, 2024
    Copy the full SHA
    6ecc54a View commit details

Commits on May 17, 2024

  1. ✨ Add tgis req timeout as configurable parameter

    Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
    gkumbhat committed May 17, 2024
    Copy the full SHA
    40fb403 View commit details
  2. ✅ Fix tgis client fixture for accepting kwargs

    Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
    gkumbhat committed May 17, 2024
    Copy the full SHA
    48216eb View commit details
  3. 🎨 Fix formatting

    Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
    gkumbhat committed May 17, 2024
    Copy the full SHA
    401510d View commit details
  4. 🐛✅ Fix fixture for tgis tests

    Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
    gkumbhat committed May 17, 2024
    Copy the full SHA
    fc5ebda View commit details
  5. Merge pull request #358 from caikit/add_tgis_timeout

    Add tgis timeout
    gkumbhat authored May 17, 2024

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
    Copy the full SHA
    98578c9 View commit details
Showing with 24 additions and 15 deletions.
  1. +3 −0 caikit_nlp/config/config.yml
  2. +12 −3 caikit_nlp/toolkit/text_generation/tgis_utils.py
  3. +6 −6 tests/fixtures/__init__.py
  4. +3 −6 tests/toolkit/text_generation/test_tgis_utils.py
3 changes: 3 additions & 0 deletions caikit_nlp/config/config.yml
Original file line number Diff line number Diff line change
@@ -56,3 +56,6 @@ embedding:

runtime:
library: caikit_nlp

# Configure request timeout for TGIS backend (in seconds)
tgis_request_timeout: 60
15 changes: 12 additions & 3 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
import grpc

# First Party
from caikit import get_config
from caikit.core.exceptions import error_handler
from caikit.core.exceptions.caikit_core_exception import (
CaikitCoreException,
@@ -326,6 +327,8 @@ def __init__(
self.producer_id = producer_id
self.prefix_id = prefix_id

self.tgis_req_timeout = get_config().tgis_request_timeout

def unary_generate(
self,
text,
@@ -432,7 +435,9 @@ def unary_generate(
# Currently, we send a batch request of len(x)==1, so we expect one response back
with alog.ContextTimer(log.trace, "TGIS request duration: "):
try:
batch_response = self.tgis_client.Generate(request)
batch_response = self.tgis_client.Generate(
request, timeout=self.tgis_req_timeout
)
except grpc.RpcError as err:
raise_caikit_core_exception(err)

@@ -576,7 +581,9 @@ def stream_generate(

# stream GenerationResponse
try:
stream_response = self.tgis_client.GenerateStream(request)
stream_response = self.tgis_client.GenerateStream(
request, timeout=self.tgis_req_timeout
)

for stream_part in stream_response:
details = TokenStreamDetails(
@@ -645,7 +652,9 @@ def unary_tokenize(
# Currently, we send a batch request of len(x)==1, so we expect one response back
with alog.ContextTimer(log.trace, "TGIS request duration: "):
try:
batch_response = self.tgis_client.Tokenize(request)
batch_response = self.tgis_client.Tokenize(
request, timeout=self.tgis_req_timeout
)
except grpc.RpcError as err:
raise_caikit_core_exception(err)

12 changes: 6 additions & 6 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -198,17 +198,17 @@ class StubTGISClient:
def __init__(self, base_model_name):
pass

def Generate(self, request):
def Generate(self, request, **kwargs):
return StubTGISClient.unary_generate(request)

def GenerateStream(self, request):
def GenerateStream(self, request, **kwargs):
return StubTGISClient.stream_generate(request)

def Tokenize(self, request):
def Tokenize(self, request, **kwargs):
return StubTGISClient.tokenize(request)

@staticmethod
def unary_generate(request):
def unary_generate(request, **kwargs):
fake_response = mock.Mock()
fake_result = mock.Mock()
fake_result.stop_reason = 5
@@ -229,7 +229,7 @@ def unary_generate(request):
return fake_response

@staticmethod
def stream_generate(request):
def stream_generate(request, **kwargs):
fake_stream = mock.Mock()
fake_stream.stop_reason = 5
fake_stream.generated_token_count = 1
@@ -250,7 +250,7 @@ def stream_generate(request):
yield fake_stream

@staticmethod
def tokenize(request):
def tokenize(request, **kwargs):
fake_response = mock.Mock()
fake_result = mock.Mock()
fake_result.token_count = 1
9 changes: 3 additions & 6 deletions tests/toolkit/text_generation/test_tgis_utils.py
Original file line number Diff line number Diff line change
@@ -54,22 +54,19 @@ def _maybe_raise(self, error_type: Type[grpc.RpcError], *args):
)

def Generate(
self,
request: generation_pb2.BatchedGenerationRequest,
self, request: generation_pb2.BatchedGenerationRequest, **kwargs
) -> generation_pb2.BatchedGenerationResponse:
self._maybe_raise(grpc._channel._InactiveRpcError)
return generation_pb2.BatchedGenerationResponse()

def GenerateStream(
self,
request: generation_pb2.SingleGenerationRequest,
self, request: generation_pb2.SingleGenerationRequest, **kwargs
) -> Iterable[generation_pb2.GenerationResponse]:
self._maybe_raise(grpc._channel._MultiThreadedRendezvous, None, None, None)
yield generation_pb2.GenerationResponse()

def Tokenize(
self,
request: generation_pb2.BatchedTokenizeRequest,
self, request: generation_pb2.BatchedTokenizeRequest, **kwargs
) -> generation_pb2.BatchedTokenizeResponse:
self._maybe_raise(grpc._channel._InactiveRpcError)
return generation_pb2.BatchedTokenizeResponse()