Skip to content

Commit

Permalink
allow np.uint64 to be used in indexing. Support numpy 1.24.1 (#510)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dr-Irv committed Jan 12, 2023
1 parent 261eabb commit bfa107b
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ np_ndarray_anyint: TypeAlias = npt.NDArray[np.integer]
np_ndarray_bool: TypeAlias = npt.NDArray[np.bool_]
np_ndarray_str: TypeAlias = npt.NDArray[np.str_]

IndexType: TypeAlias = Union[slice, np_ndarray_int64, Index, list[int], Series[int]]
IndexType: TypeAlias = Union[slice, np_ndarray_anyint, Index, list[int], Series[int]]
MaskType: TypeAlias = Union[Series[bool], np_ndarray_bool, list[bool]]
# Scratch types for generics
S1 = TypeVar(
Expand Down
5 changes: 3 additions & 2 deletions pandas-stubs/core/indexes/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ from pandas._typing import (
Level,
NaPosition,
Scalar,
np_ndarray_anyint,
np_ndarray_bool,
np_ndarray_int64,
type_t,
Expand Down Expand Up @@ -192,10 +193,10 @@ class Index(IndexOpsMixin, PandasObject):
@overload
def __getitem__(
self: IndexT,
idx: slice | np_ndarray_int64 | Index | Series[bool] | np_ndarray_bool,
idx: slice | np_ndarray_anyint | Index | Series[bool] | np_ndarray_bool,
) -> IndexT: ...
@overload
def __getitem__(self, idx: int | tuple[np_ndarray_int64, ...]) -> Scalar: ...
def __getitem__(self, idx: int | tuple[np_ndarray_anyint, ...]) -> Scalar: ...
def append(self, other): ...
def putmask(self, mask, value): ...
def equals(self, other) -> bool: ...
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pyright = ">=1.1.286"
poethepoet = ">=0.16.5"
loguru = ">=0.6.0"
pandas = "1.5.2"
numpy = "<=1.23.5"
numpy = ">=1.24.1"
typing-extensions = ">=4.2.0"
matplotlib = ">=3.5.1"
pre-commit = ">=2.19.0"
Expand Down
14 changes: 14 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

import numpy as np
import numpy.typing as npt
import pandas as pd
from pandas._testing import (
ensure_clean,
Expand Down Expand Up @@ -2363,3 +2364,16 @@ def test_frame_dropna_subset() -> None:
assert_type(df.dropna(subset=df.columns.drop("col1")), pd.DataFrame),
pd.DataFrame,
)


def test_npint_loc_indexer() -> None:
# GH 508

df = pd.DataFrame(dict(x=[1, 2, 3]), index=np.array([10, 20, 30], dtype="uint64"))

def get_NDArray(df: pd.DataFrame, key: npt.NDArray[np.uint64]) -> pd.DataFrame:
df2 = df.loc[key]
return df2

a: npt.NDArray[np.uint64] = np.array([10, 30], dtype="uint64")
check(assert_type(get_NDArray(df, a), pd.DataFrame), pd.DataFrame)
4 changes: 2 additions & 2 deletions tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pandas as pd
from pandas import Grouper
from pandas.api.extensions import ExtensionArray
from pandas.util.version import Version
import pytest
from typing_extensions import assert_type

Expand Down Expand Up @@ -1705,7 +1706,7 @@ def test_pivot_table() -> None:
),
pd.DataFrame,
)
with pytest.warns(np.VisibleDeprecationWarning):
if Version(np.__version__) <= Version("1.23.5"):
check(
assert_type(
pd.pivot_table(
Expand All @@ -1719,7 +1720,6 @@ def test_pivot_table() -> None:
),
pd.DataFrame,
)
with pytest.warns(np.VisibleDeprecationWarning):
check(
assert_type(
pd.pivot_table(
Expand Down

0 comments on commit bfa107b

Please sign in to comment.