Skip to content

Commit

Permalink
idbstubs: Split parameter-broadening from signature-sorting and impro…
Browse files Browse the repository at this point in the history
…ve both
  • Loading branch information
WMOkiishi committed Dec 22, 2022
1 parent 0486186 commit 6d9859d
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 86 deletions.
158 changes: 112 additions & 46 deletions idbstubs/processors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import logging
import re
from collections import Counter
from collections.abc import Mapping, Sequence
from itertools import combinations, zip_longest
from typing import Final

import attrs

from .reps import (
Alias,
Attribute,
Expand All @@ -25,12 +28,12 @@
)
from .typedata import (
combine_types,
expand_type,
get_mro,
get_param_type_replacement,
process_dependency,
subtype_relationship,
)
from .util import flatten

_logger: Final = logging.getLogger(__name__)
_absent_param: Final = Parameter('', 'Never', is_optional=True, named=False)
Expand Down Expand Up @@ -129,55 +132,125 @@ def merge_signatures(a: Signature, b: Signature, /) -> Signature | None:
return Signature(merged_params, merged_return)


def get_signature_depths(signatures: Sequence[Signature]) -> list[int]:
"""Return a list of integers such that if the given signatures are ordered
according to the integer appearing at the same index, a type checker will
understand which to use.
def sort_signatures(signatures: Sequence[Signature]) -> list[Signature]:
"""Return the signatures from the given sequence sorted so that
a type-checker may choose the first matching one.
"""
# A list of lists of the indices of the signatures that should be defined
# after the signature at the outermost index
define_after: list[list[int]] = [[] for _ in signatures]
for (i, w1), (k, w2) in combinations(enumerate(signatures), 2):
if w1.min_arity() != w2.min_arity():
# `define_after` maps a signature index to a set of the indices
# of all signatures that should be defined after it.
define_after: dict[int, set[int]] = {n: set() for n in range(len(signatures))}
for (i, sig_1), (j, sig_2) in combinations(enumerate(signatures), 2):
if sig_1.min_arity() != sig_2.min_arity():
# This should technically check if the number of parameters
# can overlap, but this works well enough in practice.
continue
if sig_1.return_type == sig_2.return_type:
# Ambiguity doesn't matter if they return the same type.
continue
w1_first = w2_first = True
w1_params = (p for p in w1.parameters if not p.is_optional)
w2_params = (p for p in w2.parameters if not p.is_optional)
for p, q in zip(w1_params, w2_params):
if not (w1_first or w2_first):
sig_1_first = sig_2_first = True
for p, q in zip(
(p for p in sig_1.parameters if not p.is_optional),
(p for p in sig_2.parameters if not p.is_optional),
strict=True,
):
if not (sig_1_first or sig_2_first):
# These signatures don't overlap.
break
p_subtypes_q, q_subtypes_p = subtype_relationship(p.type, q.type)
w1_first &= p_subtypes_q
w2_first &= q_subtypes_p
if w1_first:
define_after[i].append(k)
elif w2_first:
define_after[k].append(i)
sig_depths = [0 for _ in signatures]
for i, others in enumerate(define_after):
seen = {i, *others}
sig_1_first &= p_subtypes_q
sig_2_first &= q_subtypes_p
if sig_1_first:
define_after[i].add(j)
elif sig_2_first:
define_after[j].add(i)
# Count how many "layers" of signatures each one needs to appear before.
sig_depths = Counter[int](range(len(signatures)))
for i, others in define_after.items():
seen = {i} | others
while others:
sig_depths[i] -= 1
next_others: list[int] = []
for k in flatten(define_after[j] for j in others if j not in seen):
if k in seen:
sig_depths[i] += 1
next_others = set[int]()
for j in others:
if j in seen:
continue
seen.add(k)
next_others.append(k)
batch = define_after[j] - seen
seen |= batch
next_others |= batch
others = next_others
return sig_depths
return [signatures[i] for i, n in sig_depths.most_common()]


def insert_casts(signatures: Sequence[Signature]) -> list[Signature]:
"""Broaden the parameter types of a sequence of signatures
without introducing ambiguity.
"""
# Store all possible parameter type expansions in a two-layer mapping.
param_types: dict[int, dict[int, str | None]] = {
i: {
j: get_param_type_replacement(param.type)
for j, param in enumerate(sig.parameters)
}
for i, sig in enumerate(signatures)
}
for (i, sig_1), (j, sig_2) in combinations(enumerate(signatures), 2):
if sig_1.min_arity() != sig_2.min_arity():
# This should technically check if the number of parameters
# can overlap, but this works well enough in practice.
continue
r1_subtypes_r2, r2_subtypes_r1 = subtype_relationship(
sig_1.return_type, sig_2.return_type
)
if r1_subtypes_r2 and r2_subtypes_r1:
# The return types are identical; ambiguity doesn't matter.
continue
exclusive_1_params: dict[int, str] = {}
exclusive_2_params: dict[int, str] = {}
sig_1_narrower = sig_2_narrower = True
for k, (p, q) in enumerate(zip(
(p for p in sig_1.parameters if not p.is_optional),
(p for p in sig_2.parameters if not p.is_optional),
strict=True,
)):
if not (sig_1_narrower or sig_2_narrower):
# These signatures don't overlap.
break
s, t = p.type, q.type
if not r1_subtypes_r2:
s = param_types[i][k] or s
if not r2_subtypes_r1:
t = param_types[j][k] or t
s_subtypes_t, t_subtypes_s = subtype_relationship(s, t)
sig_1_narrower &= s_subtypes_t
sig_2_narrower &= t_subtypes_s
if t_subtypes_s and not s_subtypes_t:
exclusive_1_params[k] = combine_types(
*(expand_type(s) - expand_type(q.type))
)
if s_subtypes_t and not t_subtypes_s:
exclusive_2_params[k] = combine_types(
*(expand_type(t) - expand_type(p.type))
)
# If one of the signature's parameter types are always narrower
# than the other's, don't broaden the narrower types, and remove
# them from the broader types if they appear explicitly.
if sig_1_narrower:
param_types[i] |= {k: None for k in range(sig_1.min_arity())}
param_types[j] |= exclusive_2_params
elif sig_2_narrower:
param_types[j] |= {k: None for k in range(sig_2.min_arity())}
param_types[i] |= exclusive_1_params
return [
attrs.evolve(sig, parameters=[
attrs.evolve(param, type=param_types[i][j] or param.type)
for j, param in enumerate(sig.parameters)
]) for i, sig in enumerate(signatures)
]


def process_signatures(signatures: Sequence[Signature]) -> list[Signature]:
"""Sort and compress the given sequence of function signatures."""
depths = get_signature_depths(signatures)
sorted_signatures: list[Signature] = []
for i, depth in sorted(enumerate(depths), key=lambda x: x[1]):
signature = signatures[i]
if depth == 0:
account_for_casts(signature)
sorted_signatures.append(signature)

with_casts = insert_casts(signatures)
sorted_signatures = sort_signatures(with_casts)
merged_signatures: list[Signature] = []
for new in sorted(sorted_signatures, key=Signature.get_arity):
for i, old in enumerate(merged_signatures):
Expand All @@ -191,13 +264,6 @@ def process_signatures(signatures: Sequence[Signature]) -> list[Signature]:
return merged_signatures


def account_for_casts(signature: Signature) -> None:
for param in signature.parameters:
replacement = get_param_type_replacement(param.type)
if replacement is not None:
param.type = replacement


def process_function(function: Function) -> None:
return_overrides = RETURN_TYPE_OVERRIDES.get(function.scoped_name, {})
param_overrides = PARAM_TYPE_OVERRIDES.get(function.scoped_name, {})
Expand Down
6 changes: 6 additions & 0 deletions idbstubs/special_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ def log_unused() -> None:
'panda3d.vision.ARToolKit.make': {(0, 0): 'NodePath[Camera]'},
})
RETURN_TYPE_OVERRIDES: Final = TrackingMap[str, str | dict[int, str]]({
'panda3d.core.__mul__': {
16: 'LVecBase2d',
17: 'LVecBase2f',
19: 'LVecBase3d',
21: 'LVecBase3f',
},
'panda3d.core.AsyncFuture.__await__': 'Generator[Awaitable, None, None]',
'panda3d.core.AsyncFuture.__iter__': 'Generator[Awaitable, None, None]',
'panda3d.core.AsyncFuture.gather': 'AsyncFuture',
Expand Down
2 changes: 1 addition & 1 deletion src/panda3d-stubs/core/_downloader.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ class HTTPDate:
def __isub__(self: Self, seconds: int) -> Self: ... # type: ignore[misc]
def __add__(self, seconds: int) -> HTTPDate: ...
@overload
def __sub__(self, other: HTTPDate | int | str) -> int: ...
def __sub__(self, other: HTTPDate | str) -> int: ...
@overload
def __sub__(self, seconds: int) -> HTTPDate: ...
def assign(self, copy: HTTPDate | int | str) -> HTTPDate: ...
Expand Down
4 changes: 2 additions & 2 deletions src/panda3d-stubs/core/_dtoolutil.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -480,10 +480,10 @@ class Filename:
subdirectory for storing application-specific data, common to all users.
"""
@overload
def assign(self, filename: str) -> Filename:
def assign(self, copy: StrOrBytesPath) -> Filename:
"""Assignment is via the = operator."""
@overload
def assign(self, copy: StrOrBytesPath) -> Filename: ...
def assign(self, filename: str) -> Filename: ...
def c_str(self) -> str: ...
def empty(self) -> bool: ...
def length(self) -> int: ...
Expand Down
12 changes: 6 additions & 6 deletions src/panda3d-stubs/core/_dxml.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ class TiXmlDocument(TiXmlNode):
"""

@overload
def __init__(self, documentName: str = ...) -> None:
def __init__(self, copy: TiXmlDocument | str = ...) -> None:
"""`(self)`:
Create an empty document, that has no name.
Expand All @@ -361,7 +361,7 @@ class TiXmlDocument(TiXmlNode):
Constructor.
"""
@overload
def __init__(self, copy: TiXmlDocument | str) -> None: ...
def __init__(self, documentName: str) -> None: ...
def assign(self, copy: TiXmlDocument | str) -> TiXmlDocument: ...
@overload
def LoadFile(self, encoding: _TiXmlEncoding = ...) -> bool:
Expand Down Expand Up @@ -455,15 +455,15 @@ class TiXmlElement(TiXmlNode):
"""

