From 685de912ff1d39656b21843611e151300d60b3da Mon Sep 17 00:00:00 2001 From: Auguste Lalande Date: Wed, 20 Mar 2024 20:36:17 -0400 Subject: [PATCH] [`pylint`] Implement `nan-comparison` (`PLW0117`) (#10401) ## Summary Implement pylint's nan-comparison, part of #970. ## Test Plan Test fixture was added. --- .../test/fixtures/pylint/nan_comparison.py | 76 ++++++++++ .../src/checkers/ast/analyze/expression.rs | 3 + crates/ruff_linter/src/codes.rs | 1 + crates/ruff_linter/src/rules/pylint/mod.rs | 1 + .../ruff_linter/src/rules/pylint/rules/mod.rs | 2 + .../src/rules/pylint/rules/nan_comparison.rs | 135 ++++++++++++++++++ ...int__tests__PLW0117_nan_comparison.py.snap | 82 +++++++++++ ruff.schema.json | 2 + 8 files changed, 302 insertions(+) create mode 100644 crates/ruff_linter/resources/test/fixtures/pylint/nan_comparison.py create mode 100644 crates/ruff_linter/src/rules/pylint/rules/nan_comparison.rs create mode 100644 crates/ruff_linter/src/rules/pylint/snapshots/ruff_linter__rules__pylint__tests__PLW0117_nan_comparison.py.snap diff --git a/crates/ruff_linter/resources/test/fixtures/pylint/nan_comparison.py b/crates/ruff_linter/resources/test/fixtures/pylint/nan_comparison.py new file mode 100644 index 0000000000000..be3b5d7f14f84 --- /dev/null +++ b/crates/ruff_linter/resources/test/fixtures/pylint/nan_comparison.py @@ -0,0 +1,76 @@ +import math +from math import nan as bad_val +import numpy as np +from numpy import nan as npy_nan + + +x = float("nan") +y = np.NaN + +# PLW0117 +if x == float("nan"): + pass + +# PLW0117 +if x == float("NaN"): + pass + +# PLW0117 +if x == float("NAN"): + pass + +# PLW0117 +if x == float("Nan"): + pass + +# PLW0117 +if x == math.nan: + pass + +# PLW0117 +if x == bad_val: + pass + +# PLW0117 +if y == np.NaN: + pass + +# PLW0117 +if y == np.NAN: + pass + +# PLW0117 +if y == np.nan: + pass + +# PLW0117 +if y == npy_nan: + pass + +# OK +if math.isnan(x): + pass + +# OK +if np.isnan(y): + pass + +# OK +if x == 0: + pass + +# OK +if x == 
float("32"): + pass + +# OK +if x == float(42): + pass + +# OK +if y == np.inf: + pass + +# OK +if x == "nan": + pass diff --git a/crates/ruff_linter/src/checkers/ast/analyze/expression.rs b/crates/ruff_linter/src/checkers/ast/analyze/expression.rs index 9ca9a6df71838..785ffd34a6d98 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/expression.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/expression.rs @@ -1283,6 +1283,9 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) { if checker.enabled(Rule::MagicValueComparison) { pylint::rules::magic_value_comparison(checker, left, comparators); } + if checker.enabled(Rule::NanComparison) { + pylint::rules::nan_comparison(checker, left, comparators); + } if checker.enabled(Rule::InDictKeys) { flake8_simplify::rules::key_in_dict_compare(checker, compare); } diff --git a/crates/ruff_linter/src/codes.rs b/crates/ruff_linter/src/codes.rs index f318beddf4b22..774ad36d1b089 100644 --- a/crates/ruff_linter/src/codes.rs +++ b/crates/ruff_linter/src/codes.rs @@ -294,6 +294,7 @@ pub fn code_to_rule(linter: Linter, code: &str) -> Option<(RuleGroup, Rule)> { #[allow(deprecated)] (Pylint, "R6301") => (RuleGroup::Nursery, rules::pylint::rules::NoSelfUse), (Pylint, "W0108") => (RuleGroup::Preview, rules::pylint::rules::UnnecessaryLambda), + (Pylint, "W0117") => (RuleGroup::Preview, rules::pylint::rules::NanComparison), (Pylint, "W0120") => (RuleGroup::Stable, rules::pylint::rules::UselessElseOnLoop), (Pylint, "W0127") => (RuleGroup::Stable, rules::pylint::rules::SelfAssigningVariable), (Pylint, "W0128") => (RuleGroup::Preview, rules::pylint::rules::RedeclaredAssignedName), diff --git a/crates/ruff_linter/src/rules/pylint/mod.rs b/crates/ruff_linter/src/rules/pylint/mod.rs index fc377eb30b100..7d9bc56575854 100644 --- a/crates/ruff_linter/src/rules/pylint/mod.rs +++ b/crates/ruff_linter/src/rules/pylint/mod.rs @@ -188,6 +188,7 @@ mod tests { Rule::UselessExceptionStatement, 
Path::new("useless_exception_statement.py") )] + #[test_case(Rule::NanComparison, Path::new("nan_comparison.py"))] fn rules(rule_code: Rule, path: &Path) -> Result<()> { let snapshot = format!("{}_{}", rule_code.noqa_code(), path.to_string_lossy()); let diagnostics = test_path( diff --git a/crates/ruff_linter/src/rules/pylint/rules/mod.rs b/crates/ruff_linter/src/rules/pylint/rules/mod.rs index cc8fd735db43d..c190406c67776 100644 --- a/crates/ruff_linter/src/rules/pylint/rules/mod.rs +++ b/crates/ruff_linter/src/rules/pylint/rules/mod.rs @@ -38,6 +38,7 @@ pub(crate) use magic_value_comparison::*; pub(crate) use manual_import_from::*; pub(crate) use misplaced_bare_raise::*; pub(crate) use named_expr_without_context::*; +pub(crate) use nan_comparison::*; pub(crate) use nested_min_max::*; pub(crate) use no_method_decorator::*; pub(crate) use no_self_use::*; @@ -130,6 +131,7 @@ mod magic_value_comparison; mod manual_import_from; mod misplaced_bare_raise; mod named_expr_without_context; +mod nan_comparison; mod nested_min_max; mod no_method_decorator; mod no_self_use; diff --git a/crates/ruff_linter/src/rules/pylint/rules/nan_comparison.rs b/crates/ruff_linter/src/rules/pylint/rules/nan_comparison.rs new file mode 100644 index 0000000000000..8f0cb708ac333 --- /dev/null +++ b/crates/ruff_linter/src/rules/pylint/rules/nan_comparison.rs @@ -0,0 +1,135 @@ +use ruff_diagnostics::{Diagnostic, Violation}; +use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::{self as ast, Expr}; +use ruff_python_semantic::SemanticModel; +use ruff_text_size::Ranged; + +use crate::checkers::ast::Checker; + +/// ## What it does +/// Checks for comparisons against NaN values. +/// +/// ## Why is this bad? +/// Comparing against a NaN value can lead to unexpected results. For example, +/// `float("NaN") == float("NaN")` will return `False` and, in general, +/// `x == float("NaN")` will always return `False`, even if `x` is `NaN`. 
+/// +/// To determine whether a value is `NaN`, use `math.isnan` or `np.isnan` +/// instead of comparing against `NaN` directly. +/// +/// ## Example +/// ```python +/// if x == float("NaN"): +/// pass +/// ``` +/// +/// Use instead: +/// ```python +/// import math +/// +/// if math.isnan(x): +/// pass +/// ``` +/// +#[violation] +pub struct NanComparison { + nan: Nan, +} + +impl Violation for NanComparison { + #[derive_message_formats] + fn message(&self) -> String { + let NanComparison { nan } = self; + match nan { + Nan::Math => format!("Comparing against a NaN value; use `math.isnan` instead"), + Nan::NumPy => format!("Comparing against a NaN value; use `np.isnan` instead"), + } + } +} + +/// PLW0117 +pub(crate) fn nan_comparison(checker: &mut Checker, left: &Expr, comparators: &[Expr]) { + for expr in std::iter::once(left).chain(comparators.iter()) { + if let Some(qualified_name) = checker.semantic().resolve_qualified_name(expr) { + match qualified_name.segments() { + ["numpy", "nan" | "NAN" | "NaN"] => { + checker.diagnostics.push(Diagnostic::new( + NanComparison { nan: Nan::NumPy }, + expr.range(), + )); + } + ["math", "nan"] => { + checker.diagnostics.push(Diagnostic::new( + NanComparison { nan: Nan::Math }, + expr.range(), + )); + } + _ => continue, + } + } + + if is_nan_float(expr, checker.semantic()) { + checker.diagnostics.push(Diagnostic::new( + NanComparison { nan: Nan::Math }, + expr.range(), + )); + } + } +} + +#[derive(Debug, PartialEq, Eq)] +enum Nan { + /// `math.isnan` + Math, + /// `np.isnan` + NumPy, +} + +impl std::fmt::Display for Nan { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Nan::Math => fmt.write_str("math"), + Nan::NumPy => fmt.write_str("numpy"), + } + } +} + +/// Returns `true` if the expression is a call to `float("NaN")`. +fn is_nan_float(expr: &Expr, semantic: &SemanticModel) -> bool { + let Expr::Call(call) = expr else { + return false; + }; + + let Expr::Name(ast::ExprName { id, .. 
}) = call.func.as_ref() else { + return false; + }; + + if id.as_str() != "float" { + return false; + } + + if !call.arguments.keywords.is_empty() { + return false; + } + + let [arg] = call.arguments.args.as_ref() else { + return false; + }; + + let Expr::StringLiteral(ast::ExprStringLiteral { value, .. }) = arg else { + return false; + }; + + if !matches!( + value.to_str(), + "nan" | "NaN" | "NAN" | "Nan" | "nAn" | "naN" | "nAN" | "NAn" + ) { + return false; + } + + if !semantic.is_builtin("float") { + return false; + } + + true +} diff --git a/crates/ruff_linter/src/rules/pylint/snapshots/ruff_linter__rules__pylint__tests__PLW0117_nan_comparison.py.snap b/crates/ruff_linter/src/rules/pylint/snapshots/ruff_linter__rules__pylint__tests__PLW0117_nan_comparison.py.snap new file mode 100644 index 0000000000000..f9046fb384fa0 --- /dev/null +++ b/crates/ruff_linter/src/rules/pylint/snapshots/ruff_linter__rules__pylint__tests__PLW0117_nan_comparison.py.snap @@ -0,0 +1,82 @@ +--- +source: crates/ruff_linter/src/rules/pylint/mod.rs +--- +nan_comparison.py:11:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead + | +10 | # PLW0117 +11 | if x == float("nan"): + | ^^^^^^^^^^^^ PLW0117 +12 | pass + | + +nan_comparison.py:15:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead + | +14 | # PLW0117 +15 | if x == float("NaN"): + | ^^^^^^^^^^^^ PLW0117 +16 | pass + | + +nan_comparison.py:19:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead + | +18 | # PLW0117 +19 | if x == float("NAN"): + | ^^^^^^^^^^^^ PLW0117 +20 | pass + | + +nan_comparison.py:23:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead + | +22 | # PLW0117 +23 | if x == float("Nan"): + | ^^^^^^^^^^^^ PLW0117 +24 | pass + | + +nan_comparison.py:27:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead + | +26 | # PLW0117 +27 | if x == math.nan: + | ^^^^^^^^ PLW0117 +28 | pass + | + +nan_comparison.py:31:9: PLW0117 Comparing against a NaN value; use 
`math.isnan` instead + | +30 | # PLW0117 +31 | if x == bad_val: + | ^^^^^^^ PLW0117 +32 | pass + | + +nan_comparison.py:35:9: PLW0117 Comparing against a NaN value; use `np.isnan` instead + | +34 | # PLW0117 +35 | if y == np.NaN: + | ^^^^^^ PLW0117 +36 | pass + | + +nan_comparison.py:39:9: PLW0117 Comparing against a NaN value; use `np.isnan` instead + | +38 | # PLW0117 +39 | if y == np.NAN: + | ^^^^^^ PLW0117 +40 | pass + | + +nan_comparison.py:43:9: PLW0117 Comparing against a NaN value; use `np.isnan` instead + | +42 | # PLW0117 +43 | if y == np.nan: + | ^^^^^^ PLW0117 +44 | pass + | + +nan_comparison.py:47:9: PLW0117 Comparing against a NaN value; use `np.isnan` instead + | +46 | # PLW0117 +47 | if y == npy_nan: + | ^^^^^^^ PLW0117 +48 | pass + | diff --git a/ruff.schema.json b/ruff.schema.json index ed4470db8fbf1..2bd6fa00a162c 100644 --- a/ruff.schema.json +++ b/ruff.schema.json @@ -3330,6 +3330,8 @@ "PLW01", "PLW010", "PLW0108", + "PLW011", + "PLW0117", "PLW012", "PLW0120", "PLW0127",