Skip to content

Commit

Permalink
Fix checking for nltk.punkt if a machine not online (#1456)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: SkafteNicki <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
4 people committed Jan 28, 2023
1 parent 8d17f0b commit 5d4ffe0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -43,6 +43,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed mixed precision autocast for `SSIM` metric ([#1454](https://github.com/Lightning-AI/metrics/pull/1454))

- Fix checking for `nltk.punkt` in `RougeScore` if a machine is not online ([#1456](https://github.com/Lightning-AI/metrics/pull/1456))


- Fixed wrongly reset method in `MultioutputWrapper` ([#1460](https://github.com/Lightning-AI/metrics/issues/1460))


Expand Down
31 changes: 30 additions & 1 deletion src/torchmetrics/functional/text/rouge.py
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import urllib.request
from collections import Counter
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from urllib.request import HTTPError

import torch
from torch import Tensor, tensor
Expand All @@ -39,13 +41,40 @@
ALLOWED_ACCUMULATE_VALUES = ("avg", "best")


def _is_internet_connection() -> bool:
try:
urllib.request.urlopen("https://torchmetrics.readthedocs.io/")
except HTTPError:
return False
return True


def _ensure_nltk_punkt_is_downloaded() -> None:
"""Check whether `nltk` `punkt` is downloaded.
If not, try to download if a machine is connected to the internet.
"""
import nltk

try:
nltk.data.find("tokenizers/punkt.zip")
except LookupError:
if _is_internet_connection():
nltk.download("punkt", quiet=True, force=False)
else:
raise OSError(
"`nltk` resource `punkt` is not available on a disk and cannot be downloaded as a machine is not "
"connected to the internet."
)


def _split_sentence(x: str) -> Sequence[str]:
"""The sentence is split to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
if not _NLTK_AVAILABLE:
raise ModuleNotFoundError("ROUGE-Lsum calculation requires that `nltk` is installed. Use `pip install nltk`.")
import nltk

nltk.download("punkt", quiet=True, force=False)
_ensure_nltk_punkt_is_downloaded()

re.sub("<n>", "", x) # remove pegasus newline char
return nltk.sent_tokenize(x)
Expand Down

0 comments on commit 5d4ffe0

Please sign in to comment.