Skip to content

Commit

Permalink
Fixed type inference for locally defined class.
Browse files Browse the repository at this point in the history
  • Loading branch information
jdodinh committed Feb 9, 2024
1 parent 2d8d2c4 commit 4fee60f
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import org.sonar.python.tree.FunctionDefImpl;
import org.sonar.python.types.TypeShed;

class Scope {
public class Scope {

final Tree rootTree;
private PythonFile pythonFile;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ public void visitFileInput(FileInput fileInput) {
addSymbolsToTree((FileInputImpl) fileInput);
fileInput.accept(new ThirdPhaseVisitor());
if (typeContext != null) {
this.typeContext.setScopesByRootTree(scopesByRootTree);
new PyTypeAnnotation(this.typeContext, pythonFile).annotate(fileInput);
} else {
TypeInference.inferTypes(fileInput, pythonFile);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.sonar.plugins.python.api.tree.ClassDef;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.Token;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.plugins.python.api.types.InferredType;
import org.sonar.python.semantic.ClassSymbolImpl;
import org.sonar.python.semantic.Scope;

public class TypeContext {
private static final class JsonTypeInfo {
Expand Down Expand Up @@ -90,6 +94,8 @@ public int hashCode() {

private Map<NameAccess, JsonTypeInfo> typesByPosition = null;

private Map<Tree, Scope> scopesByRootTree = null;

public TypeContext() {
this.files = new HashMap<>();
}
Expand All @@ -98,6 +104,10 @@ public static TypeContext fromJSON(String json) {
return new Gson().fromJson("{files: " + json + "}", TypeContext.class);
}

public void setScopesByRootTree(Map<Tree, Scope> scopesByRootTree) {
this.scopesByRootTree = scopesByRootTree;
}

private void populateTypesByPosition() {
typesByPosition = new HashMap<>();
files.forEach((file, typeInfos) -> {
Expand All @@ -108,7 +118,7 @@ private void populateTypesByPosition() {
}

@VisibleForTesting
Optional<InferredType> getTypeFor(String fileName, int line, int column, String name, String kind) {
Optional<InferredType> getTypeFor(String fileName, int line, int column, String name, String kind, @Nullable Tree tree) {
if (typesByPosition == null) {
populateTypesByPosition();
}
Expand All @@ -120,10 +130,10 @@ Optional<InferredType> getTypeFor(String fileName, int line, int column, String
LOG.error("Found type at position, but does not match expected kind ({}): {}", kind, typeInfo);
return Optional.empty();
}
return Optional.of(typeStringToTypeInfo(typeInfo.short_type, typeInfo.type, fileName));
return Optional.of(typeStringToTypeInfo(typeInfo.short_type, typeInfo.type, fileName, tree));
}

private static InferredType typeStringToTypeInfo(String typeString, String detailedType, String fileName) {
private InferredType typeStringToTypeInfo(String typeString, String detailedType, String fileName, @Nullable Tree tree) {
if ("None".equals(typeString)) {
typeString = "NoneType";
}
Expand All @@ -137,6 +147,8 @@ private static InferredType typeStringToTypeInfo(String typeString, String detai
return new RuntimeType(callableClassSymbol);
} else if (typeString.equals("Any")) {
return InferredTypes.anyType();
} else if (isLocallyDefinedClassType(typeString)) {
return getLocallyDefinedClassSymbolType(typeString);
} else {
// Try to make a qualified name. pytype does not give a qualified name for classes defined in the same file.
// The filename should contain the module path, but we may get and extra prefix. The prefix will get removed later.
Expand All @@ -146,6 +158,29 @@ private static InferredType typeStringToTypeInfo(String typeString, String detai
}
}

private boolean isLocallyDefinedClassType(String typeString) {
return scopesByRootTree.keySet().stream()
.filter(ClassDef.class::isInstance)
.findAny()
.map(ClassDef.class::cast)
.map(ClassDef::name)
.map(Name::name)
.filter(typeString::equals)
.isPresent();
}

private InferredType getLocallyDefinedClassSymbolType(String typeString) {
return scopesByRootTree.keySet().stream()
.filter(ClassDef.class::isInstance)
.findAny()
.map(ClassDef.class::cast)
.map(ClassDef::name)
.filter(name -> typeString.equals(name.name()))
.map(Name::symbol)
.map(InferredTypes::runtimeType)
.orElse(InferredTypes.anyType());
}

private static String getBaseType(String typeString) { // Tuple[int, int]
if (typeString.startsWith("Tuple")) {
return "tuple";
Expand All @@ -164,11 +199,12 @@ private static String getBaseType(String typeString) { // Tuple[int, int]

public Optional<InferredType> getTypeFor(String fileName, Name name) {
Token token = name.firstToken();
return getTypeFor(fileName, token.line(), token.column(), name.name(), "Variable");

return getTypeFor(fileName, token.line(), token.column(), name.name(), "Variable", name);
}

public Optional<InferredType> getTypeFor(String fileName, QualifiedExpression attributeAccess) {
Token token = attributeAccess.firstToken();
return getTypeFor(fileName, token.line(), token.column(), attributeAccess.name().name(), "Attribute");
return getTypeFor(fileName, token.line(), token.column(), attributeAccess.name().name(), "Attribute", attributeAccess);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void test_type_inference_builtins() {
@Test
void test_type_ineference_custom_class() {
TypeContext typeContext = TypeContext.fromJSON("{\n" +
" \"test.py\": [\n" +
" \"mod1.py\": [\n" +
" {\n" +
" \"text\": \"a\",\n" +
" \"start_line\": 4,\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,28 +113,28 @@ void test1() {
"}";
String fileName = "src/AttributeError/181733998.py";
TypeContext typeContext = TypeContext.fromJSON(json);
assertThat(typeContext.getTypeFor(fileName, 1, 0, "t", "Variable")).contains(InferredTypes.INT);
assertThat(typeContext.getTypeFor(fileName, 1, 0, "x", "Variable")).isEmpty();
assertThat(typeContext.getTypeFor(fileName, 1, 0, "t", "Attribute")).isEmpty();
assertThat(typeContext.getTypeFor(fileName, 2, 0, "t", "Variable")).isEmpty();
assertThat(typeContext.getTypeFor(fileName, 1, 1, "t", "Variable")).isEmpty();
assertThat(typeContext.getTypeFor(fileName, 7, 4, "append", "Attribute")).contains(new RuntimeType(callableClassSymbol));
assertThat(typeContext.getTypeFor(fileName, 1, 0, "t", "Variable", null)).contains(InferredTypes.INT);
assertThat(typeContext.getTypeFor(fileName, 1, 0, "x", "Variable", null)).isEmpty();
assertThat(typeContext.getTypeFor(fileName, 1, 0, "t", "Attribute", null)).isEmpty();
assertThat(typeContext.getTypeFor(fileName, 2, 0, "t", "Variable", null)).isEmpty();
assertThat(typeContext.getTypeFor(fileName, 1, 1, "t", "Variable", null)).isEmpty();
assertThat(typeContext.getTypeFor(fileName, 7, 4, "append", "Attribute", null)).contains(new RuntimeType(callableClassSymbol));
}

@Test
void test2() {
String json = readJsonTypeInfo("src/test/resources/pytype/code.json");
String fileName = "level1.py";
TypeContext typeContext = TypeContext.fromJSON(json);
assertThat(typeContext.getTypeFor(fileName, 2, 4, "my_int", "Variable")).contains(InferredTypes.INT);
assertThat(typeContext.getTypeFor(fileName, 3, 4, "my_float", "Variable")).contains(InferredTypes.FLOAT);
assertThat(typeContext.getTypeFor(fileName, 4, 4, "my_str", "Variable")).contains(InferredTypes.STR);
assertThat(typeContext.getTypeFor(fileName, 5, 4, "my_bool", "Variable")).contains(InferredTypes.BOOL);
assertThat(typeContext.getTypeFor(fileName, 6, 4, "my_complex", "Variable")).contains(InferredTypes.COMPLEX);
assertThat(typeContext.getTypeFor(fileName, 7, 4, "my_tuple", "Variable")).contains(InferredTypes.TUPLE);
assertThat(typeContext.getTypeFor(fileName, 8, 4, "my_list", "Variable")).contains(InferredTypes.LIST);
assertThat(typeContext.getTypeFor(fileName, 9, 4, "my_set", "Variable")).contains(InferredTypes.SET);
assertThat(typeContext.getTypeFor(fileName, 10, 4, "my_dict", "Variable")).contains(InferredTypes.DICT);
assertThat(typeContext.getTypeFor(fileName, 2, 4, "my_int", "Variable", null)).contains(InferredTypes.INT);
assertThat(typeContext.getTypeFor(fileName, 3, 4, "my_float", "Variable", null)).contains(InferredTypes.FLOAT);
assertThat(typeContext.getTypeFor(fileName, 4, 4, "my_str", "Variable", null)).contains(InferredTypes.STR);
assertThat(typeContext.getTypeFor(fileName, 5, 4, "my_bool", "Variable", null)).contains(InferredTypes.BOOL);
assertThat(typeContext.getTypeFor(fileName, 6, 4, "my_complex", "Variable", null)).contains(InferredTypes.COMPLEX);
assertThat(typeContext.getTypeFor(fileName, 7, 4, "my_tuple", "Variable", null)).contains(InferredTypes.TUPLE);
assertThat(typeContext.getTypeFor(fileName, 8, 4, "my_list", "Variable", null)).contains(InferredTypes.LIST);
assertThat(typeContext.getTypeFor(fileName, 9, 4, "my_set", "Variable", null)).contains(InferredTypes.SET);
assertThat(typeContext.getTypeFor(fileName, 10, 4, "my_dict", "Variable", null)).contains(InferredTypes.DICT);
}

private String readJsonTypeInfo(String path) {
Expand Down

0 comments on commit 4fee60f

Please sign in to comment.