From c60b912181a1bd6d0657c3c651b3f0f82ff2395a Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Sat, 28 Jan 2023 17:03:00 +0100 Subject: [PATCH] Fix checking for nltk.punkt if a machine not online (#1456) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: SkafteNicki Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> (cherry picked from commit 5d4ffe01aa09b7108f7e0e4034748bdfd64bf5f9) --- CHANGELOG.md | 3 +++ src/torchmetrics/functional/text/rouge.py | 31 ++++++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 80a96065eae..7aaf855568e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed mixed precision autocast for `SSIM` metric ([#1454](https://github.com/Lightning-AI/metrics/pull/1454)) +- Fix checking for `nltk.punkt` in `RougeScore` if a machine is not online ([#1456](https://github.com/Lightning-AI/metrics/pull/1456)) + + - Fixed wrongly reset method in `MultioutputWrapper` ([#1460](https://github.com/Lightning-AI/metrics/issues/1460)) diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index 60aa8abe682..f629e8cef87 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import re +import urllib.request from collections import Counter from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from urllib.request import HTTPError import torch from torch import Tensor, tensor @@ -39,13 +41,40 @@ ALLOWED_ACCUMULATE_VALUES = ("avg", "best") +def _is_internet_connection() -> bool: + try: + urllib.request.urlopen("https://torchmetrics.readthedocs.io/") + except HTTPError: + return False + return True + + +def _ensure_nltk_punkt_is_downloaded() -> None: + """Check whether `nltk` `punkt` is downloaded. + + If not, try to download if a machine is connected to the internet. + """ + import nltk + + try: + nltk.data.find("tokenizers/punkt.zip") + except LookupError: + if _is_internet_connection(): + nltk.download("punkt", quiet=True, force=False) + else: + raise OSError( + "`nltk` resource `punkt` is not available on a disk and cannot be downloaded as a machine is not " + "connected to the internet." + ) + + def _split_sentence(x: str) -> Sequence[str]: """The sentence is split to get rougeLsum scores matching published rougeL scores for BART and PEGASUS.""" if not _NLTK_AVAILABLE: raise ModuleNotFoundError("ROUGE-Lsum calculation requires that `nltk` is installed. Use `pip install nltk`.") import nltk - nltk.download("punkt", quiet=True, force=False) + _ensure_nltk_punkt_is_downloaded() re.sub("", "", x) # remove pegasus newline char return nltk.sent_tokenize(x)