From 23b31481c1742a79cf28d1d0e7f2282c20b1b895 Mon Sep 17 00:00:00 2001 From: stancld Date: Mon, 23 Jan 2023 19:21:44 +0100 Subject: [PATCH 1/5] Fix checking for nltk.punkt to avoid freezing if machine is not connected to the internet --- src/torchmetrics/functional/text/rouge.py | 28 ++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index 60aa8abe682..889a6a10326 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import re +import urllib.request from collections import Counter from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from urllib.request import HTTPError import torch from torch import Tensor, tensor @@ -39,13 +41,37 @@ ALLOWED_ACCUMULATE_VALUES = ("avg", "best") +def _is_internet_connection(): + try: + urllib.request.urlopen("https://torchmetrics.readthedocs.io/") + return True + except HTTPError: + return False + + +def _ensure_nltk_punkt_is_downloaded(): + """Check whether `nltk` `punkt` is downloaded. If not, try to download if a machine is connected to the internet.""" + import nltk + + try: + nltk.data.find('tokenizers/punkt.zip') + except LookupError: + if _is_internet_connection(): + nltk.download("punkt", quiet=True, force=False) + else: + raise OSError( + "`nltk` resource `punkt` is not available on a disk and cannot be downloaded as a machine is not " + "connected to the internet." + ) + + def _split_sentence(x: str) -> Sequence[str]: """The sentence is split to get rougeLsum scores matching published rougeL scores for BART and PEGASUS.""" if not _NLTK_AVAILABLE: raise ModuleNotFoundError("ROUGE-Lsum calculation requires that `nltk` is installed. Use `pip install nltk`.") import nltk - nltk.download("punkt", quiet=True, force=False) + _ensure_nltk_punkt_is_downloaded() re.sub("", "", x) # remove pegasus newline char return nltk.sent_tokenize(x) From b455556df74d1ab9f2e7c3f8433562269315ecb0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Jan 2023 18:23:23 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/text/rouge.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index 889a6a10326..3c61d8d707f 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -50,11 +50,14 @@ def _is_internet_connection(): def _ensure_nltk_punkt_is_downloaded(): - """Check whether `nltk` `punkt` is downloaded. If not, try to download if a machine is connected to the internet.""" + """Check whether `nltk` `punkt` is downloaded. + + If not, try to download if a machine is connected to the internet. + """ import nltk try: - nltk.data.find('tokenizers/punkt.zip') + nltk.data.find("tokenizers/punkt.zip") except LookupError: if _is_internet_connection(): nltk.download("punkt", quiet=True, force=False) From fad0c75300f9c9869c747e6f62c18dd3db33b95d Mon Sep 17 00:00:00 2001 From: stancld Date: Mon, 23 Jan 2023 19:37:12 +0100 Subject: [PATCH 3/5] Add missing type hints --- src/torchmetrics/functional/text/rouge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index 3c61d8d707f..18361a30ad9 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -41,7 +41,7 @@ ALLOWED_ACCUMULATE_VALUES = ("avg", "best") -def _is_internet_connection(): +def _is_internet_connection() -> bool: try: urllib.request.urlopen("https://torchmetrics.readthedocs.io/") return True @@ -49,7 +49,7 @@ def _is_internet_connection(): return False -def _ensure_nltk_punkt_is_downloaded(): +def _ensure_nltk_punkt_is_downloaded() -> None: """Check whether `nltk` `punkt` is downloaded. If not, try to download if a machine is connected to the internet. From 0d50f79c4c6eedbcb6e237e6555a8ed9addb44fa Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 23 Jan 2023 19:56:27 +0100 Subject: [PATCH 4/5] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 94e9854e247..8e3f0faef21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed type checking on the `maximize` parameter at the initialization of `MetricTracker` ([#1428](https://github.com/Lightning-AI/metrics/issues/1428)) +- Fix checking for `nltk.punkt` in `RougeScore` if a machine is not online ([#1456](https://github.com/Lightning-AI/metrics/pull/1456)) + + ## [0.11.0] - 2022-11-30 ### Added From 8930ee92a816c51306dc0ab64fa1478e6eac3ce3 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 24 Jan 2023 18:52:38 +0100 Subject: [PATCH 5/5] format --- src/torchmetrics/functional/text/rouge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index 18361a30ad9..f629e8cef87 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -44,9 +44,9 @@ def _is_internet_connection() -> bool: try: urllib.request.urlopen("https://torchmetrics.readthedocs.io/") - return True except HTTPError: return False + return True def _ensure_nltk_punkt_is_downloaded() -> None: