Skip to content

Commit

Permalink
SONARPY-1815 Enable AST-based type inference for functions/module con…
Browse files Browse the repository at this point in the history
…taining try/catch blocks
  • Loading branch information
guillaume-dequenne-sonarsource committed May 13, 2024
1 parent 1c85601 commit d48bdfa
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ private static void inferTypesAndMemberAccessSymbols(Tree scopeTree,
statements.accept(tryStatementVisitor);
if (tryStatementVisitor.hasTryStatement()) {
// CFG doesn't model precisely try-except statements. Hence we fallback to AST based type inference
// TODO: Check if still relevant
/* visitor.processPropagations(getTrackedVars(declaredVariables, assignedNames));
statements.accept(new TypeInference.NameVisitor());*/
propagationVisitor.processPropagations(getTrackedVars(declaredVariables, assignedNames));
} else {
ControlFlowGraph cfg = controlFlowGraphSupplier.get();
if (cfg == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,86 @@
*/
package org.sonar.python.semantic.v2.types;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.sonar.plugins.python.api.tree.Expression;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.python.semantic.v2.SymbolV2;
import org.sonar.python.semantic.v2.UsageV2;
import org.sonar.python.tree.NameImpl;
import org.sonar.python.types.HasTypeDependencies;
import org.sonar.python.types.v2.PythonType;
import org.sonar.python.types.v2.UnionType;

public record Assignment(SymbolV2 lhsSymbol, Name lhsName, Expression rhs) {
public class Assignment {

final SymbolV2 lhsSymbol;
Name lhsName;
Expression rhs;
Set<SymbolV2> variableDependencies = new HashSet<>();
Set<Assignment> dependents = new HashSet<>();
Map<SymbolV2, Set<Assignment>> assignmentsByLhs;

public Assignment(SymbolV2 lhsSymbol, Name lhsName, Expression rhs, Map<SymbolV2, Set<Assignment>> assignmentsByLhs) {
this.lhsSymbol = lhsSymbol;
this.lhsName = lhsName;
this.rhs = rhs;
this.assignmentsByLhs = assignmentsByLhs;
}

void computeDependencies(Expression expression, Set<SymbolV2> trackedVars) {
Deque<Expression> workList = new ArrayDeque<>();
workList.push(expression);
while (!workList.isEmpty()) {
Expression e = workList.pop();
if (e.is(Tree.Kind.NAME)) {
Name name = (Name) e;
SymbolV2 symbol = name.symbolV2();
if (symbol != null && trackedVars.contains(symbol)) {
variableDependencies.add(symbol);
assignmentsByLhs.get(symbol).forEach(a -> a.dependents.add(this));
}
} else if (e instanceof HasTypeDependencies hasTypeDependencies) {
workList.addAll(hasTypeDependencies.typeDependencies());
}
}
}

boolean areDependenciesReady(Set<SymbolV2> initializedVars) {
return initializedVars.containsAll(variableDependencies);
}

/** @return true if the propagation effectively changed the inferred type of lhs */
public boolean propagate(Set<SymbolV2> initializedVars) {
PythonType rhsType = rhs.typeV2();
if (initializedVars.add(lhsSymbol)) {
lhsSymbol.usages().stream().map(UsageV2::tree).filter(NameImpl.class::isInstance).map(NameImpl.class::cast).forEach(n -> n.typeV2(rhsType));
return true;
} else {
PythonType currentType = lhsName.typeV2();
PythonType newType = UnionType.or(rhsType, currentType);
lhsSymbol.usages().stream().map(UsageV2::tree).filter(NameImpl.class::isInstance).map(NameImpl.class::cast).forEach(n -> n.typeV2(newType));
return !newType.equals(currentType);
}
}

public Name lhsName() {
return lhsName;
}

public SymbolV2 lhsSymbol() {
return lhsSymbol;
}

public Expression rhs() {
return rhs;
}

public Set<Assignment> dependents() {
return dependents;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -83,9 +84,39 @@ public void visitAnnotatedAssignment(AnnotatedAssignment annotatedAssignment){
private void processAssignment(Statement assignmentStatement, Expression lhsExpression, Expression rhsExpression){
if (lhsExpression instanceof Name lhs && lhs.symbolV2() != null) {
var symbol = lhs.symbolV2();
Assignment assignment = new Assignment(symbol, lhs, rhsExpression);
Assignment assignment = new Assignment(symbol, lhs, rhsExpression, assignmentsByLhs);
assignmentsByAssignmentStatement.put(assignmentStatement, assignment);
assignmentsByLhs.computeIfAbsent(symbol, s -> new HashSet<>()).add(assignment);
}
}

public void processPropagations(Set<SymbolV2> trackedVars) {
Set<Assignment> propagations = new HashSet<>();
Set<SymbolV2> initializedVars = new HashSet<>();

assignmentsByLhs.forEach((lhs, as) -> {
if (trackedVars.contains(lhs)) {
as.forEach(a -> a.computeDependencies(a.rhs(), trackedVars));
propagations.addAll(as);
}
});

applyPropagations(propagations, initializedVars, true);
applyPropagations(propagations, initializedVars, false);
}

private static void applyPropagations(Set<Assignment> propagations, Set<SymbolV2> initializedVars, boolean checkDependenciesReadiness) {
Set<Assignment> workSet = new HashSet<>(propagations);
while (!workSet.isEmpty()) {
Iterator<Assignment> iterator = workSet.iterator();
Assignment propagation = iterator.next();
iterator.remove();
if (!checkDependenciesReadiness || propagation.areDependenciesReady(initializedVars)) {
boolean learnt = propagation.propagate(initializedVars);
if (learnt) {
workSet.addAll(propagation.dependents());
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.sonar.plugins.python.api.tree.ImportFrom;
import org.sonar.plugins.python.api.tree.ImportName;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.RegularArgument;
import org.sonar.plugins.python.api.tree.Statement;
import org.sonar.plugins.python.api.tree.StatementList;
import org.sonar.plugins.python.api.tree.Tree;
Expand Down Expand Up @@ -626,6 +627,177 @@ def foo():
""").typeV2().unwrappedType()).isEqualTo(INT_TYPE);
}

@Test
void flow_insensitive_when_try_except() {
FileInput fileInput = inferTypes("""
try:
if p:
x = 42
type(x)
else:
x = "foo"
type(x)
except:
type(x)
""");

List<CallExpression> calls = PythonTestUtils.getAllDescendant(fileInput, tree -> tree.is(Tree.Kind.CALL_EXPR));
RegularArgument firstX = (RegularArgument) calls.get(0).arguments().get(0);
RegularArgument secondX = (RegularArgument) calls.get(1).arguments().get(0);
RegularArgument thirdX = (RegularArgument) calls.get(2).arguments().get(0);
assertThat(((UnionType) firstX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) secondX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) thirdX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
}

@Test
void nested_try_except() {
FileInput fileInput = inferTypes("""
def f(p):
try:
if p:
x = 42
type(x)
else:
x = "foo"
type(x)
except:
type(x)
def g(p):
if p:
y = 42
type(y)
else:
y = "hello"
type(y)
type(y)
if cond:
z = 42
type(z)
else:
z = "hello"
type(z)
type(z)
""");
List<CallExpression> calls = PythonTestUtils.getAllDescendant(fileInput, tree -> tree.is(Tree.Kind.CALL_EXPR));
RegularArgument firstX = (RegularArgument) calls.get(0).arguments().get(0);
RegularArgument secondX = (RegularArgument) calls.get(1).arguments().get(0);
RegularArgument thirdX = (RegularArgument) calls.get(2).arguments().get(0);
assertThat(((UnionType) firstX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) secondX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) thirdX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);

RegularArgument firstY = (RegularArgument) calls.get(3).arguments().get(0);
RegularArgument secondY = (RegularArgument) calls.get(4).arguments().get(0);
RegularArgument thirdY = (RegularArgument) calls.get(5).arguments().get(0);
assertThat(firstY.expression().typeV2().unwrappedType()).isEqualTo(INT_TYPE);
assertThat(secondY.expression().typeV2().unwrappedType()).isEqualTo(STR_TYPE);
assertThat(((UnionType) thirdY.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);

RegularArgument firstZ = (RegularArgument) calls.get(6).arguments().get(0);
RegularArgument secondZ = (RegularArgument) calls.get(7).arguments().get(0);
RegularArgument thirdZ = (RegularArgument) calls.get(8).arguments().get(0);
assertThat(firstZ.expression().typeV2().unwrappedType()).isEqualTo(INT_TYPE);
assertThat(secondZ.expression().typeV2().unwrappedType()).isEqualTo(STR_TYPE);
assertThat(((UnionType) thirdZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
}

@Test
void nested_try_except_2() {
FileInput fileInput = inferTypes("""
try:
if p:
x = 42
type(x)
else:
x = "foo"
type(x)
except:
type(x)
def g(p):
if p:
y = 42
type(y)
else:
y = "hello"
type(y)
type(y)
if cond:
z = 42
type(z)
else:
z = "hello"
type(z)
type(z)
""");
List<CallExpression> calls = PythonTestUtils.getAllDescendant(fileInput, tree -> tree.is(Tree.Kind.CALL_EXPR));
RegularArgument firstX = (RegularArgument) calls.get(0).arguments().get(0);
RegularArgument secondX = (RegularArgument) calls.get(1).arguments().get(0);
RegularArgument thirdX = (RegularArgument) calls.get(2).arguments().get(0);
assertThat(((UnionType) firstX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) secondX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) thirdX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);

RegularArgument firstY = (RegularArgument) calls.get(3).arguments().get(0);
RegularArgument secondY = (RegularArgument) calls.get(4).arguments().get(0);
RegularArgument thirdY = (RegularArgument) calls.get(5).arguments().get(0);
assertThat(firstY.expression().typeV2().unwrappedType()).isEqualTo(INT_TYPE);
assertThat(secondY.expression().typeV2().unwrappedType()).isEqualTo(STR_TYPE);
assertThat(((UnionType) thirdY.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);

RegularArgument firstZ = (RegularArgument) calls.get(6).arguments().get(0);
RegularArgument secondZ = (RegularArgument) calls.get(7).arguments().get(0);
RegularArgument thirdZ = (RegularArgument) calls.get(8).arguments().get(0);
assertThat(((UnionType) firstZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) secondZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) thirdZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
}

@Test
void try_except_with_dependents() {
FileInput fileInput = inferTypes("""
try:
x = 42
y = x
z = y
type(x)
type(y)
type(z)
except:
x = "hello"
y = x
z = y
type(x)
type(y)
type(z)
type(x)
type(y)
type(z)
""");

List<CallExpression> calls = PythonTestUtils.getAllDescendant(fileInput, tree -> tree.is(Tree.Kind.CALL_EXPR));
RegularArgument firstX = (RegularArgument) calls.get(0).arguments().get(0);
RegularArgument firstY = (RegularArgument) calls.get(1).arguments().get(0);
RegularArgument firstZ = (RegularArgument) calls.get(2).arguments().get(0);
assertThat(((UnionType) firstX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) firstY.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) firstZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);

RegularArgument secondX = (RegularArgument) calls.get(3).arguments().get(0);
RegularArgument secondY = (RegularArgument) calls.get(4).arguments().get(0);
RegularArgument secondZ = (RegularArgument) calls.get(5).arguments().get(0);
assertThat(((UnionType) secondX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) secondY.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) secondZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);

RegularArgument thirdX = (RegularArgument) calls.get(6).arguments().get(0);
RegularArgument thirdY = (RegularArgument) calls.get(7).arguments().get(0);
RegularArgument thirdZ = (RegularArgument) calls.get(8).arguments().get(0);
assertThat(((UnionType) thirdX.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) thirdY.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
assertThat(((UnionType) thirdZ.expression().typeV2()).candidates()).extracting(PythonType::unwrappedType).containsExactlyInAnyOrder(INT_TYPE, STR_TYPE);
}

private static FileInput inferTypes(String lines) {
return inferTypes(lines, new HashMap<>());
}
Expand Down

0 comments on commit d48bdfa

Please sign in to comment.