@overload
def __init__(self, in_value: str) -> None:
def __init__(self, __param0: TiXmlElement | str) -> None:
"""`(self, in_value: str)`:
Construct an element.
`(self, _value: str)`:
std::string constructor.
"""
@overload
def __init__(self, __param0: TiXmlElement | str) -> None: ...
def __init__(self, in_value: str) -> None: ...
@overload
def __init__(self, _value: str) -> None: ...
def assign(self, base: TiXmlElement | str) -> TiXmlElement: ...
Expand Down Expand Up @@ -709,7 +709,7 @@ class TiXmlText(TiXmlNode):
"""

@overload
def __init__(self, initValue: str) -> None:
def __init__(self, copy: TiXmlText | str) -> None:
"""`(self, initValue: str)`:
Constructor for text element. By default, it is treated as
normal, encoded text. If you want it be output as a CDATA text
Expand All @@ -719,7 +719,7 @@ class TiXmlText(TiXmlNode):
Constructor.
"""
@overload
def __init__(self, copy: TiXmlText | str) -> None: ...
def __init__(self, initValue: str) -> None: ...
def assign(self, base: TiXmlText | str) -> TiXmlText: ...
def CDATA(self) -> bool:
"""Queries whether this represents text using a CDATA section."""
Expand Down
8 changes: 5 additions & 3 deletions src/panda3d-stubs/core/_event.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ class EventParameter:

DtoolClassDict: ClassVar[dict[str, Any]]
@overload
def __init__(self, value: float | str = ...) -> None:
def __init__(
self, copy: EventParameter | TypedReferenceCount | TypedWritableReferenceCount | float | str | None = ...
) -> None:
"""`(self, ptr: TypedReferenceCount)`:
Defines an EventParameter that stores a pointer to a TypedReferenceCount
object. Note that a TypedReferenceCount is not the same kind of pointer as
Expand Down Expand Up @@ -64,9 +66,9 @@ class EventParameter:
Defines an EventParameter that stores a wstring value.
"""
@overload
def __init__(self, copy: EventParameter | TypedReferenceCount | TypedWritableReferenceCount | float | str | None) -> None: ...
@overload
def __init__(self, ptr: TypedReferenceCount | TypedWritableReferenceCount) -> None: ...
@overload
def __init__(self, value: float | str) -> None: ...
def assign(
self, copy: EventParameter | TypedReferenceCount | TypedWritableReferenceCount | float | str | None
) -> EventParameter: ...
Expand Down

0 comments on commit 6d9859d

Please sign in to comment.