Skip to content

Implement chained comparison improvements and related checks #7611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
214 changes: 161 additions & 53 deletions pylint/checkers/refactoring/refactoring_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import copy
import itertools
import tokenize
from collections.abc import Iterator
from collections.abc import Iterator, Sequence
from functools import cached_property, reduce
from re import Pattern
from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast
Expand All @@ -21,6 +21,7 @@
from pylint.checkers import utils
from pylint.checkers.base.basic_error_checker import _loop_exits_early
from pylint.checkers.utils import node_frame_class
from pylint.graph import get_cycles, get_paths
from pylint.interfaces import HIGH, INFERENCE, Confidence

if TYPE_CHECKING:
Expand Down Expand Up @@ -353,7 +354,7 @@ class RefactoringChecker(checkers.BaseTokenChecker):
"more idiomatic, although sometimes a bit slower",
),
"R1716": (
"Simplify chained comparison between the operands",
"Simplify chained comparison between the operands: %s",
"chained-comparison",
"This message is emitted when pylint encounters boolean operation like "
'"a < b and b < c", suggesting instead to refactor it to "a < b < c"',
Expand Down Expand Up @@ -483,6 +484,17 @@ class RefactoringChecker(checkers.BaseTokenChecker):
"value by index lookup. "
"The value can be accessed directly instead.",
),
"R1737": (
"Simplify cycle to ==",
"chained-comparison-all-equal",
"Emitted when items in a boolean condition are all <= or >="
"This is equivalent to asking if they all equal.",
),
"R1738": (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should raise the existing message for constant comparisons, using-constant-test (https://pylint.pycqa.org/en/latest/user_guide/messages/warning/using-constant-test.html#using-constant-test-w0125). We might need to relax some constraints on what a checker can do; I'm not sure a checker can currently raise messages it does not define.

"This comparison always evalutes to False",
"impossible-comparison",
"Emitted when there a comparison that is always False.",
),
}
options = (
(
Expand Down Expand Up @@ -1333,61 +1345,157 @@ def _check_consider_using_in(self, node: nodes.BoolOp) -> None:
confidence=HIGH,
)

def _check_chained_comparison(self, node: nodes.BoolOp) -> None:
"""Check if there is any chained comparison in the expression.
def _check_comparisons(self, node: nodes.BoolOp) -> None:
graph_info = self._get_graph_from_comparison_nodes(node)
if not graph_info:
return
(
graph_dict,
symbol_dict,
indegree_dict,
frequency_dict,
) = graph_info

# Convert graph_dict to all strings to access the get_cycles API
str_dict = {
str(key): {str(dest) for dest in graph_dict[key]} for key in graph_dict
}
cycles = get_cycles(str_dict)
if cycles:
self._handle_cycles(node, symbol_dict, cycles)
return

Add a refactoring message if a boolOp contains comparison like a < b and b < c,
which can be chained as a < b < c.
paths = get_paths(graph_dict, indegree_dict, frequency_dict)

Care is taken to avoid simplifying a < b < c and b < d.
"""
if len(paths) < len(node.values):
suggestions = []
for path in paths:
cur_statement = str(path[0])
for i in range(len(path) - 1):
cur_statement += (
" " + symbol_dict[path[i], path[i + 1]] + " " + str(path[i + 1])
)
suggestions.append(cur_statement)
args = " and ".join(sorted(suggestions))
self.add_message("chained-comparison", node=node, args=(args,))

def _get_graph_from_comparison_nodes(
self, node: nodes.BoolOp
) -> (
None
| tuple[
dict[str | int | float, set[str | int | float]],
dict[tuple[str | int | float, str | int | float], str],
dict[str | int | float, int],
dict[tuple[str | int | float, str | int | float], int],
]
):
if node.op != "and" or len(node.values) < 2:
return
return None

def _find_lower_upper_bounds(
comparison_node: nodes.Compare,
uses: collections.defaultdict[str, dict[str, set[nodes.Compare]]],
) -> None:
left_operand = comparison_node.left
for operator, right_operand in comparison_node.ops:
for operand in (left_operand, right_operand):
value = None
if isinstance(operand, nodes.Name):
value = operand.name
elif isinstance(operand, nodes.Const):
value = operand.value

if value is None:
continue
graph_dict: dict[
str | int | float, set[str | int | float]
] = collections.defaultdict(set)
symbol_dict: dict[
tuple[str | int | float, str | int | float], str
] = collections.defaultdict(lambda: ">")
frequency_dict: dict[
tuple[str | int | float, str | int | float], int
] = collections.defaultdict(int)
indegree_dict: dict[str | int | float, int] = collections.defaultdict(int)
const_values: list[int | float] = []

for statement in node.values:
if not isinstance(statement, nodes.Compare):
return None
ops = list(statement.ops)
left_statement = statement.left
while ops:
left = self._get_compare_operand_value(left_statement, const_values)
# Pop from ops or else we never advance along the statement
operator, right_statement = ops.pop(0)
# The operand is not a constant or variable or the operator is not a comparison
if operator not in {"<", ">", "==", "<=", ">="} or left is None:
return None
right = self._get_compare_operand_value(right_statement, const_values)
if right is None:
return None

# Make the graph always point from larger to smaller
if operator == "<":
operator = ">"
left, right = right, left
elif operator == "<=":
operator = ">="
left, right = right, left

# Update maps
graph_dict[left].add(right)
if not graph_dict[right]:
graph_dict[right] = set() # Ensure the node exists in graph
symbol_dict[(left, right)] = operator
indegree_dict[left] += 0 # Make sure every node has an entry
indegree_dict[right] += 1
frequency_dict[(left, right)] += 1

# advance onto the next comparison if it exists
left_statement = right_statement

# Nothing was added and we have no recommendations
if (
not graph_dict
or not symbol_dict
or all(val == "==" for val in symbol_dict.values())
):
return None

if operator in {"<", "<="}:
if operand is left_operand:
uses[value]["lower_bound"].add(comparison_node)
elif operand is right_operand:
uses[value]["upper_bound"].add(comparison_node)
elif operator in {">", ">="}:
if operand is left_operand:
uses[value]["upper_bound"].add(comparison_node)
elif operand is right_operand:
uses[value]["lower_bound"].add(comparison_node)
left_operand = right_operand

uses: collections.defaultdict[
str, dict[str, set[nodes.Compare]]
] = collections.defaultdict(
lambda: {"lower_bound": set(), "upper_bound": set()}
)
for comparison_node in node.values:
if isinstance(comparison_node, nodes.Compare):
_find_lower_upper_bounds(comparison_node, uses)

for bounds in uses.values():
num_shared = len(bounds["lower_bound"].intersection(bounds["upper_bound"]))
num_lower_bounds = len(bounds["lower_bound"])
num_upper_bounds = len(bounds["upper_bound"])
if num_shared < num_lower_bounds and num_shared < num_upper_bounds:
self.add_message("chained-comparison", node=node)
break
# Link up constant nodes, i.e. create synthetic nodes between 1 and 5 such that 5 > 1
sorted_consts = sorted(const_values)
while sorted_consts:
largest = sorted_consts.pop()
for smaller in set(sorted_consts):
if smaller < largest:
symbol_dict[(largest, smaller)] = ">"
indegree_dict[smaller] += 1
frequency_dict[(largest, smaller)] += 1
graph_dict[largest].add(smaller)

# Remove paths from the larger number to the smaller number's adjacent nodes
# This prevents duplicated paths in the output
for adj in graph_dict[smaller]:
if isinstance(adj, str):
graph_dict[largest].discard(adj)

return (graph_dict, symbol_dict, indegree_dict, frequency_dict)

def _get_compare_operand_value(
    self, node: nodes.NodeNG, const_values: list[int | float]
) -> str | int | float | None:
    """Return the graph key for a single comparison operand, if it has one.

    Only two operand shapes are recognised: a ``nodes.Name`` (keyed by its
    name) and a ``nodes.Const`` holding an ``int`` or ``float`` (keyed by its
    value). Anything else yields ``None``.

    :param node: One operand of a comparison (not the ``Compare`` node itself).
    :param const_values: Accumulator that receives every numeric constant
        seen; it is mutated in place and never receives ``None``.
    :returns: The variable name, the numeric constant, or ``None`` when the
        operand is neither.
    """
    if isinstance(node, nodes.Name) and isinstance(node.name, str):
        return node.name
    if isinstance(node, nodes.Const) and isinstance(node.value, (int, float)):
        # Record the constant as a side effect so the caller can collect
        # every numeric literal that appeared in the expression.
        const_values.append(node.value)
        return node.value
    return None

def _handle_cycles(
    self,
    node: nodes.BoolOp,
    symbol_dict: dict[tuple[str | int | float, str | int | float], str],
    cycles: Sequence[list[str]],
) -> None:
    """Emit one message for every comparison cycle found in ``node``.

    A cycle in which every edge is ``>=`` (including the closing edge from
    the last member back to the first) is reported as
    ``chained-comparison-all-equal``. Any other cycle is reported as
    ``impossible-comparison``.

    :param node: The boolean operation the messages are attached to.
    :param symbol_dict: Maps a ``(left, right)`` edge to its operator string.
    :param cycles: Each cycle as an ordered list of its members.
    """
    # NOTE(review): the visible caller builds ``cycles`` from a graph whose
    # keys were passed through ``str()``, while ``symbol_dict`` is keyed by
    # the original values. Numeric constants may therefore not match here
    # -- confirm.
    for cycle in cycles:
        if not cycle:
            # Nothing to classify; also avoids indexing an empty sequence.
            continue
        # Pair every member with its successor. Rotating the cycle by one
        # makes the last pair the closing edge (last -> first).
        edges = zip(cycle, cycle[1:] + cycle[:1])
        if all(symbol_dict[edge] == ">=" for edge in edges):
            # Must be the symbol registered for R1737 in the message table.
            self.add_message("chained-comparison-all-equal", node=node)
        else:
            self.add_message("impossible-comparison", node=node)

@staticmethod
def _apply_boolean_simplification_rules(
Expand Down Expand Up @@ -1478,7 +1586,7 @@ def _check_simplifiable_condition(self, node: nodes.BoolOp) -> None:
def visit_boolop(self, node: nodes.BoolOp) -> None:
    """Run the boolean-operation checks on ``node``, in the order listed."""
    self._check_consider_merging_isinstance(node)
    self._check_consider_using_in(node)
    # NOTE(review): this is a rendered diff without +/- markers; the next
    # two calls look like the before/after sides of one changed line --
    # confirm that only one of them is present in the real file.
    self._check_chained_comparison(node)
    self._check_comparisons(node)
    self._check_simplifiable_condition(node)

@staticmethod
Expand Down
Loading