Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix & extend StrEnum.from_str #99

Merged
merged 7 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Allow frozen dataclasses in `apply_to_collection` ([#98](https://github.com/Lightning-AI/utilities/pull/98))


- Extended `StrEnum.from_str` with optional raising ValueError ([#99](https://github.com/Lightning-AI/utilities/pull/99))


### Changed

- CI/docs: allow passing env. variables ([#96](https://github.com/Lightning-AI/utilities/pull/96))


### Fixed

-
- Fixed `StrEnum.from_str` with source as key ([#99](https://github.com/Lightning-AI/utilities/pull/99))


## [0.6.0] - 2023-01-23
Expand Down
68 changes: 61 additions & 7 deletions src/lightning_utilities/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#
import warnings
from enum import Enum
from typing import Optional

Expand All @@ -16,19 +17,72 @@ class StrEnum(str, Enum):
... t2 = "T-2"
>>> MySE("T-1") == MySE.t1
True
>>> MySE.from_str("t-2") == MySE.t2
>>> MySE.from_str("t-2", source="value") == MySE.t2
True
"""

@classmethod
def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]:
for st, val in cls.__members__.items():
if source in ("key", "any") and st.lower() == value.lower():
return cls[st]
if source in ("value", "any") and val.lower() == value.lower():
return cls[st]
def from_str(
cls, value: str, source: Literal["key", "value", "any"] = "key", strict: bool = False
) -> Optional["StrEnum"]:
"""Create StrEnum from a sting matching the key or value.

Args:
value: matching string
source: compare with:

- ``"key"``: validates only with Enum keys, typical alphanumeric with "_"
- ``"value"``: validates only with Enum values, could be any string
- ``"key"``: validates with any key or value, but key has priority

strict: allow not matching string and returns None; if false raises exceptions

Raises:
ValueError:
if requested string does not match any option based on selected source and use ``"strict=True"``
UserWarning:
if requested string does not match any option based on selected source and use ``"strict=False"``

Example:
>>> class MySE(StrEnum):
... t1 = "T-1"
... t2 = "T-2"
>>> MySE.from_str("t-1", source="key")
>>> MySE.from_str("t-2", source="value")
<MySE.t2: 'T-2'>
>>> MySE.from_str("t-3", source="any", strict=True)
Traceback (most recent call last):
...
ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
"""
allowed = cls._allowed_matches(source)
if strict and not any(enum_.lower() == value.lower() for enum_ in allowed):
raise ValueError(f"Invalid match: expected one of {allowed}, but got {value}.")

if source in ("key", "any"):
for enum_key in cls.__members__.keys():
if enum_key.lower() == value.lower():
return cls[enum_key]
if source in ("value", "any"):
for enum_key, enum_val in cls.__members__.items():
if enum_val == value:
return cls[enum_key]

warnings.warn(UserWarning(f"Invalid string: expected one of {allowed}, but got {value}."))
return None

@classmethod
def _allowed_matches(cls, source: str) -> list:
keys, vals = [], []
for enum_key, enum_val in cls.__members__.items():
keys.append(enum_key)
vals.append(enum_val.value)
if source == "key":
return keys
if source == "value":
return vals
return keys + vals

def __eq__(self, other: object) -> bool:
if isinstance(other, Enum):
other = other.value
Expand Down