Skip to content

Commit

Permalink
[flake8-pyi] Allow overloaded __exit__ and __aexit__ definition…
Browse files Browse the repository at this point in the history
…s (`PYI036`)
  • Loading branch information
AlexWaygood committed Apr 20, 2024
1 parent c80b9a4 commit 8142db8
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 43 deletions.
42 changes: 41 additions & 1 deletion crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing
from collections.abc import Awaitable
from types import TracebackType
from typing import Any, Type
from typing import Any, Type, overload

import _typeshed
import typing_extensions
Expand Down Expand Up @@ -73,3 +73,43 @@ async def __aexit__(self, /, typ: type[BaseException] | None, *args: Any) -> Awa
class BadSix:
def __exit__(self, typ, exc, tb, weird_extra_arg, extra_arg2 = None) -> None: ... # PYI036: Extra arg must have default
async def __aexit__(self, typ, exc, tb, *, weird_extra_arg) -> None: ... # PYI036: kwargs must have default

# Here come the overloads...

class AcceptableOverload1:
@overload
def __exit__(self, exc_typ: None, exc: None, exc_tb: None) -> None: ...
@overload
def __exit__(self, exc_typ: type[BaseException], exc: BaseException, exc_tb: TracebackType) -> None: ...
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ...

# Using `object` or `Unused` in an overload definition is kinda strange,
# but let's allow it to be on the safe side
class AcceptableOverload2:
@overload
def __exit__(self, exc_typ: None, exc: None, exc_tb: object) -> None: ...
@overload
def __exit__(self, exc_typ: Unused, exc: BaseException, exc_tb: object) -> None: ...
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ...

class AcceptableOverload2:
# Just ignore any overloads that don't have exactly 3 annotated non-self parameters.
# We don't have the ability (yet) to do arbitrary checking
# of whether one function definition is a subtype of another...
@overload
def __exit__(self, exc_typ: bool, exc: bool, exc_tb: bool, weird_extra_arg: bool) -> None: ...
@overload
def __exit__(self, *args: object) -> None: ...
def __exit__(self, *args: object) -> None: ...

class UnacceptableOverload1:
@overload
def __exit__(self, exc_typ: None, exc: None, tb: None) -> None: ... # Okay
@overload
def __exit__(self, exc_typ: Exception, exc: Exception, tb: TracebackType) -> None: ... # PYI036

class UnacceptableOverload2:
@overload
def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036
@overload
def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036
42 changes: 41 additions & 1 deletion crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import types
import typing
from collections.abc import Awaitable
from types import TracebackType
from typing import Any, Type
from typing import Any, Type, overload

import _typeshed
import typing_extensions
Expand Down Expand Up @@ -80,3 +80,43 @@ def isolated_scope():

class ShouldNotError:
def __exit__(self, typ: Type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...

# Here come the overloads...

class AcceptableOverload1:
@overload
def __exit__(self, exc_typ: None, exc: None, exc_tb: None) -> None: ...
@overload
def __exit__(self, exc_typ: type[BaseException], exc: BaseException, exc_tb: TracebackType) -> None: ...
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ...

# Using `object` or `Unused` in an overload definition is kinda strange,
# but let's allow it to be on the safe side
class AcceptableOverload2:
@overload
def __exit__(self, exc_typ: None, exc: None, exc_tb: object) -> None: ...
@overload
def __exit__(self, exc_typ: Unused, exc: BaseException, exc_tb: object) -> None: ...
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ...

class AcceptableOverload2:
# Just ignore any overloads that don't have exactly 3 annotated non-self parameters.
# We don't have the ability (yet) to do arbitrary checking
# of whether one function definition is a subtype of another...
@overload
def __exit__(self, exc_typ: bool, exc: bool, exc_tb: bool, weird_extra_arg: bool) -> None: ...
@overload
def __exit__(self, *args: object) -> None: ...
def __exit__(self, *args: object) -> None: ...

class UnacceptableOverload1:
@overload
def __exit__(self, exc_typ: None, exc: None, tb: None) -> None: ... # Okay
@overload
def __exit__(self, exc_typ: Exception, exc: Exception, tb: TracebackType) -> None: ... # PYI036

class UnacceptableOverload2:
@overload
def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036
@overload
def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036
2 changes: 1 addition & 1 deletion crates/ruff_linter/src/checkers/ast/analyze/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) {
}
}
if checker.enabled(Rule::BadExitAnnotation) {
flake8_pyi::rules::bad_exit_annotation(checker, *is_async, name, parameters);
flake8_pyi::rules::bad_exit_annotation(checker, function_def);
}
if checker.enabled(Rule::RedundantNumericUnion) {
flake8_pyi::rules::redundant_numeric_union(checker, parameters);
Expand Down
144 changes: 108 additions & 36 deletions crates/ruff_linter/src/rules/flake8_pyi/rules/exit_annotations.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::fmt::{Display, Formatter};

use ruff_python_ast::{
Expr, ExprBinOp, ExprSubscript, ExprTuple, Identifier, Operator, ParameterWithDefault,
Parameters,
Expr, ExprBinOp, ExprSubscript, ExprTuple, Operator, ParameterWithDefault, Parameters,
StmtFunctionDef,
};
use smallvec::SmallVec;

use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation};
use ruff_macros::{derive_message_formats, violation};

