Commit

Update text docs (#1416)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
5 people committed Jan 25, 2023
1 parent d233c9d commit ef13ca1
Showing 31 changed files with 292 additions and 321 deletions.
2 changes: 2 additions & 0 deletions docs/source/links.rst
@@ -51,6 +51,8 @@
.. _Bert_score Evaluating Text Generation: https://arxiv.org/abs/1904.09675
.. _BLEU score: https://en.wikipedia.org/wiki/BLEU
.. _BLEU: http://www.aclweb.org/anthology/P02-1040.pdf
.. _SacreBLEU: https://github.com/mjpost/sacrebleu
.. _SacreBleu_ter: https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/ter.py
.. _Machine Translation Evolution: https://aclanthology.org/P04-1077.pdf
.. _Rouge score_Text Normalizition: https://github.com/google-research/google-research/blob/master/rouge/tokenize.py
.. _Calculate Rouge Score: https://en.wikipedia.org/wiki/ROUGE_(metric)
1 change: 1 addition & 0 deletions docs/source/text/bert_score.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.text.bert.BERTScore
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/bleu_score.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.BLEUScore
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/char_error_rate.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.CharErrorRate
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/chrf_score.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.CHRFScore
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/extended_edit_distance.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.ExtendedEditDistance
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/infolm.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.text.infolm.InfoLM
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/match_error_rate.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.MatchErrorRate
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/perplexity.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.text.perplexity.Perplexity
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/rouge_score.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.text.rouge.ROUGEScore
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/sacre_bleu_score.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.SacreBLEUScore
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/squad.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.SQuAD
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/translation_edit_rate.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.TranslationEditRate
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/word_error_rate.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.WordErrorRate
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/word_info_lost.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.WordInfoLost
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
1 change: 1 addition & 0 deletions docs/source/text/word_info_preserved.rst
@@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.WordInfoPreserved
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
48 changes: 24 additions & 24 deletions src/torchmetrics/text/bert.py
@@ -47,48 +47,55 @@ class BERTScore(Metric):
This implementation follows the original implementation from `BERT_score`_.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~List`): An iterable of predicted sentences
- ``target`` (:class:`~List`): An iterable of reference sentences
As output of ``forward`` and ``compute`` the metric returns the following output:
- ``score`` (:class:`~Dict`): A dictionary containing the keys ``precision``, ``recall`` and ``f1`` with
corresponding values
Args:
preds: An iterable of predicted sentences.
target: An iterable of target sentences.
model_type: A name or a model path used to load `transformers` pretrained model.
model_type: A name or a model path used to load ``transformers`` pretrained model.
num_layers: A layer of representation to use.
all_layers:
An indication of whether the representation from all model's layers should be used.
If `all_layers = True`, the argument `num_layers` is ignored.
If ``all_layers=True``, the argument ``num_layers`` is ignored.
model: A user's own model. Must be a `torch.nn.Module` instance.
user_tokenizer:
A user's own tokenizer used with the own model. This must be an instance with the `__call__` method.
A user's own tokenizer used with the own model. This must be an instance with the ``__call__`` method.
This method must take an iterable of sentences (`List[str]`) and must return a python dictionary
containing `"input_ids"` and `"attention_mask"` represented by :class:`~torch.Tensor`.
It is up to the user's model of whether `"input_ids"` is a :class:`~torch.Tensor` of input ids or embedding
vectors. This tokenizer must prepend an equivalent of `[CLS]` token and append an equivalent of `[SEP]`
token as `transformers` tokenizer does.
vectors. This tokenizer must prepend an equivalent of ``[CLS]`` token and append an equivalent of ``[SEP]``
token as ``transformers`` tokenizer does.
user_forward_fn:
A user's own forward function used in combination with `user_model`. This function must take `user_model`
and a python dictionary containing `"input_ids"` and `"attention_mask"` represented
A user's own forward function used in combination with ``user_model``. This function must take
``user_model`` and a python dictionary containing ``"input_ids"`` and ``"attention_mask"`` represented
by :class:`~torch.Tensor` as an input and return the model's output represented by the single
:class:`~torch.Tensor`.
verbose: An indication of whether a progress bar is to be displayed during the embeddings' calculation.
idf: An indication whether normalization using inverse document frequencies should be used.
device: A device to be used for calculation.
max_length: A maximum length of input sequences. Sequences longer than `max_length` are to be trimmed.
max_length: A maximum length of input sequences. Sequences longer than ``max_length`` are to be trimmed.
batch_size: A batch size used for model processing.
num_threads: A number of threads to use for a dataloader.
return_hash: An indication of whether the corresponding `hash_code` should be returned.
return_hash: An indication of whether the corresponding ``hash_code`` should be returned.
lang: A language of input sentences.
rescale_with_baseline:
An indication of whether bertscore should be rescaled with a pre-computed baseline.
When a pretrained model from `transformers` model is used, the corresponding baseline is downloaded
from the original `bert-score` package from `BERT_score`_ if available.
When a pretrained model from ``transformers`` model is used, the corresponding baseline is downloaded
from the original ``bert-score`` package from `BERT_score`_ if available.
In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting
of the files from `BERT_score`_.
baseline_path: A path to the user's own local csv/tsv file with the baseline scale.
baseline_url: A url path to the user's own csv/tsv file with the baseline scale.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Returns:
Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values.
Example:
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
@@ -175,12 +182,9 @@ def __init__(
self.add_state("target_attention_mask", [], dist_reduce_fx="cat")

def update(self, preds: List[str], target: List[str]) -> None:
"""Store predictions/references for computing BERT scores. It is necessary to store sentences in a
tokenized form to ensure the DDP mode working.
"""Store predictions/references for computing BERT scores.
Args:
preds: An iterable of predicted sentences.
target: An iterable of reference sentences.
It is necessary to store sentences in a tokenized form to ensure the DDP mode working.
"""
preds_dict, _ = _preprocess_text(
preds,
Expand All @@ -205,11 +209,7 @@ def update(self, preds: List[str], target: List[str]) -> None:
self.target_attention_mask.append(target_dict["attention_mask"])

def compute(self) -> Dict[str, Union[List[float], str]]:
"""Calculate BERT scores.
Return:
Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values.
"""
"""Calculate BERT scores."""
return bert_score(
preds=_get_input_dict(self.preds_input_ids, self.preds_attention_mask),
target=_get_input_dict(self.target_input_ids, self.target_attention_mask),
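
The docstring above says the metric returns ``precision``, ``recall`` and ``f1``. As a rough illustration of how those three values relate, here is a minimal pure-Python sketch of the greedy cosine-similarity matching described in the BERTScore paper. It is not the torchmetrics implementation: the names ``cosine`` and ``toy_bert_score`` are hypothetical, the embeddings are hand-written toy vectors rather than model outputs, and idf weighting and baseline rescaling are left out.

```python
import math
from typing import Dict, List

Vector = List[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def toy_bert_score(pred_emb: List[Vector], target_emb: List[Vector]) -> Dict[str, float]:
    """Greedy matching over token embeddings (hypothetical helper, no idf weighting).

    precision: each predicted token is matched to its most similar target token.
    recall: each target token is matched to its most similar predicted token.
    """
    precision = sum(max(cosine(p, t) for t in target_emb) for p in pred_emb) / len(pred_emb)
    recall = sum(max(cosine(t, p) for p in pred_emb) for t in target_emb) / len(target_emb)
    total = precision + recall
    f1 = 2 * precision * recall / total if total > 0 else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Toy embeddings (assumed values): two predicted tokens, one target token.
scores = toy_bert_score([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0]])
print(scores)
```

In this toy case only one of the two predicted tokens has a counterpart in the target, so precision is lower than recall.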
35 changes: 14 additions & 21 deletions src/torchmetrics/text/bleu.py
@@ -28,9 +28,18 @@
class BLEUScore(Metric):
"""Calculate `BLEU score`_ of machine translated text with one or more references.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~Sequence`): An iterable of machine translated corpus
- ``target`` (:class:`~Sequence`): An iterable of iterables of reference corpus
As output of ``forward`` and ``compute`` the metric returns the following output:
- ``bleu`` (:class:`~torch.Tensor`): A tensor with the BLEU Score
Args:
n_gram: Gram value ranged from 1 to 4
smooth: Whether or not to apply smoothing, see [2]
smooth: Whether or not to apply smoothing, see `Machine Translation Evolution`_
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
weights:
Weights used for unigrams, bigrams, etc. to calculate BLEU score.
@@ -43,16 +52,9 @@ class BLEUScore(Metric):
>>> from torchmetrics import BLEUScore
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric = BLEUScore()
>>> metric(preds, target)
>>> bleu = BLEUScore()
>>> bleu(preds, target)
tensor(0.7598)
References:
[1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni,
Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu `BLEU`_
[2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_
"""

is_differentiable: bool = False
@@ -84,12 +86,7 @@ def __init__(
self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")

def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None:
"""Compute Precision Scores.
Args:
preds: An iterable of machine translated corpus
target: An iterable of iterables of reference corpus
"""
"""Update state with predictions and targets."""
self.preds_len, self.target_len = _bleu_score_update(
preds,
target,
@@ -102,11 +99,7 @@ def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None:
)

def compute(self) -> Tensor:
"""Calculate BLEU score.
Return:
Tensor with BLEU Score
"""
"""Calculate BLEU score."""
return _bleu_score_compute(
self.preds_len, self.target_len, self.numerator, self.denominator, self.n_gram, self.weights, self.smooth
)
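
To make the ``numerator`` / ``denominator`` state and the docstring example concrete, the following is a minimal pure-Python sketch of corpus-level BLEU with uniform weights and no smoothing. It is an illustration under stated assumptions, not the library code: ``ngram_counts`` and ``toy_bleu`` are hypothetical names, and tokenization is assumed to be a plain whitespace split.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngram_counts(tokens: List[str], n: int) -> Counter:
    """Count every n-gram of order n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def toy_bleu(preds: Sequence[str], target: Sequence[Sequence[str]], n_gram: int = 4) -> float:
    """Corpus-level BLEU, uniform weights, no smoothing (hypothetical helper)."""
    numerator = [0] * n_gram    # clipped n-gram matches per order
    denominator = [0] * n_gram  # total predicted n-grams per order
    preds_len = 0
    target_len = 0
    for pred, refs in zip(preds, target):
        pred_tokens = pred.split()  # assumption: whitespace tokenization
        refs_tokens = [ref.split() for ref in refs]
        preds_len += len(pred_tokens)
        # Use the reference length closest to the prediction length.
        target_len += min(
            (len(r) for r in refs_tokens),
            key=lambda length: abs(length - len(pred_tokens)),
        )
        for n in range(1, n_gram + 1):
            pred_counts = ngram_counts(pred_tokens, n)
            max_ref_counts: Counter = Counter()
            for ref_tokens in refs_tokens:
                max_ref_counts |= ngram_counts(ref_tokens, n)  # element-wise max
            clipped = pred_counts & max_ref_counts  # element-wise min
            numerator[n - 1] += sum(clipped.values())
            denominator[n - 1] += sum(pred_counts.values())
    if min(numerator) == 0:
        return 0.0  # without smoothing, a missing order zeroes the score
    log_precision = sum(math.log(num / den) for num, den in zip(numerator, denominator)) / n_gram
    brevity_penalty = 1.0 if preds_len > target_len else math.exp(1 - target_len / preds_len)
    return brevity_penalty * math.exp(log_precision)


preds = ['the cat is on the mat']
target = [['there is a cat on the mat', 'a cat is on the mat']]
score = toy_bleu(preds, target)
print(round(score, 4))
```

For the docstring inputs the clipped precisions are 5/6, 4/5, 3/4 and 2/3 with a brevity penalty of 1, which gives (1/3)^(1/4) and agrees with the ``tensor(0.7598)`` shown in the example.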
32 changes: 15 additions & 17 deletions src/torchmetrics/text/cer.py
@@ -22,7 +22,8 @@


class CharErrorRate(Metric):
r"""Character Error Rate (CER_) is a metric of the performance of an automatic speech recognition (ASR) system.
r"""Character Error Rate (`CER`_) is a metric of the performance of an automatic speech recognition (ASR)
system.
This value indicates the percentage of characters that were incorrectly predicted.
The lower the value, the better the performance of the ASR system with a CharErrorRate of 0 being
@@ -41,17 +42,23 @@ class CharErrorRate(Metric):
Compute CharErrorRate score of transcribed segments against references.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~str`): Transcription(s) to score as a string or list of strings
- ``target`` (:class:`~str`): Reference(s) for each speech input as a string or list of strings
As output of ``forward`` and ``compute`` the metric returns the following output:
- ``cer`` (:class:`~torch.Tensor`): A tensor with the Character Error Rate score
Args:
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Returns:
Character error rate score
Examples:
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric = CharErrorRate()
>>> metric(preds, target)
>>> cer = CharErrorRate()
>>> cer(preds, target)
tensor(0.3415)
"""
is_differentiable: bool = False
@@ -70,20 +77,11 @@ def __init__(
self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum")

def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore
"""Store references/predictions for computing Character Error Rate scores.
Args:
preds: Transcription(s) to score as a string or list of strings
target: Reference(s) for each speech input as a string or list of strings
"""
"""Update state with predictions and targets."""
errors, total = _cer_update(preds, target)
self.errors += errors
self.total += total

def compute(self) -> Tensor:
"""Calculate the character error rate.
Returns:
Character error rate score
"""
"""Calculate the character error rate."""
return _cer_compute(self.errors, self.total)
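
The ``update`` / ``compute`` split above accumulates ``errors`` and ``total`` across calls and only divides at the end. A minimal pure-Python sketch of that pattern follows. It assumes the error count is the character-level Levenshtein distance and the total is the number of reference characters; ``edit_distance`` and ``ToyCharErrorRate`` are hypothetical names, not the library's ``_cer_update`` / ``_cer_compute``.

```python
from typing import List, Union


def edit_distance(pred: str, ref: str) -> int:
    """Character-level Levenshtein distance using a rolling row."""
    previous = list(range(len(ref) + 1))
    for i, p_char in enumerate(pred, start=1):
        current = [i]
        for j, r_char in enumerate(ref, start=1):
            current.append(min(
                previous[j] + 1,                       # delete a predicted character
                current[j - 1] + 1,                    # insert a reference character
                previous[j - 1] + (p_char != r_char),  # substitute, or match at no cost
            ))
        previous = current
    return previous[-1]


class ToyCharErrorRate:
    """Accumulates edit errors and reference lengths across update calls."""

    def __init__(self) -> None:
        self.errors = 0
        self.total = 0

    def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None:
        if isinstance(preds, str):
            preds = [preds]
        if isinstance(target, str):
            target = [target]
        for pred, ref in zip(preds, target):
            self.errors += edit_distance(pred, ref)
            self.total += len(ref)

    def compute(self) -> float:
        return self.errors / self.total


cer = ToyCharErrorRate()
cer.update(["this is the prediction"], ["this is the reference"])
cer.update("there is an other sample", "there is another one")
print(round(cer.compute(), 4))
```

Feeding the two docstring samples in separate ``update`` calls gives 14 errors over 41 reference characters, which matches the ``tensor(0.3415)`` in the example and shows why summing the two states before dividing is equivalent to scoring the whole batch at once.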
