Comparing changes

base repository: caikit/caikit-nlp
base: v0.4.18
head repository: caikit/caikit-nlp
compare: v0.5.0
  • 3 commits
  • 3 files changed
  • 2 contributors

Commits on Aug 1, 2024

  1. set default values for params to false or none

    Signed-off-by: waleedqk <waleedqk@ibm.com>
    waleedqk committed Aug 1, 2024
    12eeab2
  2. format file

    Signed-off-by: waleedqk <waleedqk@ibm.com>
    waleedqk committed Aug 1, 2024
    842bba3
  3. Merge pull request #370 from waleedqk/generated_tokens

    set default values for params to false or none
    evaline-ju authored Aug 1, 2024
    1e5aac6
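
Taken together, these commits change the defaults of the optional token-detail parameters on the TGIS text-generation modules: generated_tokens, token_logprobs, and token_ranks now default to False, and include_stop_sequence defaults to None instead of True. Callers that relied on the old defaults must now opt in explicitly. A minimal sketch of the difference, assuming a loaded caikit-nlp TGIS-backed module (the `model` variable is hypothetical):

    # Sketch only: `model` stands in for a loaded TGIS-backed module instance.
    # As of v0.5.0, no token details are returned unless requested.
    result = model.run("What is caikit?")

    # To keep the pre-v0.5.0 behavior, pass the flags explicitly:
    result = model.run(
        "What is caikit?",
        generated_tokens=True,
        token_logprobs=True,
        token_ranks=True,
        include_stop_sequence=True,
    )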
16 changes: 8 additions & 8 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -213,10 +213,10 @@ def run(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> GeneratedTextResult:
         f"""Run inference against the model running in TGIS.
@@ -280,10 +280,10 @@ def run_stream_out(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         f"""Run output stream inferencing against the model running in TGIS
16 changes: 8 additions & 8 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -235,10 +235,10 @@ def run(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> GeneratedTextResult:
         f"""Run inference against the model running in TGIS.
@@ -296,10 +296,10 @@ def run_stream_out(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         f"""Run output stream inferencing for text generation module.
5 changes: 4 additions & 1 deletion caikit_nlp/toolkit/text_generation/tgis_utils.py
@@ -144,7 +144,10 @@ def validate_inf_params(
     error.type_check("<NLP65883540E>", bool, token_logprobs=token_logprobs)
     error.type_check("<NLP65883541E>", bool, token_ranks=token_ranks)
     error.type_check(
-        "<NLP65883542E>", bool, include_stop_sequence=include_stop_sequence
+        "<NLP65883542E>",
+        bool,
+        allow_none=True,
+        include_stop_sequence=include_stop_sequence,
     )
     error.type_check("<NLP85452188E>", str, allow_none=True, eos_token=eos_token)
     error.type_check(
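
This validation change is what makes the new include_stop_sequence=None default legal: caikit's error handler type_check rejects None unless allow_none=True is passed. A minimal standalone sketch of the pattern, assuming caikit's error_handler API as used in this file (the channel name "DEMO" is arbitrary):

    import alog
    from caikit.core.exceptions import error_handler

    log = alog.use_channel("DEMO")
    error = error_handler.get(log)

    include_stop_sequence = None  # the new default introduced in this release

    # With allow_none=True, None passes the check; a non-bool value such as
    # "yes" would still log <NLP65883542E> and raise a TypeError.
    error.type_check(
        "<NLP65883542E>",
        bool,
        allow_none=True,
        include_stop_sequence=include_stop_sequence,
    )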