use ruff_python_semantic::SemanticModel;
use ruff_text_size::Ranged;
use ruff_python_semantic::{analyze::visibility::is_overload, SemanticModel};
use ruff_text_size::{Ranged, TextRange};

use crate::checkers::ast::Checker;

Expand Down Expand Up @@ -68,6 +68,10 @@ impl Violation for BadExitAnnotation {
ErrorKind::FirstArgBadAnnotation => format!("The first argument in `{method_name}` should be annotated with `object` or `type[BaseException] | None`"),
ErrorKind::SecondArgBadAnnotation => format!("The second argument in `{method_name}` should be annotated with `object` or `BaseException | None`"),
ErrorKind::ThirdArgBadAnnotation => format!("The third argument in `{method_name}` should be annotated with `object` or `types.TracebackType | None`"),
ErrorKind::UnrecognizedExitOverload => format!(
"Annotations for a three-argument `{method_name}` overload (excluding `self`) \
should either be `None, None, None` or `type[BaseException], BaseException, types.TracebackType`"
)
}
}

Expand Down Expand Up @@ -104,37 +108,52 @@ enum ErrorKind {
ThirdArgBadAnnotation,
ArgsAfterFirstFourMustHaveDefault,
AllKwargsMustHaveDefault,
UnrecognizedExitOverload,
}

