Skip to content

Commit

Permalink
[mypyc] Use table-driven helper for imports
Browse files Browse the repository at this point in the history
Change how imports (not from imports!) are processed so they can be
table-driven and compact. Here's how it works:

Import nodes are divided in groups (in the prebuild visitor). Each group
consists of consecutive Import nodes:

  import mod         <| group python#1
  import mod2         |

  def foo() -> None:
      import mod3    <- group python#2

  import mod4        <- group python#3

Every time we encounter the first import of a group, build IR to call
CPyImport_ImportMany() that will perform all of the group's imports in
one go.

Previously, each module would imported and placed in globals manually
in IR, leading to some pretty verbose code.

The other option to collect all imports and perform them all at once in
the helper would remove even more ops, however, it's problematic for
the same reasons from the previous commit (spoiler: it's not safe).

Implementation notes:

  - I had to add support for loading the address of a static directly,
    so I shoehorned in LoadLiteral support for LoadAddress.

  - Unfortunately by replacing multiple nodes with a single function
    call at the IR level, the traceback line number is static. Even if
    an import several lines down a group fails, the line # of the first
    import in the group would be printed.

    To fix this, I had to make CPyImport_ImportMany() add the traceback
    entry itself on failure (instead of letting codegen handle it
    automatically). This is admittedly ugly.
  • Loading branch information
ichard26 committed Mar 17, 2023
1 parent 11bfb04 commit 4ca6ed0
Show file tree
Hide file tree
Showing 12 changed files with 352 additions and 67 deletions.
8 changes: 7 additions & 1 deletion mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,13 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> None:
def visit_load_address(self, op: LoadAddress) -> None:
typ = op.type
dest = self.reg(op)
src = self.reg(op.src) if isinstance(op.src, Register) else op.src
if isinstance(op.src, Register):
src = self.reg(op.src)
elif isinstance(op.src, LoadStatic):
prefix = self.PREFIX_MAP[op.src.namespace]
src = self.emitter.static_name(op.src.identifier, op.src.module_name, prefix)
else:
src = op.src
self.emit_line(f"{dest} = ({typ._ctype})&{src};")

def visit_keep_alive(self, op: KeepAlive) -> None:
Expand Down
5 changes: 3 additions & 2 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,13 +1348,14 @@ class LoadAddress(RegisterOp):
Attributes:
type: Type of the loaded address(e.g. ptr/object_ptr)
src: Source value (str for globals like 'PyList_Type',
Register for temporary values or locals)
Register for temporary values or locals, LoadStatic
for statics.)
"""

error_kind = ERR_NEVER
is_borrowed = True

def __init__(self, type: RType, src: str | Register, line: int = -1) -> None:
def __init__(self, type: RType, src: str | Register | LoadStatic, line: int = -1) -> None:
super().__init__(line)
self.type = type
self.src = src
Expand Down
5 changes: 5 additions & 0 deletions mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> str:
def visit_load_address(self, op: LoadAddress) -> str:
if isinstance(op.src, Register):
return self.format("%r = load_address %r", op, op.src)
elif isinstance(op.src, LoadStatic):
name = op.src.identifier
if op.src.module_name is not None:
name = f"{op.src.module_name}.{name}"
return self.format("%r = load_address %s :: %s", op, name, op.src.namespace)
else:
return self.format("%r = load_address %s", op, op.src)

Expand Down
2 changes: 2 additions & 0 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def __init__(
self.encapsulating_funcs = pbv.encapsulating_funcs
self.nested_fitems = pbv.nested_funcs.keys()
self.fdefs_to_decorators = pbv.funcs_to_decorators
self.module_import_groups = pbv.module_import_groups

self.singledispatch_impls = singledispatch_impls

self.visitor = visitor
Expand Down
6 changes: 6 additions & 0 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1693,6 +1693,12 @@ def new_list_op(self, values: list[Value], line: int) -> Value:
def new_set_op(self, values: list[Value], line: int) -> Value:
return self.call_c(new_set_op, values, line)

def setup_rarray(self, item_type: RType, values: Sequence[Value]) -> Value:
"""Declare and initialize a new RArray, returning its address."""
array = Register(RArray(item_type, len(values)))
self.add(AssignMulti(array, list(values)))
return self.add(LoadAddress(c_pointer_rprimitive, array))

def shortcircuit_helper(
self,
op: str,
Expand Down
25 changes: 23 additions & 2 deletions mypyc/irbuild/prebuildvisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@
Expression,
FuncDef,
FuncItem,
Import,
LambdaExpr,
MemberExpr,
MypyFile,
NameExpr,
Node,
SymbolNode,
Var,
)
from mypy.traverser import TraverserVisitor
from mypy.traverser import ExtendedTraverserVisitor
from mypyc.errors import Errors


class PreBuildVisitor(TraverserVisitor):
class PreBuildVisitor(ExtendedTraverserVisitor):
"""Mypy file AST visitor run before building the IR.
This collects various things, including:
Expand All @@ -26,6 +28,7 @@ class PreBuildVisitor(TraverserVisitor):
* Find non-local variables (free variables)
* Find property setters
* Find decorators of functions
* Find module import groups
The main IR build pass uses this information.
"""
Expand Down Expand Up @@ -68,10 +71,28 @@ def __init__(
# Map function to indices of decorators to remove
self.decorators_to_remove: dict[FuncDef, list[int]] = decorators_to_remove

# Map starting module import to import groups. Each group is a
# series of imports with nothing between.
self.module_import_groups: dict[Import, list[Import]] = {}
self._current_import_group: Import | None = None

self.errors: Errors = errors

self.current_file: MypyFile = current_file

def visit(self, o: Node) -> bool:
if isinstance(o, Import):
if self._current_import_group is not None:
self.module_import_groups[self._current_import_group].append(o)
else:
self.module_import_groups[o] = [o]
self._current_import_group = o
# Don't recurse into the import's assignments.
return False

self._current_import_group = None
return True

def visit_decorator(self, dec: Decorator) -> None:
if dec.decorators:
# Only add the function being decorated if there exist
Expand Down
88 changes: 64 additions & 24 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
LoadAddress,
LoadErrorValue,
LoadLiteral,
LoadStatic,
MethodCall,
RaiseStandardError,
Register,
Expand All @@ -63,6 +64,7 @@
)
from mypyc.ir.rtypes import (
RInstance,
c_pyssize_t_rprimitive,
exc_rtuple,
is_tagged,
none_rprimitive,
Expand Down Expand Up @@ -100,6 +102,7 @@
check_stop_op,
coro_op,
import_from_many_op,
import_many_op,
send_op,
type_op,
yield_from_except_op,
Expand Down Expand Up @@ -220,32 +223,69 @@ def transform_operator_assignment_stmt(builder: IRBuilder, stmt: OperatorAssignm
def transform_import(builder: IRBuilder, node: Import) -> None:
if node.is_mypy_only:
return
globals = builder.load_globals_dict()
for node_id, as_name in node.ids:
builder.gen_import(node_id, node.line)

# Update the globals dict with the appropriate module:
# * For 'import foo.bar as baz' we add 'foo.bar' with the name 'baz'
# * For 'import foo.bar' we add 'foo' with the name 'foo'
# Typically we then ignore these entries and access things directly
# via the module static, but we will use the globals version for modules
# that mypy couldn't find, since it doesn't analyze module references
# from those properly.

# TODO: Don't add local imports to the global namespace

# Miscompiling imports inside of functions, like below in import from.
if as_name:
name = as_name
base = node_id
else:
base = name = node_id.split(".")[0]

obj = builder.get_module(base, node.line)
# Imports (not from imports!) are processed in an odd way so they can be
# table-driven and compact. Here's how it works:
#
# Import nodes are divided in groups (in the prebuild visitor). Each group
# consists of consecutive Import nodes:
#
# import mod <| group #1
# import mod2 |
#
# def foo() -> None:
# import mod3 <- group #2
#
# import mod4 <- group #3
#
# Every time we encounter the first import of a group, build IR to call a
# helper function that will perform all of the group's imports in one go.
if node not in builder.module_import_groups:
return

builder.gen_method_call(
globals, "__setitem__", [builder.load_str(name), obj], result_type=None, line=node.line
)
modules = []
statics = []
# To show the right line number on failure, we have to add the traceback
# entry within the helper function (which is admittedly ugly). To drive
# this, we'll need the line number corresponding to each import.
import_lines = []
for import_node in builder.module_import_groups[node]:
for mod_id, as_name in import_node.ids:
builder.imports[mod_id] = None
import_lines.append(Integer(import_node.line, c_pyssize_t_rprimitive))

module_static = LoadStatic(object_rprimitive, mod_id, namespace=NAMESPACE_MODULE)
static_ptr = builder.add(LoadAddress(object_pointer_rprimitive, module_static))
statics.append(static_ptr)
# TODO: Don't add local imports to the global namespace
# Update the globals dict with the appropriate module:
# * For 'import foo.bar as baz' we add 'foo.bar' with the name 'baz'
# * For 'import foo.bar' we add 'foo' with the name 'foo'
# Typically we then ignore these entries and access things directly
# via the module static, but we will use the globals version for
# modules that mypy couldn't find, since it doesn't analyze module
# references from those properly.
if as_name or "." not in mod_id:
globals_base = None
else:
globals_base = mod_id.split(".")[0]
modules.append((mod_id, as_name, globals_base))

static_array_ptr = builder.builder.setup_rarray(object_pointer_rprimitive, statics)
import_line_ptr = builder.builder.setup_rarray(c_pyssize_t_rprimitive, import_lines)
function = "<module>" if builder.fn_info.name == "<top level>" else builder.fn_info.name
builder.call_c(
import_many_op,
[
builder.add(LoadLiteral(tuple(modules), object_rprimitive)),
static_array_ptr,
builder.load_globals_dict(),
builder.load_str(builder.module_path),
builder.load_str(function),
import_line_ptr,
],
NO_TRACEBACK_LINE_NO,
)


def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None:
Expand Down
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,8 @@ PyObject *CPy_Super(PyObject *builtins, PyObject *self);
PyObject *CPy_CallReverseOpMethod(PyObject *left, PyObject *right, const char *op,
_Py_Identifier *method);

bool CPyImport_ImportMany(PyObject *modules, CPyModule **statics[], PyObject *globals,
PyObject *tb_path, PyObject *tb_function, Py_ssize_t *tb_lines);
PyObject *CPyImport_ImportFromMany(PyObject *mod_id, PyObject *names, PyObject *as_names,
PyObject *globals);

Expand Down
62 changes: 62 additions & 0 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,68 @@ CPy_Super(PyObject *builtins, PyObject *self) {
return result;
}

static bool import_single(PyObject *mod_id,
PyObject *as_name,
PyObject **mod_static,
PyObject *globals_base,
PyObject *globals) {
if (*mod_static == Py_None) {
CPyModule *mod = PyImport_Import(mod_id);
if (mod == NULL) {
return false;
}
*mod_static = mod;
}

if (as_name == Py_None) {
as_name = mod_id;
}
PyObject *globals_id, *globals_name;
if (globals_base == Py_None) {
globals_id = mod_id;
globals_name = as_name;
} else {
globals_id = globals_name = globals_base;
}
PyObject *mod_dict = PyImport_GetModuleDict();
CPyModule *globals_mod = CPyDict_GetItem(mod_dict, globals_id);
if (globals_mod == NULL) {
return false;
}
int ret = CPyDict_SetItem(globals, globals_name, globals_mod);
Py_DECREF(globals_mod);
if (ret < 0) {
return false;
}

return true;
}

// Table-driven import helper. See transform_import() in irbuild for the details.
bool CPyImport_ImportMany(PyObject *modules, CPyModule **statics[], PyObject *globals,
PyObject *tb_path, PyObject *tb_function, Py_ssize_t *tb_lines) {
for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(modules); i++) {
PyObject *module = PyTuple_GET_ITEM(modules, i);
PyObject *mod_id = PyTuple_GET_ITEM(module, 0);
PyObject *as_name = PyTuple_GET_ITEM(module, 1);
PyObject *globals_base = PyTuple_GET_ITEM(module, 2);

if (!import_single(mod_id, as_name, statics[i], globals_base, globals)) {
const char *path = PyUnicode_AsUTF8(tb_path);
if (path == NULL) {
path = "<unable to display>";
}
const char *function = PyUnicode_AsUTF8(tb_function);
if (function == NULL) {
function = "<unable to display>";
}
CPy_AddTraceback(path, function, tb_lines[i], globals);
return false;
}
}
return true;
}

// This helper function is a simplification of cpython/ceval.c/import_from()
static PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name,
PyObject *import_name, PyObject *as_name) {
Expand Down
18 changes: 17 additions & 1 deletion mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
bit_rprimitive,
bool_rprimitive,
c_int_rprimitive,
c_pointer_rprimitive,
c_pyssize_t_rprimitive,
dict_rprimitive,
int_rprimitive,
Expand Down Expand Up @@ -111,14 +112,29 @@
is_borrowed=True,
)

# Import a module
# Import a module (plain)
import_op = custom_op(
arg_types=[str_rprimitive],
return_type=object_rprimitive,
c_function_name="PyImport_Import",
error_kind=ERR_MAGIC,
)

# Import helper op (handles globals/statics & can import multiple modules)
import_many_op = custom_op(
arg_types=[
object_rprimitive,
c_pointer_rprimitive,
object_rprimitive,
object_rprimitive,
object_rprimitive,
c_pointer_rprimitive,
],
return_type=bit_rprimitive,
c_function_name="CPyImport_ImportMany",
error_kind=ERR_FALSE,
)

# From import helper op
import_from_many_op = custom_op(
arg_types=[object_rprimitive, object_rprimitive, object_rprimitive, object_rprimitive],
Expand Down

0 comments on commit 4ca6ed0

Please sign in to comment.