/// PYI036
pub(crate) fn bad_exit_annotation(
checker: &mut Checker,
is_async: bool,
name: &Identifier,
parameters: &Parameters,
) {
pub(crate) fn bad_exit_annotation(checker: &mut Checker, function: &StmtFunctionDef) {
let StmtFunctionDef {
is_async,
decorator_list,
name,
parameters,
..
} = function;

let func_kind = match name.as_str() {
"__exit__" if !is_async => FuncKind::Sync,
"__aexit__" if is_async => FuncKind::Async,
"__aexit__" if *is_async => FuncKind::Async,
_ => return,
};

let positional_args = parameters
let non_self_positional_args = parameters
.args
.iter()
.skip(1)
.chain(parameters.posonlyargs.iter())
.collect::<SmallVec<[&ParameterWithDefault; 4]>>();
.collect::<SmallVec<[&ParameterWithDefault; 3]>>();

if is_overload(decorator_list, checker.semantic()) {
check_positional_args_for_overloaded_method(
checker,
&non_self_positional_args,
func_kind,
parameters.range(),
);
return;
}

// If there are less than three positional arguments, at least one of them must be a star-arg,
// and it must be annotated with `object`.
if positional_args.len() < 4 {
if non_self_positional_args.len() < 3 {
check_short_args_list(checker, parameters, func_kind);
}

// Every positional argument (beyond the first four) must have a default.
for parameter in positional_args
for parameter in non_self_positional_args
.iter()
.skip(4)
.skip(3)
.filter(|parameter| parameter.default.is_none())
{
checker.diagnostics.push(Diagnostic::new(
Expand All @@ -161,7 +180,7 @@ pub(crate) fn bad_exit_annotation(
));
}

check_positional_args(checker, &positional_args, func_kind);
check_positional_args_for_non_overloaded_method(checker, &non_self_positional_args, func_kind);
}

/// Determine whether a "short" argument list (i.e., an argument list with less than four elements)
Expand Down Expand Up @@ -204,11 +223,11 @@ fn check_short_args_list(checker: &mut Checker, parameters: &Parameters, func_ki
}
}

/// Determines whether the positional arguments of an `__exit__` or `__aexit__` method are
/// annotated correctly.
fn check_positional_args(
/// Determines whether the positional arguments of an `__exit__` or `__aexit__` method
/// (that is not decorated with `@typing.overload`) are annotated correctly.
fn check_positional_args_for_non_overloaded_method(
checker: &mut Checker,
positional_args: &[&ParameterWithDefault],
non_self_positional_args: &[&ParameterWithDefault],
kind: FuncKind,
) {
// For each argument, define the predicate against which to check the annotation.
Expand All @@ -222,7 +241,7 @@ fn check_positional_args(
(ErrorKind::ThirdArgBadAnnotation, is_traceback_type),
];

for (arg, (error_info, predicate)) in positional_args.iter().skip(1).take(3).zip(validations) {
for (arg, (error_info, predicate)) in non_self_positional_args.iter().take(3).zip(validations) {
let Some(annotation) = arg.parameter.annotation.as_ref() else {
continue;
};
Expand All @@ -249,36 +268,91 @@ fn check_positional_args(
}
}

/// Determines whether the positional arguments of an `__exit__` or `__aexit__` method
/// overload are annotated correctly.
fn check_positional_args_for_overloaded_method(
checker: &mut Checker,
non_self_positional_args: &[&ParameterWithDefault],
kind: FuncKind,
paramters_range: TextRange,
) {
let semantic = checker.semantic();
if non_self_positional_args.len() != 3 {
return;
}
let non_self_annotations: SmallVec<[&Expr; 3]> = non_self_positional_args
.iter()
.filter_map(|arg| arg.parameter.annotation.as_deref())
.collect();
if non_self_annotations.len() != 3 {
return;
}
// We've now established that it's a function overload with 3 non-self positional arguments,
// where all arguments are annotated. It therefore follows that, in order for it to be
// correctly annotated, it must be one of the following two possible overloads:
//
// ```
// @overload
// def __(a)exit__(self, typ: None, exc: None, tb: None) -> None: ...
// @overload
// def __(a)exit__(self, typ: type[BaseException], exc: BaseException, tb: TracebackType) -> None: ...
// ```
//
// We'll allow small varations on either of these (if, e.g. a parameter is annotated
// with `object` or `_typeshed.Unused`). *Basically*, though, the rule is:
// - If the function overload matches *either* of those, it's okay.
// - If not: emit a diagnostic.
//
// Start by checking the first possibility:
if non_self_annotations.iter().all(|annotation| {
annotation.is_none_literal_expr() | is_object_or_unused(annotation, semantic)
}) {
return;
}
// Now check the second:
let matches_second_overload_variant = {
(is_base_exception_type(non_self_annotations[0], semantic)
|| is_object_or_unused(non_self_annotations[0], semantic))
&& (semantic.match_builtin_expr(non_self_annotations[1], "BaseException")
|| is_object_or_unused(non_self_annotations[1], semantic))
&& (is_traceback_type(non_self_annotations[2], semantic)
|| is_object_or_unused(non_self_annotations[2], semantic))
};
if matches_second_overload_variant {
return;
}
// Okay, neither of them match...
checker.diagnostics.push(Diagnostic::new(
BadExitAnnotation {
func_kind: kind,
error_kind: ErrorKind::UnrecognizedExitOverload,
},
paramters_range,
));
}

/// Return the non-`None` annotation element of a PEP 604-style union or `Optional` annotation.
fn non_none_annotation_element<'a>(
annotation: &'a Expr,
semantic: &SemanticModel,
) -> Option<&'a Expr> {
// E.g., `typing.Union` or `typing.Optional`
if let Expr::Subscript(ExprSubscript { value, slice, .. }) = annotation {
let qualified_name = semantic.resolve_qualified_name(value);
let qualified_name = semantic.resolve_qualified_name(value)?;

if qualified_name
.as_ref()
.is_some_and(|value| semantic.match_typing_qualified_name(value, "Optional"))
{
if semantic.match_typing_qualified_name(&qualified_name, "Optional") {
return if slice.is_none_literal_expr() {
None
} else {
Some(slice)
};
}

if !qualified_name
.as_ref()
.is_some_and(|value| semantic.match_typing_qualified_name(value, "Union"))
{
if !semantic.match_typing_qualified_name(&qualified_name, "Union") {
return None;
}

let Expr::Tuple(ExprTuple { elts, .. }) = slice.as_ref() else {
return None;
};
let ExprTuple { elts, .. } = slice.as_tuple_expr()?;

let [left, right] = elts.as_slice() else {
return None;
Expand Down Expand Up @@ -318,7 +392,6 @@ fn non_none_annotation_element<'a>(
fn is_object_or_unused(expr: &Expr, semantic: &SemanticModel) -> bool {
semantic
.resolve_qualified_name(expr)
.as_ref()
.is_some_and(|qualified_name| {
matches!(
qualified_name.segments(),
Expand All @@ -331,7 +404,6 @@ fn is_object_or_unused(expr: &Expr, semantic: &SemanticModel) -> bool {
fn is_traceback_type(expr: &Expr, semantic: &SemanticModel) -> bool {
semantic
.resolve_qualified_name(expr)
.as_ref()
.is_some_and(|qualified_name| {
matches!(qualified_name.segments(), ["types", "TracebackType"])
})
Expand Down

0 comments on commit 8142db8

Please sign in to comment.