From b95d3e0de5df191806cf4f029a940c7fd2297cb3 Mon Sep 17 00:00:00 2001 From: climbus Date: Sun, 26 Sep 2021 11:18:58 +1000 Subject: [PATCH] Blacken all files Also add pyproject.toml to avoid re-running black on ropetest black currently has poor multi-line string treatment for dedent()-ed code. I've ran aneeshusa's black branch https://github.com/psf/black/pull/1879 on ropetest instead, which leaves dedent()-ed lines alone; however most people likely will be running mainline black which would have mucked these formatting , so we're adding an exclusion rule in pyproject.toml prevent people from auto-formatting ropetest. --- pyproject.toml | 4 + rope/__init__.py | 2 +- rope/base/__init__.py | 2 +- rope/base/arguments.py | 18 +- rope/base/ast.py | 18 +- rope/base/astutils.py | 3 +- rope/base/builtins.py | 378 ++-- rope/base/change.py | 88 +- rope/base/codeanalyze.py | 76 +- rope/base/default_config.py | 56 +- rope/base/evaluate.py | 141 +- rope/base/exceptions.py | 7 +- rope/base/fscommands.py | 96 +- rope/base/history.py | 45 +- rope/base/libutils.py | 33 +- rope/base/oi/doa.py | 81 +- rope/base/oi/memorydb.py | 14 +- rope/base/oi/objectdb.py | 13 +- rope/base/oi/objectinfo.py | 60 +- rope/base/oi/runmod.py | 118 +- rope/base/oi/soa.py | 58 +- rope/base/oi/soi.py | 51 +- rope/base/oi/transform.py | 75 +- rope/base/oi/type_hinting/evaluate.py | 136 +- rope/base/oi/type_hinting/factory.py | 33 +- rope/base/oi/type_hinting/interfaces.py | 1 - .../oi/type_hinting/providers/composite.py | 3 - .../oi/type_hinting/providers/docstrings.py | 22 +- .../oi/type_hinting/providers/inheritance.py | 3 - .../oi/type_hinting/providers/interfaces.py | 3 +- .../type_hinting/providers/numpydocstrings.py | 9 +- .../providers/pep0484_type_comments.py | 12 +- .../oi/type_hinting/resolvers/composite.py | 1 - .../oi/type_hinting/resolvers/interfaces.py | 1 - rope/base/oi/type_hinting/utils.py | 71 +- rope/base/prefs.py | 1 - rope/base/project.py | 120 +- rope/base/pycore.py | 108 +- rope/base/pynames.py | 27 +- rope/base/pynamesdef.py | 11 +- rope/base/pyobjects.py | 33 +- rope/base/pyobjectsdef.py | 180 +- rope/base/pyscopes.py | 63 +- rope/base/resourceobserver.py | 48 +- rope/base/resources.py | 53 +- rope/base/simplify.py | 23 +- rope/base/stdmods.py | 15 +- rope/base/taskhandle.py | 8 +- rope/base/utils/__init__.py | 45 +- rope/base/utils/datastructures.py | 7 +- rope/base/utils/pycompat.py | 11 +- rope/base/worder.py | 190 +- rope/contrib/autoimport.py | 46 +- rope/contrib/changestack.py | 3 +- rope/contrib/codeassist.py | 332 +-- rope/contrib/finderrors.py | 23 +- rope/contrib/findit.py | 55 +- rope/contrib/fixmodnames.py | 17 +- rope/contrib/fixsyntax.py | 68 +- rope/contrib/generate.py | 163 +- rope/refactor/__init__.py | 21 +- rope/refactor/change_signature.py | 178 +- rope/refactor/encapsulate_field.py | 122 +- rope/refactor/extract.py | 359 +-- rope/refactor/functionutils.py | 136 +- rope/refactor/importutils/__init__.py | 147 +- rope/refactor/importutils/actions.py | 102 +- rope/refactor/importutils/importinfo.py | 64 +- rope/refactor/importutils/module_imports.py | 188 +- rope/refactor/inline.py | 365 ++-- rope/refactor/introduce_factory.py | 95 +- rope/refactor/introduce_parameter.py | 44 +- rope/refactor/localtofield.py | 29 +- rope/refactor/method_object.py | 80 +- rope/refactor/move.py | 461 ++-- rope/refactor/multiproject.py | 24 +- rope/refactor/occurrences.py | 100 +- rope/refactor/patchedast.py | 462 ++-- rope/refactor/rename.py | 148 +- rope/refactor/restructure.py | 92 +- rope/refactor/similarfinder.py | 83 +- rope/refactor/sourceutils.py | 28 +- rope/refactor/suites.py | 15 +- rope/refactor/topackage.py | 13 +- rope/refactor/usefunction.py | 86 +- rope/refactor/wildcards.py | 47 +- ropetest/__init__.py | 3 +- ropetest/advanced_oi_test.py | 805 +++---- ropetest/builtinstest.py | 545 ++--- ropetest/codeanalyzetest.py | 403 ++-- ropetest/contrib/__init__.py | 19 +- ropetest/contrib/autoimporttest.py | 123 +- ropetest/contrib/changestacktest.py | 17 +- ropetest/contrib/codeassisttest.py | 1204 +++++----- ropetest/contrib/finderrorstest.py | 16 +- ropetest/contrib/findittest.py | 98 +- ropetest/contrib/fixmodnamestest.py | 35 +- ropetest/contrib/generatetest.py | 272 ++- ropetest/doatest.py | 40 +- ropetest/historytest.py | 248 +-- ropetest/objectdbtest.py | 88 +- ropetest/objectinfertest.py | 373 ++-- ropetest/projecttest.py | 688 +++--- ropetest/pycoretest.py | 1303 +++++------ ropetest/pyscopestest.py | 272 +-- ropetest/refactor/__init__.py | 967 ++++---- ropetest/refactor/change_signature_test.py | 549 ++--- ropetest/refactor/extracttest.py | 1940 +++++++++-------- ropetest/refactor/importutilstest.py | 1359 ++++++------ ropetest/refactor/inlinetest.py | 846 +++---- ropetest/refactor/movetest.py | 1058 ++++----- ropetest/refactor/multiprojecttest.py | 70 +- ropetest/refactor/patchedasttest.py | 1557 ++++++++----- ropetest/refactor/renametest.py | 1082 ++++----- ropetest/refactor/restructuretest.py | 173 +- ropetest/refactor/similarfindertest.py | 213 +- ropetest/refactor/suitestest.py | 55 +- ropetest/refactor/usefunctiontest.py | 93 +- ropetest/runmodtest.py | 140 +- ropetest/simplifytest.py | 28 +- ropetest/testutils.py | 48 +- ropetest/type_hinting_test.py | 471 ++-- setup.py | 111 +- 123 files changed, 12914 insertions(+), 11269 deletions(-) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..b6407a20b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,4 @@ +[tool.black] + +target-version = ['py27', 'py33', 'py34', 'py35', 'py36', 'py37', 'py38', 'py39'] +exclude = 'ropetest' diff --git a/rope/__init__.py b/rope/__init__.py index 3504e0f1d..dc054508b 100644 --- a/rope/__init__.py +++ b/rope/__init__.py @@ -1,7 +1,7 @@ """rope, a python refactoring library""" INFO = __doc__ -VERSION = '0.20.1' +VERSION = "0.20.1" COPYRIGHT = """\ Copyright (C) 2019-2021 Matej Cepl Copyright (C) 2015-2018 Nicholas Smith diff --git a/rope/base/__init__.py b/rope/base/__init__.py index ff5f8c63a..a6f798b26 100644 --- a/rope/base/__init__.py +++ b/rope/base/__init__.py @@ -5,4 +5,4 @@ """ -__all__ = ['project', 'libutils', 'exceptions'] +__all__ = ["project", "libutils", "exceptions"] diff --git a/rope/base/arguments.py b/rope/base/arguments.py index 7ba436406..73fee0316 100644 --- a/rope/base/arguments.py +++ b/rope/base/arguments.py @@ -47,14 +47,12 @@ def create_arguments(primary, pyfunction, call_node, scope): args.extend(call_node.keywords) called = call_node.func # XXX: Handle constructors - if _is_method_call(primary, pyfunction) and \ - isinstance(called, ast.Attribute): + if _is_method_call(primary, pyfunction) and isinstance(called, ast.Attribute): args.insert(0, called.value) return Arguments(args, scope) class ObjectArguments(object): - def __init__(self, pynames): self.pynames = pynames @@ -75,7 +73,6 @@ def get_instance_pyname(self): class MixedArguments(object): - def __init__(self, pyname, arguments, scope): """`argumens` is an instance of `Arguments`""" self.pyname = pyname @@ -101,11 +98,14 @@ def _is_method_call(primary, pyfunction): if primary is None: return False pyobject = primary.get_object() - if isinstance(pyobject.get_type(), rope.base.pyobjects.PyClass) and \ - isinstance(pyfunction, rope.base.pyobjects.PyFunction) and \ - isinstance(pyfunction.parent, rope.base.pyobjects.PyClass): + if ( + isinstance(pyobject.get_type(), rope.base.pyobjects.PyClass) + and isinstance(pyfunction, rope.base.pyobjects.PyFunction) + and isinstance(pyfunction.parent, rope.base.pyobjects.PyClass) + ): return True - if isinstance(pyobject.get_type(), rope.base.pyobjects.AbstractClass) and \ - isinstance(pyfunction, rope.base.builtins.BuiltinFunction): + if isinstance( + pyobject.get_type(), rope.base.pyobjects.AbstractClass + ) and isinstance(pyfunction, rope.base.builtins.BuiltinFunction): return True return False diff --git a/rope/base/ast.py b/rope/base/ast.py index d24524e72..b893dc953 100644 --- a/rope/base/ast.py +++ b/rope/base/ast.py @@ -10,16 +10,16 @@ unicode = str -def parse(source, filename=''): +def parse(source, filename=""): # NOTE: the raw string should be given to `compile` function if isinstance(source, unicode): source = fscommands.unicode_to_file_data(source) - if b'\r' in source: - source = source.replace(b'\r\n', b'\n').replace(b'\r', b'\n') - if not source.endswith(b'\n'): - source += b'\n' + if b"\r" in source: + source = source.replace(b"\r\n", b"\n").replace(b"\r", b"\n") + if not source.endswith(b"\n"): + source += b"\n" try: - return ast.parse(source, filename='') + return ast.parse(source, filename="") except (TypeError, ValueError) as e: error = SyntaxError() error.lineno = 1 @@ -30,13 +30,13 @@ def parse(source, filename=''): def walk(node, walker): """Walk the syntax tree""" - method_name = '_' + node.__class__.__name__ + method_name = "_" + node.__class__.__name__ method = getattr(walker, method_name, None) if method is not None: if isinstance(node, ast.ImportFrom) and node.module is None: # In python < 2.7 ``node.module == ''`` for relative imports # but for python 2.7 it is None. Generalizing it to ''. - node.module = '' + node.module = "" return method(node) for child in get_child_nodes(node): walk(child, walker) @@ -70,7 +70,7 @@ def get_children(node): result = [] if node._fields is not None: for name in node._fields: - if name in ['lineno', 'col_offset']: + if name in ["lineno", "col_offset"]: continue child = getattr(node, name) result.append(child) diff --git a/rope/base/astutils.py b/rope/base/astutils.py index 6c0b3d78e..68ae42343 100644 --- a/rope/base/astutils.py +++ b/rope/base/astutils.py @@ -19,7 +19,6 @@ def get_name_levels(node): class _NodeNameCollector(object): - def __init__(self, levels=None): self.names = [] self.levels = levels @@ -34,7 +33,7 @@ def _add_node(self, node): self._added(node, new_levels) def _added(self, node, levels): - if hasattr(node, 'id'): + if hasattr(node, "id"): self.names.append((node.id, levels)) def _Name(self, node): diff --git a/rope/base/builtins.py b/rope/base/builtins.py index 2dfbc12ef..ae5f5b861 100644 --- a/rope/base/builtins.py +++ b/rope/base/builtins.py @@ -13,7 +13,6 @@ class BuiltinModule(pyobjects.AbstractModule): - def __init__(self, name, pycore=None, initial={}): super(BuiltinModule, self).__init__() self.name = name @@ -30,7 +29,7 @@ def get_doc(self): return self.module.__doc__ def get_name(self): - return self.name.split('.')[-1] + return self.name.split(".")[-1] @property @utils.saveit @@ -48,7 +47,7 @@ def attributes(self): def module(self): try: result = __import__(self.name) - for token in self.name.split('.')[1:]: + for token in self.name.split(".")[1:]: result = getattr(result, token, None) return result except ImportError: @@ -56,18 +55,17 @@ def module(self): class _BuiltinElement(object): - def __init__(self, builtin, parent=None): self.builtin = builtin self._parent = parent def get_doc(self): if self.builtin: - return getattr(self.builtin, '__doc__', None) + return getattr(self.builtin, "__doc__", None) def get_name(self): if self.builtin: - return getattr(self.builtin, '__name__', None) + return getattr(self.builtin, "__name__", None) @property def parent(self): @@ -77,7 +75,6 @@ def parent(self): class BuiltinClass(_BuiltinElement, pyobjects.AbstractClass): - def __init__(self, builtin, attributes, parent=None): _BuiltinElement.__init__(self, builtin, parent) pyobjects.AbstractClass.__init__(self) @@ -94,9 +91,9 @@ def get_module(self): class BuiltinFunction(_BuiltinElement, pyobjects.AbstractFunction): - - def __init__(self, returned=None, function=None, builtin=None, - argnames=[], parent=None): + def __init__( + self, returned=None, function=None, builtin=None, argnames=[], parent=None + ): _BuiltinElement.__init__(self, builtin, parent) pyobjects.AbstractFunction.__init__(self) self.argnames = argnames @@ -114,14 +111,13 @@ def get_param_names(self, special_args=True): class BuiltinUnknown(_BuiltinElement, pyobjects.PyObject): - def __init__(self, builtin): super(BuiltinUnknown, self).__init__(pyobjects.get_unknown()) self.builtin = builtin self.type = pyobjects.get_unknown() def get_name(self): - return getattr(type(self.builtin), '__name__', None) + return getattr(type(self.builtin), "__name__", None) @utils.saveit def get_attributes(self): @@ -131,7 +127,7 @@ def get_attributes(self): def _object_attributes(obj, parent): attributes = {} for name in dir(obj): - if name == 'None': + if name == "None": continue try: child = getattr(obj, name) @@ -152,11 +148,12 @@ def _object_attributes(obj, parent): def _create_builtin_type_getter(cls): def _get_builtin(*args): - if not hasattr(cls, '_generated'): + if not hasattr(cls, "_generated"): cls._generated = {} if args not in cls._generated: cls._generated[args] = cls(*args) return cls._generated[args] + return _get_builtin @@ -165,11 +162,11 @@ def _create_builtin_getter(cls): def _get_builtin(*args): return pyobjects.PyObject(type_getter(*args)) + return _get_builtin class _CallContext(object): - def __init__(self, argnames, args): self.argnames = argnames self.args = args @@ -233,14 +230,19 @@ def save_per_name(self, value): class _AttributeCollector(object): - def __init__(self, type): self.attributes = {} self.type = type - def __call__(self, name, returned=None, function=None, - argnames=['self'], check_existence=True, - parent=None): + def __call__( + self, + name, + returned=None, + function=None, + argnames=["self"], + check_existence=True, + parent=None, + ): try: builtin = getattr(self.type, name) except AttributeError: @@ -248,38 +250,55 @@ def __call__(self, name, returned=None, function=None, raise builtin = None self.attributes[name] = BuiltinName( - BuiltinFunction(returned=returned, function=function, - argnames=argnames, builtin=builtin, - parent=parent)) + BuiltinFunction( + returned=returned, + function=function, + argnames=argnames, + builtin=builtin, + parent=parent, + ) + ) def __setitem__(self, name, value): self.attributes[name] = value class List(BuiltinClass): - def __init__(self, holding=None): self.holding = holding collector = _AttributeCollector(list) - collector('__iter__', function=self._iterator_get, parent=self) - collector('__new__', function=self._new_list, parent=self) + collector("__iter__", function=self._iterator_get, parent=self) + collector("__new__", function=self._new_list, parent=self) # Adding methods - collector('append', function=self._list_add, - argnames=['self', 'value'], parent=self) - collector('__setitem__', function=self._list_add, - argnames=['self', 'index', 'value'], parent=self) - collector('insert', function=self._list_add, - argnames=['self', 'index', 'value'], parent=self) - collector('extend', function=self._self_set, - argnames=['self', 'iterable'], parent=self) + collector( + "append", function=self._list_add, argnames=["self", "value"], parent=self + ) + collector( + "__setitem__", + function=self._list_add, + argnames=["self", "index", "value"], + parent=self, + ) + collector( + "insert", + function=self._list_add, + argnames=["self", "index", "value"], + parent=self, + ) + collector( + "extend", + function=self._self_set, + argnames=["self", "iterable"], + parent=self, + ) # Getting methods - collector('__getitem__', function=self._list_get, parent=self) - collector('pop', function=self._list_get, parent=self) + collector("__getitem__", function=self._list_get, parent=self) + collector("pop", function=self._list_get, parent=self) try: - collector('__getslice__', function=self._list_get) + collector("__getslice__", function=self._list_get) except AttributeError: pass @@ -291,23 +310,26 @@ def _new_list(self, args): def _list_add(self, context): if self.holding is not None: return - holding = context.get_argument('value') + holding = context.get_argument("value") if holding is not None and holding != pyobjects.get_unknown(): context.save_per_name(holding) def _self_set(self, context): if self.holding is not None: return - iterable = context.get_pyname('iterable') + iterable = context.get_pyname("iterable") holding = _infer_sequence_for_pyname(iterable) if holding is not None and holding != pyobjects.get_unknown(): context.save_per_name(holding) def _list_get(self, context): if self.holding is not None: - args = context.get_arguments(['self', 'key']) - if (len(args) > 1 and args[1] is not None and - args[1].get_type() == builtins['slice'].get_object()): + args = context.get_arguments(["self", "key"]) + if ( + len(args) > 1 + and args[1] is not None + and args[1].get_type() == builtins["slice"].get_object() + ): return get_list(self.holding) return self.holding return context.get_per_name() @@ -324,23 +346,22 @@ def _self_get(self, context): class Dict(BuiltinClass): - def __init__(self, keys=None, values=None): self.keys = keys self.values = values collector = _AttributeCollector(dict) - collector('__new__', function=self._new_dict, parent=self) - collector('__setitem__', function=self._dict_add, parent=self) - collector('popitem', function=self._item_get, parent=self) - collector('pop', function=self._value_get, parent=self) - collector('get', function=self._key_get, parent=self) - collector('keys', function=self._key_list, parent=self) - collector('values', function=self._value_list, parent=self) - collector('items', function=self._item_list, parent=self) - collector('copy', function=self._self_get, parent=self) - collector('__getitem__', function=self._value_get, parent=self) - collector('__iter__', function=self._key_iter, parent=self) - collector('update', function=self._self_set, parent=self) + collector("__new__", function=self._new_dict, parent=self) + collector("__setitem__", function=self._dict_add, parent=self) + collector("popitem", function=self._item_get, parent=self) + collector("pop", function=self._value_get, parent=self) + collector("get", function=self._key_get, parent=self) + collector("keys", function=self._key_list, parent=self) + collector("values", function=self._value_list, parent=self) + collector("items", function=self._item_list, parent=self) + collector("copy", function=self._self_get, parent=self) + collector("__getitem__", function=self._value_get, parent=self) + collector("__iter__", function=self._key_iter, parent=self) + collector("update", function=self._self_set, parent=self) super(Dict, self).__init__(dict, collector.attributes) def _new_dict(self, args): @@ -348,15 +369,15 @@ def do_create(holding=None): if holding is None: return get_dict() type = holding.get_type() - if isinstance(type, Tuple) and \ - len(type.get_holding_objects()) == 2: + if isinstance(type, Tuple) and len(type.get_holding_objects()) == 2: return get_dict(*type.get_holding_objects()) + return _create_builtin(args, do_create) def _dict_add(self, context): if self.keys is not None: return - key, value = context.get_arguments(['self', 'key', 'value'])[1:] + key, value = context.get_arguments(["self", "key", "value"])[1:] if key is not None and key != pyobjects.get_unknown(): context.save_per_name(get_tuple(key, value)) @@ -402,11 +423,12 @@ def _self_get(self, context): def _self_set(self, context): if self.keys is not None: return - new_dict = context.get_pynames(['self', 'd'])[1] + new_dict = context.get_pynames(["self", "d"])[1] if new_dict and isinstance(new_dict.get_object().get_type(), Dict): args = arguments.ObjectArguments([new_dict]) - items = new_dict.get_object()['popitem'].\ - get_object().get_returned_object(args) + items = ( + new_dict.get_object()["popitem"].get_object().get_returned_object(args) + ) context.save_per_name(items) else: holding = _infer_sequence_for_pyname(new_dict) @@ -419,18 +441,19 @@ def _self_set(self, context): class Tuple(BuiltinClass): - def __init__(self, *objects): self.objects = objects first = None if objects: first = objects[0] attributes = { - '__getitem__': BuiltinName(BuiltinFunction(first)), # TODO: add slice support - '__getslice__': - BuiltinName(BuiltinFunction(pyobjects.PyObject(self))), - '__new__': BuiltinName(BuiltinFunction(function=self._new_tuple)), - '__iter__': BuiltinName(BuiltinFunction(get_iterator(first)))} + "__getitem__": BuiltinName( + BuiltinFunction(first) + ), # TODO: add slice support + "__getslice__": BuiltinName(BuiltinFunction(pyobjects.PyObject(self))), + "__new__": BuiltinName(BuiltinFunction(function=self._new_tuple)), + "__iter__": BuiltinName(BuiltinFunction(get_iterator(first))), + } super(Tuple, self).__init__(tuple, attributes) def get_holding_objects(self): @@ -445,25 +468,28 @@ def _new_tuple(self, args): class Set(BuiltinClass): - def __init__(self, holding=None): self.holding = holding collector = _AttributeCollector(set) - collector('__new__', function=self._new_set) - - self_methods = ['copy', 'difference', 'intersection', - 'symmetric_difference', 'union'] + collector("__new__", function=self._new_set) + + self_methods = [ + "copy", + "difference", + "intersection", + "symmetric_difference", + "union", + ] for method in self_methods: collector(method, function=self._self_get, parent=self) - collector('add', function=self._set_add, parent=self) - collector('update', function=self._self_set, parent=self) - collector('update', function=self._self_set, parent=self) - collector('symmetric_difference_update', function=self._self_set, - parent=self) - collector('difference_update', function=self._self_set, parent=self) - - collector('pop', function=self._set_get, parent=self) - collector('__iter__', function=self._iterator_get, parent=self) + collector("add", function=self._set_add, parent=self) + collector("update", function=self._self_set, parent=self) + collector("update", function=self._self_set, parent=self) + collector("symmetric_difference_update", function=self._self_set, parent=self) + collector("difference_update", function=self._self_set, parent=self) + + collector("pop", function=self._set_get, parent=self) + collector("__iter__", function=self._iterator_get, parent=self) super(Set, self).__init__(set, collector.attributes) def _new_set(self, args): @@ -472,14 +498,14 @@ def _new_set(self, args): def _set_add(self, context): if self.holding is not None: return - holding = context.get_arguments(['self', 'value'])[1] + holding = context.get_arguments(["self", "value"])[1] if holding is not None and holding != pyobjects.get_unknown(): context.save_per_name(holding) def _self_set(self, context): if self.holding is not None: return - iterable = context.get_pyname('iterable') + iterable = context.get_pyname("iterable") holding = _infer_sequence_for_pyname(iterable) if holding is not None and holding != pyobjects.get_unknown(): context.save_per_name(holding) @@ -501,17 +527,31 @@ def _self_get(self, context): class Str(BuiltinClass): - def __init__(self): self_object = pyobjects.PyObject(self) collector = _AttributeCollector(str) - collector('__iter__', get_iterator(self_object), check_existence=False) - - self_methods = ['__getitem__', 'capitalize', 'center', - 'encode', 'expandtabs', 'join', 'ljust', - 'lower', 'lstrip', 'replace', 'rjust', 'rstrip', - 'strip', 'swapcase', 'title', 'translate', 'upper', - 'zfill'] + collector("__iter__", get_iterator(self_object), check_existence=False) + + self_methods = [ + "__getitem__", + "capitalize", + "center", + "encode", + "expandtabs", + "join", + "ljust", + "lower", + "lstrip", + "replace", + "rjust", + "rstrip", + "strip", + "swapcase", + "title", + "translate", + "upper", + "zfill", + ] for method in self_methods: collector(method, self_object, parent=self) @@ -522,7 +562,7 @@ def __init__(self): except AttributeError: pass - for method in ['rsplit', 'split', 'splitlines']: + for method in ["rsplit", "split", "splitlines"]: collector(method, get_list(self_object), parent=self) super(Str, self).__init__(str, collector.attributes) @@ -536,7 +576,6 @@ def get_doc(self): class BuiltinName(pynames.PyName): - def __init__(self, pyobject): self.pyobject = pyobject @@ -548,13 +587,13 @@ def get_definition_location(self): class Iterator(pyobjects.AbstractClass): - def __init__(self, holding=None): super(Iterator, self).__init__() self.holding = holding self.attributes = { - 'next': BuiltinName(BuiltinFunction(self.holding)), - '__iter__': BuiltinName(BuiltinFunction(self))} + "next": BuiltinName(BuiltinFunction(self.holding)), + "__iter__": BuiltinName(BuiltinFunction(self)), + } def get_attributes(self): return self.attributes @@ -562,21 +601,21 @@ def get_attributes(self): def get_returned_object(self, args): return self.holding + get_iterator = _create_builtin_getter(Iterator) class Generator(pyobjects.AbstractClass): - def __init__(self, holding=None): super(Generator, self).__init__() self.holding = holding self.attributes = { - 'next': BuiltinName(BuiltinFunction(self.holding)), - '__iter__': BuiltinName(BuiltinFunction( - get_iterator(self.holding))), - 'close': BuiltinName(BuiltinFunction()), - 'send': BuiltinName(BuiltinFunction()), - 'throw': BuiltinName(BuiltinFunction())} + "next": BuiltinName(BuiltinFunction(self.holding)), + "__iter__": BuiltinName(BuiltinFunction(get_iterator(self.holding))), + "close": BuiltinName(BuiltinFunction()), + "send": BuiltinName(BuiltinFunction()), + "throw": BuiltinName(BuiltinFunction()), + } def get_attributes(self): return self.attributes @@ -584,12 +623,12 @@ def get_attributes(self): def get_returned_object(self, args): return self.holding + get_generator = _create_builtin_getter(Generator) class File(BuiltinClass): - - def __init__(self, filename=None, mode='r', *args): + def __init__(self, filename=None, mode="r", *args): self.filename = filename self.mode = mode self.args = args @@ -600,14 +639,24 @@ def __init__(self, filename=None, mode='r', *args): def add(name, returned=None, function=None): builtin = getattr(io.TextIOBase, name, None) attributes[name] = BuiltinName( - BuiltinFunction(returned=returned, function=function, - builtin=builtin)) - add('__iter__', get_iterator(str_object)) - add('__enter__', returned=pyobjects.PyObject(self)) - for method in ['next', 'read', 'readline', 'readlines']: + BuiltinFunction(returned=returned, function=function, builtin=builtin) + ) + + add("__iter__", get_iterator(str_object)) + add("__enter__", returned=pyobjects.PyObject(self)) + for method in ["next", "read", "readline", "readlines"]: add(method, str_list) - for method in ['close', 'flush', 'lineno', 'isatty', 'seek', 'tell', - 'truncate', 'write', 'writelines']: + for method in [ + "close", + "flush", + "lineno", + "isatty", + "seek", + "tell", + "truncate", + "write", + "writelines", + ]: add(method) super(File, self).__init__(open, attributes) @@ -617,16 +666,15 @@ def add(name, returned=None, function=None): class Property(BuiltinClass): - def __init__(self, fget=None, fset=None, fdel=None, fdoc=None): self._fget = fget self._fdoc = fdoc attributes = { - 'fget': BuiltinName(BuiltinFunction()), - 'fset': BuiltinName(pynames.UnboundName()), - 'fdel': BuiltinName(pynames.UnboundName()), - '__new__': BuiltinName( - BuiltinFunction(function=_property_function))} + "fget": BuiltinName(BuiltinFunction()), + "fset": BuiltinName(pynames.UnboundName()), + "fdel": BuiltinName(pynames.UnboundName()), + "__new__": BuiltinName(BuiltinFunction(function=_property_function)), + } super(Property, self).__init__(property, attributes) def get_property_object(self, args): @@ -635,12 +683,11 @@ def get_property_object(self, args): def _property_function(args): - parameters = args.get_arguments(['fget', 'fset', 'fdel', 'fdoc']) + parameters = args.get_arguments(["fget", "fset", "fdel", "fdoc"]) return pyobjects.PyObject(Property(parameters[0])) class Lambda(pyobjects.AbstractFunction): - def __init__(self, node, scope): super(Lambda, self).__init__() self.node = node @@ -661,7 +708,7 @@ def get_scope(self): return self.scope def get_kind(self): - return 'lambda' + return "lambda" def get_ast(self): return self.node @@ -670,15 +717,18 @@ def get_attributes(self): return {} def get_name(self): - return 'lambda' + return "lambda" def get_param_names(self, special_args=True): - result = [pycompat.get_ast_arg_arg(node) for node in self.arguments.args - if isinstance(node, pycompat.ast_arg_type)] + result = [ + pycompat.get_ast_arg_arg(node) + for node in self.arguments.args + if isinstance(node, pycompat.ast_arg_type) + ] if self.arguments.vararg: - result.append('*' + pycompat.get_ast_arg_arg(self.arguments.vararg)) + result.append("*" + pycompat.get_ast_arg_arg(self.arguments.vararg)) if self.arguments.kwarg: - result.append('**' + pycompat.get_ast_arg_arg(self.arguments.kwarg)) + result.append("**" + pycompat.get_ast_arg_arg(self.arguments.kwarg)) return result @property @@ -687,13 +737,11 @@ def parent(self): class BuiltinObject(BuiltinClass): - def __init__(self): super(BuiltinObject, self).__init__(object, {}) class BuiltinType(BuiltinClass): - def __init__(self): super(BuiltinType, self).__init__(type, {}) @@ -703,19 +751,18 @@ def _infer_sequence_for_pyname(pyname): return None seq = pyname.get_object() args = arguments.ObjectArguments([pyname]) - if '__iter__' in seq: - obj = seq['__iter__'].get_object() + if "__iter__" in seq: + obj = seq["__iter__"].get_object() if not isinstance(obj, pyobjects.AbstractFunction): return None iter = obj.get_returned_object(args) - if iter is not None and 'next' in iter: - holding = iter['next'].get_object().\ - get_returned_object(args) + if iter is not None and "next" in iter: + holding = iter["next"].get_object().get_returned_object(args) return holding def _create_builtin(args, creator): - passed = args.get_pynames(['sequence'])[0] + passed = args.get_pynames(["sequence"])[0] if passed is None: holding = None else: @@ -743,11 +790,11 @@ def _sorted_function(args): def _super_function(args): - passed_class, passed_self = args.get_arguments(['type', 'self']) + passed_class, passed_self = args.get_arguments(["type", "self"]) if passed_self is None: return passed_class else: - #pyclass = passed_self.get_type() + # pyclass = passed_self.get_type() pyclass = passed_class if isinstance(pyclass, pyobjects.AbstractClass): supers = pyclass.get_superclasses() @@ -757,7 +804,7 @@ def _super_function(args): def _zip_function(args): - args = args.get_pynames(['sequence']) + args = args.get_pynames(["sequence"]) objects = [] for seq in args: if seq is None: @@ -770,7 +817,7 @@ def _zip_function(args): def _enumerate_function(args): - passed = args.get_pynames(['sequence'])[0] + passed = args.get_pynames(["sequence"])[0] if passed is None: holding = None else: @@ -780,7 +827,7 @@ def _enumerate_function(args): def _iter_function(args): - passed = args.get_pynames(['sequence'])[0] + passed = args.get_pynames(["sequence"])[0] if passed is None: holding = None else: @@ -793,34 +840,33 @@ def _input_function(args): _initial_builtins = { - 'list': BuiltinName(get_list_type()), - 'dict': BuiltinName(get_dict_type()), - 'tuple': BuiltinName(get_tuple_type()), - 'set': BuiltinName(get_set_type()), - 'str': BuiltinName(get_str_type()), - 'file': BuiltinName(get_file_type()), - 'open': BuiltinName(BuiltinFunction(function=_open_function, - builtin=open)), - 'unicode': BuiltinName(get_str_type()), - 'range': BuiltinName(BuiltinFunction(function=_range_function, - builtin=range)), - 'reversed': BuiltinName(BuiltinFunction(function=_reversed_function, - builtin=reversed)), - 'sorted': BuiltinName(BuiltinFunction(function=_sorted_function, - builtin=sorted)), - 'super': BuiltinName(BuiltinFunction(function=_super_function, - builtin=super)), - 'property': BuiltinName(BuiltinFunction(function=_property_function, - builtin=property)), - 'zip': BuiltinName(BuiltinFunction(function=_zip_function, builtin=zip)), - 'enumerate': BuiltinName(BuiltinFunction(function=_enumerate_function, - builtin=enumerate)), - 'object': BuiltinName(BuiltinObject()), - 'type': BuiltinName(BuiltinType()), - 'iter': BuiltinName(BuiltinFunction(function=_iter_function, - builtin=iter)), - 'raw_input': BuiltinName(BuiltinFunction(function=_input_function, - builtin=raw_input)), + "list": BuiltinName(get_list_type()), + "dict": BuiltinName(get_dict_type()), + "tuple": BuiltinName(get_tuple_type()), + "set": BuiltinName(get_set_type()), + "str": BuiltinName(get_str_type()), + "file": BuiltinName(get_file_type()), + "open": BuiltinName(BuiltinFunction(function=_open_function, builtin=open)), + "unicode": BuiltinName(get_str_type()), + "range": BuiltinName(BuiltinFunction(function=_range_function, builtin=range)), + "reversed": BuiltinName( + BuiltinFunction(function=_reversed_function, builtin=reversed) + ), + "sorted": BuiltinName(BuiltinFunction(function=_sorted_function, builtin=sorted)), + "super": BuiltinName(BuiltinFunction(function=_super_function, builtin=super)), + "property": BuiltinName( + BuiltinFunction(function=_property_function, builtin=property) + ), + "zip": BuiltinName(BuiltinFunction(function=_zip_function, builtin=zip)), + "enumerate": BuiltinName( + BuiltinFunction(function=_enumerate_function, builtin=enumerate) + ), + "object": BuiltinName(BuiltinObject()), + "type": BuiltinName(BuiltinType()), + "iter": BuiltinName(BuiltinFunction(function=_iter_function, builtin=iter)), + "raw_input": BuiltinName( + BuiltinFunction(function=_input_function, builtin=raw_input) + ), } builtins = BuiltinModule(pycompat.builtins.__name__, initial=_initial_builtins) diff --git a/rope/base/change.py b/rope/base/change.py index 07679f66c..f0a61c57b 100644 --- a/rope/base/change.py +++ b/rope/base/change.py @@ -85,27 +85,26 @@ def add_change(self, change): self.changes.append(change) def get_description(self): - result = [str(self) + ':\n\n\n'] + result = [str(self) + ":\n\n\n"] for change in self.changes: result.append(change.get_description()) - result.append('\n') - return ''.join(result) + result.append("\n") + return "".join(result) def __str__(self): if self.time is not None: date = datetime.datetime.fromtimestamp(self.time) if date.date() == datetime.date.today(): - string_date = 'today' - elif date.date() == (datetime.date.today() - - datetime.timedelta(1)): - string_date = 'yesterday' + string_date = "today" + elif date.date() == (datetime.date.today() - datetime.timedelta(1)): + string_date = "yesterday" elif date.year == datetime.date.today().year: - string_date = date.strftime('%b %d') + string_date = date.strftime("%b %d") else: - string_date = date.strftime('%d %b, %Y') - string_time = date.strftime('%H:%M:%S') - string_time = '%s %s ' % (string_date, string_time) - return self.description + ' - ' + string_time + string_date = date.strftime("%d %b, %Y") + string_time = date.strftime("%H:%M:%S") + string_time = "%s %s " % (string_date, string_time) + return self.description + " - " + string_time return self.description def get_changed_resources(self): @@ -121,10 +120,12 @@ def _handle_job_set(function): A decorator for handling `taskhandle.JobSet` for `do` and `undo` methods of `Change`. """ + def call(self, job_set=taskhandle.NullJobSet()): job_set.started_job(str(self)) function(self) job_set.finished_job() + return call @@ -152,12 +153,11 @@ def do(self): @_handle_job_set def undo(self): if self.old_contents is None: - raise exceptions.HistoryError( - 'Undoing a change that is not performed yet!') + raise exceptions.HistoryError("Undoing a change that is not performed yet!") self._operations.write_file(self.resource, self.old_contents) def __str__(self): - return 'Change <%s>' % self.resource.path + return "Change <%s>" % self.resource.path def get_description(self): new = self.new_contents @@ -166,11 +166,14 @@ def get_description(self): if self.resource.exists(): old = self.resource.read() else: - old = '' + old = "" result = difflib.unified_diff( - old.splitlines(True), new.splitlines(True), - 'a/' + self.resource.path, 'b/' + self.resource.path) - return ''.join(list(result)) + old.splitlines(True), + new.splitlines(True), + "a/" + self.resource.path, + "b/" + self.resource.path, + ) + return "".join(list(result)) def get_changed_resources(self): return [self.resource] @@ -205,11 +208,13 @@ def undo(self): self._operations.move(self.new_resource, self.resource) def __str__(self): - return 'Move <%s>' % self.resource.path + return "Move <%s>" % self.resource.path def get_description(self): - return 'rename from %s\nrename to %s' % (self.resource.path, - self.new_resource.path) + return "rename from %s\nrename to %s" % ( + self.resource.path, + self.new_resource.path, + ) def get_changed_resources(self): return [self.resource, self.new_resource] @@ -235,19 +240,19 @@ def undo(self): self._operations.remove(self.resource) def __str__(self): - return 'Create Resource <%s>' % (self.resource.path) + return "Create Resource <%s>" % (self.resource.path) def get_description(self): - return 'new file %s' % (self.resource.path) + return "new file %s" % (self.resource.path) def get_changed_resources(self): return [self.resource] def _get_child_path(self, parent, name): - if parent.path == '': + if parent.path == "": return name else: - return parent.path + '/' + name + return parent.path + "/" + name class CreateFolder(CreateResource): @@ -257,8 +262,7 @@ class CreateFolder(CreateResource): """ def __init__(self, parent, name): - resource = parent.project.get_folder( - self._get_child_path(parent, name)) + resource = parent.project.get_folder(self._get_child_path(parent, name)) super(CreateFolder, self).__init__(resource) @@ -291,11 +295,10 @@ def do(self): # TODO: Undoing remove operations @_handle_job_set def undo(self): - raise NotImplementedError( - 'Undoing `RemoveResource` is not implemented yet.') + raise NotImplementedError("Undoing `RemoveResource` is not implemented yet.") def __str__(self): - return 'Remove <%s>' % (self.resource.path) + return "Remove <%s>" % (self.resource.path) def get_changed_resources(self): return [self.resource] @@ -316,7 +319,6 @@ def create_job_set(task_handle, change): class _ResourceOperations(object): - def __init__(self, project): self.project = project self.fscommands = project.fscommands @@ -342,7 +344,7 @@ def move(self, resource, new_resource): def create(self, resource): if resource.is_folder(): - self._create_resource(resource.path, kind='folder') + self._create_resource(resource.path, kind="folder") else: self._create_resource(resource.path) for observer in list(self.project.observers): @@ -354,18 +356,18 @@ def remove(self, resource): for observer in list(self.project.observers): observer.resource_removed(resource) - def _create_resource(self, file_name, kind='file'): + def _create_resource(self, file_name, kind="file"): resource_path = self.project._get_resource_path(file_name) if os.path.exists(resource_path): - raise exceptions.RopeError('Resource <%s> already exists' - % resource_path) + raise exceptions.RopeError("Resource <%s> already exists" % resource_path) resource = self.project.get_file(file_name) if not resource.parent.exists(): raise exceptions.ResourceNotFoundError( - 'Parent folder of <%s> does not exist' % resource.path) + "Parent folder of <%s> does not exist" % resource.path + ) fscommands = self._get_fscommands(resource) try: - if kind == 'file': + if kind == "file": fscommands.create_file(resource_path) else: fscommands.create_folder(resource_path) @@ -376,15 +378,14 @@ def _create_resource(self, file_name, kind='file'): def _get_destination_for_move(resource, destination): dest_path = resource.project._get_resource_path(destination) if os.path.isdir(dest_path): - if destination != '': - return destination + '/' + resource.name + if destination != "": + return destination + "/" + resource.name else: return resource.name return destination class ChangeToData(object): - def convertChangeSet(self, change): description = change.description changes = [] @@ -408,12 +409,11 @@ def __call__(self, change): change_type = type(change) if change_type in (CreateFolder, CreateFile): change_type = CreateResource - method = getattr(self, 'convert' + change_type.__name__) + method = getattr(self, "convert" + change_type.__name__) return (change_type.__name__, method(change)) class DataToChange(object): - def __init__(self, project): self.project = project @@ -446,5 +446,5 @@ def makeRemoveResource(self, path, is_folder): return RemoveResource(resource) def __call__(self, data): - method = getattr(self, 'make' + data[0]) + method = getattr(self, "make" + data[0]) return method(*data[1]) diff --git a/rope/base/codeanalyze.py b/rope/base/codeanalyze.py index fcd3c833b..bb0d988e3 100644 --- a/rope/base/codeanalyze.py +++ b/rope/base/codeanalyze.py @@ -5,7 +5,6 @@ class ChangeCollector(object): - def __init__(self, text): self.text = text self.changes = [] @@ -28,7 +27,7 @@ def get_changed(self): last_changed = end if last_changed < len(self.text): pieces.append(self.text[last_changed:]) - result = ''.join(pieces) + result = "".join(pieces) if result != self.text: return result @@ -50,15 +49,14 @@ def _initialize_line_starts(self): try: i = 0 while True: - i = self.code.index('\n', i) + 1 + i = self.code.index("\n", i) + 1 self.starts.append(i) except ValueError: pass self.starts.append(len(self.code) + 1) def get_line(self, lineno): - return self.code[self.starts[lineno - 1]: - self.starts[lineno] - 1] + return self.code[self.starts[lineno - 1] : self.starts[lineno] - 1] def length(self): return len(self.starts) - 1 @@ -74,7 +72,6 @@ def get_line_end(self, lineno): class ArrayLinesAdapter(object): - def __init__(self, lines): self.lines = lines @@ -86,7 +83,6 @@ def length(self): class LinesToReadline(object): - def __init__(self, lines, start): self.lines = lines self.current = start @@ -94,18 +90,17 @@ def __init__(self, lines, start): def readline(self): if self.current <= self.lines.length(): self.current += 1 - return self.lines.get_line(self.current - 1) + '\n' - return '' + return self.lines.get_line(self.current - 1) + "\n" + return "" def __call__(self): return self.readline() class _CustomGenerator(object): - def __init__(self, lines): self.lines = lines - self.in_string = '' + self.in_string = "" self.open_count = 0 self.continuation = False @@ -121,8 +116,10 @@ def __call__(self): while True: line = self.lines.get_line(i) self._analyze_line(line) - if not (self.continuation or self.open_count or - self.in_string) or i == size: + if ( + not (self.continuation or self.open_count or self.in_string) + or i == size + ): break i += 1 result.append((start, i)) @@ -143,17 +140,19 @@ def _analyze_line(self, line): if token in ["'''", '"""', "'", '"']: if not self.in_string: self.in_string = token - elif self.in_string == token or (self.in_string in ['"', "'"] and token == 3*self.in_string): - self.in_string = '' + elif self.in_string == token or ( + self.in_string in ['"', "'"] and token == 3 * self.in_string + ): + self.in_string = "" if self.in_string: continue - if token == '#': + if token == "#": break - if token in '([{': + if token in "([{": self.open_count += 1 - elif token in ')]}': + elif token in ")]}": self.open_count -= 1 - if line and token != '#' and line.endswith('\\'): + if line and token != "#" and line.endswith("\\"): self.continuation = True else: self.continuation = False @@ -164,7 +163,6 @@ def custom_generator(lines): class LogicalLineFinder(object): - def __init__(self, lines): self.lines = lines @@ -237,7 +235,7 @@ def _first_non_blank(self, line_number): current = line_number while current < self.lines.length(): line = self.lines.get_line(current).strip() - if line and not line.startswith('#'): + if line and not line.startswith("#"): return current current += 1 return current @@ -248,7 +246,6 @@ def tokenizer_generator(lines): class CachingLogicalLineFinder(object): - def __init__(self, lines, generate=custom_generator): self.lines = lines self._generate = generate @@ -302,19 +299,21 @@ def get_block_start(lines, lineno, maximum_indents=80): pattern = get_block_start_patterns() for i in range(lineno, 0, -1): match = pattern.search(lines.get_line(i)) - if match is not None and \ - count_line_indents(lines.get_line(i)) <= maximum_indents: + if ( + match is not None + and count_line_indents(lines.get_line(i)) <= maximum_indents + ): striped = match.string.lstrip() # Maybe we're in a list comprehension or generator expression - if i > 1 and striped.startswith('if') or striped.startswith('for'): + if i > 1 and striped.startswith("if") or striped.startswith("for"): bracs = 0 for j in range(i, min(i + 5, lines.length() + 1)): for c in lines.get_line(j): - if c == '#': + if c == "#": break - if c in '[(': + if c in "[(": bracs += 1 - if c in ')]': + if c in ")]": bracs -= 1 if bracs < 0: break @@ -332,8 +331,10 @@ def get_block_start(lines, lineno, maximum_indents=80): def get_block_start_patterns(): global _block_start_pattern if not _block_start_pattern: - pattern = '^\\s*(((def|class|if|elif|except|for|while|with)\\s)|'\ - '((try|else|finally|except)\\s*:))' + pattern = ( + "^\\s*(((def|class|if|elif|except|for|while|with)\\s)|" + "((try|else|finally|except)\\s*:))" + ) _block_start_pattern = re.compile(pattern, re.M) return _block_start_pattern @@ -341,9 +342,9 @@ def get_block_start_patterns(): def count_line_indents(line): indents = 0 for char in line: - if char == ' ': + if char == " ": indents += 1 - elif char == '\t': + elif char == "\t": indents += 8 else: return indents @@ -353,19 +354,20 @@ def count_line_indents(line): def get_string_pattern_with_prefix(prefix): longstr = r'%s"""(\\.|"(?!"")|\\\n|[^"\\])*"""' % prefix shortstr = r'%s"(\\.|\\\n|[^"\\\n])*"' % prefix - return '|'.join([longstr, longstr.replace('"', "'"), - shortstr, shortstr.replace('"', "'")]) + return "|".join( + [longstr, longstr.replace('"', "'"), shortstr, shortstr.replace('"', "'")] + ) def get_string_pattern(): - prefix = r'(? import ` by default. - prefs['prefer_module_from_imports'] = False + prefs["prefer_module_from_imports"] = False # If `True`, rope will transform a comma list of imports into # multiple separate import statements when organizing # imports. - prefs['split_imports'] = False + prefs["split_imports"] = False # If `True`, rope will remove all top-level import statements and # reinsert them at the top of the module when making changes. - prefs['pull_imports_to_top'] = True + prefs["pull_imports_to_top"] = True # If `True`, rope will sort imports alphabetically by module name instead # of alphabetically by import statement, with from imports after normal # imports. - prefs['sort_imports_alphabetically'] = False + prefs["sort_imports_alphabetically"] = False # Location of implementation of # rope.base.oi.type_hinting.interfaces.ITypeHintingFactory In general @@ -106,8 +115,9 @@ def set_prefs(prefs): # listed in module rope.base.oi.type_hinting.providers.interfaces # For example, you can add you own providers for Django Models, or disable # the search type-hinting in a class hierarchy, etc. - prefs['type_hinting_factory'] = ( - 'rope.base.oi.type_hinting.factory.default_type_hinting_factory') + prefs[ + "type_hinting_factory" + ] = "rope.base.oi.type_hinting.factory.default_type_hinting_factory" def project_opened(project): diff --git a/rope/base/evaluate.py b/rope/base/evaluate.py index 4634981a5..48c8fe2b1 100644 --- a/rope/base/evaluate.py +++ b/rope/base/evaluate.py @@ -42,36 +42,40 @@ def eval_str(holding_scope, name): def eval_str2(holding_scope, name): try: # parenthesizing for handling cases like 'a_var.\nattr' - node = ast.parse('(%s)' % name) + node = ast.parse("(%s)" % name) except SyntaxError: - raise BadIdentifierError( - 'Not a resolvable python identifier selected.') + raise BadIdentifierError("Not a resolvable python identifier selected.") return eval_node2(holding_scope, node) class ScopeNameFinder(object): - def __init__(self, pymodule): self.module_scope = pymodule.get_scope() self.lines = pymodule.lines self.worder = worder.Worder(pymodule.source_code, True) def _is_defined_in_class_body(self, holding_scope, offset, lineno): - if lineno == holding_scope.get_start() and \ - holding_scope.parent is not None and \ - holding_scope.parent.get_kind() == 'Class' and \ - self.worder.is_a_class_or_function_name_in_header(offset): + if ( + lineno == holding_scope.get_start() + and holding_scope.parent is not None + and holding_scope.parent.get_kind() == "Class" + and self.worder.is_a_class_or_function_name_in_header(offset) + ): return True - if lineno != holding_scope.get_start() and \ - holding_scope.get_kind() == 'Class' and \ - self.worder.is_name_assigned_in_class_body(offset): + if ( + lineno != holding_scope.get_start() + and holding_scope.get_kind() == "Class" + and self.worder.is_name_assigned_in_class_body(offset) + ): return True return False def _is_function_name_in_function_header(self, scope, offset, lineno): - if scope.get_start() <= lineno <= scope.get_body_start() and \ - scope.get_kind() == 'Function' and \ - self.worder.is_a_class_or_function_name_in_header(offset): + if ( + scope.get_start() <= lineno <= scope.get_body_start() + and scope.get_kind() == "Function" + and self.worder.is_a_class_or_function_name_in_header(offset) + ): return True return False @@ -86,8 +90,7 @@ def get_primary_and_pyname_at(self, offset): keyword_name = self.worder.get_word_at(offset) pyobject = self.get_enclosing_function(offset) if isinstance(pyobject, pyobjects.PyFunction): - return (None, - pyobject.get_parameters().get(keyword_name, None)) + return (None, pyobject.get_parameters().get(keyword_name, None)) # class body if self._is_defined_in_class_body(holding_scope, offset, lineno): class_scope = holding_scope @@ -99,13 +102,13 @@ def get_primary_and_pyname_at(self, offset): except rope.base.exceptions.AttributeNotFoundError: return (None, None) # function header - if self._is_function_name_in_function_header(holding_scope, - offset, lineno): + if self._is_function_name_in_function_header(holding_scope, offset, lineno): name = self.worder.get_primary_at(offset).strip() return (None, holding_scope.parent[name]) # module in a from statement or an imported name that is aliased - if (self.worder.is_from_statement_module(offset) or - self.worder.is_import_statement_aliased_module(offset)): + if self.worder.is_from_statement_module( + offset + ) or self.worder.is_import_statement_aliased_module(offset): module = self.worder.get_primary_at(offset) module_pyname = self._find_module(module) return (None, module_pyname) @@ -125,23 +128,24 @@ def get_enclosing_function(self, offset): pyobject = function_pyname.get_object() if isinstance(pyobject, pyobjects.AbstractFunction): return pyobject - elif isinstance(pyobject, pyobjects.AbstractClass) and \ - '__init__' in pyobject: - return pyobject['__init__'].get_object() - elif '__call__' in pyobject: - return pyobject['__call__'].get_object() + elif ( + isinstance(pyobject, pyobjects.AbstractClass) and "__init__" in pyobject + ): + return pyobject["__init__"].get_object() + elif "__call__" in pyobject: + return pyobject["__call__"].get_object() return None def _find_module(self, module_name): dots = 0 - while module_name[dots] == '.': + while module_name[dots] == ".": dots += 1 return rope.base.pynames.ImportedModule( - self.module_scope.pyobject, module_name[dots:], dots) + self.module_scope.pyobject, module_name[dots:], dots + ) class StatementEvaluator(object): - def __init__(self, scope): self.scope = scope self.result = None @@ -167,16 +171,15 @@ def _Call(self, node): return def _get_returned(pyobject): - args = arguments.create_arguments(primary, pyobject, - node, self.scope) + args = arguments.create_arguments(primary, pyobject, node, self.scope) return pyobject.get_returned_object(args) + if isinstance(pyobject, rope.base.pyobjects.AbstractClass): result = None - if '__new__' in pyobject: - new_function = pyobject['__new__'].get_object() + if "__new__" in pyobject: + new_function = pyobject["__new__"].get_object() result = _get_returned(new_function) - if result is None or \ - result == rope.base.pyobjects.get_unknown(): + if result is None or result == rope.base.pyobjects.get_unknown(): result = rope.base.pyobjects.PyObject(pyobject) self.result = rope.base.pynames.UnboundName(pyobject=result) return @@ -184,15 +187,17 @@ def _get_returned(pyobject): pyfunction = None if isinstance(pyobject, rope.base.pyobjects.AbstractFunction): pyfunction = pyobject - elif '__call__' in pyobject: - pyfunction = pyobject['__call__'].get_object() + elif "__call__" in pyobject: + pyfunction = pyobject["__call__"].get_object() if pyfunction is not None: self.result = rope.base.pynames.UnboundName( - pyobject=_get_returned(pyfunction)) + pyobject=_get_returned(pyfunction) + ) def _Str(self, node): self.result = rope.base.pynames.UnboundName( - pyobject=rope.base.builtins.get_str()) + pyobject=rope.base.builtins.get_str() + ) def _Num(self, node): type_name = type(node.n).__name__ @@ -208,12 +213,12 @@ def _Constant(self, node): def _get_builtin_name(self, type_name): pytype = rope.base.builtins.builtins[type_name].get_object() - return rope.base.pynames.UnboundName( - rope.base.pyobjects.PyObject(pytype)) + return rope.base.pynames.UnboundName(rope.base.pyobjects.PyObject(pytype)) def _BinOp(self, node): self.result = rope.base.pynames.UnboundName( - self._get_object_for_node(node.left)) + self._get_object_for_node(node.left) + ) def _BoolOp(self, node): pyobject = self._get_object_for_node(node.values[0]) @@ -222,43 +227,50 @@ def _BoolOp(self, node): self.result = rope.base.pynames.UnboundName(pyobject) def _Repr(self, node): - self.result = self._get_builtin_name('str') + self.result = self._get_builtin_name("str") def _UnaryOp(self, node): self.result = rope.base.pynames.UnboundName( - self._get_object_for_node(node.operand)) + self._get_object_for_node(node.operand) + ) def _Compare(self, node): - self.result = self._get_builtin_name('bool') + self.result = self._get_builtin_name("bool") def _Dict(self, node): keys = None values = None if node.keys and node.keys[0]: - keys, values = next(iter(filter(itemgetter(0), zip(node.keys, node.values))), (None, None)) + keys, values = next( + iter(filter(itemgetter(0), zip(node.keys, node.values))), (None, None) + ) if keys: keys = self._get_object_for_node(keys) if values: values = self._get_object_for_node(values) self.result = rope.base.pynames.UnboundName( - pyobject=rope.base.builtins.get_dict(keys, values)) + pyobject=rope.base.builtins.get_dict(keys, values) + ) def _List(self, node): holding = None if node.elts: holding = self._get_object_for_node(node.elts[0]) self.result = rope.base.pynames.UnboundName( - pyobject=rope.base.builtins.get_list(holding)) + pyobject=rope.base.builtins.get_list(holding) + ) def _ListComp(self, node): pyobject = self._what_does_comprehension_hold(node) self.result = rope.base.pynames.UnboundName( - pyobject=rope.base.builtins.get_list(pyobject)) + pyobject=rope.base.builtins.get_list(pyobject) + ) def _GeneratorExp(self, node): pyobject = self._what_does_comprehension_hold(node) self.result = rope.base.pynames.UnboundName( - pyobject=rope.base.builtins.get_iterator(pyobject)) + pyobject=rope.base.builtins.get_iterator(pyobject) + ) def _what_does_comprehension_hold(self, node): scope = self._make_comprehension_scope(node) @@ -270,8 +282,9 @@ def _make_comprehension_scope(self, node): module = scope.pyobject.get_module() names = {} for comp in node.generators: - new_names = _get_evaluated_names(comp.target, comp.iter, module, - '.__iter__().next()', node.lineno) + new_names = _get_evaluated_names( + comp.target, comp.iter, module, ".__iter__().next()", node.lineno + ) names.update(new_names) return rope.base.pyscopes.TemporaryScope(scope.pycore, scope, names) @@ -284,7 +297,8 @@ def _Tuple(self, node): else: objects.append(self._get_object_for_node(node.elts[0])) self.result = rope.base.pynames.UnboundName( - pyobject=rope.base.builtins.get_tuple(*objects)) + pyobject=rope.base.builtins.get_tuple(*objects) + ) def _get_object_for_node(self, stmt): pyname = eval_node(self.scope, stmt) @@ -302,17 +316,14 @@ def _get_primary_and_object_for_node(self, stmt): def _Subscript(self, node): if isinstance(node.slice, ast.Index): - self._call_function(node.value, '__getitem__', - [node.slice.value]) + self._call_function(node.value, "__getitem__", [node.slice.value]) elif isinstance(node.slice, ast.Slice): - self._call_function(node.value, '__getitem__', - [node.slice]) + self._call_function(node.value, "__getitem__", [node.slice]) elif isinstance(node.slice, ast.expr): - self._call_function(node.value, '__getitem__', - [node.value]) + self._call_function(node.value, "__getitem__", [node.value]) def _Slice(self, node): - self.result = self._get_builtin_name('slice') + self.result = self._get_builtin_name("slice") def _call_function(self, node, function_name, other_args=None): pyname = eval_node(self.scope, node) @@ -322,26 +333,26 @@ def _call_function(self, node, function_name, other_args=None): return if function_name in pyobject: called = pyobject[function_name].get_object() - if not called or \ - not isinstance(called, pyobjects.AbstractFunction): + if not called or not isinstance(called, pyobjects.AbstractFunction): return args = [node] if other_args: args += other_args arguments_ = arguments.Arguments(args, self.scope) self.result = rope.base.pynames.UnboundName( - pyobject=called.get_returned_object(arguments_)) + pyobject=called.get_returned_object(arguments_) + ) def _Lambda(self, node): self.result = rope.base.pynames.UnboundName( - pyobject=rope.base.builtins.Lambda(node, self.scope)) + pyobject=rope.base.builtins.Lambda(node, self.scope) + ) def _get_evaluated_names(targets, assigned, module, evaluation, lineno): result = {} for name, levels in astutils.get_name_levels(targets): - assignment = rope.base.pynames.AssignmentValue(assigned, levels, - evaluation) + assignment = rope.base.pynames.AssignmentValue(assigned, levels, evaluation) # XXX: this module should not access `rope.base.pynamesdef`! pyname = rope.base.pynamesdef.AssignedName(lineno, module) pyname.assignments.append(assignment) diff --git a/rope/base/exceptions.py b/rope/base/exceptions.py index d161c89ed..ff38a405d 100644 --- a/rope/base/exceptions.py +++ b/rope/base/exceptions.py @@ -47,8 +47,8 @@ def __init__(self, filename, lineno, message): self.lineno = lineno self.message_ = message super(ModuleSyntaxError, self).__init__( - 'Syntax error in file <%s> line <%s>: %s' % - (filename, lineno, message)) + "Syntax error in file <%s> line <%s>: %s" % (filename, lineno, message) + ) class ModuleDecodeError(RopeError): @@ -58,4 +58,5 @@ def __init__(self, filename, message): self.filename = filename self.message_ = message super(ModuleDecodeError, self).__init__( - 'Cannot decode file <%s>: %s' % (filename, message)) + "Cannot decode file <%s>: %s" % (filename, message) + ) diff --git a/rope/base/fscommands.py b/rope/base/fscommands.py index fff01282e..6fd03ca52 100644 --- a/rope/base/fscommands.py +++ b/rope/base/fscommands.py @@ -18,13 +18,16 @@ except NameError: unicode = str + def create_fscommands(root): dirlist = os.listdir(root) - commands = {'.hg': MercurialCommands, - '.svn': SubversionCommands, - '.git': GITCommands, - '_svn': SubversionCommands, - '_darcs': DarcsCommands} + commands = { + ".hg": MercurialCommands, + ".svn": SubversionCommands, + ".git": GITCommands, + "_svn": SubversionCommands, + "_darcs": DarcsCommands, + } for key in commands: if key in dirlist: try: @@ -35,9 +38,8 @@ def create_fscommands(root): class FileSystemCommands(object): - def create_file(self, path): - open(path, 'w').close() + open(path, "w").close() def create_folder(self, path): os.mkdir(path) @@ -52,7 +54,7 @@ def remove(self, path): shutil.rmtree(path) def write(self, path, data): - file_ = open(path, 'wb') + file_ = open(path, "wb") try: file_.write(data) finally: @@ -60,10 +62,10 @@ def write(self, path, data): class SubversionCommands(object): - def __init__(self, *args): self.normal_actions = FileSystemCommands() import pysvn + self.client = pysvn.Client() def create_file(self, path): @@ -85,22 +87,26 @@ def write(self, path, data): class MercurialCommands(object): - def __init__(self, root): self.hg = self._import_mercurial() self.normal_actions = FileSystemCommands() try: self.ui = self.hg.ui.ui( - verbose=False, debug=False, quiet=True, - interactive=False, traceback=False, report_untrusted=False) + verbose=False, + debug=False, + quiet=True, + interactive=False, + traceback=False, + report_untrusted=False, + ) except: self.ui = self.hg.ui.ui() - self.ui.setconfig('ui', 'interactive', 'no') - self.ui.setconfig('ui', 'debug', 'no') - self.ui.setconfig('ui', 'traceback', 'no') - self.ui.setconfig('ui', 'verbose', 'no') - self.ui.setconfig('ui', 'report_untrusted', 'no') - self.ui.setconfig('ui', 'quiet', 'yes') + self.ui.setconfig("ui", "interactive", "no") + self.ui.setconfig("ui", "debug", "no") + self.ui.setconfig("ui", "traceback", "no") + self.ui.setconfig("ui", "verbose", "no") + self.ui.setconfig("ui", "report_untrusted", "no") + self.ui.setconfig("ui", "quiet", "yes") self.repo = self.hg.hg.repository(self.ui, root) @@ -108,6 +114,7 @@ def _import_mercurial(self): import mercurial.commands import mercurial.hg import mercurial.ui + return mercurial def create_file(self, path): @@ -118,8 +125,7 @@ def create_folder(self, path): self.normal_actions.create_folder(path) def move(self, path, new_location): - self.hg.commands.rename(self.ui, self.repo, path, - new_location, after=False) + self.hg.commands.rename(self.ui, self.repo, path, new_location, after=False) def remove(self, path): self.hg.commands.remove(self.ui, self.repo, path) @@ -129,54 +135,52 @@ def write(self, path, data): class GITCommands(object): - def __init__(self, root): self.root = root - self._do(['version']) + self._do(["version"]) self.normal_actions = FileSystemCommands() def create_file(self, path): self.normal_actions.create_file(path) - self._do(['add', self._in_dir(path)]) + self._do(["add", self._in_dir(path)]) def create_folder(self, path): self.normal_actions.create_folder(path) def move(self, path, new_location): - self._do(['mv', self._in_dir(path), self._in_dir(new_location)]) + self._do(["mv", self._in_dir(path), self._in_dir(new_location)]) def remove(self, path): - self._do(['rm', self._in_dir(path)]) + self._do(["rm", self._in_dir(path)]) def write(self, path, data): # XXX: should we use ``git add``? self.normal_actions.write(path, data) def _do(self, args): - _execute(['git'] + args, cwd=self.root) + _execute(["git"] + args, cwd=self.root) def _in_dir(self, path): if path.startswith(self.root): - return path[len(self.root) + 1:] + return path[len(self.root) + 1 :] return self.root class DarcsCommands(object): - def __init__(self, root): self.root = root self.normal_actions = FileSystemCommands() def create_file(self, path): self.normal_actions.create_file(path) - self._do(['add', path]) + self._do(["add", path]) def create_folder(self, path): self.normal_actions.create_folder(path) - self._do(['add', path]) + self._do(["add", path]) def move(self, path, new_location): - self._do(['mv', path, new_location]) + self._do(["mv", path, new_location]) def remove(self, path): self.normal_actions.remove(path) @@ -185,7 +189,7 @@ def write(self, path, data): self.normal_actions.write(path, data) def _do(self, args): - _execute(['darcs'] + args, cwd=self.root) + _execute(["darcs"] + args, cwd=self.root) def _execute(args, cwd=None): @@ -204,13 +208,13 @@ def unicode_to_file_data(contents, encoding=None): try: return contents.encode() except UnicodeEncodeError: - return contents.encode('utf-8') + return contents.encode("utf-8") def file_data_to_unicode(data, encoding=None): result = _decode_data(data, encoding) - if '\r' in result: - result = result.replace('\r\n', '\n').replace('\r', '\n') + if "\r" in result: + result = result.replace("\r\n", "\n").replace("\r", "\n") return result @@ -224,24 +228,24 @@ def _decode_data(data, encoding): # PEP263 says that "encoding not explicitly defined" means it is ascii, # but we will use utf8 instead since utf8 fully covers ascii and btw is # the only non-latin sane encoding. - encoding = 'utf-8' + encoding = "utf-8" try: return data.decode(encoding) except (UnicodeError, LookupError): # fallback to latin1: it should never fail - return data.decode('latin1') + return data.decode("latin1") def read_str_coding(source): # as defined by PEP-263 (https://www.python.org/dev/peps/pep-0263/) - CODING_LINE_PATTERN = b'^[ \t\f]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)' + CODING_LINE_PATTERN = b"^[ \t\f]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)" if type(source) == bytes: - newline = b'\n' + newline = b"\n" CODING_LINE_PATTERN = re.compile(CODING_LINE_PATTERN) else: - newline = '\n' - CODING_LINE_PATTERN = re.compile(CODING_LINE_PATTERN.decode('ascii')) + newline = "\n" + CODING_LINE_PATTERN = re.compile(CODING_LINE_PATTERN.decode("ascii")) for line in source.split(newline, 2)[:2]: if re.match(CODING_LINE_PATTERN, line): return _find_coding(line) @@ -251,12 +255,12 @@ def read_str_coding(source): def _find_coding(text): if isinstance(text, pycompat.str): - text = text.encode('utf-8') - coding = b'coding' + text = text.encode("utf-8") + coding = b"coding" to_chr = chr if pycompat.PY3 else lambda x: x try: start = text.index(coding) + len(coding) - if text[start] not in b'=:': + if text[start] not in b"=:": return start += 1 while start < len(text) and to_chr(text[start]).isspace(): @@ -264,12 +268,12 @@ def _find_coding(text): end = start while end < len(text): c = text[end] - if not to_chr(c).isalnum() and c not in b'-_': + if not to_chr(c).isalnum() and c not in b"-_": break end += 1 result = text[start:end] if isinstance(result, bytes): - result = result.decode('utf-8') + result = result.decode("utf-8") return result except ValueError: pass diff --git a/rope/base/history.py b/rope/base/history.py index d3c523d31..d2ba8072f 100644 --- a/rope/base/history.py +++ b/rope/base/history.py @@ -16,7 +16,8 @@ def __init__(self, project, maxundos=None): def _load_history(self): if self.save: result = self.project.data_files.read_data( - 'history', compress=self.compress, import_=True) + "history", compress=self.compress, import_=True + ) if result is not None: to_change = change.DataToChange(self.project) for data in result[0]: @@ -43,7 +44,7 @@ def do(self, changes, task_handle=taskhandle.NullTaskHandle()): def _remove_extra_items(self): if len(self.undo_list) > self.max_undos: - del self.undo_list[0:len(self.undo_list) - self.max_undos] + del self.undo_list[0 : len(self.undo_list) - self.max_undos] def _is_change_interesting(self, changes): for resource in changes.get_changed_resources(): @@ -51,8 +52,7 @@ def _is_change_interesting(self, changes): return True return False - def undo(self, change=None, drop=False, - task_handle=taskhandle.NullTaskHandle()): + def undo(self, change=None, drop=False, task_handle=taskhandle.NullTaskHandle()): """Redo done changes from the history When `change` is `None`, the last done change will be undone. @@ -66,15 +66,15 @@ def undo(self, change=None, drop=False, """ if not self._undo_list: - raise exceptions.HistoryError('Undo list is empty') + raise exceptions.HistoryError("Undo list is empty") if change is None: change = self.undo_list[-1] dependencies = self._find_dependencies(self.undo_list, change) self._move_front(self.undo_list, dependencies) self._perform_undos(len(dependencies), task_handle) - result = self.redo_list[-len(dependencies):] + result = self.redo_list[-len(dependencies) :] if drop: - del self.redo_list[-len(dependencies):] + del self.redo_list[-len(dependencies) :] return result def redo(self, change=None, task_handle=taskhandle.NullTaskHandle()): @@ -88,13 +88,13 @@ def redo(self, change=None, task_handle=taskhandle.NullTaskHandle()): """ if not self.redo_list: - raise exceptions.HistoryError('Redo list is empty') + raise exceptions.HistoryError("Redo list is empty") if change is None: change = self.redo_list[-1] dependencies = self._find_dependencies(self.redo_list, change) self._move_front(self.redo_list, dependencies) self._perform_redos(len(dependencies), task_handle) - return self.undo_list[-len(dependencies):] + return self.undo_list[-len(dependencies) :] def _move_front(self, change_list, changes): for change in changes: @@ -109,8 +109,7 @@ def _perform_undos(self, count, task_handle): for i in range(count): self.current_change = self.undo_list[-1] try: - job_set = change.create_job_set(task_handle, - self.current_change) + job_set = change.create_job_set(task_handle, self.current_change) self.current_change.undo(job_set) finally: self.current_change = None @@ -120,8 +119,7 @@ def _perform_redos(self, count, task_handle): for i in range(count): self.current_change = self.redo_list[-1] try: - job_set = change.create_job_set(task_handle, - self.current_change) + job_set = change.create_job_set(task_handle, self.current_change) self.current_change.do(job_set) finally: self.current_change = None @@ -141,12 +139,10 @@ def contents_before_current_change(self, file): def _search_for_change_contents(self, change_list, file): for change_ in reversed(change_list): if isinstance(change_, change.ChangeSet): - result = self._search_for_change_contents(change_.changes, - file) + result = self._search_for_change_contents(change_.changes, file) if result is not None: return result - if isinstance(change_, change.ChangeContents) and \ - change_.resource == file: + if isinstance(change_, change.ChangeContents) and change_.resource == file: return change_.old_contents def write(self): @@ -156,8 +152,7 @@ def write(self): self._remove_extra_items() data.append([to_data(change_) for change_ in self.undo_list]) data.append([to_data(change_) for change_ in self.redo_list]) - self.project.data_files.write_data('history', data, - compress=self.compress) + self.project.data_files.write_data("history", data, compress=self.compress) def get_file_undo_list(self, resource): result = [] @@ -167,8 +162,9 @@ def get_file_undo_list(self, resource): return result def __str__(self): - return 'History holds %s changes in memory' % \ - (len(self.undo_list) + len(self.redo_list)) + return "History holds %s changes in memory" % ( + len(self.undo_list) + len(self.redo_list) + ) undo_list = property(lambda self: self._undo_list) redo_list = property(lambda self: self._redo_list) @@ -188,17 +184,17 @@ def tobe_redone(self): @property def max_undos(self): if self._maxundos is None: - return self.project.prefs.get('max_history_items', 100) + return self.project.prefs.get("max_history_items", 100) else: return self._maxundos @property def save(self): - return self.project.prefs.get('save_history', False) + return self.project.prefs.get("save_history", False) @property def compress(self): - return self.project.prefs.get('compress_history', False) + return self.project.prefs.get("compress_history", False) def clear(self): """Forget all undo and redo information""" @@ -207,7 +203,6 @@ def clear(self): class _FindChangeDependencies(object): - def __init__(self, change_list): self.change = change_list[0] self.change_list = change_list diff --git a/rope/base/libutils.py b/rope/base/libutils.py index 7cbacb2cc..4c966ccae 100644 --- a/rope/base/libutils.py +++ b/rope/base/libutils.py @@ -25,9 +25,9 @@ def path_to_resource(project, path, type=None): project = rope.base.project.get_no_project() if type is None: return project.get_resource(project_path) - if type == 'file': + if type == "file": return project.get_file(project_path) - if type == 'folder': + if type == "folder": return project.get_folder(project_path) return None @@ -38,12 +38,12 @@ def path_relative_to_project_root(project, path): @utils.deprecated() def relative(root, path): - root = rope.base.project._realpath(root).replace(os.path.sep, '/') - path = rope.base.project._realpath(path).replace(os.path.sep, '/') + root = rope.base.project._realpath(root).replace(os.path.sep, "/") + path = rope.base.project._realpath(path).replace(os.path.sep, "/") if path == root: - return '' - if path.startswith(root + '/'): - return path[len(root) + 1:] + return "" + if path.startswith(root + "/"): + return path[len(root) + 1 :] def report_change(project, path, old_content): @@ -58,8 +58,7 @@ def report_change(project, path, old_content): for observer in list(project.observers): observer.resource_changed(resource) if project.pycore.automatic_soa: - rope.base.pycore.perform_soa_on_changed_scopes(project, resource, - old_content) + rope.base.pycore.perform_soa_on_changed_scopes(project, resource, old_content) def analyze_module(project, resource): @@ -76,7 +75,7 @@ def analyze_modules(project, task_handle=taskhandle.NullTaskHandle()): Note that this might be really time consuming. """ resources = project.get_python_files() - job_set = task_handle.create_jobset('Analyzing Modules', len(resources)) + job_set = task_handle.create_jobset("Analyzing Modules", len(resources)) for resource in resources: job_set.started_job(resource.path) analyze_module(project, resource) @@ -91,8 +90,9 @@ def get_string_module(project, code, resource=None, force_errors=False): ``ignore_syntax_errors`` project config. """ - return pyobjectsdef.PyModule(project.pycore, code, resource, - force_errors=force_errors) + return pyobjectsdef.PyModule( + project.pycore, code, resource, force_errors=force_errors + ) def get_string_scope(project, code, resource=None): @@ -108,16 +108,17 @@ def modname(resource): if resource.is_folder(): module_name = resource.name source_folder = resource.parent - elif resource.name == '__init__.py': + elif resource.name == "__init__.py": module_name = resource.parent.name source_folder = resource.parent.parent else: module_name = resource.name[:-3] source_folder = resource.parent - while source_folder != source_folder.parent and \ - source_folder.has_child('__init__.py'): - module_name = source_folder.name + '.' + module_name + while source_folder != source_folder.parent and source_folder.has_child( + "__init__.py" + ): + module_name = source_folder.name + "." + module_name source_folder = source_folder.parent return module_name diff --git a/rope/base/oi/doa.py b/rope/base/oi/doa.py index 63ebc50ea..48b5db98f 100644 --- a/rope/base/oi/doa.py +++ b/rope/base/oi/doa.py @@ -1,6 +1,7 @@ import base64 import hashlib import hmac + try: import cPickle as pickle except ImportError: @@ -30,6 +31,7 @@ def _compat_compare_digest(a, b): difference |= ord(a_char) ^ ord(b_char) return difference == 0 + try: from hmac import compare_digest except ImportError: @@ -39,8 +41,9 @@ def _compat_compare_digest(a, b): class PythonFileRunner(object): """A class for running python project files""" - def __init__(self, pycore, file_, args=None, stdin=None, - stdout=None, analyze_data=None): + def __init__( + self, pycore, file_, args=None, stdin=None, stdout=None, analyze_data=None + ): self.pycore = pycore self.file = file_ self.analyze_data = analyze_data @@ -53,26 +56,38 @@ def run(self): """Execute the process""" env = dict(os.environ) file_path = self.file.real_path - path_folders = self.pycore.project.get_source_folders() + \ - self.pycore.project.get_python_path_folders() - env['PYTHONPATH'] = os.pathsep.join(folder.real_path - for folder in path_folders) - runmod_path = self.pycore.project.find_module('rope.base.oi.runmod').real_path + path_folders = ( + self.pycore.project.get_source_folders() + + self.pycore.project.get_python_path_folders() + ) + env["PYTHONPATH"] = os.pathsep.join(folder.real_path for folder in path_folders) + runmod_path = self.pycore.project.find_module("rope.base.oi.runmod").real_path self.receiver = None self._init_data_receiving() - send_info = '-' + send_info = "-" if self.receiver: send_info = self.receiver.get_send_info() - args = [sys.executable, runmod_path, send_info, - self.pycore.project.address, self.file.real_path] + args = [ + sys.executable, + runmod_path, + send_info, + self.pycore.project.address, + self.file.real_path, + ] if self.analyze_data is None: del args[1:4] if self.args is not None: args.extend(self.args) self.process = subprocess.Popen( - executable=sys.executable, args=args, env=env, - cwd=os.path.split(file_path)[0], stdin=self.stdin, - stdout=self.stdout, stderr=self.stdout, close_fds=os.name != 'nt') + executable=sys.executable, + args=args, + env=env, + cwd=os.path.split(file_path)[0], + stdin=self.stdin, + stdout=self.stdout, + stderr=self.stdout, + close_fds=os.name != "nt", + ) def _init_data_receiving(self): if self.analyze_data is None: @@ -80,21 +95,20 @@ def _init_data_receiving(self): # Disabling FIFO data transfer due to blocking when running # unittests in the GUI. # XXX: Handle FIFO data transfer for `rope.ui.testview` - if True or os.name == 'nt': + if True or os.name == "nt": self.receiver = _SocketReceiver() else: self.receiver = _FIFOReceiver() - self.receiving_thread = threading.Thread( - target=self._receive_information) + self.receiving_thread = threading.Thread(target=self._receive_information) self.receiving_thread.setDaemon(True) self.receiving_thread.start() def _receive_information(self): - #temp = open('/dev/shm/info', 'wb') + # temp = open('/dev/shm/info', 'wb') for data in self.receiver.receive_data(): self.analyze_data(data) - #temp.write(str(data) + '\n') - #temp.close() + # temp.write(str(data) + '\n') + # temp.close() for observer in self.observers: observer() @@ -109,12 +123,13 @@ def kill_process(self): if self.process.poll() is not None: return try: - if hasattr(self.process, 'terminate'): + if hasattr(self.process, "terminate"): self.process.terminate() - elif os.name != 'nt': + elif os.name != "nt": os.kill(self.process.pid, 9) else: import ctypes + handle = int(self.process._handle) ctypes.windll.kernel32.TerminateProcess(handle, -1) except OSError: @@ -126,7 +141,6 @@ def add_finishing_observer(self, observer): class _MessageReceiver(object): - def receive_data(self): pass @@ -135,7 +149,6 @@ def get_send_info(self): class _SocketReceiver(_MessageReceiver): - def __init__(self): self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.data_port = 3037 @@ -143,24 +156,23 @@ def __init__(self): while self.data_port < 4000: try: - self.server_socket.bind(('localhost', self.data_port)) + self.server_socket.bind(("localhost", self.data_port)) break except socket.error: self.data_port += 1 self.server_socket.listen(1) def get_send_info(self): - return '%d:%s' % (self.data_port, - base64.b64encode(self.key).decode('utf-8')) + return "%d:%s" % (self.data_port, base64.b64encode(self.key).decode("utf-8")) def receive_data(self): conn, addr = self.server_socket.accept() self.server_socket.close() - my_file = conn.makefile('rb') + my_file = conn.makefile("rb") while True: # Received messages must meet the following criteria: # 1. Must be contained on a single line. - # 2. Must be prefixed with a base64 encoded sha256 message digest + # 2. Must be prefixed with a base64 encoded sha256 message digest # of the base64 encoded pickle data. # 3. Message digest must be computed using the correct key. # @@ -172,9 +184,9 @@ def receive_data(self): break try: - digest_end = buf.index(b':') + digest_end = buf.index(b":") buf_digest = base64.b64decode(buf[:digest_end]) - buf_data = buf[digest_end + 1:-1] + buf_data = buf[digest_end + 1 : -1] decoded_buf_data = base64.b64decode(buf_data) except: # Corrupted data; the payload cannot be trusted and just has @@ -195,24 +207,23 @@ def receive_data(self): class _FIFOReceiver(_MessageReceiver): - def __init__(self): # XXX: this is insecure and might cause race conditions self.file_name = self._get_file_name() os.mkfifo(self.file_name) def _get_file_name(self): - prefix = tempfile.gettempdir() + '/__rope_' + prefix = tempfile.gettempdir() + "/__rope_" i = 0 - while os.path.exists(prefix + str(i).rjust(4, '0')): + while os.path.exists(prefix + str(i).rjust(4, "0")): i += 1 - return prefix + str(i).rjust(4, '0') + return prefix + str(i).rjust(4, "0") def get_send_info(self): return self.file_name def receive_data(self): - my_file = open(self.file_name, 'rb') + my_file = open(self.file_name, "rb") while True: try: yield marshal.load(my_file) diff --git a/rope/base/oi/memorydb.py b/rope/base/oi/memorydb.py index 01c814ce4..6388e53f5 100644 --- a/rope/base/oi/memorydb.py +++ b/rope/base/oi/memorydb.py @@ -2,7 +2,6 @@ class MemoryDB(objectdb.FileDict): - def __init__(self, project, persist=None): self.project = project self._persist = persist @@ -14,7 +13,8 @@ def _load_files(self): self._files = {} if self.persist: result = self.project.data_files.read_data( - 'objectdb', compress=self.compress, import_=True) + "objectdb", compress=self.compress, import_=True + ) if result is not None: self._files = result @@ -51,23 +51,21 @@ def __delitem__(self, file): def write(self): if self.persist: - self.project.data_files.write_data('objectdb', self._files, - self.compress) + self.project.data_files.write_data("objectdb", self._files, self.compress) @property def compress(self): - return self.project.prefs.get('compress_objectdb', False) + return self.project.prefs.get("compress_objectdb", False) @property def persist(self): if self._persist is not None: return self._persist else: - return self.project.prefs.get('save_objectdb', False) + return self.project.prefs.get("save_objectdb", False) class FileInfo(objectdb.FileInfo): - def __init__(self, scopes): self.scopes = scopes @@ -97,9 +95,7 @@ def __setitem__(self): raise NotImplementedError() - class ScopeInfo(objectdb.ScopeInfo): - def __init__(self): self.call_info = {} self.per_name = {} diff --git a/rope/base/oi/objectdb.py b/rope/base/oi/objectdb.py index 1ec4e350b..53583ee0c 100644 --- a/rope/base/oi/objectdb.py +++ b/rope/base/oi/objectdb.py @@ -2,7 +2,6 @@ class ObjectDB(object): - def __init__(self, db, validation): self.db = db self.validation = validation @@ -93,12 +92,13 @@ def __str__(self): scope_count = 0 for file_dict in self.files.values(): scope_count += len(file_dict) - return 'ObjectDB holds %s file and %s scope infos' % \ - (len(self.files), scope_count) + return "ObjectDB holds %s file and %s scope infos" % ( + len(self.files), + scope_count, + ) class _NullScopeInfo(object): - def __init__(self, error_on_write=True): self.error_on_write = error_on_write @@ -121,13 +121,11 @@ def add_call(self, parameters, returned): class FileInfo(dict): - def create_scope(self, key): pass class FileDict(dict): - def create(self, key): pass @@ -136,7 +134,6 @@ def rename(self, key, new_key): class ScopeInfo(object): - def get_per_name(self, name): pass @@ -154,7 +151,6 @@ def add_call(self, parameters, returned): class CallInfo(object): - def __init__(self, args, returned): self.args = args self.returned = returned @@ -167,7 +163,6 @@ def get_returned(self): class FileListObserver(object): - def added(self, path): pass diff --git a/rope/base/oi/objectinfo.py b/rope/base/oi/objectinfo.py index f86d72e0b..fc8296f37 100644 --- a/rope/base/oi/objectinfo.py +++ b/rope/base/oi/objectinfo.py @@ -18,19 +18,20 @@ def __init__(self, project): self.to_pyobject = transform.TextualToPyObject(project) self.doi_to_pyobject = transform.DOITextualToPyObject(project) self._init_objectdb() - if project.prefs.get('validate_objectdb', False): + if project.prefs.get("validate_objectdb", False): self._init_validation() def _init_objectdb(self): - dbtype = self.project.get_prefs().get('objectdb_type', None) + dbtype = self.project.get_prefs().get("objectdb_type", None) persist = None if dbtype is not None: warnings.warn( '"objectdb_type" project config is deprecated;\n' 'Use "save_objectdb" instead in your project ' 'config file.\n(".ropeproject/config.py" by default)\n', - DeprecationWarning) - if dbtype != 'memory' and self.project.ropefolder is not None: + DeprecationWarning, + ) + if dbtype != "memory" and self.project.ropefolder is not None: persist = True self.validation = TextualValidation(self.to_pyobject) db = memorydb.MemoryDB(self.project, persist=persist) @@ -39,22 +40,22 @@ def _init_objectdb(self): def _init_validation(self): self.objectdb.validate_files() observer = resourceobserver.ResourceObserver( - changed=self._resource_changed, moved=self._resource_moved, - removed=self._resource_moved) + changed=self._resource_changed, + moved=self._resource_moved, + removed=self._resource_moved, + ) files = [] for path in self.objectdb.get_files(): resource = self.to_pyobject.path_to_resource(path) if resource is not None and resource.project == self.project: files.append(resource) - self.observer = resourceobserver.FilteredResourceObserver(observer, - files) + self.observer = resourceobserver.FilteredResourceObserver(observer, files) self.objectdb.add_file_list_observer(_FileListObserver(self)) self.project.add_observer(self.observer) def _resource_changed(self, resource): try: - self.objectdb.validate_file( - self.to_textual.resource_to_path(resource)) + self.objectdb.validate_file(self.to_textual.resource_to_path(resource)) except exceptions.ModuleSyntaxError: pass @@ -75,7 +76,7 @@ def get_returned(self, pyobject, args): return None for call_info in self.objectdb.get_callinfos(path, key): returned = call_info.get_returned() - if returned and returned[0] not in ('unknown', 'none'): + if returned and returned[0] not in ("unknown", "none"): result = returned break if result is None: @@ -87,15 +88,15 @@ def get_exact_returned(self, pyobject, args): path, key = self._get_scope(pyobject) if path is not None: returned = self.objectdb.get_returned( - path, key, self._args_to_textual(pyobject, args)) + path, key, self._args_to_textual(pyobject, args) + ) if returned is not None: return self.to_pyobject(returned) def _args_to_textual(self, pyfunction, args): parameters = list(pyfunction.get_param_names(special_args=False)) - arguments = args.get_arguments(parameters)[:len(parameters)] - textual_args = tuple([self.to_textual(arg) - for arg in arguments]) + arguments = args.get_arguments(parameters)[: len(parameters)] + textual_args = tuple([self.to_textual(arg) for arg in arguments]) return textual_args def get_parameter_objects(self, pyobject): @@ -116,8 +117,7 @@ def get_parameter_objects(self, pyobject): if unknowns == 0: break if unknowns < arg_count: - return [self.to_pyobject(parameter) - for parameter in parameters] + return [self.to_pyobject(parameter) for parameter in parameters] def get_passed_objects(self, pyfunction, parameter_index): path, key = self._get_scope(pyfunction) @@ -136,17 +136,17 @@ def doa_data_received(self, data): def doi_to_normal(textual): pyobject = self.doi_to_pyobject(textual) return self.to_textual(pyobject) + function = doi_to_normal(data[0]) args = tuple([doi_to_normal(textual) for textual in data[1]]) returned = doi_to_normal(data[2]) - if function[0] == 'defined' and len(function) == 3: + if function[0] == "defined" and len(function) == 3: self._save_data(function, args, returned) def function_called(self, pyfunction, params, returned=None): function_text = self.to_textual(pyfunction) - params_text = tuple([self.to_textual(param) - for param in params]) - returned_text = ('unknown',) + params_text = tuple([self.to_textual(param) for param in params]) + returned_text = ("unknown",) if returned is not None: returned_text = self.to_textual(returned) self._save_data(function_text, params_text, returned_text) @@ -163,7 +163,7 @@ def get_per_name(self, scope, name): if result is not None: return self.to_pyobject(result) - def _save_data(self, function, args, returned=('unknown',)): + def _save_data(self, function, args, returned=("unknown",)): self.objectdb.add_callinfo(function[1], function[2], args, returned) def _get_scope(self, pyobject): @@ -171,12 +171,12 @@ def _get_scope(self, pyobject): if resource is None: return None, None textual = self.to_textual(pyobject) - if textual[0] == 'defined': + if textual[0] == "defined": path = textual[1] if len(textual) == 3: key = textual[2] else: - key = '' + key = "" return path, key return None, None @@ -188,34 +188,32 @@ def __str__(self): class TextualValidation(object): - def __init__(self, to_pyobject): self.to_pyobject = to_pyobject def is_value_valid(self, value): # ???: Should none and unknown be considered valid? - if value is None or value[0] in ('none', 'unknown'): + if value is None or value[0] in ("none", "unknown"): return False return self.to_pyobject(value) is not None def is_more_valid(self, new, old): if old is None: return True - return new[0] not in ('unknown', 'none') + return new[0] not in ("unknown", "none") def is_file_valid(self, path): return self.to_pyobject.path_to_resource(path) is not None def is_scope_valid(self, path, key): - if key == '': - textual = ('defined', path) + if key == "": + textual = ("defined", path) else: - textual = ('defined', path, key) + textual = ("defined", path, key) return self.to_pyobject(textual) is not None class _FileListObserver(object): - def __init__(self, object_info): self.object_info = object_info self.observer = self.object_info.observer diff --git a/rope/base/oi/runmod.py b/rope/base/oi/runmod.py index 055d9ae86..2f169e392 100644 --- a/rope/base/oi/runmod.py +++ b/rope/base/oi/runmod.py @@ -2,6 +2,7 @@ def __rope_start_everything(): import os import sys import socket + try: import cPickle as pickle except ImportError: @@ -16,32 +17,30 @@ def __rope_start_everything(): import hmac class _MessageSender(object): - def send_data(self, data): pass class _SocketSender(_MessageSender): - def __init__(self, port, key): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.connect(('127.0.0.1', port)) - self.my_file = s.makefile('wb') + s.connect(("127.0.0.1", port)) + self.my_file = s.makefile("wb") self.key = base64.b64decode(key) def send_data(self, data): if not self.my_file.closed: pickled_data = base64.b64encode( - pickle.dumps(data, pickle.HIGHEST_PROTOCOL)) + pickle.dumps(data, pickle.HIGHEST_PROTOCOL) + ) dgst = hmac.new(self.key, pickled_data, hashlib.sha256).digest() - self.my_file.write(base64.b64encode(dgst) + b':' + - pickled_data + b'\n') + self.my_file.write(base64.b64encode(dgst) + b":" + pickled_data + b"\n") + def close(self): self.my_file.close() class _FileSender(_MessageSender): - def __init__(self, file_name): - self.my_file = open(file_name, 'wb') + self.my_file = open(file_name, "wb") def send_data(self, data): if not self.my_file.closed: @@ -59,14 +58,14 @@ def newfunc(self, arg): result = func(self, arg) cache[arg] = result return result + return newfunc class _FunctionCallDataSender(object): - def __init__(self, send_info, project_root): self.project_root = project_root if send_info[0].isdigit(): - port, key = send_info.split(':', 1) + port, key = send_info.split(":", 1) self.sender = _SocketSender(int(port), key) else: self.sender = _FileSender(send_info) @@ -76,85 +75,97 @@ def global_trace(frame, event, arg): # This might lose some information if self._is_an_interesting_call(frame): return self.on_function_call + sys.settrace(global_trace) threading.settrace(global_trace) def on_function_call(self, frame, event, arg): - if event != 'return': + if event != "return": return args = [] - returned = ('unknown',) + returned = ("unknown",) code = frame.f_code - for argname in code.co_varnames[:code.co_argcount]: + for argname in code.co_varnames[: code.co_argcount]: try: - argvalue = self._object_to_persisted_form( - frame.f_locals[argname]) + argvalue = self._object_to_persisted_form(frame.f_locals[argname]) args.append(argvalue) except (TypeError, AttributeError): - args.append(('unknown',)) + args.append(("unknown",)) try: returned = self._object_to_persisted_form(arg) except (TypeError, AttributeError): pass try: - data = (self._object_to_persisted_form(frame.f_code), - tuple(args), returned) + data = ( + self._object_to_persisted_form(frame.f_code), + tuple(args), + returned, + ) self.sender.send_data(data) except (TypeError): pass return self.on_function_call def _is_an_interesting_call(self, frame): - #if frame.f_code.co_name in ['?', '']: + # if frame.f_code.co_name in ['?', '']: # return False - #return not frame.f_back or + # return not frame.f_back or # not self._is_code_inside_project(frame.f_back.f_code) - if not self._is_code_inside_project(frame.f_code) and \ - (not frame.f_back or - not self._is_code_inside_project(frame.f_back.f_code)): + if not self._is_code_inside_project(frame.f_code) and ( + not frame.f_back + or not self._is_code_inside_project(frame.f_back.f_code) + ): return False return True def _is_code_inside_project(self, code): source = self._path(code.co_filename) - return source is not None and os.path.exists(source) and \ - _realpath(source).startswith(self.project_root) + return ( + source is not None + and os.path.exists(source) + and _realpath(source).startswith(self.project_root) + ) @_cached def _get_persisted_code(self, object_): source = self._path(object_.co_filename) if not os.path.exists(source): - raise TypeError('no source') - return ('defined', _realpath(source), str(object_.co_firstlineno)) + raise TypeError("no source") + return ("defined", _realpath(source), str(object_.co_firstlineno)) @_cached def _get_persisted_class(self, object_): try: - return ('defined', _realpath(inspect.getsourcefile(object_)), - object_.__name__) + return ( + "defined", + _realpath(inspect.getsourcefile(object_)), + object_.__name__, + ) except (TypeError, AttributeError): - return ('unknown',) + return ("unknown",) def _get_persisted_builtin(self, object_): if isinstance(object_, pycompat.string_types): - return ('builtin', 'str') + return ("builtin", "str") if isinstance(object_, list): holding = None if len(object_) > 0: holding = object_[0] - return ('builtin', 'list', - self._object_to_persisted_form(holding)) + return ("builtin", "list", self._object_to_persisted_form(holding)) if isinstance(object_, dict): keys = None values = None if len(object_) > 0: # @todo - fix it properly, why is __locals__ being # duplicated ? - keys = [key for key in object_.keys() if key != '__locals__'][0] + keys = [key for key in object_.keys() if key != "__locals__"][0] values = object_[keys] - return ('builtin', 'dict', - self._object_to_persisted_form(keys), - self._object_to_persisted_form(values)) + return ( + "builtin", + "dict", + self._object_to_persisted_form(keys), + self._object_to_persisted_form(values), + ) if isinstance(object_, tuple): objects = [] if len(object_) < 3: @@ -162,20 +173,19 @@ def _get_persisted_builtin(self, object_): objects.append(self._object_to_persisted_form(holding)) else: objects.append(self._object_to_persisted_form(object_[0])) - return tuple(['builtin', 'tuple'] + objects) + return tuple(["builtin", "tuple"] + objects) if isinstance(object_, set): holding = None if len(object_) > 0: for o in object_: holding = o break - return ('builtin', 'set', - self._object_to_persisted_form(holding)) - return ('unknown',) + return ("builtin", "set", self._object_to_persisted_form(holding)) + return ("unknown",) def _object_to_persisted_form(self, object_): if object_ is None: - return ('none',) + return ("none",) if isinstance(object_, types.CodeType): return self._get_persisted_code(object_) if isinstance(object_, types.FunctionType): @@ -188,19 +198,19 @@ def _object_to_persisted_form(self, object_): return self._get_persisted_builtin(object_) if isinstance(object_, type): return self._get_persisted_class(object_) - return ('instance', self._get_persisted_class(type(object_))) + return ("instance", self._get_persisted_class(type(object_))) @_cached def _get_persisted_module(self, object_): path = self._path(object_.__file__) if path and os.path.exists(path): - return ('defined', _realpath(path)) - return ('unknown',) + return ("defined", _realpath(path)) + return ("unknown",) def _path(self, path): - if path.endswith('.pyc'): + if path.endswith(".pyc"): path = path[:-1] - if path.endswith('.py'): + if path.endswith(".py"): return path def close(self): @@ -214,17 +224,17 @@ def _realpath(path): project_root = sys.argv[2] file_to_run = sys.argv[3] run_globals = globals() - run_globals.update({'__name__': '__main__', - '__builtins__': __builtins__, - '__file__': file_to_run}) + run_globals.update( + {"__name__": "__main__", "__builtins__": __builtins__, "__file__": file_to_run} + ) - if send_info != '-': + if send_info != "-": data_sender = _FunctionCallDataSender(send_info, project_root) del sys.argv[1:4] pycompat.execfile(file_to_run, run_globals) - if send_info != '-': + if send_info != "-": data_sender.close() -if __name__ == '__main__': +if __name__ == "__main__": __rope_start_everything() diff --git a/rope/base/oi/soa.py b/rope/base/oi/soa.py index 20ab567ed..6d8de5ca7 100644 --- a/rope/base/oi/soa.py +++ b/rope/base/oi/soa.py @@ -4,32 +4,31 @@ from rope.base import pyobjects, evaluate, astutils, arguments -def analyze_module(pycore, pymodule, should_analyze, - search_subscopes, followed_calls): +def analyze_module(pycore, pymodule, should_analyze, search_subscopes, followed_calls): """Analyze `pymodule` for static object inference Analyzes scopes for collecting object information. The analysis starts from inner scopes. """ - _analyze_node(pycore, pymodule, should_analyze, - search_subscopes, followed_calls) + _analyze_node(pycore, pymodule, should_analyze, search_subscopes, followed_calls) -def _analyze_node(pycore, pydefined, should_analyze, - search_subscopes, followed_calls): +def _analyze_node(pycore, pydefined, should_analyze, search_subscopes, followed_calls): if search_subscopes(pydefined): for scope in pydefined.get_scope().get_scopes(): - _analyze_node(pycore, scope.pyobject, should_analyze, - search_subscopes, followed_calls) + _analyze_node( + pycore, scope.pyobject, should_analyze, search_subscopes, followed_calls + ) if should_analyze(pydefined): new_followed_calls = max(0, followed_calls - 1) return_true = lambda pydefined: True return_false = lambda pydefined: False def _follow(pyfunction): - _analyze_node(pycore, pyfunction, return_true, - return_false, new_followed_calls) + _analyze_node( + pycore, pyfunction, return_true, return_false, new_followed_calls + ) if not followed_calls: _follow = None @@ -39,7 +38,6 @@ def _follow(pyfunction): class SOAVisitor(object): - def __init__(self, pycore, pydefined, follow_callback=None): self.pycore = pycore self.pymodule = pydefined.get_module() @@ -60,24 +58,22 @@ def _Call(self, node): return pyfunction = pyname.get_object() if isinstance(pyfunction, pyobjects.AbstractFunction): - args = arguments.create_arguments(primary, pyfunction, - node, self.scope) + args = arguments.create_arguments(primary, pyfunction, node, self.scope) elif isinstance(pyfunction, pyobjects.PyClass): pyclass = pyfunction - if '__init__' in pyfunction: - pyfunction = pyfunction['__init__'].get_object() + if "__init__" in pyfunction: + pyfunction = pyfunction["__init__"].get_object() pyname = rope.base.pynames.UnboundName(pyobjects.PyObject(pyclass)) args = self._args_with_self(primary, pyname, pyfunction, node) - elif '__call__' in pyfunction: - pyfunction = pyfunction['__call__'].get_object() + elif "__call__" in pyfunction: + pyfunction = pyfunction["__call__"].get_object() args = self._args_with_self(primary, pyname, pyfunction, node) else: return self._call(pyfunction, args) def _args_with_self(self, primary, self_pyname, pyfunction, node): - base_args = arguments.create_arguments(primary, pyfunction, - node, self.scope) + base_args = arguments.create_arguments(primary, pyfunction, node, self.scope) return arguments.MixedArguments(self_pyname, base_args, self.scope) def _call(self, pyfunction, args): @@ -85,7 +81,8 @@ def _call(self, pyfunction, args): if self.follow is not None: before = self._parameter_objects(pyfunction) self.pycore.object_info.function_called( - pyfunction, args.get_arguments(pyfunction.get_param_names())) + pyfunction, args.get_arguments(pyfunction.get_param_names()) + ) pyfunction._set_parameter_pyobjects(None) if self.follow is not None: after = self._parameter_objects(pyfunction) @@ -125,29 +122,30 @@ def _Assign(self, node): def _evaluate_assign_value(self, node, nodes, type_hint=False): for subscript, levels in nodes: instance = evaluate.eval_node(self.scope, subscript.value) - args_pynames = [evaluate.eval_node(self.scope, - subscript.slice)] + args_pynames = [evaluate.eval_node(self.scope, subscript.slice)] value = rope.base.oi.soi._infer_assignment( - rope.base.pynames.AssignmentValue(node.value, levels, - type_hint=type_hint), - self.pymodule) + rope.base.pynames.AssignmentValue( + node.value, levels, type_hint=type_hint + ), + self.pymodule, + ) args_pynames.append(rope.base.pynames.UnboundName(value)) if instance is not None and value is not None: pyobject = instance.get_object() - if '__setitem__' in pyobject: - pyfunction = pyobject['__setitem__'].get_object() + if "__setitem__" in pyobject: + pyfunction = pyobject["__setitem__"].get_object() args = arguments.ObjectArguments([instance] + args_pynames) self._call(pyfunction, args) # IDEA: handle `__setslice__`, too class _SOAAssignVisitor(astutils._NodeNameCollector): - def __init__(self): super(_SOAAssignVisitor, self).__init__() self.nodes = [] def _added(self, node, levels): - if isinstance(node, rope.base.ast.Subscript) and \ - isinstance(node.slice, (rope.base.ast.Index, rope.base.ast.expr)): + if isinstance(node, rope.base.ast.Subscript) and isinstance( + node.slice, (rope.base.ast.Index, rope.base.ast.expr) + ): self.nodes.append((node, levels)) diff --git a/rope/base/oi/soi.py b/rope/base/oi/soi.py index adea9f3d5..2cecdf8b1 100644 --- a/rope/base/oi/soi.py +++ b/rope/base/oi/soi.py @@ -11,8 +11,7 @@ from rope.base.oi.type_hinting.factory import get_type_hinting_factory -_ignore_inferred = utils.ignore_exception( - rope.base.pyobjects.IsBeingInferredError) +_ignore_inferred = utils.ignore_exception(rope.base.pyobjects.IsBeingInferredError) @_ignore_inferred @@ -25,14 +24,15 @@ def infer_returned_object(pyfunction, args): result = _infer_returned(pyfunction, args) if result is not None: if args and pyfunction.get_module().get_resource() is not None: - params = args.get_arguments( - pyfunction.get_param_names(special_args=False)) + params = args.get_arguments(pyfunction.get_param_names(special_args=False)) object_info.function_called(pyfunction, params, result) return result result = object_info.get_returned(pyfunction, args) if result is not None: return result - hint_return = get_type_hinting_factory(pyfunction.pycore.project).make_return_provider() + hint_return = get_type_hinting_factory( + pyfunction.pycore.project + ).make_return_provider() type_ = hint_return(pyfunction) if type_ is not None: return rope.base.pyobjects.PyObject(type_) @@ -51,15 +51,15 @@ def infer_parameter_objects(pyfunction): def _handle_first_parameter(pyobject, parameters): kind = pyobject.get_kind() - if parameters is None or kind not in ['method', 'classmethod']: + if parameters is None or kind not in ["method", "classmethod"]: pass if not parameters: if not pyobject.get_param_names(special_args=False): return parameters.append(rope.base.pyobjects.get_unknown()) - if kind == 'method': + if kind == "method": parameters[0] = rope.base.pyobjects.PyObject(pyobject.parent) - if kind == 'classmethod': + if kind == "classmethod": parameters[0] = pyobject.parent @@ -69,14 +69,19 @@ def infer_assigned_object(pyname): return for assignment in reversed(pyname.assignments): result = _infer_assignment(assignment, pyname.module) - if isinstance(result, rope.base.builtins.BuiltinUnknown) and result.get_name() == 'NotImplementedType': + if ( + isinstance(result, rope.base.builtins.BuiltinUnknown) + and result.get_name() == "NotImplementedType" + ): break elif result == rope.base.pyobjects.get_unknown(): break elif result is not None: return result - hint_assignment = get_type_hinting_factory(pyname.module.pycore.project).make_assignment_provider() + hint_assignment = get_type_hinting_factory( + pyname.module.pycore.project + ).make_assignment_provider() hinting_result = hint_assignment(pyname) if hinting_result is not None: return rope.base.pyobjects.PyObject(hinting_result) @@ -85,8 +90,7 @@ def infer_assigned_object(pyname): def get_passed_objects(pyfunction, parameter_index): object_info = pyfunction.pycore.object_info - result = object_info.get_passed_objects(pyfunction, - parameter_index) + result = object_info.get_passed_objects(pyfunction, parameter_index) if not result: statically_inferred = _parameter_objects(pyfunction) if len(statically_inferred) > parameter_index: @@ -101,7 +105,8 @@ def _infer_returned(pyobject, args): # does not come from a good call site pyobject.get_scope().invalidate_data() pyobject._set_parameter_pyobjects( - args.get_arguments(pyobject.get_param_names(special_args=False))) + args.get_arguments(pyobject.get_param_names(special_args=False)) + ) scope = pyobject.get_scope() if not scope._get_returned_asts(): return @@ -134,6 +139,7 @@ def _parameter_objects(pyobject): result.append(rope.base.pyobjects.get_unknown()) return result + # handling `rope.base.pynames.AssignmentValue` @@ -175,21 +181,25 @@ def _follow_pyname(assignment, pymodule, lineno=None): pyname = evaluate.eval_node(holding_scope, assign_node) if pyname is not None: result = pyname.get_object() - if isinstance(result.get_type(), rope.base.builtins.Property) and \ - holding_scope.get_kind() == 'Class': + if ( + isinstance(result.get_type(), rope.base.builtins.Property) + and holding_scope.get_kind() == "Class" + ): arg = rope.base.pynames.UnboundName( - rope.base.pyobjects.PyObject(holding_scope.pyobject)) + rope.base.pyobjects.PyObject(holding_scope.pyobject) + ) return pyname, result.get_type().get_property_object( - arguments.ObjectArguments([arg])) + arguments.ObjectArguments([arg]) + ) return pyname, result @_ignore_inferred def _follow_evaluations(assignment, pyname, pyobject): new_pyname = pyname - tokens = assignment.evaluation.split('.') + tokens = assignment.evaluation.split(".") for token in tokens: - call = token.endswith('()') + call = token.endswith("()") if call: token = token[:-2] if token: @@ -211,8 +221,7 @@ def _follow_evaluations(assignment, pyname, pyobject): def _get_lineno_for_node(assign_node): - if hasattr(assign_node, 'lineno') and \ - assign_node.lineno is not None: + if hasattr(assign_node, "lineno") and assign_node.lineno is not None: return assign_node.lineno return 1 diff --git a/rope/base/oi/transform.py b/rope/base/oi/transform.py index 05870844d..768d5aea9 100644 --- a/rope/base/oi/transform.py +++ b/rope/base/oi/transform.py @@ -20,13 +20,13 @@ def __init__(self, project): def transform(self, pyobject): """Transform a `PyObject` to textual form""" if pyobject is None: - return ('none',) + return ("none",) object_type = type(pyobject) try: - method = getattr(self, object_type.__name__ + '_to_textual') + method = getattr(self, object_type.__name__ + "_to_textual") return method(pyobject) except AttributeError: - return ('unknown',) + return ("unknown",) def __call__(self, pyobject): return self.transform(pyobject) @@ -34,10 +34,10 @@ def __call__(self, pyobject): def PyObject_to_textual(self, pyobject): if isinstance(pyobject.get_type(), rope.base.pyobjects.AbstractClass): result = self.transform(pyobject.get_type()) - if result[0] == 'defined': - return ('instance', result) + if result[0] == "defined": + return ("instance", result) return result - return ('unknown',) + return ("unknown",) def PyFunction_to_textual(self, pyobject): return self._defined_to_textual(pyobject) @@ -50,44 +50,52 @@ def _defined_to_textual(self, pyobject): while pyobject.parent is not None: address.insert(0, pyobject.get_name()) pyobject = pyobject.parent - return ('defined', self._get_pymodule_path(pyobject.get_module()), - '.'.join(address)) + return ( + "defined", + self._get_pymodule_path(pyobject.get_module()), + ".".join(address), + ) def PyModule_to_textual(self, pyobject): - return ('defined', self._get_pymodule_path(pyobject)) + return ("defined", self._get_pymodule_path(pyobject)) def PyPackage_to_textual(self, pyobject): - return ('defined', self._get_pymodule_path(pyobject)) + return ("defined", self._get_pymodule_path(pyobject)) def List_to_textual(self, pyobject): - return ('builtin', 'list', self.transform(pyobject.holding)) + return ("builtin", "list", self.transform(pyobject.holding)) def Dict_to_textual(self, pyobject): - return ('builtin', 'dict', self.transform(pyobject.keys), - self.transform(pyobject.values)) + return ( + "builtin", + "dict", + self.transform(pyobject.keys), + self.transform(pyobject.values), + ) def Tuple_to_textual(self, pyobject): - objects = [self.transform(holding) - for holding in pyobject.get_holding_objects()] - return tuple(['builtin', 'tuple'] + objects) + objects = [ + self.transform(holding) for holding in pyobject.get_holding_objects() + ] + return tuple(["builtin", "tuple"] + objects) def Set_to_textual(self, pyobject): - return ('builtin', 'set', self.transform(pyobject.holding)) + return ("builtin", "set", self.transform(pyobject.holding)) def Iterator_to_textual(self, pyobject): - return ('builtin', 'iter', self.transform(pyobject.holding)) + return ("builtin", "iter", self.transform(pyobject.holding)) def Generator_to_textual(self, pyobject): - return ('builtin', 'generator', self.transform(pyobject.holding)) + return ("builtin", "generator", self.transform(pyobject.holding)) def Str_to_textual(self, pyobject): - return ('builtin', 'str') + return ("builtin", "str") def File_to_textual(self, pyobject): - return ('builtin', 'file') + return ("builtin", "file") def BuiltinFunction_to_textual(self, pyobject): - return ('builtin', 'function', pyobject.get_name()) + return ("builtin", "function", pyobject.get_name()) def _get_pymodule_path(self, pymodule): return self.resource_to_path(pymodule.get_resource()) @@ -114,13 +122,13 @@ def transform(self, textual): return None type = textual[0] try: - method = getattr(self, type + '_to_pyobject') + method = getattr(self, type + "_to_pyobject") return method(textual) except AttributeError: return None def builtin_to_pyobject(self, textual): - method = getattr(self, 'builtin_%s_to_pyobject' % textual[1], None) + method = getattr(self, "builtin_%s_to_pyobject" % textual[1], None) if method is not None: return method(textual) @@ -173,7 +181,7 @@ def _module_to_pyobject(self, textual): def _hierarchical_defined_to_pyobject(self, textual): path = textual[1] - names = textual[2].split('.') + names = textual[2].split(".") pymodule = self._get_pymodule(path) pyobject = pymodule for name in names: @@ -189,7 +197,7 @@ def _hierarchical_defined_to_pyobject(self, textual): return pyobject def defined_to_pyobject(self, textual): - if len(textual) == 2 or textual[2] == '': + if len(textual) == 2 or textual[2] == "": return self._module_to_pyobject(textual) else: return self._hierarchical_defined_to_pyobject(textual) @@ -213,6 +221,7 @@ def path_to_resource(self, path): # INFO: This is a project file; should not be absolute return None import rope.base.project + return rope.base.project.get_no_project().get_resource(path) except exceptions.ResourceNotFoundError: return None @@ -248,12 +257,10 @@ def _class_to_pyobject(self, textual): suspected = None if name in module_scope.get_names(): suspected = module_scope[name].get_object() - if suspected is not None and \ - isinstance(suspected, rope.base.pyobjects.PyClass): + if suspected is not None and isinstance(suspected, rope.base.pyobjects.PyClass): return suspected else: - lineno = self._find_occurrence(name, - pymodule.get_resource().read()) + lineno = self._find_occurrence(name, pymodule.get_resource().read()) if lineno is not None: inner_scope = module_scope.get_inner_scope_for_line(lineno) return inner_scope.pyobject @@ -270,16 +277,16 @@ def defined_to_pyobject(self, textual): return result def _find_occurrence(self, name, source): - pattern = re.compile(r'^\s*class\s*' + name + r'\b') - lines = source.split('\n') + pattern = re.compile(r"^\s*class\s*" + name + r"\b") + lines = source.split("\n") for i in range(len(lines)): if pattern.match(lines[i]): return i + 1 def path_to_resource(self, path): import rope.base.libutils - relpath = rope.base.libutils.path_relative_to_project_root( - self.project, path) + + relpath = rope.base.libutils.path_relative_to_project_root(self.project, path) if relpath is not None: path = relpath return super(DOITextualToPyObject, self).path_to_resource(path) diff --git a/rope/base/oi/type_hinting/evaluate.py b/rope/base/oi/type_hinting/evaluate.py index 3b82eb079..9b6ec3dc1 100644 --- a/rope/base/oi/type_hinting/evaluate.py +++ b/rope/base/oi/type_hinting/evaluate.py @@ -17,32 +17,28 @@ def __init__(self): self.third = None # used by tree nodes def nud(self, parser): - raise SyntaxError( - "Syntax error (%r)." % self.name - ) + raise SyntaxError("Syntax error (%r)." % self.name) def led(self, left, parser): - raise SyntaxError( - "Unknown operator (%r)." % self.name - ) + raise SyntaxError("Unknown operator (%r)." % self.name) def evaluate(self, pyobject): raise NotImplementedError(self.name, self) def __repr__(self): - if self.name == '(name)': + if self.name == "(name)": return "(%s %s)" % (self.name[1:-1], self.value) out = [repr(self.name), self.first, self.second, self.third] out = [str(i) for i in out if i] - return '(' + ' '.join(out) + ')' + return "(" + " ".join(out) + ")" class SymbolTable(object): - def multi(func): def _inner(self, names, *a, **kw): for name in names.split(): func(self, name, *a, **kw) + return _inner def __init__(self): @@ -126,12 +122,14 @@ def led(self, left, parser): multi = staticmethod(multi) # Just for code checker + symbol_table = SymbolTable() class Lexer(object): - _token_pattern = re.compile(r""" + _token_pattern = re.compile( + r""" \s* (?: ( @@ -141,7 +139,9 @@ class Lexer(object): ) # operator | ([a-zA-Z](?:\w|\.)*) # name ) - """, re.U | re.S | re.X) + """, + re.U | re.S | re.X, + ) def __init__(self, symbol_table): self.symbol_table = symbol_table @@ -156,24 +156,26 @@ def tokenize(self, program): s = symbol() s.value = value else: - raise SyntaxError("Unknown operator ({0}). Possible operators are {1!r}".format( - value, list(self.symbol_table) - )) + raise SyntaxError( + "Unknown operator ({0}). Possible operators are {1!r}".format( + value, list(self.symbol_table) + ) + ) yield s def _tokenize_expr(self, program): if isinstance(program, bytes): - program = program.decode('utf-8') + program = program.decode("utf-8") # import pprint; pprint.pprint(self._token_pattern.findall(program)) for operator, name in self._token_pattern.findall(program): if operator: - yield '(operator)', operator + yield "(operator)", operator elif name: - yield '(name)', name + yield "(name)", name else: raise SyntaxError - yield '(end)', '(end)' + yield "(end)", "(end)" class Parser(object): @@ -205,7 +207,9 @@ def expression(self, rbp=0): def advance(self, name=None): if name and self.token.name != name: - raise SyntaxError("Expected {0!r} but found {1!r}".format(name, self.token.name)) + raise SyntaxError( + "Expected {0!r} but found {1!r}".format(name, self.token.name) + ) self.token = self.next() @@ -218,103 +222,107 @@ def bind(fn): return bind + symbol, infix, infix_r, prefix, postfix, ternary = ( - symbol_table.symbol, symbol_table.infix, symbol_table.infix_r, symbol_table.prefix, - symbol_table.postfix, symbol_table.ternary + symbol_table.symbol, + symbol_table.infix, + symbol_table.infix_r, + symbol_table.prefix, + symbol_table.postfix, + symbol_table.ternary, ) -symbol('(', 270) -symbol(')') -symbol('[', 250) # Parameters -symbol(']') -symbol('->', 230) -infix('|', 170) -infix('or', 170) -symbol(',') +symbol("(", 270) +symbol(")") +symbol("[", 250) # Parameters +symbol("]") +symbol("->", 230) +infix("|", 170) +infix("or", 170) +symbol(",") -symbol('(name)') -symbol('(end)') +symbol("(name)") +symbol("(end)") -@method(symbol('(name)')) +@method(symbol("(name)")) def nud(self, parser): return self -@method(symbol('(name)')) +@method(symbol("(name)")) def evaluate(self, pyobject): return utils.resolve_type(self.value, pyobject) # Parametrized objects -@method(symbol('[')) +@method(symbol("[")) def led(self, left, parser): self.first = left self.second = [] - if parser.token.name != ']': + if parser.token.name != "]": while 1: - if parser.token.name == ']': + if parser.token.name == "]": break self.second.append(parser.expression()) - if parser.token.name != ',': + if parser.token.name != ",": break - parser.advance(',') - parser.advance(']') + parser.advance(",") + parser.advance("]") return self -@method(symbol('[')) +@method(symbol("[")) def evaluate(self, pyobject): return utils.parametrize_type( - self.first.evaluate(pyobject), - *[i.evaluate(pyobject) for i in self.second] + self.first.evaluate(pyobject), *[i.evaluate(pyobject) for i in self.second] ) # Anonymous Function Calls -@method(symbol('(')) +@method(symbol("(")) def nud(self, parser): self.second = [] - if parser.token.name != ')': + if parser.token.name != ")": while 1: self.second.append(parser.expression()) - if parser.token.name != ',': + if parser.token.name != ",": break - parser.advance(',') - parser.advance(')') - parser.advance('->') - self.third = parser.expression(symbol('->').lbp + 0.1) + parser.advance(",") + parser.advance(")") + parser.advance("->") + self.third = parser.expression(symbol("->").lbp + 0.1) return self # Function Calls -@method(symbol('(')) +@method(symbol("(")) def led(self, left, parser): self.first = left self.second = [] - if parser.token.name != ')': + if parser.token.name != ")": while 1: self.second.append(parser.expression()) - if parser.token.name != ',': + if parser.token.name != ",": break - parser.advance(',') - parser.advance(')') - parser.advance('->') - self.third = parser.expression(symbol('->').lbp + 0.1) + parser.advance(",") + parser.advance(")") + parser.advance("->") + self.third = parser.expression(symbol("->").lbp + 0.1) return self -@method(symbol('(')) +@method(symbol("(")) def evaluate(self, pyobject): # TODO: Implement me raise NotImplementedError -@method(symbol('or')) -@method(symbol('|')) +@method(symbol("or")) +@method(symbol("|")) def evaluate(self, pyobject): # TODO: Implement me - raise NotImplementedError + raise NotImplementedError class Compiler(object): @@ -334,6 +342,7 @@ def __call__(self, program): """ return self._make_parser().parse(program) + compile = Compiler() @@ -347,7 +356,12 @@ def __call__(self, program, pyobject): :type program: str or rope.base.oi.type_hinting.evaluate.SymbolBase :rtype: rope.base.pyobjects.PyDefinedObject | rope.base.pyobjects.PyObject or None """ - ast = self.compile(program) if isinstance(program, pycompat.string_types) else program + ast = ( + self.compile(program) + if isinstance(program, pycompat.string_types) + else program + ) return ast.evaluate(pyobject) + evaluate = Evaluator() diff --git a/rope/base/oi/type_hinting/factory.py b/rope/base/oi/type_hinting/factory.py index 644d12c03..c37a3de91 100644 --- a/rope/base/oi/type_hinting/factory.py +++ b/rope/base/oi/type_hinting/factory.py @@ -1,25 +1,34 @@ from rope.base.oi.type_hinting import interfaces from rope.base.oi.type_hinting.providers import ( - composite, inheritance, docstrings, numpydocstrings, pep0484_type_comments + composite, + inheritance, + docstrings, + numpydocstrings, + pep0484_type_comments, ) from rope.base.oi.type_hinting.resolvers import composite as composite_resolvers, types from rope.base import utils class TypeHintingFactory(interfaces.ITypeHintingFactory): - @utils.saveit def make_param_provider(self): providers = [ - docstrings.ParamProvider(docstrings.DocstringParamParser(), self.make_resolver()), - docstrings.ParamProvider(numpydocstrings.NumPyDocstringParamParser(), self.make_resolver()), + docstrings.ParamProvider( + docstrings.DocstringParamParser(), self.make_resolver() + ), + docstrings.ParamProvider( + numpydocstrings.NumPyDocstringParamParser(), self.make_resolver() + ), ] return inheritance.ParamProvider(composite.ParamProvider(*providers)) @utils.saveit def make_return_provider(self): providers = [ - docstrings.ReturnProvider(docstrings.DocstringReturnParser(), self.make_resolver()), + docstrings.ReturnProvider( + docstrings.DocstringReturnParser(), self.make_resolver() + ), ] return inheritance.ReturnProvider(composite.ReturnProvider(*providers)) @@ -27,8 +36,12 @@ def make_return_provider(self): def make_assignment_provider(self): providers = [ pep0484_type_comments.AssignmentProvider(self.make_resolver()), - docstrings.AssignmentProvider(docstrings.DocstringParamParser(), self.make_resolver()), - docstrings.AssignmentProvider(numpydocstrings.NumPyDocstringParamParser(), self.make_resolver()), + docstrings.AssignmentProvider( + docstrings.DocstringParamParser(), self.make_resolver() + ), + docstrings.AssignmentProvider( + numpydocstrings.NumPyDocstringParamParser(), self.make_resolver() + ), ] return inheritance.AssignmentProvider(composite.AssignmentProvider(*providers)) @@ -47,15 +60,14 @@ def make_resolver(self): class TypeHintingFactoryAccessor(object): - def __call__(self, project): """ :type project: rope.base.project.Project :rtype: rope.base.oi.type_hinting.interfaces.ITypeHintingFactory """ factory_location = project.get_prefs().get( - 'type_hinting_factory', - 'rope.base.oi.type_hinting.factory.default_type_hinting_factory' + "type_hinting_factory", + "rope.base.oi.type_hinting.factory.default_type_hinting_factory", ) return self._get_factory(factory_location) @@ -67,4 +79,5 @@ def _get_factory(self, factory_location): """ return utils.resolve(factory_location) + get_type_hinting_factory = TypeHintingFactoryAccessor() diff --git a/rope/base/oi/type_hinting/interfaces.py b/rope/base/oi/type_hinting/interfaces.py index fc5568f32..fa9eef165 100644 --- a/rope/base/oi/type_hinting/interfaces.py +++ b/rope/base/oi/type_hinting/interfaces.py @@ -1,5 +1,4 @@ class ITypeHintingFactory(object): - def make_param_provider(self): """ :rtype: rope.base.oi.type_hinting.providers.interfaces.IParamProvider diff --git a/rope/base/oi/type_hinting/providers/composite.py b/rope/base/oi/type_hinting/providers/composite.py index 34a9fae7b..8cfb2d7b9 100644 --- a/rope/base/oi/type_hinting/providers/composite.py +++ b/rope/base/oi/type_hinting/providers/composite.py @@ -2,7 +2,6 @@ class ParamProvider(interfaces.IParamProvider): - def __init__(self, *delegates): """ :type delegates: list[rope.base.oi.type_hinting.providers.interfaces.IParamProvider] @@ -22,7 +21,6 @@ def __call__(self, pyfunc, param_name): class ReturnProvider(interfaces.IReturnProvider): - def __init__(self, *delegates): """ :type delegates: list[rope.base.oi.type_hinting.providers.interfaces.IReturnProvider] @@ -41,7 +39,6 @@ def __call__(self, pyfunc): class AssignmentProvider(interfaces.IAssignmentProvider): - def __init__(self, *delegates): """ :type delegates: list[rope.base.oi.type_hinting.providers.interfaces.IAssignmentProvider] diff --git a/rope/base/oi/type_hinting/providers/docstrings.py b/rope/base/oi/type_hinting/providers/docstrings.py index 8d52e6fbf..2e1c864a5 100644 --- a/rope/base/oi/type_hinting/providers/docstrings.py +++ b/rope/base/oi/type_hinting/providers/docstrings.py @@ -28,7 +28,6 @@ class ParamProvider(interfaces.IParamProvider): - def __init__(self, docstring_parser, resolver): """ :type docstring_parser: rope.base.oi.type_hinting.providers.docstrings.IParamParser @@ -49,7 +48,6 @@ def __call__(self, pyfunc, param_name): class ReturnProvider(interfaces.IReturnProvider): - def __init__(self, docstring_parser, resolver): """ :type docstring_parser: rope.base.oi.type_hinting.providers.docstrings.IReturnParser @@ -69,7 +67,6 @@ def __call__(self, pyfunc): class AssignmentProvider(interfaces.IAssignmentProvider): - def __init__(self, docstring_parser, resolver): """ :type docstring_parser: rope.base.oi.type_hinting.providers.docstrings.IParamParser @@ -94,7 +91,6 @@ def __call__(self, pyname): class IParamParser(object): - def __call__(self, docstring, param_name): """ :type docstring: str @@ -103,7 +99,6 @@ def __call__(self, docstring, param_name): class IReturnParser(object): - def __call__(self, docstring): """ :type docstring: str @@ -113,9 +108,9 @@ def __call__(self, docstring): class DocstringParamParser(IParamParser): DOCSTRING_PARAM_PATTERNS = [ - r'\s*:type\s+%s:\s*([^\n]+)', # Sphinx - r'\s*:param\s+(\w+)\s+%s:[^\n]+', # Sphinx param with type - r'\s*@type\s+%s:\s*([^\n]+)', # Epydoc + r"\s*:type\s+%s:\s*([^\n]+)", # Sphinx + r"\s*:param\s+(\w+)\s+%s:[^\n]+", # Sphinx param with type + r"\s*@type\s+%s:\s*([^\n]+)", # Epydoc ] def __init__(self): @@ -137,8 +132,9 @@ def __call__(self, docstring, param_name): """ if not docstring: return [] - patterns = [re.compile(p % re.escape(param_name)) - for p in self.DOCSTRING_PARAM_PATTERNS] + patterns = [ + re.compile(p % re.escape(param_name)) for p in self.DOCSTRING_PARAM_PATTERNS + ] for pattern in patterns: match = pattern.search(docstring) if match: @@ -150,8 +146,8 @@ def __call__(self, docstring, param_name): class DocstringReturnParser(IReturnParser): DOCSTRING_RETURN_PATTERNS = [ - re.compile(r'\s*:rtype:\s*([^\n]+)', re.M), # Sphinx - re.compile(r'\s*@rtype:\s*([^\n]+)', re.M), # Epydoc + re.compile(r"\s*:rtype:\s*([^\n]+)", re.M), # Sphinx + re.compile(r"\s*@rtype:\s*([^\n]+)", re.M), # Epydoc ] def __init__(self): @@ -169,7 +165,7 @@ def __call__(self, docstring): class RSTRoleStrip(object): - RST_ROLE_PATTERN = re.compile(r':[^`]+:`([^`]+)`') + RST_ROLE_PATTERN = re.compile(r":[^`]+:`([^`]+)`") def __call__(self, type_str): """ diff --git a/rope/base/oi/type_hinting/providers/inheritance.py b/rope/base/oi/type_hinting/providers/inheritance.py index cbefc43b8..43815ba43 100644 --- a/rope/base/oi/type_hinting/providers/inheritance.py +++ b/rope/base/oi/type_hinting/providers/inheritance.py @@ -3,7 +3,6 @@ class ParamProvider(interfaces.IParamProvider): - def __init__(self, delegate): """ :type delegate: rope.base.oi.type_hinting.providers.interfaces.IParamProvider @@ -25,7 +24,6 @@ def __call__(self, pyfunc, param_name): class ReturnProvider(interfaces.IReturnProvider): - def __init__(self, delegate): """ :type delegate: rope.base.oi.type_hinting.providers.interfaces.IReturnProvider @@ -46,7 +44,6 @@ def __call__(self, pyfunc): class AssignmentProvider(interfaces.IAssignmentProvider): - def __init__(self, delegate): """ :type delegate: rope.base.oi.type_hinting.providers.interfaces.IAssignmentProvider diff --git a/rope/base/oi/type_hinting/providers/interfaces.py b/rope/base/oi/type_hinting/providers/interfaces.py index 595cf82ca..9d10a70b7 100644 --- a/rope/base/oi/type_hinting/providers/interfaces.py +++ b/rope/base/oi/type_hinting/providers/interfaces.py @@ -1,5 +1,4 @@ class IParamProvider(object): - def __call__(self, pyfunc, param_name): """ :type pyfunc: rope.base.pyobjectsdef.PyFunction @@ -13,6 +12,7 @@ class IReturnProvider(object): """ :type resolve: rope.base.oi.type_hinting.resolvers.interfaces.IResolver """ + resolve = None def __call__(self, pyfunc): @@ -27,6 +27,7 @@ class IAssignmentProvider(object): """ :type resolve: rope.base.oi.type_hinting.resolvers.interfaces.IResolver """ + resolve = None def __call__(self, pyname): diff --git a/rope/base/oi/type_hinting/providers/numpydocstrings.py b/rope/base/oi/type_hinting/providers/numpydocstrings.py index 74c05900a..6c4fb229d 100644 --- a/rope/base/oi/type_hinting/providers/numpydocstrings.py +++ b/rope/base/oi/type_hinting/providers/numpydocstrings.py @@ -14,19 +14,18 @@ class NumPyDocstringParamParser(docstrings.IParamParser): - def __call__(self, docstring, param_name): """Search `docstring` (in numpydoc format) for type(-s) of `param_name`.""" if not docstring: - return [] - params = NumpyDocString(docstring)._parsed_data['Parameters'] + return [] + params = NumpyDocString(docstring)._parsed_data["Parameters"] for p_name, p_type, p_descr in params: if p_name == param_name: - m = re.match('([^,]+(,[^,]+)*?)(,[ ]*optional)?$', p_type) + m = re.match("([^,]+(,[^,]+)*?)(,[ ]*optional)?$", p_type) if m: p_type = m.group(1) - if p_type.startswith('{'): + if p_type.startswith("{"): types = set(type(x).__name__ for x in literal_eval(p_type)) return list(types) else: diff --git a/rope/base/oi/type_hinting/providers/pep0484_type_comments.py b/rope/base/oi/type_hinting/providers/pep0484_type_comments.py index 357d5906a..243d8bb9f 100644 --- a/rope/base/oi/type_hinting/providers/pep0484_type_comments.py +++ b/rope/base/oi/type_hinting/providers/pep0484_type_comments.py @@ -4,16 +4,13 @@ class AssignmentProvider(interfaces.IAssignmentProvider): - def __init__(self, resolver): """ :type resolver: rope.base.oi.type_hinting.resolvers.interfaces.IResolver """ self._resolve = resolver - PEP0484_TYPE_COMMENT_PATTERNS = ( - re.compile(r'type:\s*([^\n]+)'), - ) + PEP0484_TYPE_COMMENT_PATTERNS = (re.compile(r"type:\s*([^\n]+)"),) def __call__(self, pyname): """ @@ -21,16 +18,17 @@ def __call__(self, pyname): :rtype: rope.base.pyobjects.PyDefinedObject | rope.base.pyobjects.PyObject or None """ from rope.base.oi.soi import _get_lineno_for_node + lineno = _get_lineno_for_node(pyname.assignments[0].ast_node) holding_scope = pyname.module.get_scope().get_inner_scope_for_line(lineno) line = holding_scope._get_global_scope()._scope_finder.lines.get_line(lineno) - if '#' in line: - type_strs = self._search_type_in_type_comment(line.split('#', 1)[1]) + if "#" in line: + type_strs = self._search_type_in_type_comment(line.split("#", 1)[1]) if type_strs: return self._resolve(type_strs[0], holding_scope.pyobject) def _search_type_in_type_comment(self, code): - """ For more info see: + """For more info see: https://www.python.org/dev/peps/pep-0484/#type-comments >>> AssignmentProvider()._search_type_in_type_comment('type: int') diff --git a/rope/base/oi/type_hinting/resolvers/composite.py b/rope/base/oi/type_hinting/resolvers/composite.py index 4e331e48b..48f0612c8 100644 --- a/rope/base/oi/type_hinting/resolvers/composite.py +++ b/rope/base/oi/type_hinting/resolvers/composite.py @@ -2,7 +2,6 @@ class Resolver(interfaces.IResolver): - def __init__(self, *delegates): """ :type delegates: list[rope.base.oi.type_hinting.resolvers.interfaces.IResolver] diff --git a/rope/base/oi/type_hinting/resolvers/interfaces.py b/rope/base/oi/type_hinting/resolvers/interfaces.py index cfcff3d99..120b97232 100644 --- a/rope/base/oi/type_hinting/resolvers/interfaces.py +++ b/rope/base/oi/type_hinting/resolvers/interfaces.py @@ -1,5 +1,4 @@ class IResolver(object): - def __call__(self, hint, pyobject): """ :param hint: For example "List[int]" or "(Foo, Bar) -> Baz" or simple "Foo" diff --git a/rope/base/oi/type_hinting/utils.py b/rope/base/oi/type_hinting/utils.py index ce90dfeb5..bad700af5 100644 --- a/rope/base/oi/type_hinting/utils.py +++ b/rope/base/oi/type_hinting/utils.py @@ -1,4 +1,5 @@ import logging + try: from typing import Union, Optional except ImportError: @@ -51,8 +52,7 @@ def get_class_with_attr_name(pyname): pyobject = holding_scope.pyobject if isinstance(pyobject, PyClass): pyclass = pyobject - elif (isinstance(pyobject, PyFunction) and - isinstance(pyobject.parent, PyClass)): + elif isinstance(pyobject, PyFunction) and isinstance(pyobject.parent, PyClass): pyclass = pyobject.parent else: return @@ -62,8 +62,7 @@ def get_class_with_attr_name(pyname): def get_lineno_for_node(assign_node): - if hasattr(assign_node, 'lineno') and \ - assign_node.lineno is not None: + if hasattr(assign_node, "lineno") and assign_node.lineno is not None: return assign_node.lineno return 1 @@ -83,17 +82,18 @@ def resolve_type(type_name, pyobject): """ Find proper type object from its name. """ - deprecated_aliases = {'collections': 'collections.abc'} + deprecated_aliases = {"collections": "collections.abc"} ret_type = None - logging.debug('Looking for %s', type_name) - if '.' not in type_name: + logging.debug("Looking for %s", type_name) + if "." not in type_name: try: - ret_type = pyobject.get_module().get_scope().get_name( - type_name).get_object() + ret_type = ( + pyobject.get_module().get_scope().get_name(type_name).get_object() + ) except AttributeNotFoundError: - logging.exception('Cannot resolve type %s', type_name) + logging.exception("Cannot resolve type %s", type_name) else: - mod_name, attr_name = type_name.rsplit('.', 1) + mod_name, attr_name = type_name.rsplit(".", 1) try: mod_finder = ScopeNameFinder(pyobject.get_module()) mod = mod_finder._find_module(mod_name).get_object() @@ -101,35 +101,45 @@ def resolve_type(type_name, pyobject): except AttributeNotFoundError: if mod_name in deprecated_aliases: try: - logging.debug('Looking for %s in %s', - attr_name, deprecated_aliases[mod_name]) + logging.debug( + "Looking for %s in %s", attr_name, deprecated_aliases[mod_name] + ) mod = mod_finder._find_module( - deprecated_aliases[mod_name]).get_object() + deprecated_aliases[mod_name] + ).get_object() ret_type = mod.get_attribute(attr_name).get_object() except AttributeNotFoundError: - logging.exception('Cannot resolve type %s in %s', - attr_name, dir(mod)) - logging.debug('ret_type = %s', ret_type) + logging.exception( + "Cannot resolve type %s in %s", attr_name, dir(mod) + ) + logging.debug("ret_type = %s", ret_type) return ret_type class ParametrizeType(object): _supported_mapping = { - 'builtins.list': 'rope.base.builtins.get_list', - 'builtins.tuple': 'rope.base.builtins.get_tuple', - 'builtins.set': 'rope.base.builtins.get_set', - 'builtins.dict': 'rope.base.builtins.get_dict', - '_collections_abc.Iterable': 'rope.base.builtins.get_iterator', - '_collections_abc.Iterator': 'rope.base.builtins.get_iterator', - 'collections.abc.Iterable': 'rope.base.builtins.get_iterator', # Python3.3 - 'collections.abc.Iterator': 'rope.base.builtins.get_iterator', # Python3.3 + "builtins.list": "rope.base.builtins.get_list", + "builtins.tuple": "rope.base.builtins.get_tuple", + "builtins.set": "rope.base.builtins.get_set", + "builtins.dict": "rope.base.builtins.get_dict", + "_collections_abc.Iterable": "rope.base.builtins.get_iterator", + "_collections_abc.Iterator": "rope.base.builtins.get_iterator", + "collections.abc.Iterable": "rope.base.builtins.get_iterator", # Python3.3 + "collections.abc.Iterator": "rope.base.builtins.get_iterator", # Python3.3 } if pycompat.PY2: - _supported_mapping = dict(( - (k.replace('builtins.', '__builtin__.').replace('_collections_abc.', '_abcoll.'), v) - for k, v in _supported_mapping.items() - )) + _supported_mapping = dict( + ( + ( + k.replace("builtins.", "__builtin__.").replace( + "_collections_abc.", "_abcoll." + ), + v, + ) + for k, v in _supported_mapping.items() + ) + ) def __call__(self, pyobject, *args, **kwargs): """ @@ -144,11 +154,12 @@ def __call__(self, pyobject, *args, **kwargs): return pyobject def _get_type_factory(self, pyobject): - type_str = '{0}.{1}'.format( + type_str = "{0}.{1}".format( pyobject.get_module().get_name(), pyobject.get_name(), ) if type_str in self._supported_mapping: return base_utils.resolve(self._supported_mapping[type_str]) + parametrize_type = ParametrizeType() diff --git a/rope/base/prefs.py b/rope/base/prefs.py index 2ab45dac5..222e41c37 100644 --- a/rope/base/prefs.py +++ b/rope/base/prefs.py @@ -1,5 +1,4 @@ class Prefs(object): - def __init__(self): self.prefs = {} self.callbacks = {} diff --git a/rope/base/project.py b/rope/base/project.py index 033ff98f9..02407edb5 100644 --- a/rope/base/project.py +++ b/rope/base/project.py @@ -17,7 +17,6 @@ class _Project(object): - def __init__(self, fscommands): self.observers = [] self.fscommands = fscommands @@ -39,14 +38,14 @@ def get_resource(self, resource_name): path = self._get_resource_path(resource_name) if not os.path.exists(path): raise exceptions.ResourceNotFoundError( - 'Resource <%s> does not exist' % resource_name) + "Resource <%s> does not exist" % resource_name + ) elif os.path.isfile(path): return File(self, resource_name) elif os.path.isdir(path): return Folder(self, resource_name) else: - raise exceptions.ResourceNotFoundError('Unknown resource ' - + resource_name) + raise exceptions.ResourceNotFoundError("Unknown resource " + resource_name) def get_module(self, name, folder=None): """Returns a `PyObject` if the module was found.""" @@ -56,12 +55,12 @@ def get_module(self, name, folder=None): return pymod module = self.find_module(name, folder) if module is None: - raise ModuleNotFoundError('Module %s not found' % name) + raise ModuleNotFoundError("Module %s not found" % name) return self.pycore.resource_to_pyobject(module) def get_python_path_folders(self): result = [] - for src in self.prefs.get('python_path', []) + sys.path: + for src in self.prefs.get("python_path", []) + sys.path: try: src_folder = get_no_project().get_resource(src) result.append(src_folder) @@ -132,7 +131,7 @@ def get_prefs(self): def get_relative_module(self, name, folder, level): module = self.find_relative_module(name, folder, level) if module is None: - raise ModuleNotFoundError('Module %s not found' % name) + raise ModuleNotFoundError("Module %s not found" % name) return self.pycore.resource_to_pyobject(module) def find_module(self, modname, folder=None): @@ -157,7 +156,7 @@ def find_module(self, modname, folder=None): def find_relative_module(self, modname, folder, level): for i in range(level - 1): folder = folder.parent - if modname == '': + if modname == "": return folder else: return _find_module_in_folder(folder, modname) @@ -179,8 +178,7 @@ def pycore(self): return pycore.PyCore(self) def close(self): - warnings.warn('Cannot close a NoProject', - DeprecationWarning, stacklevel=2) + warnings.warn("Cannot close a NoProject", DeprecationWarning, stacklevel=2) ropefolder = None @@ -188,8 +186,9 @@ def close(self): class Project(_Project): """A Project containing files and folders""" - def __init__(self, projectroot, fscommands=None, - ropefolder='.ropeproject', **prefs): + def __init__( + self, projectroot, fscommands=None, ropefolder=".ropeproject", **prefs + ): """A rope project :parameters: @@ -203,29 +202,28 @@ def __init__(self, projectroot, fscommands=None, overwrite config file preferences. """ - if projectroot != '/': - projectroot = _realpath(projectroot).rstrip('/\\') + if projectroot != "/": + projectroot = _realpath(projectroot).rstrip("/\\") self._address = projectroot self._ropefolder_name = ropefolder if not os.path.exists(self._address): os.mkdir(self._address) elif not os.path.isdir(self._address): - raise exceptions.RopeError('Project root exists and' - ' is not a directory') + raise exceptions.RopeError("Project root exists and" " is not a directory") if fscommands is None: fscommands = rope.base.fscommands.create_fscommands(self._address) super(Project, self).__init__(fscommands) self.ignored = _ResourceMatcher() self.file_list = _FileListCacher(self) - self.prefs.add_callback('ignored_resources', self.ignored.set_patterns) + self.prefs.add_callback("ignored_resources", self.ignored.set_patterns) if ropefolder is not None: - self.prefs['ignored_resources'] = [ropefolder] + self.prefs["ignored_resources"] = [ropefolder] self._init_prefs(prefs) self._init_source_folders() - @utils.deprecated('Delete once deprecated functions are gone') + @utils.deprecated("Delete once deprecated functions are gone") def _init_source_folders(self): - for path in self.prefs.get('source_folders', []): + for path in self.prefs.get("source_folders", []): folder = self.get_resource(path) self._custom_source_folders.append(folder) @@ -234,18 +232,21 @@ def get_files(self): def get_python_files(self): """Returns all python files available in the project""" - return [resource for resource in self.get_files() - if self.pycore.is_python_file(resource)] + return [ + resource + for resource in self.get_files() + if self.pycore.is_python_file(resource) + ] def _get_resource_path(self, name): - return os.path.join(self._address, *name.split('/')) + return os.path.join(self._address, *name.split("/")) def _init_ropefolder(self): if self.ropefolder is not None: if not self.ropefolder.exists(): self._create_recursively(self.ropefolder) - if not self.ropefolder.has_child('config.py'): - config = self.ropefolder.create_file('config.py') + if not self.ropefolder.has_child("config.py"): + config = self.ropefolder.create_file("config.py") config.write(self._default_config()) def _create_recursively(self, folder): @@ -256,27 +257,32 @@ def _create_recursively(self, folder): def _init_prefs(self, prefs): run_globals = {} if self.ropefolder is not None: - config = self.get_file(self.ropefolder.path + '/config.py') - run_globals.update({'__name__': '__main__', - '__builtins__': __builtins__, - '__file__': config.real_path}) + config = self.get_file(self.ropefolder.path + "/config.py") + run_globals.update( + { + "__name__": "__main__", + "__builtins__": __builtins__, + "__file__": config.real_path, + } + ) if config.exists(): - config = self.ropefolder.get_child('config.py') + config = self.ropefolder.get_child("config.py") pycompat.execfile(config.real_path, run_globals) else: exec(self._default_config(), run_globals) - if 'set_prefs' in run_globals: - run_globals['set_prefs'](self.prefs) + if "set_prefs" in run_globals: + run_globals["set_prefs"](self.prefs) for key, value in prefs.items(): self.prefs[key] = value self._init_other_parts() self._init_ropefolder() - if 'project_opened' in run_globals: - run_globals['project_opened'](self) + if "project_opened" in run_globals: + run_globals["project_opened"](self) def _default_config(self): import rope.base.default_config import inspect + return inspect.getsource(rope.base.default_config) def _init_other_parts(self): @@ -308,7 +314,7 @@ def validate(self, folder=None): folder = self.root super(Project, self).validate(folder) - root = property(lambda self: self.get_resource('')) + root = property(lambda self: self.get_resource("")) address = property(lambda self: self._address) @@ -323,11 +329,11 @@ def __init__(self): super(NoProject, self).__init__(fscommands) def _get_resource_path(self, name): - real_name = name.replace('/', os.path.sep) + real_name = name.replace("/", os.path.sep) return _realpath(real_name) def get_resource(self, name): - universal_name = _realpath(name).replace(os.path.sep, '/') + universal_name = _realpath(name).replace(os.path.sep, "/") return super(NoProject, self).get_resource(universal_name) def get_files(self): @@ -346,13 +352,12 @@ def get_no_project(): class _FileListCacher(object): - def __init__(self, project): self.project = project self.files = None rawobserver = resourceobserver.ResourceObserver( - self._changed, self._invalid, self._invalid, - self._invalid, self._invalid) + self._changed, self._invalid, self._invalid, self._invalid, self._invalid + ) self.project.add_observer(rawobserver) def get_files(self): @@ -377,7 +382,6 @@ def _invalid(self, resource, new_resource=None): class _DataFiles(object): - def __init__(self, project): self.project = project self.hooks = [] @@ -391,7 +395,7 @@ def read_data(self, name, compress=False, import_=False): if not compress and import_: self._import_old_files(name) if file.exists(): - input = opener(file.real_path, 'rb') + input = opener(file.real_path, "rb") try: result = [] try: @@ -411,7 +415,7 @@ def write_data(self, name, data, compress=False): compress = compress and self._can_compress() file = self._get_file(name, compress) opener = self._get_opener(compress) - output = opener(file.real_path, 'wb') + output = opener(file.real_path, "wb") try: pickle.dump(data, output, 2) finally: @@ -427,12 +431,13 @@ def write(self): def _can_compress(self): try: import gzip # noqa + return True except ImportError: return False def _import_old_files(self, name): - old = self._get_file(name + '.pickle', False) + old = self._get_file(name + ".pickle", False) new = self._get_file(name, False) if old.exists() and not new.exists(): shutil.move(old.real_path, new.real_path) @@ -441,15 +446,16 @@ def _get_opener(self, compress): if compress: try: import gzip + return gzip.open except ImportError: pass return open def _get_file(self, name, compress): - path = self.project.ropefolder.path + '/' + name + path = self.project.ropefolder.path + "/" + name if compress: - path += '.gz' + path += ".gz" return self.project.get_file(path) @@ -465,10 +471,10 @@ def _realpath(path): """ # there is a bug in cygwin for os.path.abspath() for abs paths - if sys.platform == 'cygwin': - if path[1:3] == ':\\': + if sys.platform == "cygwin": + if path[1:3] == ":\\": return path - elif path[1:3] == ':/': + elif path[1:3] == ":/": path = "/cygdrive/" + path[0] + path[2:] return os.path.abspath(os.path.expanduser(path)) return os.path.realpath(os.path.abspath(os.path.expanduser(path))) @@ -476,16 +482,20 @@ def _realpath(path): def _find_module_in_folder(folder, modname): module = folder - packages = modname.split('.') + packages = modname.split(".") for pkg in packages[:-1]: if module.is_folder() and module.has_child(pkg): module = module.get_child(pkg) else: return None if module.is_folder(): - if module.has_child(packages[-1]) and \ - module.get_child(packages[-1]).is_folder(): + if ( + module.has_child(packages[-1]) + and module.get_child(packages[-1]).is_folder() + ): return module.get_child(packages[-1]) - elif module.has_child(packages[-1] + '.py') and \ - not module.get_child(packages[-1] + '.py').is_folder(): - return module.get_child(packages[-1] + '.py') + elif ( + module.has_child(packages[-1] + ".py") + and not module.get_child(packages[-1] + ".py").is_folder() + ): + return module.get_child(packages[-1] + ".py") diff --git a/rope/base/pycore.py b/rope/base/pycore.py index c4c1195a4..352a509b3 100644 --- a/rope/base/pycore.py +++ b/rope/base/pycore.py @@ -19,7 +19,6 @@ class PyCore(object): - def __init__(self, project): self.project = project self._init_resource_observer() @@ -32,7 +31,7 @@ def __init__(self, project): def _init_python_files(self): self.python_matcher = None - patterns = self.project.prefs.get('python_files', None) + patterns = self.project.prefs.get("python_files", None) if patterns is not None: self.python_matcher = rope.base.resources._ResourceMatcher() self.python_matcher.set_patterns(patterns) @@ -40,9 +39,9 @@ def _init_python_files(self): def _init_resource_observer(self): callback = self._invalidate_resource_cache observer = rope.base.resourceobserver.ResourceObserver( - changed=callback, moved=callback, removed=callback) - self.observer = \ - rope.base.resourceobserver.FilteredResourceObserver(observer) + changed=callback, moved=callback, removed=callback + ) + self.observer = rope.base.resourceobserver.FilteredResourceObserver(observer) self.project.add_observer(self.observer) def _init_automatic_soa(self): @@ -50,17 +49,17 @@ def _init_automatic_soa(self): return callback = self._file_changed_for_soa observer = rope.base.resourceobserver.ResourceObserver( - changed=callback, moved=callback, removed=callback) + changed=callback, moved=callback, removed=callback + ) self.project.add_observer(observer) @property def automatic_soa(self): - auto_soa = self.project.prefs.get('automatic_soi', None) - return self.project.prefs.get('automatic_soa', auto_soa) + auto_soa = self.project.prefs.get("automatic_soi", None) + return self.project.prefs.get("automatic_soa", auto_soa) def _file_changed_for_soa(self, resource, new_resource=None): - old_contents = self.project.history.\ - contents_before_current_change(resource) + old_contents = self.project.history.contents_before_current_change(resource) if old_contents is not None: perform_soa_on_changed_scopes(self.project, resource, old_contents) @@ -68,10 +67,10 @@ def is_python_file(self, resource): if resource.is_folder(): return False if self.python_matcher is None: - return resource.name.endswith('.py') + return resource.name.endswith(".py") return self.python_matcher.does_match(resource) - @utils.deprecated('Use `project.get_module` instead') + @utils.deprecated("Use `project.get_module` instead") def get_module(self, name, folder=None): """Returns a `PyObject` if the module was found.""" return self.project.get_module(name, folder) @@ -79,20 +78,20 @@ def get_module(self, name, folder=None): def _builtin_submodules(self, modname): result = {} for extension in self.extension_modules: - if extension.startswith(modname + '.'): - name = extension[len(modname) + 1:] - if '.' not in name: + if extension.startswith(modname + "."): + name = extension[len(modname) + 1 :] + if "." not in name: result[name] = self.builtin_module(extension) return result def builtin_module(self, name): return self.extension_cache.get_pymodule(name) - @utils.deprecated('Use `project.get_relative_module` instead') + @utils.deprecated("Use `project.get_relative_module` instead") def get_relative_module(self, name, folder, level): return self.project.get_relative_module(name, folder, level) - @utils.deprecated('Use `libutils.get_string_module` instead') + @utils.deprecated("Use `libutils.get_string_module` instead") def get_string_module(self, code, resource=None, force_errors=False): """Returns a `PyObject` object for the given code @@ -103,7 +102,7 @@ def get_string_module(self, code, resource=None, force_errors=False): """ return PyModule(self, code, resource, force_errors=force_errors) - @utils.deprecated('Use `libutils.get_string_scope` instead') + @utils.deprecated("Use `libutils.get_string_scope` instead") def get_string_scope(self, code, resource=None): """Returns a `Scope` object for the given code""" return rope.base.libutils.get_string_scope(code, resource) @@ -112,11 +111,11 @@ def _invalidate_resource_cache(self, resource, new_resource=None): for observer in self.cache_observers: observer(resource) - @utils.deprecated('Use `project.get_python_path_folders` instead') + @utils.deprecated("Use `project.get_python_path_folders` instead") def get_python_path_folders(self): return self.project.get_python_path_folders() - @utils.deprecated('Use `project.find_module` instead') + @utils.deprecated("Use `project.find_module` instead") def find_module(self, modname, folder=None): """Returns a resource corresponding to the given module @@ -124,7 +123,7 @@ def find_module(self, modname, folder=None): """ return self.project.find_module(modname, folder) - @utils.deprecated('Use `project.find_relative_module` instead') + @utils.deprecated("Use `project.find_relative_module` instead") def find_relative_module(self, modname, folder, level): return self.project.find_relative_module(modname, folder, level) @@ -133,7 +132,7 @@ def find_relative_module(self, modname, folder, level): # packages, that is most of the time # - We need a separate resource observer; `self.observer` # does not get notified about module and folder creations - @utils.deprecated('Use `project.get_source_folders` instead') + @utils.deprecated("Use `project.get_source_folders` instead") def get_source_folders(self): """Returns project source folders""" return self.project.get_source_folders() @@ -141,14 +140,16 @@ def get_source_folders(self): def resource_to_pyobject(self, resource, force_errors=False): return self.module_cache.get_pymodule(resource, force_errors) - @utils.deprecated('Use `project.get_python_files` instead') + @utils.deprecated("Use `project.get_python_files` instead") def get_python_files(self): """Returns all python files available in the project""" return self.project.get_python_files() def _is_package(self, folder): - if folder.has_child('__init__.py') and \ - not folder.get_child('__init__.py').is_folder(): + if ( + folder.has_child("__init__.py") + and not folder.get_child("__init__.py").is_folder() + ): return True else: return False @@ -159,7 +160,7 @@ def _find_source_folders(self, folder): return [folder] result = [] for resource in folder.get_files(): - if resource.name.endswith('.py'): + if resource.name.endswith(".py"): result.append(folder) break for resource in folder.get_folders(): @@ -173,19 +174,25 @@ def run_module(self, resource, args=None, stdin=None, stdout=None): controlling the process. """ - perform_doa = self.project.prefs.get('perform_doi', True) - perform_doa = self.project.prefs.get('perform_doa', perform_doa) + perform_doa = self.project.prefs.get("perform_doi", True) + perform_doa = self.project.prefs.get("perform_doa", perform_doa) receiver = self.object_info.doa_data_received if not perform_doa: receiver = None runner = rope.base.oi.doa.PythonFileRunner( - self, resource, args, stdin, stdout, receiver) + self, resource, args, stdin, stdout, receiver + ) runner.add_finishing_observer(self.module_cache.forget_all_data) runner.run() return runner - def analyze_module(self, resource, should_analyze=lambda py: True, - search_subscopes=lambda py: True, followed_calls=None): + def analyze_module( + self, + resource, + should_analyze=lambda py: True, + search_subscopes=lambda py: True, + followed_calls=None, + ): """Analyze `resource` module for static object inference This function forces rope to analyze this module to collect @@ -203,35 +210,36 @@ def analyze_module(self, resource, should_analyze=lambda py: True, project config. """ if followed_calls is None: - followed_calls = self.project.prefs.get('soa_followed_calls', 0) + followed_calls = self.project.prefs.get("soa_followed_calls", 0) pymodule = self.resource_to_pyobject(resource) self.module_cache.forget_all_data() rope.base.oi.soa.analyze_module( - self, pymodule, should_analyze, search_subscopes, followed_calls) + self, pymodule, should_analyze, search_subscopes, followed_calls + ) def get_classes(self, task_handle=taskhandle.NullTaskHandle()): - warnings.warn('`PyCore.get_classes()` is deprecated', - DeprecationWarning, stacklevel=2) + warnings.warn( + "`PyCore.get_classes()` is deprecated", DeprecationWarning, stacklevel=2 + ) return [] def __str__(self): return str(self.module_cache) + str(self.object_info) - @utils.deprecated('Use `libutils.modname` instead') + @utils.deprecated("Use `libutils.modname` instead") def modname(self, resource): return rope.base.libutils.modname(resource) @property @utils.cacheit def extension_modules(self): - result = set(self.project.prefs.get('extension_modules', [])) - if self.project.prefs.get('import_dynload_stdmods', False): + result = set(self.project.prefs.get("extension_modules", [])) + if self.project.prefs.get("import_dynload_stdmods", False): result.update(stdmods.dynload_modules()) return result class _ModuleCache(object): - def __init__(self, pycore): self.pycore = pycore self.module_map = {} @@ -248,11 +256,9 @@ def get_pymodule(self, resource, force_errors=False): if resource in self.module_map: return self.module_map[resource] if resource.is_folder(): - result = PyPackage(self.pycore, resource, - force_errors=force_errors) + result = PyPackage(self.pycore, resource, force_errors=force_errors) else: - result = PyModule(self.pycore, resource=resource, - force_errors=force_errors) + result = PyModule(self.pycore, resource=resource, force_errors=force_errors) if result.has_errors: return result self.module_map[resource] = result @@ -264,17 +270,16 @@ def forget_all_data(self): pymodule._forget_concluded_data() def __str__(self): - return 'PyCore caches %d PyModules\n' % len(self.module_map) + return "PyCore caches %d PyModules\n" % len(self.module_map) class _ExtensionCache(object): - def __init__(self, pycore): self.pycore = pycore self.extensions = {} def get_pymodule(self, name): - if name == '__builtin__': + if name == "__builtin__": return builtins.builtins allowed = self.pycore.extension_modules if name not in self.extensions and name in allowed: @@ -299,13 +304,13 @@ def should_analyze(pydefined): start = scope.get_start() end = scope.get_end() return detector.consume_changes(start, end) + pycore.analyze_module(resource, should_analyze, search_subscopes) except exceptions.ModuleSyntaxError: pass class _TextChangeDetector(object): - def __init__(self, old, new): self.old = old self.new = new @@ -315,11 +320,12 @@ def _set_diffs(self): differ = difflib.Differ() self.lines = [] lineno = 0 - for line in differ.compare(self.old.splitlines(True), - self.new.splitlines(True)): - if line.startswith(' '): + for line in differ.compare( + self.old.splitlines(True), self.new.splitlines(True) + ): + if line.startswith(" "): lineno += 1 - elif line.startswith('-'): + elif line.startswith("-"): lineno += 1 self.lines.append(lineno) diff --git a/rope/base/pynames.py b/rope/base/pynames.py index b50b4e151..285ee38f8 100644 --- a/rope/base/pynames.py +++ b/rope/base/pynames.py @@ -13,7 +13,6 @@ def get_definition_location(self): class DefinedName(PyName): - def __init__(self, pyobject): self.pyobject = pyobject @@ -21,7 +20,9 @@ def get_object(self): return self.pyobject def get_definition_location(self): - lineno = utils.guess_def_lineno(self.pyobject.get_module(), self.pyobject.get_ast()) + lineno = utils.guess_def_lineno( + self.pyobject.get_module(), self.pyobject.get_ast() + ) return (self.pyobject.get_module(), lineno) @@ -30,7 +31,6 @@ class AssignedName(PyName): class UnboundName(PyName): - def __init__(self, pyobject=None): self.pyobject = pyobject if self.pyobject is None: @@ -46,8 +46,9 @@ def get_definition_location(self): class AssignmentValue(object): """An assigned expression""" - def __init__(self, ast_node, levels=None, evaluation='', - assign_type=False, type_hint=None): + def __init__( + self, ast_node, levels=None, evaluation="", assign_type=False, type_hint=None + ): """The `level` is `None` for simple assignments and is a list of numbers for tuple assignments for example in:: @@ -95,9 +96,7 @@ class ParameterName(PyName): class ImportedModule(PyName): - - def __init__(self, importing_module, module_name=None, - level=0, resource=None): + def __init__(self, importing_module, module_name=None, level=0, resource=None): self.importing_module = importing_module self.module_name = module_name self.level = level @@ -119,11 +118,12 @@ def _get_pymodule(self): try: if self.level == 0: pymodule = pycore.project.get_module( - self.module_name, self._current_folder()) + self.module_name, self._current_folder() + ) else: pymodule = pycore.project.get_relative_module( - self.module_name, self._current_folder(), - self.level) + self.module_name, self._current_folder(), self.level + ) self.pymodule.set(pymodule) except exceptions.ModuleNotFoundError: pass @@ -142,7 +142,6 @@ def get_definition_location(self): class ImportedName(PyName): - def __init__(self, imported_module, imported_name): self.imported_module = imported_module self.imported_name = imported_name @@ -172,12 +171,10 @@ def _get_concluded_data(module): def _circular_inference(): - raise rope.base.pyobjects.IsBeingInferredError( - 'Circular Object Inference') + raise rope.base.pyobjects.IsBeingInferredError("Circular Object Inference") class _Inferred(object): - def __init__(self, get_inferred, concluded=None): self.get_inferred = get_inferred self.concluded = concluded diff --git a/rope/base/pynamesdef.py b/rope/base/pynamesdef.py index 41eb02ff9..ffbfb8e3c 100644 --- a/rope/base/pynamesdef.py +++ b/rope/base/pynamesdef.py @@ -4,13 +4,13 @@ class AssignedName(pynames.AssignedName): - def __init__(self, lineno=None, module=None, pyobject=None): self.lineno = lineno self.module = module self.assignments = [] - self.pyobject = _Inferred(self._get_inferred, - pynames._get_concluded_data(module)) + self.pyobject = _Inferred( + self._get_inferred, pynames._get_concluded_data(module) + ) self.pyobject.set(pyobject) @utils.prevent_recursion(lambda: None) @@ -36,7 +36,6 @@ def invalidate(self): class ParameterName(pynames.ParameterName): - def __init__(self, pyfunction, index): self.pyfunction = pyfunction self.index = index @@ -49,10 +48,10 @@ def get_object(self): def get_objects(self): """Returns the list of objects passed as this parameter""" - return rope.base.oi.soi.get_passed_objects( - self.pyfunction, self.index) + return rope.base.oi.soi.get_passed_objects(self.pyfunction, self.index) def get_definition_location(self): return (self.pyfunction.get_module(), self.pyfunction.get_ast().lineno) + _Inferred = pynames._Inferred diff --git a/rope/base/pyobjects.py b/rope/base/pyobjects.py index 912d79825..7a0e4d013 100644 --- a/rope/base/pyobjects.py +++ b/rope/base/pyobjects.py @@ -3,7 +3,6 @@ class PyObject(object): - def __init__(self, type_): if type_ is None: type_ = self @@ -16,8 +15,7 @@ def get_attributes(self): def get_attribute(self, name): if name not in self.get_attributes(): - raise exceptions.AttributeNotFoundError( - 'Attribute %s not found' % name) + raise exceptions.AttributeNotFoundError("Attribute %s not found" % name) return self.get_attributes()[name] def get_type(self): @@ -72,10 +70,10 @@ def _get_base_type(name): if PyObject._types is None: PyObject._types = {} base_type = PyObject(None) - PyObject._types['Type'] = base_type - PyObject._types['Module'] = PyObject(base_type) - PyObject._types['Function'] = PyObject(base_type) - PyObject._types['Unknown'] = PyObject(base_type) + PyObject._types["Type"] = base_type + PyObject._types["Module"] = PyObject(base_type) + PyObject._types["Function"] = PyObject(base_type) + PyObject._types["Unknown"] = PyObject(base_type) return PyObject._types[name] @@ -114,14 +112,13 @@ def get_unknown(): """ if PyObject._unknown is None: - PyObject._unknown = PyObject(get_base_type('Unknown')) + PyObject._unknown = PyObject(get_base_type("Unknown")) return PyObject._unknown class AbstractClass(PyObject): - def __init__(self): - super(AbstractClass, self).__init__(get_base_type('Type')) + super(AbstractClass, self).__init__(get_base_type("Type")) def get_name(self): pass @@ -134,9 +131,8 @@ def get_superclasses(self): class AbstractFunction(PyObject): - def __init__(self): - super(AbstractFunction, self).__init__(get_base_type('Function')) + super(AbstractFunction, self).__init__(get_base_type("Function")) def get_name(self): pass @@ -152,9 +148,8 @@ def get_returned_object(self, args): class AbstractModule(PyObject): - def __init__(self, doc=None): - super(AbstractModule, self).__init__(get_base_type('Module')) + super(AbstractModule, self).__init__(get_base_type("Module")) def get_doc(self): pass @@ -203,8 +198,7 @@ def get_attribute(self, name): return self._get_structural_attributes()[name] if name in self._get_concluded_attributes(): return self._get_concluded_attributes()[name] - raise exceptions.AttributeNotFoundError('Attribute %s not found' % - name) + raise exceptions.AttributeNotFoundError("Attribute %s not found" % name) def get_scope(self): if self.scope is None: @@ -220,8 +214,7 @@ def get_module(self): def get_doc(self): if len(self.get_ast().body) > 0: expr = self.get_ast().body[0] - if isinstance(expr, ast.Expr) and \ - isinstance(expr.value, ast.Str): + if isinstance(expr, ast.Expr) and isinstance(expr.value, ast.Str): docstring = expr.value.s coding = self.get_module().coding return _decode_data(docstring, coding) @@ -259,7 +252,6 @@ class PyClass(PyDefinedObject, AbstractClass): class _ConcludedData(object): - def __init__(self): self.data_ = None @@ -275,11 +267,10 @@ def _invalidate(self): self.data = None def __str__(self): - return '<' + str(self.data) + '>' + return "<" + str(self.data) + ">" class _PyModule(PyDefinedObject, AbstractModule): - def __init__(self, pycore, ast_node, resource): self.resource = resource self.concluded_data = [] diff --git a/rope/base/pyobjectsdef.py b/rope/base/pyobjectsdef.py index 22946ebce..1fe9c4b21 100644 --- a/rope/base/pyobjectsdef.py +++ b/rope/base/pyobjectsdef.py @@ -4,8 +4,16 @@ import rope.base.libutils import rope.base.oi.soi import rope.base.pyscopes -from rope.base import (pynamesdef as pynames, exceptions, ast, - astutils, pyobjects, fscommands, arguments, utils) +from rope.base import ( + pynamesdef as pynames, + exceptions, + ast, + astutils, + pyobjects, + fscommands, + arguments, + utils, +) from rope.base.utils import pycompat try: @@ -15,14 +23,13 @@ class PyFunction(pyobjects.PyFunction): - def __init__(self, pycore, ast_node, parent): rope.base.pyobjects.AbstractFunction.__init__(self) - rope.base.pyobjects.PyDefinedObject.__init__( - self, pycore, ast_node, parent) + rope.base.pyobjects.PyDefinedObject.__init__(self, pycore, ast_node, parent) self.arguments = self.ast_node.args self.parameter_pyobjects = pynames._Inferred( - self._infer_parameters, self.get_module()._get_concluded_data()) + self._infer_parameters, self.get_module()._get_concluded_data() + ) self.returned = pynames._Inferred(self._infer_returned) self.parameter_pynames = None @@ -33,8 +40,7 @@ def _create_concluded_attributes(self): return {} def _create_scope(self): - return rope.base.pyscopes.FunctionScope(self.pycore, self, - _FunctionVisitor) + return rope.base.pyscopes.FunctionScope(self.pycore, self, _FunctionVisitor) def _infer_parameters(self): pyobjects = rope.base.oi.soi.infer_parameter_objects(self) @@ -77,8 +83,11 @@ def get_name(self): def get_param_names(self, special_args=True): # TODO: handle tuple parameters - result = [pycompat.get_ast_arg_arg(node) for node in self.arguments.args - if isinstance(node, pycompat.ast_arg_type)] + result = [ + pycompat.get_ast_arg_arg(node) + for node in self.arguments.args + if isinstance(node, pycompat.ast_arg_type) + ] if special_args: if self.arguments.vararg: result.append(pycompat.get_ast_arg_arg(self.arguments.vararg)) @@ -97,28 +106,26 @@ def get_kind(self): if isinstance(self.parent, PyClass): for decorator in self.decorators: pyname = rope.base.evaluate.eval_node(scope, decorator) - if pyname == rope.base.builtins.builtins['staticmethod']: - return 'staticmethod' - if pyname == rope.base.builtins.builtins['classmethod']: - return 'classmethod' - return 'method' - return 'function' + if pyname == rope.base.builtins.builtins["staticmethod"]: + return "staticmethod" + if pyname == rope.base.builtins.builtins["classmethod"]: + return "classmethod" + return "method" + return "function" @property def decorators(self): try: - return getattr(self.ast_node, 'decorator_list') + return getattr(self.ast_node, "decorator_list") except AttributeError: - return getattr(self.ast_node, 'decorators', None) + return getattr(self.ast_node, "decorators", None) class PyClass(pyobjects.PyClass): - def __init__(self, pycore, ast_node, parent): self.visitor_class = _ClassVisitor rope.base.pyobjects.AbstractClass.__init__(self) - rope.base.pyobjects.PyDefinedObject.__init__( - self, pycore, ast_node, parent) + rope.base.pyobjects.PyDefinedObject.__init__(self, pycore, ast_node, parent) self.parent = parent self._superclasses = self.get_module()._get_concluded_data() @@ -139,12 +146,13 @@ def _create_concluded_attributes(self): def _get_bases(self): result = [] for base_name in self.ast_node.bases: - base = rope.base.evaluate.eval_node(self.parent.get_scope(), - base_name) - if base is not None and \ - base.get_object().get_type() == \ - rope.base.pyobjects.get_base_type('Type'): - result.append(base.get_object()) + base = rope.base.evaluate.eval_node(self.parent.get_scope(), base_name) + if ( + base is not None + and base.get_object().get_type() + == rope.base.pyobjects.get_base_type("Type") + ): + result.append(base.get_object()) return result def _create_scope(self): @@ -152,10 +160,8 @@ def _create_scope(self): class PyModule(pyobjects.PyModule): - - def __init__(self, pycore, source=None, - resource=None, force_errors=False): - ignore = pycore.project.prefs.get('ignore_syntax_errors', False) + def __init__(self, pycore, source=None, resource=None, force_errors=False): + ignore = pycore.project.prefs.get("ignore_syntax_errors", False) syntax_errors = force_errors or not ignore self.has_errors = False try: @@ -165,8 +171,8 @@ def __init__(self, pycore, source=None, if syntax_errors: raise else: - source = '\n' - node = ast.parse('\n') + source = "\n" + node = ast.parse("\n") self.source_code = source self.star_imports = [] self.visitor_class = _GlobalVisitor @@ -174,7 +180,7 @@ def __init__(self, pycore, source=None, super(PyModule, self).__init__(pycore, node, resource) def _init_source(self, pycore, source_code, resource): - filename = 'string' + filename = "string" if resource: filename = resource.path try: @@ -190,7 +196,7 @@ def _init_source(self, pycore, source_code, resource): except SyntaxError as e: raise exceptions.ModuleSyntaxError(filename, e.lineno, e.msg) except UnicodeDecodeError as e: - raise exceptions.ModuleSyntaxError(filename, 1, '%s' % (e.reason)) + raise exceptions.ModuleSyntaxError(filename, 1, "%s" % (e.reason)) return source_code, ast_node @utils.prevent_recursion(lambda: {}) @@ -218,16 +224,17 @@ def logical_lines(self): def get_name(self): return rope.base.libutils.modname(self.get_resource()) -class PyPackage(pyobjects.PyPackage): +class PyPackage(pyobjects.PyPackage): def __init__(self, pycore, resource=None, force_errors=False): self.resource = resource init_dot_py = self._get_init_dot_py() if init_dot_py is not None: ast_node = pycore.project.get_pymodule( - init_dot_py, force_errors=force_errors).get_ast() + init_dot_py, force_errors=force_errors + ).get_ast() else: - ast_node = ast.parse('\n') + ast_node = ast.parse("\n") super(PyPackage, self).__init__(pycore, ast_node, resource) def _create_structural_attributes(self): @@ -255,16 +262,14 @@ def _get_child_resources(self): for child in self.resource.get_children(): if child.is_folder(): result[child.name] = child - elif child.name.endswith('.py') and \ - child.name != '__init__.py': + elif child.name.endswith(".py") and child.name != "__init__.py": name = child.name[:-3] result[name] = child return result def _get_init_dot_py(self): - if self.resource is not None and \ - self.resource.has_child('__init__.py'): - return self.resource.get_child('__init__.py') + if self.resource is not None and self.resource.has_child("__init__.py"): + return self.resource.get_child("__init__.py") else: return None @@ -279,7 +284,6 @@ def get_module(self): class _AnnAssignVisitor(object): - def __init__(self, scope_visitor): self.scope_visitor = scope_visitor self.assigned_ast = None @@ -295,9 +299,9 @@ def _assigned(self, name, assignment=None): self.scope_visitor._assigned(name, assignment) def _Name(self, node): - assignment = pynames.AssignmentValue(self.assigned_ast, - assign_type=True, - type_hint=self.type_hint) + assignment = pynames.AssignmentValue( + self.assigned_ast, assign_type=True, type_hint=self.type_hint + ) self._assigned(node.id, assignment) def _Tuple(self, node): @@ -329,7 +333,7 @@ def _assigned(self, name, assignment=None): self.scope_visitor._assigned(name, assignment) def _GeneratorExp(self, node): - for child in ['elt', 'key', 'value']: + for child in ["elt", "key", "value"]: if hasattr(node, child): ast.walk(getattr(node, child), self) for comp in node.generators: @@ -353,7 +357,6 @@ def _NamedExpr(self, node): class _AssignVisitor(object): - def __init__(self, scope_visitor): self.scope_visitor = scope_visitor self.assigned_ast = None @@ -392,7 +395,6 @@ def _Slice(self, node): class _ScopeVisitor(_ExpressionVisitor): - def __init__(self, pycore, owner_object): self.pycore = pycore self.owner_object = owner_object @@ -413,20 +415,23 @@ def _ClassDef(self, node): def _FunctionDef(self, node): pyfunction = PyFunction(self.pycore, node, self.owner_object) for decorator in pyfunction.decorators: - if isinstance(decorator, ast.Name) and decorator.id == 'property': + if isinstance(decorator, ast.Name) and decorator.id == "property": if isinstance(self, _ClassVisitor): type_ = rope.base.builtins.Property(pyfunction) arg = pynames.UnboundName( - rope.base.pyobjects.PyObject(self.owner_object)) + rope.base.pyobjects.PyObject(self.owner_object) + ) def _eval(type_=type_, arg=arg): return type_.get_property_object( - arguments.ObjectArguments([arg])) + arguments.ObjectArguments([arg]) + ) lineno = utils.guess_def_lineno(self.get_module(), node) self.names[node.name] = pynames.EvaluatedName( - _eval, module=self.get_module(), lineno=lineno) + _eval, module=self.get_module(), lineno=lineno + ) break else: self.names[node.name] = pynames.DefinedName(pyfunction) @@ -445,8 +450,9 @@ def _AugAssign(self, node): pass def _For(self, node): - names = self._update_evaluated(node.target, node.iter, # noqa - '.__iter__().next()') + names = self._update_evaluated( + node.target, node.iter, ".__iter__().next()" # noqa + ) for child in node.body + node.orelse: ast.walk(child, self) @@ -462,26 +468,28 @@ def _assigned(self, name, assignment): pyname.assignments.append(assignment) self.names[name] = pyname - def _update_evaluated(self, targets, assigned, - evaluation='', eval_type=False, type_hint=None): + def _update_evaluated( + self, targets, assigned, evaluation="", eval_type=False, type_hint=None + ): result = {} if isinstance(targets, str): - assignment = pynames.AssignmentValue(assigned, [], - evaluation, eval_type) + assignment = pynames.AssignmentValue(assigned, [], evaluation, eval_type) self._assigned(targets, assignment) else: names = astutils.get_name_levels(targets) for name, levels in names: - assignment = pynames.AssignmentValue(assigned, levels, - evaluation, eval_type) + assignment = pynames.AssignmentValue( + assigned, levels, evaluation, eval_type + ) self._assigned(name, assignment) return result def _With(self, node): for item in pycompat.get_ast_with_items(node): if item.optional_vars: - self._update_evaluated(item.optional_vars, - item.context_expr, '.__enter__()') + self._update_evaluated( + item.optional_vars, item.context_expr, ".__enter__()" + ) for child in node.body: ast.walk(child, self) @@ -506,15 +514,13 @@ def _Import(self, node): for import_pair in node.names: module_name = import_pair.name alias = import_pair.asname - first_package = module_name.split('.')[0] + first_package = module_name.split(".")[0] if alias is not None: - imported = pynames.ImportedModule(self.get_module(), - module_name) + imported = pynames.ImportedModule(self.get_module(), module_name) if not self._is_ignored_import(imported): self.names[alias] = imported else: - imported = pynames.ImportedModule(self.get_module(), - first_package) + imported = pynames.ImportedModule(self.get_module(), first_package) if not self._is_ignored_import(imported): self.names[first_package] = imported @@ -522,28 +528,28 @@ def _ImportFrom(self, node): level = 0 if node.level: level = node.level - imported_module = pynames.ImportedModule(self.get_module(), - node.module, level) + imported_module = pynames.ImportedModule(self.get_module(), node.module, level) if self._is_ignored_import(imported_module): return - if len(node.names) == 1 and node.names[0].name == '*': + if len(node.names) == 1 and node.names[0].name == "*": if isinstance(self.owner_object, PyModule): - self.owner_object.star_imports.append( - StarImport(imported_module)) + self.owner_object.star_imports.append(StarImport(imported_module)) else: for imported_name in node.names: imported = imported_name.name alias = imported_name.asname if alias is not None: imported = alias - self.names[imported] = pynames.ImportedName(imported_module, - imported_name.name) + self.names[imported] = pynames.ImportedName( + imported_module, imported_name.name + ) def _is_ignored_import(self, imported_module): - if not self.pycore.project.prefs.get('ignore_bad_imports', False): + if not self.pycore.project.prefs.get("ignore_bad_imports", False): return False - return not isinstance(imported_module.get_object(), - rope.base.pyobjects.AbstractModule) + return not isinstance( + imported_module.get_object(), rope.base.pyobjects.AbstractModule + ) def _Global(self, node): module = self.get_module() @@ -557,13 +563,11 @@ def _Global(self, node): class _GlobalVisitor(_ScopeVisitor): - def __init__(self, pycore, owner_object): super(_GlobalVisitor, self).__init__(pycore, owner_object) class _ClassVisitor(_ScopeVisitor): - def __init__(self, pycore, owner_object): super(_ClassVisitor, self).__init__(pycore, owner_object) @@ -580,7 +584,6 @@ def _FunctionDef(self, node): class _FunctionVisitor(_ScopeVisitor): - def __init__(self, pycore, owner_object): super(_FunctionVisitor, self).__init__(pycore, owner_object) self.returned_asts = [] @@ -597,7 +600,6 @@ def _Yield(self, node): class _ClassInitVisitor(_AssignVisitor): - def __init__(self, scope_visitor, self_name): super(_ClassInitVisitor, self).__init__(scope_visitor) self.self_name = self_name @@ -605,16 +607,17 @@ def __init__(self, scope_visitor, self_name): def _Attribute(self, node): if not isinstance(node.ctx, ast.Store): return - if isinstance(node.value, ast.Name) and \ - node.value.id == self.self_name: + if isinstance(node.value, ast.Name) and node.value.id == self.self_name: if node.attr not in self.scope_visitor.names: self.scope_visitor.names[node.attr] = pynames.AssignedName( - lineno=node.lineno, module=self.scope_visitor.get_module()) + lineno=node.lineno, module=self.scope_visitor.get_module() + ) if self.assigned_ast is not None: pyname = self.scope_visitor.names[node.attr] if isinstance(pyname, pynames.AssignedName): pyname.assignments.append( - pynames.AssignmentValue(self.assigned_ast)) + pynames.AssignmentValue(self.assigned_ast) + ) def _Tuple(self, node): if not isinstance(node.ctx, ast.Store): @@ -639,7 +642,6 @@ def _With(self, node): class StarImport(object): - def __init__(self, imported_module): self.imported_module = imported_module @@ -647,6 +649,6 @@ def get_names(self): result = {} imported = self.imported_module.get_object() for name in imported: - if not name.startswith('_'): + if not name.startswith("_"): result[name] = pynames.ImportedName(self.imported_module, name) return result diff --git a/rope/base/pyscopes.py b/rope/base/pyscopes.py index 0bed19a92..949f3109a 100644 --- a/rope/base/pyscopes.py +++ b/rope/base/pyscopes.py @@ -5,7 +5,6 @@ class Scope(object): - def __init__(self, pycore, pyobject, parent_scope): self.pycore = pycore self.pyobject = pyobject @@ -22,7 +21,7 @@ def get_defined_names(self): def get_name(self, name): """Return name `PyName` defined in this scope""" if name not in self.get_names(): - raise exceptions.NameNotFoundError('name %s not found' % name) + raise exceptions.NameNotFoundError("name %s not found" % name) return self.get_names()[name] def __getitem__(self, key): @@ -66,8 +65,9 @@ def _propagated_lookup(self, name): return None def _create_scopes(self): - return [pydefined.get_scope() - for pydefined in self.pyobject._get_defined_objects()] + return [ + pydefined.get_scope() for pydefined in self.pyobject._get_defined_objects() + ] def _get_global_scope(self): current = self @@ -102,7 +102,6 @@ def get_kind(self): class GlobalScope(Scope): - def __init__(self, pycore, module): super(GlobalScope, self).__init__(pycore, module, None) self.names = module._get_concluded_data() @@ -111,7 +110,7 @@ def get_start(self): return 1 def get_kind(self): - return 'Module' + return "Module" def get_name(self, name): try: @@ -119,7 +118,7 @@ def get_name(self, name): except exceptions.AttributeNotFoundError: if name in self.builtin_names: return self.builtin_names[name] - raise exceptions.NameNotFoundError('name %s not found' % name) + raise exceptions.NameNotFoundError("name %s not found" % name) def get_names(self): if self.names.get() is None: @@ -145,10 +144,10 @@ def builtin_names(self): class FunctionScope(Scope): - def __init__(self, pycore, pyobject, visitor): - super(FunctionScope, self).__init__(pycore, pyobject, - pyobject.parent.get_scope()) + super(FunctionScope, self).__init__( + pycore, pyobject, pyobject.parent.get_scope() + ) self.names = None self.returned_asts = None self.is_generator = None @@ -190,36 +189,34 @@ def _create_scopes(self): return [pydefined.get_scope() for pydefined in self.defineds] def get_kind(self): - return 'Function' + return "Function" def invalidate_data(self): for pyname in self.get_names().values(): - if isinstance(pyname, (rope.base.pynames.AssignedName, - rope.base.pynames.EvaluatedName)): + if isinstance( + pyname, + (rope.base.pynames.AssignedName, rope.base.pynames.EvaluatedName), + ): pyname.invalidate() class ClassScope(Scope): - def __init__(self, pycore, pyobject): - super(ClassScope, self).__init__(pycore, pyobject, - pyobject.parent.get_scope()) + super(ClassScope, self).__init__(pycore, pyobject, pyobject.parent.get_scope()) def get_kind(self): - return 'Class' + return "Class" def get_propagated_names(self): return {} class _HoldingScopeFinder(object): - def __init__(self, pymodule): self.pymodule = pymodule def get_indents(self, lineno): - return rope.base.codeanalyze.count_line_indents( - self.lines.get_line(lineno)) + return rope.base.codeanalyze.count_line_indents(self.lines.get_line(lineno)) def _get_scope_indents(self, scope): return self.get_indents(scope.get_start()) @@ -229,12 +226,15 @@ def get_holding_scope(self, module_scope, lineno, line_indents=None): line_indents = self.get_indents(lineno) current_scope = module_scope new_scope = current_scope - while new_scope is not None and \ - (new_scope.get_kind() == 'Module' or - self._get_scope_indents(new_scope) <= line_indents): + while new_scope is not None and ( + new_scope.get_kind() == "Module" + or self._get_scope_indents(new_scope) <= line_indents + ): current_scope = new_scope - if current_scope.get_start() == lineno and \ - current_scope.get_kind() != 'Module': + if ( + current_scope.get_start() == lineno + and current_scope.get_kind() != "Module" + ): return current_scope new_scope = None for scope in current_scope.get_scopes(): @@ -248,14 +248,13 @@ def get_holding_scope(self, module_scope, lineno, line_indents=None): def _is_empty_line(self, lineno): line = self.lines.get_line(lineno) - return line.strip() == '' or line.lstrip().startswith('#') + return line.strip() == "" or line.lstrip().startswith("#") def _get_body_indents(self, scope): return self.get_indents(scope.get_body_start()) def get_holding_scope_for_offset(self, scope, offset): - return self.get_holding_scope( - scope, self.lines.get_line_number(offset)) + return self.get_holding_scope(scope, self.lines.get_line_number(offset)) def find_scope_end(self, scope): if not scope.parent: @@ -268,7 +267,8 @@ def find_scope_end(self, scope): else: body_indents = self._get_body_indents(scope) for l in self.logical_lines.generate_starts( - min(end + 1, self.lines.length()), self.lines.length() + 1): + min(end + 1, self.lines.length()), self.lines.length() + 1 + ): if not self._is_empty_line(l): if self.get_indents(l) < body_indents: return end @@ -298,7 +298,8 @@ class TemporaryScope(Scope): def __init__(self, pycore, parent_scope, names): super(TemporaryScope, self).__init__( - pycore, parent_scope.pyobject, parent_scope) + pycore, parent_scope.pyobject, parent_scope + ) self.names = names def get_names(self): @@ -311,4 +312,4 @@ def _create_scopes(self): return [] def get_kind(self): - return 'Temporary' + return "Temporary" diff --git a/rope/base/resourceobserver.py b/rope/base/resourceobserver.py index 71d5f793e..4f16d5806 100644 --- a/rope/base/resourceobserver.py +++ b/rope/base/resourceobserver.py @@ -16,8 +16,9 @@ class ResourceObserver(object): """ - def __init__(self, changed=None, moved=None, created=None, - removed=None, validate=None): + def __init__( + self, changed=None, moved=None, created=None, removed=None, validate=None + ): self.changed = changed self.moved = moved self.created = created @@ -80,8 +81,7 @@ class FilteredResourceObserver(object): """ - def __init__(self, resource_observer, initial_resources=None, - timekeeper=None): + def __init__(self, resource_observer, initial_resources=None, timekeeper=None): self.observer = resource_observer self.resources = {} if timekeeper is not None: @@ -119,8 +119,7 @@ def _update_changes_caused_by_changed(self, changes, changed): if self._is_parent_changed(changed): changes.add_changed(changed.parent) - def _update_changes_caused_by_moved(self, changes, resource, - new_resource=None): + def _update_changes_caused_by_moved(self, changes, resource, new_resource=None): if resource in self.resources: changes.add_removed(resource, new_resource) if new_resource in self.resources: @@ -129,7 +128,8 @@ def _update_changes_caused_by_moved(self, changes, resource, for file in list(self.resources): if resource.contains(file): new_file = self._calculate_new_resource( - resource, new_resource, file) + resource, new_resource, file + ) changes.add_removed(file, new_file) if self._is_parent_changed(resource): changes.add_changed(resource.parent) @@ -190,13 +190,19 @@ def validate(self, resource): def _search_resource_creations(self, resource): creations = set() - if resource in self.resources and resource.exists() and \ - self.resources[resource] is None: + if ( + resource in self.resources + and resource.exists() + and self.resources[resource] is None + ): creations.add(resource) if resource.is_folder(): for file in self.resources: - if file.exists() and resource.contains(file) and \ - self.resources[file] is None: + if ( + file.exists() + and resource.contains(file) + and self.resources[file] is None + ): creations.add(file) return creations @@ -231,32 +237,30 @@ def _search_resource_changes(self, resource): def _is_changed(self, resource): if self.resources[resource] is None: return False - return self.resources[resource] != \ - self.timekeeper.get_indicator(resource) + return self.resources[resource] != self.timekeeper.get_indicator(resource) def _calculate_new_resource(self, main, new_main, resource): if new_main is None: return None - diff = resource.path[len(main.path):] + diff = resource.path[len(main.path) :] return resource.project.get_resource(new_main.path + diff) class ChangeIndicator(object): - def get_indicator(self, resource): """Return the modification time and size of a `Resource`.""" path = resource.real_path # on dos, mtime does not change for a folder when files are added - if os.name != 'posix' and os.path.isdir(path): - return (os.path.getmtime(path), - len(os.listdir(path)), - os.path.getsize(path)) - return (os.path.getmtime(path), - os.path.getsize(path)) + if os.name != "posix" and os.path.isdir(path): + return ( + os.path.getmtime(path), + len(os.listdir(path)), + os.path.getsize(path), + ) + return (os.path.getmtime(path), os.path.getsize(path)) class _Changes(object): - def __init__(self): self.changes = set() self.creations = set() diff --git a/rope/base/resources.py b/rope/base/resources.py index ad54dcf14..e8273d29e 100644 --- a/rope/base/resources.py +++ b/rope/base/resources.py @@ -43,13 +43,14 @@ def __init__(self, project, path): def move(self, new_location): """Move resource to `new_location`""" - self._perform_change(change.MoveResource(self, new_location), - 'Moving <%s> to <%s>' % (self.path, new_location)) + self._perform_change( + change.MoveResource(self, new_location), + "Moving <%s> to <%s>" % (self.path, new_location), + ) def remove(self): """Remove resource from the project""" - self._perform_change(change.RemoveResource(self), - 'Removing <%s>' % self.path) + self._perform_change(change.RemoveResource(self), "Removing <%s>" % self.path) def is_folder(self): """Return true if the resource is a folder""" @@ -62,7 +63,7 @@ def exists(self): @property def parent(self): - parent = '/'.join(self.path.split('/')[0:-1]) + parent = "/".join(self.path.split("/")[0:-1]) return self.project.get_folder(parent) @property @@ -77,7 +78,7 @@ def path(self): @property def name(self): """Return the name of this resource""" - return self.path.split('/')[-1] + return self.path.split("/")[-1] @property def real_path(self): @@ -113,7 +114,7 @@ def read(self): raise exceptions.ModuleDecodeError(self.path, e.reason) def read_bytes(self): - handle = open(self.real_path, 'rb') + handle = open(self.real_path, "rb") try: return handle.read() finally: @@ -125,8 +126,9 @@ def write(self, contents): return except IOError: pass - self._perform_change(change.ChangeContents(self, contents), - 'Writing file <%s>' % self.path) + self._perform_change( + change.ChangeContents(self, contents), "Writing file <%s>" % self.path + ) def is_folder(self): return False @@ -163,18 +165,20 @@ def get_children(self): def create_file(self, file_name): self._perform_change( change.CreateFile(self, file_name), - 'Creating file <%s>' % self._get_child_path(file_name)) + "Creating file <%s>" % self._get_child_path(file_name), + ) return self.get_child(file_name) def create_folder(self, folder_name): self._perform_change( change.CreateFolder(self, folder_name), - 'Creating folder <%s>' % self._get_child_path(folder_name)) + "Creating folder <%s>" % self._get_child_path(folder_name), + ) return self.get_child(folder_name) def _get_child_path(self, name): if self.path: - return self.path + '/' + name + return self.path + "/" + name else: return name @@ -189,24 +193,23 @@ def has_child(self, name): return False def get_files(self): - return [resource for resource in self.get_children() - if not resource.is_folder()] + return [ + resource for resource in self.get_children() if not resource.is_folder() + ] def get_folders(self): - return [resource for resource in self.get_children() - if resource.is_folder()] + return [resource for resource in self.get_children() if resource.is_folder()] def contains(self, resource): if self == resource: return False - return self.path == '' or resource.path.startswith(self.path + '/') + return self.path == "" or resource.path.startswith(self.path + "/") def create(self): self.parent.create_folder(self.name) class _ResourceMatcher(object): - def __init__(self): self.patterns = [] self._compiled_patterns = [] @@ -222,18 +225,20 @@ def set_patterns(self, patterns): self.patterns = patterns def _add_pattern(self, pattern): - re_pattern = pattern.replace('.', '\\.').\ - replace('*', '[^/]*').replace('?', '[^/]').\ - replace('//', '/(.*/)?') - re_pattern = '^(.*/)?' + re_pattern + '(/.*)?$' + re_pattern = ( + pattern.replace(".", "\\.") + .replace("*", "[^/]*") + .replace("?", "[^/]") + .replace("//", "/(.*/)?") + ) + re_pattern = "^(.*/)?" + re_pattern + "(/.*)?$" self.compiled_patterns.append(re.compile(re_pattern)) def does_match(self, resource): for pattern in self.compiled_patterns: if pattern.match(resource.path): return True - path = os.path.join(resource.project.address, - *resource.path.split('/')) + path = os.path.join(resource.project.address, *resource.path.split("/")) if os.path.islink(path): return True return False diff --git a/rope/base/simplify.py b/rope/base/simplify.py index bc4cade4a..4f89925d7 100644 --- a/rope/base/simplify.py +++ b/rope/base/simplify.py @@ -23,10 +23,10 @@ def real_code(source): """ collector = codeanalyze.ChangeCollector(source) for start, end in ignored_regions(source): - if source[start] == '#': - replacement = ' ' * (end - start) + if source[start] == "#": + replacement = " " * (end - start) else: - replacement = '"%s"' % (' ' * (end - start - 2)) + replacement = '"%s"' % (" " * (end - start - 2)) collector.add_change(start, end, replacement) source = collector.get_changed() or source collector = codeanalyze.ChangeCollector(source) @@ -34,14 +34,14 @@ def real_code(source): for match in _parens.finditer(source): i = match.start() c = match.group() - if c in '({[': + if c in "({[": parens += 1 - if c in ')}]': + if c in ")}]": parens -= 1 - if c == '\n' and parens > 0: - collector.add_change(i, i + 1, ' ') + if c == "\n" and parens > 0: + collector.add_change(i, i + 1, " ") source = collector.get_changed() or source - return source.replace('\\\n', ' ').replace('\t', ' ').replace(';', '\n') + return source.replace("\\\n", " ").replace("\t", " ").replace(";", "\n") @utils.cached(7) @@ -50,6 +50,7 @@ def ignored_regions(source): return [(match.start(), match.end()) for match in _str.finditer(source)] -_str = re.compile('%s|%s' % (codeanalyze.get_comment_pattern(), - codeanalyze.get_string_pattern())) -_parens = re.compile(r'[\({\[\]}\)\n]') +_str = re.compile( + "%s|%s" % (codeanalyze.get_comment_pattern(), codeanalyze.get_string_pattern()) +) +_parens = re.compile(r"[\({\[\]}\)\n]") diff --git a/rope/base/stdmods.py b/rope/base/stdmods.py index 7d5db1d10..c57e2e54b 100644 --- a/rope/base/stdmods.py +++ b/rope/base/stdmods.py @@ -9,10 +9,11 @@ def _stdlib_path(): if pycompat.PY2: from distutils import sysconfig - return sysconfig.get_python_lib(standard_lib=True, - plat_specific=True) + + return sysconfig.get_python_lib(standard_lib=True, plat_specific=True) elif pycompat.PY3: import inspect + return os.path.dirname(inspect.getsourcefile(inspect)) @@ -29,10 +30,10 @@ def python_modules(): for name in os.listdir(lib_path): path = os.path.join(lib_path, name) if os.path.isdir(path): - if '-' not in name: + if "-" not in name: result.add(name) else: - if name.endswith('.py'): + if name.endswith(".py"): result.add(name[:-3]) return result @@ -53,13 +54,13 @@ def normalize_so_name(name): @utils.cached(1) def dynload_modules(): result = set(sys.builtin_module_names) - dynload_path = os.path.join(_stdlib_path(), 'lib-dynload') + dynload_path = os.path.join(_stdlib_path(), "lib-dynload") if os.path.exists(dynload_path): for name in os.listdir(dynload_path): path = os.path.join(dynload_path, name) if os.path.isfile(path): - if name.endswith('.dll'): + if name.endswith(".dll"): result.add(normalize_so_name(name)) - if name.endswith('.so'): + if name.endswith(".so"): result.add(normalize_so_name(name)) return result diff --git a/rope/base/taskhandle.py b/rope/base/taskhandle.py index c1f01b984..4724492c7 100644 --- a/rope/base/taskhandle.py +++ b/rope/base/taskhandle.py @@ -2,8 +2,7 @@ class TaskHandle(object): - - def __init__(self, name='Task', interrupts=True): + def __init__(self, name="Task", interrupts=True): """Construct a TaskHandle If `interrupts` is `False` the task won't be interrupted by @@ -42,7 +41,7 @@ def is_stopped(self): def get_jobsets(self): return self.job_sets - def create_jobset(self, name='JobSet', count=None): + def create_jobset(self, name="JobSet", count=None): result = JobSet(self, name=name, count=count) self.job_sets.append(result) self._inform_observers() @@ -54,7 +53,6 @@ def _inform_observers(self): class JobSet(object): - def __init__(self, handle, name, count): self.handle = handle self.name = name @@ -90,7 +88,6 @@ def get_name(self): class NullTaskHandle(object): - def __init__(self): pass @@ -111,7 +108,6 @@ def add_observer(self, observer): class NullJobSet(object): - def started_job(self, name): pass diff --git a/rope/base/utils/__init__.py b/rope/base/utils/__init__.py index cc220bc0c..884d71c08 100644 --- a/rope/base/utils/__init__.py +++ b/rope/base/utils/__init__.py @@ -5,21 +5,24 @@ def saveit(func): """A decorator that caches the return value of a function""" - name = '_' + func.__name__ + name = "_" + func.__name__ def _wrapper(self, *args, **kwds): if not hasattr(self, name): setattr(self, name, func(self, *args, **kwds)) return getattr(self, name) + return _wrapper + cacheit = saveit def prevent_recursion(default): """A decorator that returns the return value of `default` in recursions""" + def decorator(func): - name = '_calling_%s_' % func.__name__ + name = "_calling_%s_" % func.__name__ def newfunc(self, *args, **kwds): if getattr(self, name, False): @@ -29,45 +32,54 @@ def newfunc(self, *args, **kwds): return func(self, *args, **kwds) finally: setattr(self, name, False) + return newfunc + return decorator def ignore_exception(exception_class): """A decorator that ignores `exception_class` exceptions""" + def _decorator(func): def newfunc(*args, **kwds): try: return func(*args, **kwds) except exception_class: pass + return newfunc + return _decorator def deprecated(message=None): """A decorator for deprecated functions""" + def _decorator(func, message=message): if message is None: - message = '%s is deprecated' % func.__name__ + message = "%s is deprecated" % func.__name__ def newfunc(*args, **kwds): warnings.warn(message, DeprecationWarning, stacklevel=2) return func(*args, **kwds) + return newfunc + return _decorator def cached(size): """A caching decorator based on parameter objects""" + def decorator(func): cached_func = _Cached(func, size) return lambda *a, **kw: cached_func(*a, **kw) + return decorator class _Cached(object): - def __init__(self, func, count): self.func = func self.cache = [] @@ -88,26 +100,29 @@ def __call__(self, *args, **kwds): def resolve(str_or_obj): """Returns object from string""" from rope.base.utils.pycompat import string_types + if not isinstance(str_or_obj, string_types): return str_or_obj - if '.' not in str_or_obj: - str_or_obj += '.' - mod_name, obj_name = str_or_obj.rsplit('.', 1) + if "." not in str_or_obj: + str_or_obj += "." + mod_name, obj_name = str_or_obj.rsplit(".", 1) __import__(mod_name) mod = sys.modules[mod_name] return getattr(mod, obj_name) if obj_name else mod + def guess_def_lineno(module, node): - """ Find the line number for a function or class definition. + """Find the line number for a function or class definition. - `node` may be either an ast.FunctionDef, ast.AsyncFunctionDef, or ast.ClassDef + `node` may be either an ast.FunctionDef, ast.AsyncFunctionDef, or ast.ClassDef - Python 3.8 simply provides this to us, but in earlier versions the ast - node.lineno points to the first decorator rather than the actual - definition, so we try our best to find where the definitions are. + Python 3.8 simply provides this to us, but in earlier versions the ast + node.lineno points to the first decorator rather than the actual + definition, so we try our best to find where the definitions are. - This is to workaround bpo-33211 (https://bugs.python.org/issue33211) + This is to workaround bpo-33211 (https://bugs.python.org/issue33211) """ + def is_inline_body(): # class Foo(object): # def inline_body(): pass @@ -124,5 +139,7 @@ def is_inline_body(): if sys.version_info >= (3, 8): return node.lineno - possible_def_line = node.body[0].lineno if is_inline_body() else node.body[0].lineno - 1 + possible_def_line = ( + node.body[0].lineno if is_inline_body() else node.body[0].lineno - 1 + ) return module.logical_lines.logical_line_in(possible_def_line)[0] diff --git a/rope/base/utils/datastructures.py b/rope/base/utils/datastructures.py index 4e39a9ba0..e9c0b60ee 100644 --- a/rope/base/utils/datastructures.py +++ b/rope/base/utils/datastructures.py @@ -8,7 +8,6 @@ class OrderedSet(MutableSet): - def __init__(self, iterable=None): self.end = end = [] end += [None, end, end] # sentinel @@ -54,15 +53,15 @@ def __reversed__(self): def pop(self, last=True): if not self: - raise KeyError('set is empty') + raise KeyError("set is empty") key = self.end[1][0] if last else self.end[2][0] self.discard(key) return key def __repr__(self): if not self: - return '%s()' % (self.__class__.__name__,) - return '%s(%r)' % (self.__class__.__name__, list(self)) + return "%s()" % (self.__class__.__name__,) + return "%s(%r)" % (self.__class__.__name__, list(self)) def __eq__(self, other): if isinstance(other, OrderedSet): diff --git a/rope/base/utils/pycompat.py b/rope/base/utils/pycompat.py index de7cf2e44..5da280fe0 100644 --- a/rope/base/utils/pycompat.py +++ b/rope/base/utils/pycompat.py @@ -1,5 +1,6 @@ import sys import ast + # from rope.base import ast PY2 = sys.version_info[0] == 2 @@ -15,25 +16,29 @@ str = str string_types = (str,) import builtins + ast_arg_type = ast.arg def execfile(fn, global_vars=None, local_vars=None): with open(fn) as f: - code = compile(f.read(), fn, 'exec') + code = compile(f.read(), fn, "exec") exec(code, global_vars or {}, local_vars) def get_ast_arg_arg(node): - if isinstance(node, string_types): # TODO: G21: Understand the Algorithm (Where it's used?) + if isinstance( + node, string_types + ): # TODO: G21: Understand the Algorithm (Where it's used?) return node return node.arg def get_ast_with_items(node): return node.items + else: # PY2 string_types = (basestring,) - builtins = __import__('__builtin__') + builtins = __import__("__builtin__") ast_arg_type = ast.Name execfile = execfile diff --git a/rope/base/worder.py b/rope/base/worder.py index 557c4d7e6..ab91fe152 100644 --- a/rope/base/worder.py +++ b/rope/base/worder.py @@ -31,7 +31,7 @@ def _init_ignores(self): def _context_call(self, name, offset): if self.handle_ignores: - if not hasattr(self, 'starts'): + if not hasattr(self, "starts"): self._init_ignores() start = bisect.bisect(self.starts, offset) if start > 0 and offset < self.ends[start - 1]: @@ -39,19 +39,19 @@ def _context_call(self, name, offset): return getattr(self.code_finder, name)(offset) def get_primary_at(self, offset): - return self._context_call('get_primary_at', offset) + return self._context_call("get_primary_at", offset) def get_word_at(self, offset): - return self._context_call('get_word_at', offset) + return self._context_call("get_word_at", offset) def get_primary_range(self, offset): - return self._context_call('get_primary_range', offset) + return self._context_call("get_primary_range", offset) def get_splitted_primary_before(self, offset): - return self._context_call('get_splitted_primary_before', offset) + return self._context_call("get_splitted_primary_before", offset) def get_word_range(self, offset): - return self._context_call('get_word_range', offset) + return self._context_call("get_word_range", offset) def is_function_keyword_parameter(self, offset): return self.code_finder.is_function_keyword_parameter(offset) @@ -124,7 +124,6 @@ def find_function_offset(self, offset): class _RealFinder(object): - def __init__(self, code, raw): self.code = code self.raw = raw @@ -142,15 +141,14 @@ def _find_word_end(self, offset): def _find_last_non_space_char(self, offset): while offset >= 0 and self.code[offset].isspace(): - if self.code[offset] == '\n': + if self.code[offset] == "\n": return offset offset -= 1 return max(-1, offset) def get_word_at(self, offset): offset = self._get_fixed_offset(offset) - return self.raw[self._find_word_start(offset): - self._find_word_end(offset) + 1] + return self.raw[self._find_word_start(offset) : self._find_word_end(offset) + 1] def _get_fixed_offset(self, offset): if offset >= len(self.code): @@ -163,7 +161,7 @@ def _get_fixed_offset(self, offset): return offset def _is_id_char(self, offset): - return self.code[offset].isalnum() or self.code[offset] == '_' + return self.code[offset].isalnum() or self.code[offset] == "_" def _find_string_start(self, offset): kind = self.code[offset] @@ -174,21 +172,21 @@ def _find_string_start(self, offset): def _find_parens_start(self, offset): offset = self._find_last_non_space_char(offset - 1) - while offset >= 0 and self.code[offset] not in '[({': - if self.code[offset] not in ':,': + while offset >= 0 and self.code[offset] not in "[({": + if self.code[offset] not in ":,": offset = self._find_primary_start(offset) offset = self._find_last_non_space_char(offset - 1) return offset def _find_atom_start(self, offset): old_offset = offset - if self.code[offset] == '\n': + if self.code[offset] == "\n": return offset + 1 if self.code[offset].isspace(): offset = self._find_last_non_space_char(offset) - if self.code[offset] in '\'"': + if self.code[offset] in "'\"": return self._find_string_start(offset) - if self.code[offset] in ')]}': + if self.code[offset] in ")]}": return self._find_parens_start(offset) if self._is_id_char(offset): return self._find_word_start(offset) @@ -203,33 +201,33 @@ def _find_primary_without_dot_start(self, offset): """ last_atom = offset offset = self._find_last_non_space_char(last_atom) - while offset > 0 and self.code[offset] in ')]': + while offset > 0 and self.code[offset] in ")]": last_atom = self._find_parens_start(offset) offset = self._find_last_non_space_char(last_atom - 1) - if offset >= 0 and (self.code[offset] in '"\'})]' or - self._is_id_char(offset)): + if offset >= 0 and (self.code[offset] in "\"'})]" or self._is_id_char(offset)): atom_start = self._find_atom_start(offset) - if not keyword.iskeyword(self.code[atom_start:offset + 1]) or \ - (offset + 1 < len(self.code) and self._is_id_char(offset + 1)): + if not keyword.iskeyword(self.code[atom_start : offset + 1]) or ( + offset + 1 < len(self.code) and self._is_id_char(offset + 1) + ): return atom_start return last_atom def _find_primary_start(self, offset): if offset >= len(self.code): offset = len(self.code) - 1 - if self.code[offset] != '.': + if self.code[offset] != ".": offset = self._find_primary_without_dot_start(offset) else: offset = offset + 1 while offset > 0: prev = self._find_last_non_space_char(offset - 1) - if offset <= 0 or self.code[prev] != '.': + if offset <= 0 or self.code[prev] != ".": break # Check if relative import # XXX: Looks like a hack... prev_word_end = self._find_last_non_space_char(prev - 1) - if self.code[prev_word_end-3:prev_word_end+1] == "from": + if self.code[prev_word_end - 3 : prev_word_end + 1] == "from": offset = prev break @@ -250,66 +248,67 @@ def get_splitted_primary_before(self, offset): This function is used in `rope.codeassist.assist` function. """ if offset == 0: - return ('', '', 0) + return ("", "", 0) end = offset - 1 word_start = self._find_atom_start(end) real_start = self._find_primary_start(end) - if self.code[word_start:offset].strip() == '': + if self.code[word_start:offset].strip() == "": word_start = end if self.code[end].isspace(): word_start = end - if self.code[real_start:word_start].strip() == '': + if self.code[real_start:word_start].strip() == "": real_start = word_start if real_start == word_start == end and not self._is_id_char(end): - return ('', '', offset) + return ("", "", offset) if real_start == word_start: - return ('', self.raw[word_start:offset], word_start) + return ("", self.raw[word_start:offset], word_start) else: - if self.code[end] == '.': - return (self.raw[real_start:end], '', offset) + if self.code[end] == ".": + return (self.raw[real_start:end], "", offset) last_dot_position = word_start - if self.code[word_start] != '.': - last_dot_position = \ - self._find_last_non_space_char(word_start - 1) - last_char_position = \ - self._find_last_non_space_char(last_dot_position - 1) + if self.code[word_start] != ".": + last_dot_position = self._find_last_non_space_char(word_start - 1) + last_char_position = self._find_last_non_space_char(last_dot_position - 1) if self.code[word_start].isspace(): word_start = offset - return (self.raw[real_start:last_char_position + 1], - self.raw[word_start:offset], word_start) + return ( + self.raw[real_start : last_char_position + 1], + self.raw[word_start:offset], + word_start, + ) def _get_line_start(self, offset): try: - return self.code.rindex('\n', 0, offset + 1) + return self.code.rindex("\n", 0, offset + 1) except ValueError: return 0 def _get_line_end(self, offset): try: - return self.code.index('\n', offset) + return self.code.index("\n", offset) except ValueError: return len(self.code) def is_name_assigned_in_class_body(self, offset): word_start = self._find_word_start(offset - 1) word_end = self._find_word_end(offset) + 1 - if '.' in self.code[word_start:word_end]: + if "." in self.code[word_start:word_end]: return False line_start = self._get_line_start(word_start) line = self.code[line_start:word_start].strip() - return not line and self.get_assignment_type(offset) == '=' + return not line and self.get_assignment_type(offset) == "=" def is_a_class_or_function_name_in_header(self, offset): word_start = self._find_word_start(offset - 1) line_start = self._get_line_start(word_start) prev_word = self.code[line_start:word_start].strip() - return prev_word in ['def', 'class'] + return prev_word in ["def", "class"] def _find_first_non_space_char(self, offset): if offset >= len(self.code): return len(self.code) while offset < len(self.code) and self.code[offset].isspace(): - if self.code[offset] == '\n': + if self.code[offset] == "\n": return offset offset += 1 return offset @@ -317,26 +316,30 @@ def _find_first_non_space_char(self, offset): def is_a_function_being_called(self, offset): word_end = self._find_word_end(offset) + 1 next_char = self._find_first_non_space_char(word_end) - return next_char < len(self.code) and \ - self.code[next_char] == '(' and \ - not self.is_a_class_or_function_name_in_header(offset) + return ( + next_char < len(self.code) + and self.code[next_char] == "(" + and not self.is_a_class_or_function_name_in_header(offset) + ) def _find_import_end(self, start): return self._get_line_end(start) def is_import_statement(self, offset): try: - last_import = self.code.rindex('import ', 0, offset) + last_import = self.code.rindex("import ", 0, offset) except ValueError: return False line_start = self._get_line_start(last_import) - return (self._find_import_end(last_import + 7) >= offset and - self._find_word_start(line_start) == last_import) + return ( + self._find_import_end(last_import + 7) >= offset + and self._find_word_start(line_start) == last_import + ) def is_from_statement(self, offset): try: - last_from = self.code.rindex('from ', 0, offset) - from_import = self.code.index(' import ', last_from) + last_from = self.code.rindex("from ", 0, offset) + from_import = self.code.index(" import ", last_from) from_names = from_import + 8 except ValueError: return False @@ -349,37 +352,39 @@ def is_from_statement_module(self, offset): stmt_start = self._find_primary_start(offset) line_start = self._get_line_start(stmt_start) prev_word = self.code[line_start:stmt_start].strip() - return prev_word == 'from' + return prev_word == "from" def is_import_statement_aliased_module(self, offset): if not self.is_import_statement(offset): return False try: line_start = self._get_line_start(offset) - import_idx = self.code.rindex('import', line_start, offset) + import_idx = self.code.rindex("import", line_start, offset) imported_names = import_idx + 7 except ValueError: return False # Check if the offset is within the imported names - if (imported_names - 1 > offset or - self._find_import_end(imported_names) < offset): + if ( + imported_names - 1 > offset + or self._find_import_end(imported_names) < offset + ): return False try: end = self._find_word_end(offset) as_end = min(self._find_word_end(end + 1), len(self.code)) as_start = self._find_word_start(as_end) - return self.code[as_start:as_end + 1] == 'as' + return self.code[as_start : as_end + 1] == "as" except ValueError: return False def is_a_name_after_from_import(self, offset): try: - if len(self.code) > offset and self.code[offset] == '\n': + if len(self.code) > offset and self.code[offset] == "\n": line_start = self._get_line_start(offset - 1) else: line_start = self._get_line_start(offset) - last_from = self.code.rindex('from ', line_start, offset) - from_import = self.code.index(' import ', last_from) + last_from = self.code.rindex("from ", line_start, offset) + from_import = self.code.index(" import ", last_from) from_names = from_import + 8 except ValueError: return False @@ -389,8 +394,8 @@ def is_a_name_after_from_import(self, offset): def get_from_module(self, offset): try: - last_from = self.code.rindex('from ', 0, offset) - import_offset = self.code.index(' import ', last_from) + last_from = self.code.rindex("from ", 0, offset) + import_offset = self.code.index(" import ", last_from) end = self._find_last_non_space_char(import_offset) return self.get_primary_at(end) except ValueError: @@ -403,7 +408,7 @@ def is_from_aliased(self, offset): end = self._find_word_end(offset) as_end = min(self._find_word_end(end + 1), len(self.code)) as_start = self._find_word_start(as_end) - return self.code[as_start:as_end + 1] == 'as' + return self.code[as_start : as_end + 1] == "as" except ValueError: return False @@ -413,7 +418,7 @@ def get_from_aliased(self, offset): as_ = self._find_word_end(end + 1) alias = self._find_word_end(as_ + 1) start = self._find_word_start(alias) - return self.raw[start:alias + 1] + return self.raw[start : alias + 1] except ValueError: pass @@ -422,19 +427,19 @@ def is_function_keyword_parameter(self, offset): if word_end + 1 == len(self.code): return False next_char = self._find_first_non_space_char(word_end + 1) - equals = self.code[next_char:next_char + 2] - if equals == '==' or not equals.startswith('='): + equals = self.code[next_char : next_char + 2] + if equals == "==" or not equals.startswith("="): return False word_start = self._find_word_start(offset) prev_char = self._find_last_non_space_char(word_start - 1) - return prev_char - 1 >= 0 and self.code[prev_char] in ',(' + return prev_char - 1 >= 0 and self.code[prev_char] in ",(" def is_on_function_call_keyword(self, offset): stop = self._get_line_start(offset) if self._is_id_char(offset): offset = self._find_word_start(offset) - 1 offset = self._find_last_non_space_char(offset) - if offset <= stop or self.code[offset] not in '(,': + if offset <= stop or self.code[offset] not in "(,": return False parens_start = self.find_parens_start_from_inside(offset) return stop < parens_start @@ -442,9 +447,9 @@ def is_on_function_call_keyword(self, offset): def find_parens_start_from_inside(self, offset): stop = self._get_line_start(offset) while offset > stop: - if self.code[offset] == '(': + if self.code[offset] == "(": break - if self.code[offset] != ',': + if self.code[offset] != ",": offset = self._find_primary_start(offset) offset -= 1 return max(stop, offset) @@ -456,12 +461,12 @@ def get_assignment_type(self, offset): # XXX: does not handle tuple assignments word_end = self._find_word_end(offset) next_char = self._find_first_non_space_char(word_end + 1) - single = self.code[next_char:next_char + 1] - double = self.code[next_char:next_char + 2] - triple = self.code[next_char:next_char + 3] - if double not in ('==', '<=', '>=', '!='): + single = self.code[next_char : next_char + 1] + double = self.code[next_char : next_char + 2] + triple = self.code[next_char : next_char + 3] + if double not in ("==", "<=", ">=", "!="): for op in [single, double, triple]: - if op.endswith('='): + if op.endswith("="): return op def get_primary_range(self, offset): @@ -475,7 +480,7 @@ def get_word_range(self, offset): end = self._find_word_end(offset) + 1 return (start, end) - def get_word_parens_range(self, offset, opening='(', closing=')'): + def get_word_parens_range(self, offset, opening="(", closing=")"): end = self._find_word_end(offset) start_parens = self.code.index(opening, end) index = start_parens @@ -497,16 +502,17 @@ def get_parameters(self, first, last): while current > first: primary_start = current current = self._find_primary_start(current) - while current != first and (self.code[current] not in '=,' - or self.code[current-1] in '=!<>'): + while current != first and ( + self.code[current] not in "=," or self.code[current - 1] in "=!<>" + ): current = self._find_last_non_space_char(current - 1) - primary = self.raw[current + 1:primary_start + 1].strip() - if self.code[current] == '=': + primary = self.raw[current + 1 : primary_start + 1].strip() + if self.code[current] == "=": primary_start = current - 1 current -= 1 - while current != first and self.code[current] not in ',': + while current != first and self.code[current] not in ",": current = self._find_last_non_space_char(current - 1) - param_name = self.raw[current + 1:primary_start + 1].strip() + param_name = self.raw[current + 1 : primary_start + 1].strip() keywords.append((param_name, primary)) else: args.append(primary) @@ -523,28 +529,27 @@ def is_assigned_in_a_tuple_assignment(self, offset): prev_char_offset = self._find_last_non_space_char(primary_start - 1) next_char_offset = self._find_first_non_space_char(primary_end + 1) - next_char = prev_char = '' + next_char = prev_char = "" if prev_char_offset >= start: prev_char = self.code[prev_char_offset] if next_char_offset < end: next_char = self.code[next_char_offset] try: - equals_offset = self.code.index('=', start, end) + equals_offset = self.code.index("=", start, end) except ValueError: return False - if prev_char not in '(,' and next_char not in ',)': + if prev_char not in "(," and next_char not in ",)": return False parens_start = self.find_parens_start_from_inside(offset) # XXX: only handling (x, y) = value - return offset < equals_offset and \ - self.code[start:parens_start].strip() == '' + return offset < equals_offset and self.code[start:parens_start].strip() == "" def get_function_and_args_in_header(self, offset): offset = self.find_function_offset(offset) lparens, rparens = self.get_word_parens_range(offset) - return self.raw[offset:rparens + 1] + return self.raw[offset : rparens + 1] - def find_function_offset(self, offset, definition='def '): + def find_function_offset(self, offset, definition="def "): while True: offset = self.code.index(definition, offset) if offset == 0 or not self._is_id_char(offset - 1): @@ -554,7 +559,6 @@ def find_function_offset(self, offset, definition='def '): return self._find_first_non_space_char(def_) def get_lambda_and_args(self, offset): - offset = self.find_function_offset(offset, definition='lambda ') - lparens, rparens = self.get_word_parens_range(offset, opening=' ', - closing=':') - return self.raw[offset:rparens + 1] + offset = self.find_function_offset(offset, definition="lambda ") + lparens, rparens = self.get_word_parens_range(offset, opening=" ", closing=":") + return self.raw[offset : rparens + 1] diff --git a/rope/contrib/autoimport.py b/rope/contrib/autoimport.py index 72c565abb..ec1a18a03 100644 --- a/rope/contrib/autoimport.py +++ b/rope/contrib/autoimport.py @@ -29,13 +29,14 @@ def __init__(self, project, observe=True, underlined=False): """ self.project = project self.underlined = underlined - self.names = project.data_files.read_data('globalnames') + self.names = project.data_files.read_data("globalnames") if self.names is None: self.names = {} project.data_files.add_write_hook(self._write) # XXX: using a filtered observer observer = resourceobserver.ResourceObserver( - changed=self._changed, moved=self._moved, removed=self._removed) + changed=self._changed, moved=self._moved, removed=self._removed + ) if observe: project.add_observer(observer) @@ -86,8 +87,9 @@ def get_name_locations(self, name): pass return result - def generate_cache(self, resources=None, underlined=None, - task_handle=taskhandle.NullTaskHandle()): + def generate_cache( + self, resources=None, underlined=None, task_handle=taskhandle.NullTaskHandle() + ): """Generate global name cache for project files If `resources` is a list of `rope.base.resource.File`, only @@ -98,20 +100,23 @@ def generate_cache(self, resources=None, underlined=None, if resources is None: resources = self.project.get_python_files() job_set = task_handle.create_jobset( - 'Generatig autoimport cache', len(resources)) + "Generatig autoimport cache", len(resources) + ) for file in resources: - job_set.started_job('Working on <%s>' % file.path) + job_set.started_job("Working on <%s>" % file.path) self.update_resource(file, underlined) job_set.finished_job() - def generate_modules_cache(self, modules, underlined=None, - task_handle=taskhandle.NullTaskHandle()): + def generate_modules_cache( + self, modules, underlined=None, task_handle=taskhandle.NullTaskHandle() + ): """Generate global name cache for modules listed in `modules`""" job_set = task_handle.create_jobset( - 'Generatig autoimport cache for modules', len(modules)) + "Generatig autoimport cache for modules", len(modules) + ) for modname in modules: - job_set.started_job('Working on <%s>' % modname) - if modname.endswith('.*'): + job_set.started_job("Working on <%s>" % modname) + if modname.endswith(".*"): mod = self.project.find_module(modname[:-2]) if mod: for sub in submodules(mod): @@ -131,21 +136,20 @@ def clear_cache(self): def find_insertion_line(self, code): """Guess at what line the new import should be inserted""" - match = re.search(r'^(def|class)\s+', code) + match = re.search(r"^(def|class)\s+", code) if match is not None: - code = code[:match.start()] + code = code[: match.start()] try: pymodule = libutils.get_string_module(self.project, code) except exceptions.ModuleSyntaxError: return 1 - testmodname = '__rope_testmodule_rope' + testmodname = "__rope_testmodule_rope" importinfo = importutils.NormalImport(((testmodname, None),)) - module_imports = importutils.get_module_imports(self.project, - pymodule) + module_imports = importutils.get_module_imports(self.project, pymodule) module_imports.add_import(importinfo) code = module_imports.get_changed_source() offset = code.index(testmodname) - lineno = code.count('\n', 0, offset) + 1 + lineno = code.count("\n", 0, offset) + 1 return lineno def update_resource(self, resource, underlined=None): @@ -180,7 +184,7 @@ def _add_names(self, pymodule, modname, underlined): else: attributes = pymodule.get_attributes() for name, pyname in attributes.items(): - if not underlined and name.startswith('_'): + if not underlined and name.startswith("_"): continue if isinstance(pyname, (pynames.AssignedName, pynames.DefinedName)): globals.append(name) @@ -189,7 +193,7 @@ def _add_names(self, pymodule, modname, underlined): self.names[modname] = globals def _write(self): - self.project.data_files.write_data('globalnames', self.names) + self.project.data_files.write_data("globalnames", self.names) def _changed(self, resource): if not resource.is_folder(): @@ -211,10 +215,10 @@ def _removed(self, resource): def submodules(mod): if isinstance(mod, resources.File): - if mod.name.endswith('.py') and mod.name != '__init__.py': + if mod.name.endswith(".py") and mod.name != "__init__.py": return set([mod]) return set() - if not mod.has_child('__init__.py'): + if not mod.has_child("__init__.py"): return set() result = set([mod]) for child in mod.get_children(): diff --git a/rope/contrib/changestack.py b/rope/contrib/changestack.py index 70f2271f7..1325cef25 100644 --- a/rope/contrib/changestack.py +++ b/rope/contrib/changestack.py @@ -22,8 +22,7 @@ class ChangeStack(object): - - def __init__(self, project, description='merged changes'): + def __init__(self, project, description="merged changes"): self.project = project self.description = description self.stack = [] diff --git a/rope/contrib/codeassist.py b/rope/contrib/codeassist.py index 9575e7c83..ebc5e3969 100644 --- a/rope/contrib/codeassist.py +++ b/rope/contrib/codeassist.py @@ -17,8 +17,15 @@ from rope.refactor import functionutils -def code_assist(project, source_code, offset, resource=None, - templates=None, maxfixes=1, later_locals=True): +def code_assist( + project, + source_code, + offset, + resource=None, + templates=None, + maxfixes=1, + later_locals=True, +): """Return python code completions as a list of `CodeAssistProposal` `resource` is a `rope.base.resources.Resource` object. If @@ -32,11 +39,17 @@ def code_assist(project, source_code, offset, resource=None, """ if templates is not None: - warnings.warn('Codeassist no longer supports templates', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Codeassist no longer supports templates", DeprecationWarning, stacklevel=2 + ) assist = _PythonCodeAssist( - project, source_code, offset, resource=resource, - maxfixes=maxfixes, later_locals=later_locals) + project, + source_code, + offset, + resource=resource, + maxfixes=maxfixes, + later_locals=later_locals, + ) return assist() @@ -53,8 +66,9 @@ def starting_offset(source_code, offset): """ word_finder = worder.Worder(source_code, True) - expression, starting, starting_offset = \ - word_finder.get_splitted_primary_before(offset) + expression, starting, starting_offset = word_finder.get_splitted_primary_before( + offset + ) return starting_offset @@ -68,8 +82,15 @@ def get_doc(project, source_code, offset, resource=None, maxfixes=1): return PyDocExtractor().get_doc(pyobject) -def get_calltip(project, source_code, offset, resource=None, - maxfixes=1, ignore_unknown=False, remove_self=False): +def get_calltip( + project, + source_code, + offset, + resource=None, + maxfixes=1, + ignore_unknown=False, + remove_self=False, +): """Get the calltip of a function The format of the returned string is @@ -101,8 +122,7 @@ def get_calltip(project, source_code, offset, resource=None, return PyDocExtractor().get_calltip(pyobject, ignore_unknown, remove_self) -def get_definition_location(project, source_code, offset, - resource=None, maxfixes=1): +def get_definition_location(project, source_code, offset, resource=None, maxfixes=1): """Return the definition location of the python name at `offset` Return a (`rope.base.resources.Resource`, lineno) tuple. If no @@ -122,8 +142,12 @@ def get_definition_location(project, source_code, offset, def find_occurrences(*args, **kwds): import rope.contrib.findit - warnings.warn('Use `rope.contrib.findit.find_occurrences()` instead', - DeprecationWarning, stacklevel=2) + + warnings.warn( + "Use `rope.contrib.findit.find_occurrences()` instead", + DeprecationWarning, + stacklevel=2, + ) return rope.contrib.findit.find_occurrences(*args, **kwds) @@ -163,24 +187,22 @@ def mux(self, x): # Start with the name of the object we're interested in. names = [] if isinstance(pyname, pynamesdef.ParameterName): - names = [(worder.get_name_at(pymod.get_resource(), offset), - 'PARAMETER') ] + names = [(worder.get_name_at(pymod.get_resource(), offset), "PARAMETER")] elif isinstance(pyname, pynamesdef.AssignedName): - names = [(worder.get_name_at(pymod.get_resource(), offset), - 'VARIABLE')] + names = [(worder.get_name_at(pymod.get_resource(), offset), "VARIABLE")] # Collect scope names. while scope.parent: if isinstance(scope, pyscopes.FunctionScope): - scope_type = 'FUNCTION' + scope_type = "FUNCTION" elif isinstance(scope, pyscopes.ClassScope): - scope_type = 'CLASS' + scope_type = "CLASS" else: scope_type = None names.append((scope.pyobject.get_name(), scope_type)) scope = scope.parent - names.append((defmod.get_resource().real_path, 'MODULE')) + names.append((defmod.get_resource().real_path, "MODULE")) names.reverse() return names @@ -216,7 +238,7 @@ def __init__(self, name, scope, pyname=None): self.scope = self._get_scope(scope) def __str__(self): - return '%s (%s, %s)' % (self.name, self.scope, self.type) + return "%s (%s, %s)" % (self.name, self.scope, self.type) def __repr__(self): return str(self) @@ -241,29 +263,32 @@ def type(self): if isinstance(pyname, builtins.BuiltinName): pyobject = pyname.get_object() if isinstance(pyobject, builtins.BuiltinFunction): - return 'function' + return "function" elif isinstance(pyobject, builtins.BuiltinClass): - return 'class' - elif isinstance(pyobject, builtins.BuiltinObject) or \ - isinstance(pyobject, builtins.BuiltinName): - return 'instance' + return "class" + elif isinstance(pyobject, builtins.BuiltinObject) or isinstance( + pyobject, builtins.BuiltinName + ): + return "instance" elif isinstance(pyname, pynames.ImportedModule): - return 'module' - elif isinstance(pyname, pynames.ImportedName) or \ - isinstance(pyname, pynames.DefinedName): + return "module" + elif isinstance(pyname, pynames.ImportedName) or isinstance( + pyname, pynames.DefinedName + ): pyobject = pyname.get_object() if isinstance(pyobject, pyobjects.AbstractFunction): - return 'function' + return "function" if isinstance(pyobject, pyobjects.AbstractClass): - return 'class' - return 'instance' + return "class" + return "instance" def _get_scope(self, scope): if isinstance(self.pyname, builtins.BuiltinName): - return 'builtin' - if isinstance(self.pyname, pynames.ImportedModule) or \ - isinstance(self.pyname, pynames.ImportedName): - return 'imported' + return "builtin" + if isinstance(self.pyname, pynames.ImportedModule) or isinstance( + self.pyname, pynames.ImportedName + ): + return "imported" return scope def get_doc(self): @@ -274,14 +299,15 @@ def get_doc(self): if not self.pyname: return None pyobject = self.pyname.get_object() - if not hasattr(pyobject, 'get_doc'): + if not hasattr(pyobject, "get_doc"): return None return self.pyname.get_object().get_doc() @property def kind(self): - warnings.warn("the proposal's `kind` property is deprecated, " - "use `scope` instead") + warnings.warn( + "the proposal's `kind` property is deprecated, " "use `scope` instead" + ) return self.scope @@ -296,10 +322,11 @@ class NamedParamProposal(CompletionProposal): parameter ``name`` belongs to. This allows to determine default value for this parameter. """ + def __init__(self, name, function): self.argname = name - name = '%s=' % name - super(NamedParamProposal, self).__init__(name, 'parameter_keyword') + name = "%s=" % name + super(NamedParamProposal, self).__init__(name, "parameter_keyword") self._function = function def get_default(self): @@ -334,38 +361,44 @@ def sorted_proposals(proposals, scopepref=None, typepref=None): def starting_expression(source_code, offset): """Return the expression to complete""" word_finder = worder.Worder(source_code, True) - expression, starting, starting_offset = \ - word_finder.get_splitted_primary_before(offset) + expression, starting, starting_offset = word_finder.get_splitted_primary_before( + offset + ) if expression: - return expression + '.' + starting + return expression + "." + starting return starting def default_templates(): - warnings.warn('default_templates() is deprecated.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "default_templates() is deprecated.", DeprecationWarning, stacklevel=2 + ) return {} class _PythonCodeAssist(object): - - def __init__(self, project, source_code, offset, resource=None, - maxfixes=1, later_locals=True): + def __init__( + self, project, source_code, offset, resource=None, maxfixes=1, later_locals=True + ): self.project = project self.code = source_code self.resource = resource self.maxfixes = maxfixes self.later_locals = later_locals self.word_finder = worder.Worder(source_code, True) - self.expression, self.starting, self.offset = \ - self.word_finder.get_splitted_primary_before(offset) + ( + self.expression, + self.starting, + self.offset, + ) = self.word_finder.get_splitted_primary_before(offset) keywords = keyword.kwlist def _find_starting_offset(self, source_code, offset): current_offset = offset - 1 - while current_offset >= 0 and (source_code[current_offset].isalnum() or - source_code[current_offset] in '_'): + while current_offset >= 0 and ( + source_code[current_offset].isalnum() or source_code[current_offset] in "_" + ): current_offset -= 1 return current_offset + 1 @@ -373,31 +406,28 @@ def _matching_keywords(self, starting): result = [] for kw in self.keywords: if kw.startswith(starting): - result.append(CompletionProposal(kw, 'keyword')) + result.append(CompletionProposal(kw, "keyword")) return result def __call__(self): if self.offset > len(self.code): return [] completions = list(self._code_completions().values()) - if self.expression.strip() == '' and self.starting.strip() != '': + if self.expression.strip() == "" and self.starting.strip() != "": completions.extend(self._matching_keywords(self.starting)) return completions def _dotted_completions(self, module_scope, holding_scope): result = {} - found_pyname = rope.base.evaluate.eval_str(holding_scope, - self.expression) + found_pyname = rope.base.evaluate.eval_str(holding_scope, self.expression) if found_pyname is not None: element = found_pyname.get_object() - compl_scope = 'attribute' - if isinstance(element, (pyobjectsdef.PyModule, - pyobjectsdef.PyPackage)): - compl_scope = 'imported' + compl_scope = "attribute" + if isinstance(element, (pyobjectsdef.PyModule, pyobjectsdef.PyPackage)): + compl_scope = "imported" for name, pyname in element.get_attributes().items(): if name.startswith(self.starting): - result[name] = CompletionProposal(name, compl_scope, - pyname) + result[name] = CompletionProposal(name, compl_scope, pyname) return result def _undotted_completions(self, scope, result, lineno=None): @@ -409,13 +439,15 @@ def _undotted_completions(self, scope, result, lineno=None): names = scope.get_names() for name, pyname in names.items(): if name.startswith(self.starting): - compl_scope = 'local' - if scope.get_kind() == 'Module': - compl_scope = 'global' - if lineno is None or self.later_locals or \ - not self._is_defined_after(scope, pyname, lineno): - result[name] = CompletionProposal(name, compl_scope, - pyname) + compl_scope = "local" + if scope.get_kind() == "Module": + compl_scope = "global" + if ( + lineno is None + or self.later_locals + or not self._is_defined_after(scope, pyname, lineno) + ): + result[name] = CompletionProposal(name, compl_scope, pyname) def _from_import_completions(self, pymodule): module_name = self.word_finder.get_from_module(self.offset) @@ -425,44 +457,46 @@ def _from_import_completions(self, pymodule): result = {} for name in pymodule: if name.startswith(self.starting): - result[name] = CompletionProposal(name, scope='global', - pyname=pymodule[name]) + result[name] = CompletionProposal( + name, scope="global", pyname=pymodule[name] + ) return result def _find_module(self, pymodule, module_name): dots = 0 - while module_name[dots] == '.': + while module_name[dots] == ".": dots += 1 - pyname = pynames.ImportedModule(pymodule, - module_name[dots:], dots) + pyname = pynames.ImportedModule(pymodule, module_name[dots:], dots) return pyname.get_object() def _is_defined_after(self, scope, pyname, lineno): location = pyname.get_definition_location() if location is not None and location[1] is not None: - if location[0] == scope.pyobject.get_module() and \ - lineno <= location[1] <= scope.get_end(): + if ( + location[0] == scope.pyobject.get_module() + and lineno <= location[1] <= scope.get_end() + ): return True def _code_completions(self): - lineno = self.code.count('\n', 0, self.offset) + 1 - fixer = fixsyntax.FixSyntax(self.project, self.code, - self.resource, self.maxfixes) + lineno = self.code.count("\n", 0, self.offset) + 1 + fixer = fixsyntax.FixSyntax( + self.project, self.code, self.resource, self.maxfixes + ) pymodule = fixer.get_pymodule() module_scope = pymodule.get_scope() code = pymodule.source_code - lines = code.split('\n') + lines = code.split("\n") result = {} start = fixsyntax._logical_start(lines, lineno) indents = fixsyntax._get_line_indents(lines[start - 1]) inner_scope = module_scope.get_inner_scope_for_line(start, indents) if self.word_finder.is_a_name_after_from_import(self.offset): return self._from_import_completions(pymodule) - if self.expression.strip() != '': + if self.expression.strip() != "": result.update(self._dotted_completions(module_scope, inner_scope)) else: - result.update(self._keyword_parameters(module_scope.pyobject, - inner_scope)) + result.update(self._keyword_parameters(module_scope.pyobject, inner_scope)) self._undotted_completions(inner_scope, result, lineno=lineno) return result @@ -472,33 +506,30 @@ def _keyword_parameters(self, pymodule, scope): return {} word_finder = worder.Worder(self.code, True) if word_finder.is_on_function_call_keyword(offset - 1): - function_parens = word_finder.\ - find_parens_start_from_inside(offset - 1) + function_parens = word_finder.find_parens_start_from_inside(offset - 1) primary = word_finder.get_primary_at(function_parens - 1) try: - function_pyname = rope.base.evaluate.\ - eval_str(scope, primary) + function_pyname = rope.base.evaluate.eval_str(scope, primary) except exceptions.BadIdentifierError: return {} if function_pyname is not None: pyobject = function_pyname.get_object() if isinstance(pyobject, pyobjects.AbstractFunction): pass - elif isinstance(pyobject, pyobjects.AbstractClass) and \ - '__init__' in pyobject: - pyobject = pyobject['__init__'].get_object() - elif '__call__' in pyobject: - pyobject = pyobject['__call__'].get_object() + elif ( + isinstance(pyobject, pyobjects.AbstractClass) + and "__init__" in pyobject + ): + pyobject = pyobject["__init__"].get_object() + elif "__call__" in pyobject: + pyobject = pyobject["__call__"].get_object() if isinstance(pyobject, pyobjects.AbstractFunction): param_names = [] - param_names.extend( - pyobject.get_param_names(special_args=False)) + param_names.extend(pyobject.get_param_names(special_args=False)) result = {} for name in param_names: if name.startswith(self.starting): - result[name + '='] = NamedParamProposal( - name, pyobject - ) + result[name + "="] = NamedParamProposal(name, pyobject) return result return {} @@ -509,13 +540,19 @@ class _ProposalSorter(object): def __init__(self, code_assist_proposals, scopepref=None, typepref=None): self.proposals = code_assist_proposals if scopepref is None: - scopepref = ['parameter_keyword', 'local', 'global', 'imported', - 'attribute', 'builtin', 'keyword'] + scopepref = [ + "parameter_keyword", + "local", + "global", + "imported", + "attribute", + "builtin", + "keyword", + ] self.scopepref = scopepref if typepref is None: - typepref = ['class', 'function', 'instance', 'module', None] - self.typerank = dict((type, index) - for index, type in enumerate(typepref)) + typepref = ["class", "function", "instance", "module", None] + self.typerank = dict((type, index) for index, type in enumerate(typepref)) def get_sorted_proposal_list(self): """Return a list of `CodeAssistProposal`""" @@ -525,27 +562,32 @@ def get_sorted_proposal_list(self): result = [] for scope in self.scopepref: scope_proposals = proposals.get(scope, []) - scope_proposals = [proposal for proposal in scope_proposals - if proposal.type in self.typerank] + scope_proposals = [ + proposal + for proposal in scope_proposals + if proposal.type in self.typerank + ] scope_proposals.sort(key=self._proposal_key) result.extend(scope_proposals) return result def _proposal_key(self, proposal1): def _underline_count(name): - return sum(1 for c in name if c == "_") - return (self.typerank.get(proposal1.type, 100), - _underline_count(proposal1.name), - proposal1.name) - #if proposal1.type != proposal2.type: + return sum(1 for c in name if c == "_") + + return ( + self.typerank.get(proposal1.type, 100), + _underline_count(proposal1.name), + proposal1.name, + ) + # if proposal1.type != proposal2.type: # return cmp(self.typerank.get(proposal1.type, 100), # self.typerank.get(proposal2.type, 100)) - #return self._compare_underlined_names(proposal1.name, + # return self._compare_underlined_names(proposal1.name, # proposal2.name) class PyDocExtractor(object): - def get_doc(self, pyobject): if isinstance(pyobject, pyobjects.AbstractFunction): return self._get_function_docstring(pyobject) @@ -558,9 +600,9 @@ def get_doc(self, pyobject): def get_calltip(self, pyobject, ignore_unknown=False, remove_self=False): try: if isinstance(pyobject, pyobjects.AbstractClass): - pyobject = pyobject['__init__'].get_object() + pyobject = pyobject["__init__"].get_object() if not isinstance(pyobject, pyobjects.AbstractFunction): - pyobject = pyobject['__call__'].get_object() + pyobject = pyobject["__call__"].get_object() except exceptions.AttributeNotFoundError: return None if ignore_unknown and not isinstance(pyobject, pyobjects.PyFunction): @@ -568,37 +610,39 @@ def get_calltip(self, pyobject, ignore_unknown=False, remove_self=False): if isinstance(pyobject, pyobjects.AbstractFunction): result = self._get_function_signature(pyobject, add_module=True) if remove_self and self._is_method(pyobject): - return result.replace('(self)', '()').replace('(self, ', '(') + return result.replace("(self)", "()").replace("(self, ", "(") return result def _get_class_docstring(self, pyclass): contents = self._trim_docstring(pyclass.get_doc(), 2) supers = [super.get_name() for super in pyclass.get_superclasses()] - doc = 'class %s(%s):\n\n' % (pyclass.get_name(), ', '.join(supers)) \ - + contents + doc = "class %s(%s):\n\n" % (pyclass.get_name(), ", ".join(supers)) + contents - if '__init__' in pyclass: - init = pyclass['__init__'].get_object() + if "__init__" in pyclass: + init = pyclass["__init__"].get_object() if isinstance(init, pyobjects.AbstractFunction): - doc += '\n\n' + self._get_single_function_docstring(init) + doc += "\n\n" + self._get_single_function_docstring(init) return doc def _get_function_docstring(self, pyfunction): functions = [pyfunction] if self._is_method(pyfunction): - functions.extend(self._get_super_methods(pyfunction.parent, - pyfunction.get_name())) - return '\n\n'.join([self._get_single_function_docstring(function) - for function in functions]) + functions.extend( + self._get_super_methods(pyfunction.parent, pyfunction.get_name()) + ) + return "\n\n".join( + [self._get_single_function_docstring(function) for function in functions] + ) def _is_method(self, pyfunction): - return isinstance(pyfunction, pyobjects.PyFunction) and \ - isinstance(pyfunction.parent, pyobjects.PyClass) + return isinstance(pyfunction, pyobjects.PyFunction) and isinstance( + pyfunction.parent, pyobjects.PyClass + ) def _get_single_function_docstring(self, pyfunction): signature = self._get_function_signature(pyfunction) docs = self._trim_docstring(pyfunction.get_doc(), indents=2) - return signature + ':\n\n' + docs + return signature + ":\n\n" + docs def _get_super_methods(self, pyclass, name): result = [] @@ -616,35 +660,37 @@ def _get_function_signature(self, pyfunction, add_module=False): info = functionutils.DefinitionInfo.read(pyfunction) return location + info.to_string() else: - return '%s(%s)' % (location + pyfunction.get_name(), - ', '.join(pyfunction.get_param_names())) + return "%s(%s)" % ( + location + pyfunction.get_name(), + ", ".join(pyfunction.get_param_names()), + ) def _location(self, pyobject, add_module=False): location = [] parent = pyobject.parent while parent and not isinstance(parent, pyobjects.AbstractModule): location.append(parent.get_name()) - location.append('.') + location.append(".") parent = parent.parent if add_module: if isinstance(pyobject, pyobjects.PyFunction): location.insert(0, self._get_module(pyobject)) if isinstance(parent, builtins.BuiltinModule): - location.insert(0, parent.get_name() + '.') - return ''.join(location) + location.insert(0, parent.get_name() + ".") + return "".join(location) def _get_module(self, pyfunction): module = pyfunction.get_module() if module is not None: resource = module.get_resource() if resource is not None: - return libutils.modname(resource) + '.' - return '' + return libutils.modname(resource) + "." + return "" def _trim_docstring(self, docstring, indents=0): """The sample code from :PEP:`257`""" if not docstring: - return '' + return "" # Convert tabs to spaces (following normal Python rules) # and split into a list of lines: lines = docstring.expandtabs().splitlines() @@ -665,25 +711,25 @@ def _trim_docstring(self, docstring, indents=0): while trimmed and not trimmed[0]: trimmed.pop(0) # Return a single string: - return '\n'.join((' ' * indents + line for line in trimmed)) + return "\n".join((" " * indents + line for line in trimmed)) # Deprecated classes + class TemplateProposal(CodeAssistProposal): def __init__(self, name, template): - warnings.warn('TemplateProposal is deprecated.', - DeprecationWarning, stacklevel=2) - super(TemplateProposal, self).__init__(name, 'template') + warnings.warn( + "TemplateProposal is deprecated.", DeprecationWarning, stacklevel=2 + ) + super(TemplateProposal, self).__init__(name, "template") self.template = template class Template(object): - def __init__(self, template): self.template = template - warnings.warn('Template is deprecated.', - DeprecationWarning, stacklevel=2) + warnings.warn("Template is deprecated.", DeprecationWarning, stacklevel=2) def variables(self): return [] diff --git a/rope/contrib/finderrors.py b/rope/contrib/finderrors.py index 868e47d45..b1742d595 100644 --- a/rope/contrib/finderrors.py +++ b/rope/contrib/finderrors.py @@ -38,7 +38,6 @@ def find_errors(project, resource): class _BadAccessFinder(object): - def __init__(self, pymodule): self.pymodule = pymodule self.scope = pymodule.get_scope() @@ -50,18 +49,17 @@ def _Name(self, node): scope = self.scope.get_inner_scope_for_line(node.lineno) pyname = scope.lookup(node.id) if pyname is None: - self._add_error(node, 'Unresolved variable') + self._add_error(node, "Unresolved variable") elif self._is_defined_after(scope, pyname, node.lineno): - self._add_error(node, 'Defined later') + self._add_error(node, "Defined later") def _Attribute(self, node): if not isinstance(node.ctx, ast.Store): scope = self.scope.get_inner_scope_for_line(node.lineno) pyname = evaluate.eval_node(scope, node.value) - if pyname is not None and \ - pyname.get_object() != pyobjects.get_unknown(): + if pyname is not None and pyname.get_object() != pyobjects.get_unknown(): if node.attr not in pyname.get_object(): - self._add_error(node, 'Unresolved attribute') + self._add_error(node, "Unresolved attribute") ast.walk(node.value, self) def _add_error(self, node, msg): @@ -69,23 +67,24 @@ def _add_error(self, node, msg): name = node.attr else: name = node.id - if name != 'None': - error = Error(node.lineno, msg + ' ' + name) + if name != "None": + error = Error(node.lineno, msg + " " + name) self.errors.append(error) def _is_defined_after(self, scope, pyname, lineno): location = pyname.get_definition_location() if location is not None and location[1] is not None: - if location[0] == self.pymodule and \ - lineno <= location[1] <= scope.get_end(): + if ( + location[0] == self.pymodule + and lineno <= location[1] <= scope.get_end() + ): return True class Error(object): - def __init__(self, lineno, error): self.lineno = lineno self.error = error def __str__(self): - return '%s: %s' % (self.lineno, self.error) + return "%s: %s" % (self.lineno, self.error) diff --git a/rope/contrib/findit.py b/rope/contrib/findit.py index ed85526b2..f22c9d1ce 100644 --- a/rope/contrib/findit.py +++ b/rope/contrib/findit.py @@ -6,9 +6,15 @@ from rope.refactor import occurrences -def find_occurrences(project, resource, offset, unsure=False, resources=None, - in_hierarchy=False, - task_handle=taskhandle.NullTaskHandle()): +def find_occurrences( + project, + resource, + offset, + unsure=False, + resources=None, + in_hierarchy=False, + task_handle=taskhandle.NullTaskHandle(), +): """Return a list of `Location` If `unsure` is `True`, possible matches are returned, too. You @@ -20,23 +26,28 @@ def find_occurrences(project, resource, offset, unsure=False, resources=None, """ name = worder.get_name_at(resource, offset) this_pymodule = project.get_pymodule(resource) - primary, pyname = rope.base.evaluate.eval_location2( - this_pymodule, offset) + primary, pyname = rope.base.evaluate.eval_location2(this_pymodule, offset) def is_match(occurrence): return unsure + finder = occurrences.create_finder( - project, name, pyname, unsure=is_match, - in_hierarchy=in_hierarchy, instance=primary) + project, + name, + pyname, + unsure=is_match, + in_hierarchy=in_hierarchy, + instance=primary, + ) if resources is None: resources = project.get_python_files() - job_set = task_handle.create_jobset('Finding Occurrences', - count=len(resources)) + job_set = task_handle.create_jobset("Finding Occurrences", count=len(resources)) return _find_locations(finder, resources, job_set) -def find_implementations(project, resource, offset, resources=None, - task_handle=taskhandle.NullTaskHandle()): +def find_implementations( + project, resource, offset, resources=None, task_handle=taskhandle.NullTaskHandle() +): """Find the places a given method is overridden. Finds the places a method is implemented. Returns a list of @@ -47,11 +58,13 @@ def find_implementations(project, resource, offset, resources=None, pyname = rope.base.evaluate.eval_location(this_pymodule, offset) if pyname is not None: pyobject = pyname.get_object() - if not isinstance(pyobject, rope.base.pyobjects.PyFunction) or \ - pyobject.get_kind() != 'method': - raise exceptions.BadIdentifierError('Not a method!') + if ( + not isinstance(pyobject, rope.base.pyobjects.PyFunction) + or pyobject.get_kind() != "method" + ): + raise exceptions.BadIdentifierError("Not a method!") else: - raise exceptions.BadIdentifierError('Cannot resolve the identifier!') + raise exceptions.BadIdentifierError("Cannot resolve the identifier!") def is_defined(occurrence): if not occurrence.is_defined(): @@ -60,13 +73,12 @@ def is_defined(occurrence): def not_self(occurrence): if occurrence.get_pyname().get_object() == pyname.get_object(): return False - filters = [is_defined, not_self, - occurrences.InHierarchyFilter(pyname, True)] + + filters = [is_defined, not_self, occurrences.InHierarchyFilter(pyname, True)] finder = occurrences.Finder(project, name, filters=filters) if resources is None: resources = project.get_python_files() - job_set = task_handle.create_jobset('Finding Implementations', - count=len(resources)) + job_set = task_handle.create_jobset("Finding Implementations", count=len(resources)) return _find_locations(finder, resources, job_set) @@ -87,15 +99,14 @@ def find_definition(project, code, offset, resource=None, maxfixes=1): def check_offset(occurrence): if occurrence.offset < start: return False + pyname_filter = occurrences.PyNameFilter(pyname) - finder = occurrences.Finder(project, name, - [check_offset, pyname_filter]) + finder = occurrences.Finder(project, name, [check_offset, pyname_filter]) for occurrence in finder.find_occurrences(pymodule=module): return Location(occurrence) class Location(object): - def __init__(self, occurrence): self.resource = occurrence.resource self.region = occurrence.get_word_range() diff --git a/rope/contrib/fixmodnames.py b/rope/contrib/fixmodnames.py index d8bd3da10..6290472e2 100644 --- a/rope/contrib/fixmodnames.py +++ b/rope/contrib/fixmodnames.py @@ -21,21 +21,20 @@ class FixModuleNames(object): - def __init__(self, project): self.project = project - def get_changes(self, fixer=str.lower, - task_handle=taskhandle.NullTaskHandle()): + def get_changes(self, fixer=str.lower, task_handle=taskhandle.NullTaskHandle()): """Fix module names `fixer` is a function that takes and returns a `str`. Given the name of a module, it should return the fixed name. """ - stack = changestack.ChangeStack(self.project, 'Fixing module names') - jobset = task_handle.create_jobset('Fixing module names', - self._count_fixes(fixer) + 1) + stack = changestack.ChangeStack(self.project, "Fixing module names") + jobset = task_handle.create_jobset( + "Fixing module names", self._count_fixes(fixer) + 1 + ) try: while True: for resource in self._tobe_fixed(fixer): @@ -48,7 +47,7 @@ def get_changes(self, fixer=str.lower, else: break finally: - jobset.started_job('Reverting to original state') + jobset.started_job("Reverting to original state") stack.pop_all() jobset.finished_job() return stack.merged() @@ -63,7 +62,7 @@ def _tobe_fixed(self, fixer): yield resource def _name(self, resource): - modname = resource.name.rsplit('.', 1)[0] - if modname == '__init__': + modname = resource.name.rsplit(".", 1)[0] + if modname == "__init__": modname = resource.parent.name return modname diff --git a/rope/contrib/fixsyntax.py b/rope/contrib/fixsyntax.py index fa2a17d93..b7eaac7ab 100644 --- a/rope/contrib/fixsyntax.py +++ b/rope/contrib/fixsyntax.py @@ -8,7 +8,6 @@ class FixSyntax(object): - def __init__(self, project, code, resource, maxfixes=1): self.project = project self.code = code @@ -23,24 +22,26 @@ def get_pymodule(self): tries = 0 while True: try: - if tries == 0 and self.resource is not None and \ - self.resource.read() == code: - return self.project.get_pymodule(self.resource, - force_errors=True) + if ( + tries == 0 + and self.resource is not None + and self.resource.read() == code + ): + return self.project.get_pymodule(self.resource, force_errors=True) return libutils.get_string_module( - self.project, code, resource=self.resource, - force_errors=True) + self.project, code, resource=self.resource, force_errors=True + ) except exceptions.ModuleSyntaxError as e: if msg is None: - msg = '%s:%s %s' % (e.filename, e.lineno, e.message_) + msg = "%s:%s %s" % (e.filename, e.lineno, e.message_) if tries < self.maxfixes: tries += 1 self.commenter.comment(e.lineno) - code = '\n'.join(self.commenter.lines) + code = "\n".join(self.commenter.lines) else: raise exceptions.ModuleSyntaxError( - e.filename, e.lineno, - 'Failed to fix error: {0}'.format(msg)) + e.filename, e.lineno, "Failed to fix error: {0}".format(msg) + ) @property @utils.saveit @@ -53,16 +54,18 @@ def pyname_at(self, offset): def old_pyname(): word_finder = worder.Worder(self.code, True) expression = word_finder.get_primary_at(offset) - expression = expression.replace('\\\n', ' ').replace('\n', ' ') - lineno = self.code.count('\n', 0, offset) + expression = expression.replace("\\\n", " ").replace("\n", " ") + lineno = self.code.count("\n", 0, offset) scope = pymodule.get_scope().get_inner_scope_for_line(lineno) return rope.base.evaluate.eval_str(scope, expression) + new_code = pymodule.source_code def new_pyname(): newoffset = self.commenter.transfered_offset(offset) return rope.base.evaluate.eval_location(pymodule, newoffset) - if new_code.startswith(self.code[:offset + 1]): + + if new_code.startswith(self.code[: offset + 1]): return new_pyname() result = old_pyname() if result is None: @@ -71,11 +74,10 @@ def new_pyname(): class _Commenter(object): - def __init__(self, code): self.code = code - self.lines = self.code.split('\n') - self.lines.append('\n') + self.lines = self.code.split("\n") + self.lines.append("\n") self.origs = list(range(len(self.lines) + 1)) self.diffs = [0] * (len(self.lines) + 1) @@ -88,20 +90,20 @@ def comment(self, lineno): if 0 < start: last_lineno = self._last_non_blank(start - 1) last_line = self.lines[last_lineno] - if last_line.rstrip().endswith(':'): + if last_line.rstrip().endswith(":"): indents = _get_line_indents(last_line) + 4 - self._set(start, ' ' * indents + 'pass') + self._set(start, " " * indents + "pass") for line in range(start + 1, end + 1): self._set(line, self.lines[start]) self._fix_incomplete_try_blocks(lineno, indents) def transfered_offset(self, offset): - lineno = self.code.count('\n', 0, offset) + lineno = self.code.count("\n", 0, offset) diff = sum(self.diffs[:lineno]) return offset + diff def _last_non_blank(self, start): - while start > 0 and self.lines[start].strip() == '': + while start > 0 and self.lines[start].strip() == "": start -= 1 return start @@ -126,27 +128,33 @@ def _fix_incomplete_try_blocks(self, lineno, indents): block_start = lineno last_indents = indents while block_start > 0: - block_start = rope.base.codeanalyze.get_block_start( - ArrayLinesAdapter(self.lines), block_start) - 1 - if self.lines[block_start].strip().startswith('try:'): + block_start = ( + rope.base.codeanalyze.get_block_start( + ArrayLinesAdapter(self.lines), block_start + ) + - 1 + ) + if self.lines[block_start].strip().startswith("try:"): indents = _get_line_indents(self.lines[block_start]) if indents > last_indents: continue last_indents = indents block_end = self._find_matching_deindent(block_start) line = self.lines[block_end].strip() - if not (line.startswith('finally:') or - line.startswith('except ') or - line.startswith('except:')): - self._insert(block_end, ' ' * indents + 'finally:') - self._insert(block_end + 1, ' ' * indents + ' pass') + if not ( + line.startswith("finally:") + or line.startswith("except ") + or line.startswith("except:") + ): + self._insert(block_end, " " * indents + "finally:") + self._insert(block_end + 1, " " * indents + " pass") def _find_matching_deindent(self, line_number): indents = _get_line_indents(self.lines[line_number]) current_line = line_number + 1 while current_line < len(self.lines): line = self.lines[current_line] - if not line.strip().startswith('#') and not line.strip() == '': + if not line.strip().startswith("#") and not line.strip() == "": # HACK: We should have used logical lines here if _get_line_indents(self.lines[current_line]) <= indents: return current_line diff --git a/rope/contrib/generate.py b/rope/contrib/generate.py index 3b9826304..0c225ee88 100644 --- a/rope/contrib/generate.py +++ b/rope/contrib/generate.py @@ -1,7 +1,6 @@ import rope.base.evaluate from rope.base import libutils -from rope.base import (change, pyobjects, exceptions, pynames, worder, - codeanalyze) +from rope.base import change, pyobjects, exceptions, pynames, worder, codeanalyze from rope.refactor import sourceutils, importutils, functionutils, suites @@ -12,7 +11,7 @@ def create_generate(kind, project, resource, offset, goal_resource=None): 'package'. """ - generate = eval('Generate' + kind.title()) + generate = eval("Generate" + kind.title()) return generate(project, resource, offset, goal_resource=goal_resource) @@ -20,28 +19,27 @@ def create_module(project, name, sourcefolder=None): """Creates a module and returns a `rope.base.resources.File`""" if sourcefolder is None: sourcefolder = project.root - packages = name.split('.') + packages = name.split(".") parent = sourcefolder for package in packages[:-1]: parent = parent.get_child(package) - return parent.create_file(packages[-1] + '.py') + return parent.create_file(packages[-1] + ".py") def create_package(project, name, sourcefolder=None): """Creates a package and returns a `rope.base.resources.Folder`""" if sourcefolder is None: sourcefolder = project.root - packages = name.split('.') + packages = name.split(".") parent = sourcefolder for package in packages[:-1]: parent = parent.get_child(package) made_packages = parent.create_folder(packages[-1]) - made_packages.create_file('__init__.py') + made_packages.create_file("__init__.py") return made_packages class _Generate(object): - def __init__(self, project, resource, offset, goal_resource=None): self.project = project self.resource = resource @@ -56,36 +54,37 @@ def _generate_info(self, project, resource, offset): def _check_exceptional_conditions(self): if self.info.element_already_exists(): raise exceptions.RefactoringError( - 'Element <%s> already exists.' % self.name) + "Element <%s> already exists." % self.name + ) if not self.info.primary_is_found(): raise exceptions.RefactoringError( - 'Cannot determine the scope <%s> should be defined in.' % - self.name) + "Cannot determine the scope <%s> should be defined in." % self.name + ) def get_changes(self): - changes = change.ChangeSet('Generate %s <%s>' % - (self._get_element_kind(), self.name)) + changes = change.ChangeSet( + "Generate %s <%s>" % (self._get_element_kind(), self.name) + ) indents = self.info.get_scope_indents() blanks = self.info.get_blank_lines() - base_definition = sourceutils.fix_indentation(self._get_element(), - indents) - definition = '\n' * blanks[0] + base_definition + '\n' * blanks[1] + base_definition = sourceutils.fix_indentation(self._get_element(), indents) + definition = "\n" * blanks[0] + base_definition + "\n" * blanks[1] resource = self.info.get_insertion_resource() start, end = self.info.get_insertion_offsets() collector = codeanalyze.ChangeCollector(resource.read()) collector.add_change(start, end, definition) - changes.add_change(change.ChangeContents( - resource, collector.get_changed())) + changes.add_change(change.ChangeContents(resource, collector.get_changed())) if self.goal_resource: - relative_import = _add_relative_import_to_module(self.project, self.resource, self.goal_resource, self.name) + relative_import = _add_relative_import_to_module( + self.project, self.resource, self.goal_resource, self.name + ) changes.add_change(relative_import) return changes def get_location(self): - return (self.info.get_insertion_resource(), - self.info.get_insertion_lineno()) + return (self.info.get_insertion_resource(), self.info.get_insertion_lineno()) def _get_element_kind(self): raise NotImplementedError() @@ -95,86 +94,89 @@ def _get_element(self): class GenerateFunction(_Generate): - def _generate_info(self, project, resource, offset): return _FunctionGenerationInfo(project.pycore, resource, offset) def _get_element(self): - decorator = '' + decorator = "" args = [] if self.info.is_static_method(): - decorator = '@staticmethod\n' - if self.info.is_method() or self.info.is_constructor() or \ - self.info.is_instance(): - args.append('self') + decorator = "@staticmethod\n" + if ( + self.info.is_method() + or self.info.is_constructor() + or self.info.is_instance() + ): + args.append("self") args.extend(self.info.get_passed_args()) - definition = '%sdef %s(%s):\n pass\n' % (decorator, self.name, - ', '.join(args)) + definition = "%sdef %s(%s):\n pass\n" % ( + decorator, + self.name, + ", ".join(args), + ) return definition def _get_element_kind(self): - return 'Function' + return "Function" class GenerateVariable(_Generate): - def _get_element(self): - return '%s = None\n' % self.name + return "%s = None\n" % self.name def _get_element_kind(self): - return 'Variable' + return "Variable" class GenerateClass(_Generate): - def _get_element(self): - return 'class %s(object):\n pass\n' % self.name + return "class %s(object):\n pass\n" % self.name def _get_element_kind(self): - return 'Class' + return "Class" class GenerateModule(_Generate): - def get_changes(self): package = self.info.get_package() - changes = change.ChangeSet('Generate Module <%s>' % self.name) - new_resource = self.project.get_file('%s/%s.py' % - (package.path, self.name)) + changes = change.ChangeSet("Generate Module <%s>" % self.name) + new_resource = self.project.get_file("%s/%s.py" % (package.path, self.name)) if new_resource.exists(): raise exceptions.RefactoringError( - 'Module <%s> already exists' % new_resource.path) + "Module <%s> already exists" % new_resource.path + ) changes.add_change(change.CreateResource(new_resource)) - changes.add_change(_add_import_to_module( - self.project, self.resource, new_resource)) + changes.add_change( + _add_import_to_module(self.project, self.resource, new_resource) + ) return changes def get_location(self): package = self.info.get_package() - return (package.get_child('%s.py' % self.name), 1) + return (package.get_child("%s.py" % self.name), 1) class GeneratePackage(_Generate): - def get_changes(self): package = self.info.get_package() - changes = change.ChangeSet('Generate Package <%s>' % self.name) - new_resource = self.project.get_folder('%s/%s' % - (package.path, self.name)) + changes = change.ChangeSet("Generate Package <%s>" % self.name) + new_resource = self.project.get_folder("%s/%s" % (package.path, self.name)) if new_resource.exists(): raise exceptions.RefactoringError( - 'Package <%s> already exists' % new_resource.path) + "Package <%s> already exists" % new_resource.path + ) changes.add_change(change.CreateResource(new_resource)) - changes.add_change(_add_import_to_module( - self.project, self.resource, new_resource)) - child = self.project.get_folder(package.path + '/' + self.name) - changes.add_change(change.CreateFile(child, '__init__.py')) + changes.add_change( + _add_import_to_module(self.project, self.resource, new_resource) + ) + child = self.project.get_folder(package.path + "/" + self.name) + changes.add_change(change.CreateFile(child, "__init__.py")) return changes def get_location(self): package = self.info.get_package() child = package.get_child(self.name) - return (child.get_child('__init__.py'), 1) + return (child.get_child("__init__.py"), 1) def _add_import_to_module(project, resource, imported): @@ -182,7 +184,7 @@ def _add_import_to_module(project, resource, imported): import_tools = importutils.ImportTools(project) module_imports = import_tools.module_imports(pymodule) module_name = libutils.modname(imported) - new_import = importutils.NormalImport(((module_name, None), )) + new_import = importutils.NormalImport(((module_name, None),)) module_imports.add_import(new_import) return change.ChangeContents(resource, module_imports.get_changed_source()) @@ -197,7 +199,6 @@ def _add_relative_import_to_module(project, resource, imported, name): class _GenerationInfo(object): - def __init__(self, pycore, resource, offset, goal_resource=None): self.pycore = pycore self.resource = resource @@ -258,28 +259,32 @@ def get_insertion_resource(self): return self.goal_pymodule.get_resource() def get_insertion_offsets(self): - if self.goal_scope.get_kind() == 'Class': + if self.goal_scope.get_kind() == "Class": start, end = sourceutils.get_body_region(self.goal_scope.pyobject) - if self.goal_pymodule.source_code[start:end].strip() == 'pass': + if self.goal_pymodule.source_code[start:end].strip() == "pass": return start, end lines = self.goal_pymodule.lines start = lines.get_line_start(self.get_insertion_lineno()) return (start, start) def get_scope_indents(self): - if self.goal_scope.get_kind() == 'Module': + if self.goal_scope.get_kind() == "Module": return 0 - return sourceutils.get_indents(self.goal_pymodule.lines, - self.goal_scope.get_start()) + 4 + return ( + sourceutils.get_indents( + self.goal_pymodule.lines, self.goal_scope.get_start() + ) + + 4 + ) def get_blank_lines(self): - if self.goal_scope.get_kind() == 'Module': + if self.goal_scope.get_kind() == "Module": base_blanks = 2 - if self.goal_pymodule.source_code.strip() == '': + if self.goal_pymodule.source_code.strip() == "": base_blanks = 0 - if self.goal_scope.get_kind() == 'Class': + if self.goal_scope.get_kind() == "Class": base_blanks = 1 - if self.goal_scope.get_kind() == 'Function': + if self.goal_scope.get_kind() == "Function": base_blanks = 0 if self.goal_scope == self.source_scope: return (0, base_blanks) @@ -292,7 +297,8 @@ def get_package(self): if isinstance(primary.get_object(), pyobjects.PyPackage): return primary.get_object().get_resource() raise exceptions.RefactoringError( - 'A module/package can be only created in a package.') + "A module/package can be only created in a package." + ) def primary_is_found(self): return self.goal_scope is not None @@ -307,7 +313,6 @@ def get_name(self): class _FunctionGenerationInfo(_GenerationInfo): - def _get_goal_scope(self): if self.is_constructor(): return self.pyname.get_object().get_scope() @@ -327,16 +332,19 @@ def element_already_exists(self): return self.get_name() in self.goal_scope.get_defined_names() def is_static_method(self): - return self.primary is not None and \ - isinstance(self.primary.get_object(), pyobjects.PyClass) + return self.primary is not None and isinstance( + self.primary.get_object(), pyobjects.PyClass + ) def is_method(self): - return self.primary is not None and \ - isinstance(self.primary.get_object().get_type(), pyobjects.PyClass) + return self.primary is not None and isinstance( + self.primary.get_object().get_type(), pyobjects.PyClass + ) def is_constructor(self): - return self.pyname is not None and \ - isinstance(self.pyname.get_object(), pyobjects.PyClass) + return self.pyname is not None and isinstance( + self.pyname.get_object(), pyobjects.PyClass + ) def is_instance(self): if self.pyname is None: @@ -346,9 +354,9 @@ def is_instance(self): def get_name(self): if self.is_constructor(): - return '__init__' + return "__init__" if self.is_instance(): - return '__call__' + return "__call__" return worder.get_name_at(self.resource, self.offset) def get_passed_args(self): @@ -365,14 +373,15 @@ def get_passed_args(self): if self._is_id(arg): result.append(arg) else: - result.append('arg%d' % len(result)) + result.append("arg%d" % len(result)) for name, value in keywords: result.append(name) return result def _is_id(self, arg): def id_or_underline(c): - return c.isalpha() or c == '_' + return c.isalpha() or c == "_" + for c in arg: if not id_or_underline(c) and not c.isdigit(): return False diff --git a/rope/refactor/__init__.py b/rope/refactor/__init__.py index 4ef675134..a43d3b86c 100644 --- a/rope/refactor/__init__.py +++ b/rope/refactor/__init__.py @@ -49,7 +49,20 @@ from rope.refactor.topackage import ModuleToPackage # noqa -__all__ = ['rename', 'move', 'inline', 'extract', 'restructure', 'topackage', - 'importutils', 'usefunction', 'change_signature', - 'encapsulate_field', 'introduce_factory', 'introduce_parameter', - 'localtofield', 'method_object', 'multiproject'] +__all__ = [ + "rename", + "move", + "inline", + "extract", + "restructure", + "topackage", + "importutils", + "usefunction", + "change_signature", + "encapsulate_field", + "introduce_factory", + "introduce_parameter", + "localtofield", + "method_object", + "multiproject", +] diff --git a/rope/refactor/change_signature.py b/rope/refactor/change_signature.py index 90f6ce1c4..f4fdcb424 100644 --- a/rope/refactor/change_signature.py +++ b/rope/refactor/change_signature.py @@ -12,56 +12,69 @@ class ChangeSignature(object): - def __init__(self, project, resource, offset): self.project = project self.resource = resource self.offset = offset self._set_name_and_pyname() - if self.pyname is None or self.pyname.get_object() is None or \ - not isinstance(self.pyname.get_object(), pyobjects.PyFunction): + if ( + self.pyname is None + or self.pyname.get_object() is None + or not isinstance(self.pyname.get_object(), pyobjects.PyFunction) + ): raise rope.base.exceptions.RefactoringError( - 'Change method signature should be performed on functions') + "Change method signature should be performed on functions" + ) def _set_name_and_pyname(self): self.name = worder.get_name_at(self.resource, self.offset) this_pymodule = self.project.get_pymodule(self.resource) - self.primary, self.pyname = evaluate.eval_location2( - this_pymodule, self.offset) + self.primary, self.pyname = evaluate.eval_location2(this_pymodule, self.offset) if self.pyname is None: return pyobject = self.pyname.get_object() - if isinstance(pyobject, pyobjects.PyClass) and \ - '__init__' in pyobject: - self.pyname = pyobject['__init__'] - self.name = '__init__' + if isinstance(pyobject, pyobjects.PyClass) and "__init__" in pyobject: + self.pyname = pyobject["__init__"] + self.name = "__init__" pyobject = self.pyname.get_object() self.others = None - if self.name == '__init__' and \ - isinstance(pyobject, pyobjects.PyFunction) and \ - isinstance(pyobject.parent, pyobjects.PyClass): + if ( + self.name == "__init__" + and isinstance(pyobject, pyobjects.PyFunction) + and isinstance(pyobject.parent, pyobjects.PyClass) + ): pyclass = pyobject.parent - self.others = (pyclass.get_name(), - pyclass.parent[pyclass.get_name()]) - - def _change_calls(self, call_changer, in_hierarchy=None, resources=None, - handle=taskhandle.NullTaskHandle()): + self.others = (pyclass.get_name(), pyclass.parent[pyclass.get_name()]) + + def _change_calls( + self, + call_changer, + in_hierarchy=None, + resources=None, + handle=taskhandle.NullTaskHandle(), + ): if resources is None: resources = self.project.get_python_files() - changes = ChangeSet('Changing signature of <%s>' % self.name) - job_set = handle.create_jobset('Collecting Changes', len(resources)) + changes = ChangeSet("Changing signature of <%s>" % self.name) + job_set = handle.create_jobset("Collecting Changes", len(resources)) finder = occurrences.create_finder( - self.project, self.name, self.pyname, instance=self.primary, - in_hierarchy=in_hierarchy and self.is_method()) + self.project, + self.name, + self.pyname, + instance=self.primary, + in_hierarchy=in_hierarchy and self.is_method(), + ) if self.others: name, pyname = self.others constructor_finder = occurrences.create_finder( - self.project, name, pyname, only_calls=True) + self.project, name, pyname, only_calls=True + ) finder = _MultipleFinders([finder, constructor_finder]) for file in resources: job_set.started_job(file.path) change_calls = _ChangeCallsInModule( - self.project, finder, file, call_changer) + self.project, finder, file, call_changer + ) changed_file = change_calls.get_changed_module() if changed_file is not None: changes.add_change(ChangeContents(file, changed_file)) @@ -81,7 +94,7 @@ def is_method(self): pyfunction = self.pyname.get_object() return isinstance(pyfunction.parent, pyobjects.PyClass) - @utils.deprecated('Use `ChangeSignature.get_args()` instead') + @utils.deprecated("Use `ChangeSignature.get_args()` instead") def get_definition_info(self): return self._definfo() @@ -91,40 +104,53 @@ def _definfo(self): @utils.deprecated() def normalize(self): changer = _FunctionChangers( - self.pyname.get_object(), self.get_definition_info(), - [ArgumentNormalizer()]) + self.pyname.get_object(), self.get_definition_info(), [ArgumentNormalizer()] + ) return self._change_calls(changer) @utils.deprecated() def remove(self, index): changer = _FunctionChangers( - self.pyname.get_object(), self.get_definition_info(), - [ArgumentRemover(index)]) + self.pyname.get_object(), + self.get_definition_info(), + [ArgumentRemover(index)], + ) return self._change_calls(changer) @utils.deprecated() def add(self, index, name, default=None, value=None): changer = _FunctionChangers( - self.pyname.get_object(), self.get_definition_info(), - [ArgumentAdder(index, name, default, value)]) + self.pyname.get_object(), + self.get_definition_info(), + [ArgumentAdder(index, name, default, value)], + ) return self._change_calls(changer) @utils.deprecated() def inline_default(self, index): changer = _FunctionChangers( - self.pyname.get_object(), self.get_definition_info(), - [ArgumentDefaultInliner(index)]) + self.pyname.get_object(), + self.get_definition_info(), + [ArgumentDefaultInliner(index)], + ) return self._change_calls(changer) @utils.deprecated() def reorder(self, new_ordering): changer = _FunctionChangers( - self.pyname.get_object(), self.get_definition_info(), - [ArgumentReorderer(new_ordering)]) + self.pyname.get_object(), + self.get_definition_info(), + [ArgumentReorderer(new_ordering)], + ) return self._change_calls(changer) - def get_changes(self, changers, in_hierarchy=False, resources=None, - task_handle=taskhandle.NullTaskHandle()): + def get_changes( + self, + changers, + in_hierarchy=False, + resources=None, + task_handle=taskhandle.NullTaskHandle(), + ): """Get changes caused by this refactoring `changers` is a list of `_ArgumentChanger`. If `in_hierarchy` @@ -135,14 +161,15 @@ def get_changes(self, changers, in_hierarchy=False, resources=None, in the project are searched. """ - function_changer = _FunctionChangers(self.pyname.get_object(), - self._definfo(), changers) - return self._change_calls(function_changer, in_hierarchy, - resources, task_handle) + function_changer = _FunctionChangers( + self.pyname.get_object(), self._definfo(), changers + ) + return self._change_calls( + function_changer, in_hierarchy, resources, task_handle + ) class _FunctionChangers(object): - def __init__(self, pyfunction, definition_info, changers=None): self.pyfunction = pyfunction self.definition_info = definition_info @@ -164,20 +191,19 @@ def change_definition(self, call): def change_call(self, primary, pyname, call): call_info = functionutils.CallInfo.read( - primary, pyname, self.definition_info, call) - mapping = functionutils.ArgumentMapping(self.definition_info, - call_info) + primary, pyname, self.definition_info, call + ) + mapping = functionutils.ArgumentMapping(self.definition_info, call_info) - for definition_info, changer in zip(self.changed_definition_infos, - self.changers): + for definition_info, changer in zip( + self.changed_definition_infos, self.changers + ): changer.change_argument_mapping(definition_info, mapping) - return mapping.to_call_info( - self.changed_definition_infos[-1]).to_string() + return mapping.to_call_info(self.changed_definition_infos[-1]).to_string() class _ArgumentChanger(object): - def change_definition_info(self, definition_info): pass @@ -190,22 +216,26 @@ class ArgumentNormalizer(_ArgumentChanger): class ArgumentRemover(_ArgumentChanger): - def __init__(self, index): self.index = index def change_definition_info(self, call_info): if self.index < len(call_info.args_with_defaults): del call_info.args_with_defaults[self.index] - elif self.index == len(call_info.args_with_defaults) and \ - call_info.args_arg is not None: + elif ( + self.index == len(call_info.args_with_defaults) + and call_info.args_arg is not None + ): call_info.args_arg = None - elif (self.index == len(call_info.args_with_defaults) and - call_info.args_arg is None and - call_info.keywords_arg is not None) or \ - (self.index == len(call_info.args_with_defaults) + 1 and - call_info.args_arg is not None and - call_info.keywords_arg is not None): + elif ( + self.index == len(call_info.args_with_defaults) + and call_info.args_arg is None + and call_info.keywords_arg is not None + ) or ( + self.index == len(call_info.args_with_defaults) + 1 + and call_info.args_arg is not None + and call_info.keywords_arg is not None + ): call_info.keywords_arg = None def change_argument_mapping(self, definition_info, mapping): @@ -216,7 +246,6 @@ def change_argument_mapping(self, definition_info, mapping): class ArgumentAdder(_ArgumentChanger): - def __init__(self, index, name, default=None, value=None): self.index = index self.name = name @@ -227,9 +256,9 @@ def change_definition_info(self, definition_info): for pair in definition_info.args_with_defaults: if pair[0] == self.name: raise rope.base.exceptions.RefactoringError( - 'Adding duplicate parameter: <%s>.' % self.name) - definition_info.args_with_defaults.insert(self.index, - (self.name, self.default)) + "Adding duplicate parameter: <%s>." % self.name + ) + definition_info.args_with_defaults.insert(self.index, (self.name, self.default)) def change_argument_mapping(self, definition_info, mapping): if self.value is not None: @@ -237,15 +266,16 @@ def change_argument_mapping(self, definition_info, mapping): class ArgumentDefaultInliner(_ArgumentChanger): - def __init__(self, index): self.index = index self.remove = False def change_definition_info(self, definition_info): if self.remove: - definition_info.args_with_defaults[self.index] = \ - (definition_info.args_with_defaults[self.index][0], None) + definition_info.args_with_defaults[self.index] = ( + definition_info.args_with_defaults[self.index][0], + None, + ) def change_argument_mapping(self, definition_info, mapping): default = definition_info.args_with_defaults[self.index][1] @@ -255,7 +285,6 @@ def change_argument_mapping(self, definition_info, mapping): class ArgumentReorderer(_ArgumentChanger): - def __init__(self, new_order, autodef=None): """Construct an `ArgumentReorderer` @@ -291,7 +320,6 @@ def change_definition_info(self, definition_info): class _ChangeCallsInModule(object): - def __init__(self, project, occurrence_finder, resource, call_changer): self.project = project self.occurrence_finder = occurrence_finder @@ -301,20 +329,20 @@ def __init__(self, project, occurrence_finder, resource, call_changer): def get_changed_module(self): word_finder = worder.Worder(self.source) change_collector = codeanalyze.ChangeCollector(self.source) - for occurrence in self.occurrence_finder.find_occurrences( - self.resource): + for occurrence in self.occurrence_finder.find_occurrences(self.resource): if not occurrence.is_called() and not occurrence.is_defined(): continue start, end = occurrence.get_primary_range() - begin_parens, end_parens = word_finder.\ - get_word_parens_range(end - 1) + begin_parens, end_parens = word_finder.get_word_parens_range(end - 1) if occurrence.is_called(): primary, pyname = occurrence.get_primary_and_pyname() changed_call = self.call_changer.change_call( - primary, pyname, self.source[start:end_parens]) + primary, pyname, self.source[start:end_parens] + ) else: changed_call = self.call_changer.change_definition( - self.source[start:end_parens]) + self.source[start:end_parens] + ) if changed_call is not None: change_collector.add_change(start, end_parens, changed_call) return change_collector.get_changed() @@ -339,7 +367,6 @@ def lines(self): class _MultipleFinders(object): - def __init__(self, finders): self.finders = finders @@ -349,4 +376,3 @@ def find_occurrences(self, resource=None, pymodule=None): all_occurrences.extend(finder.find_occurrences(resource, pymodule)) all_occurrences.sort(key=lambda x: x.get_primary_range()) return all_occurrences - diff --git a/rope/refactor/encapsulate_field.py b/rope/refactor/encapsulate_field.py index 9aa59f887..ce114312a 100644 --- a/rope/refactor/encapsulate_field.py +++ b/rope/refactor/encapsulate_field.py @@ -10,7 +10,6 @@ class EncapsulateField(object): - def __init__(self, project, resource, offset): self.project = project self.name = worder.get_name_at(resource, offset) @@ -18,11 +17,17 @@ def __init__(self, project, resource, offset): self.pyname = evaluate.eval_location(this_pymodule, offset) if not self._is_an_attribute(self.pyname): raise exceptions.RefactoringError( - 'Encapsulate field should be performed on class attributes.') + "Encapsulate field should be performed on class attributes." + ) self.resource = self.pyname.get_definition_location()[0].get_resource() - def get_changes(self, getter=None, setter=None, resources=None, - task_handle=taskhandle.NullTaskHandle()): + def get_changes( + self, + getter=None, + setter=None, + resources=None, + task_handle=taskhandle.NullTaskHandle(), + ): """Get the changes this refactoring makes If `getter` is not `None`, that will be the name of the @@ -37,20 +42,19 @@ def get_changes(self, getter=None, setter=None, resources=None, """ if resources is None: resources = self.project.get_python_files() - changes = ChangeSet('Encapsulate field <%s>' % self.name) - job_set = task_handle.create_jobset('Collecting Changes', - len(resources)) + changes = ChangeSet("Encapsulate field <%s>" % self.name) + job_set = task_handle.create_jobset("Collecting Changes", len(resources)) if getter is None: - getter = 'get_' + self.name + getter = "get_" + self.name if setter is None: - setter = 'set_' + self.name + setter = "set_" + self.name renamer = GetterSetterRenameInModule( - self.project, self.name, self.pyname, getter, setter) + self.project, self.name, self.pyname, getter, setter + ) for file in resources: job_set.started_job(file.path) if file == self.resource: - result = self._change_holding_module(changes, renamer, - getter, setter) + result = self._change_holding_module(changes, renamer, getter, setter) changes.add_change(ChangeContents(self.resource, result)) else: result = renamer.get_changed_module(file) @@ -66,18 +70,17 @@ def get_field_name(self): def _is_an_attribute(self, pyname): if pyname is not None and isinstance(pyname, pynames.AssignedName): pymodule, lineno = self.pyname.get_definition_location() - scope = pymodule.get_scope().\ - get_inner_scope_for_line(lineno) - if scope.get_kind() == 'Class': + scope = pymodule.get_scope().get_inner_scope_for_line(lineno) + if scope.get_kind() == "Class": return pyname in scope.get_names().values() parent = scope.parent - if parent is not None and parent.get_kind() == 'Class': + if parent is not None and parent.get_kind() == "Class": return pyname in parent.get_names().values() return False def _get_defining_class_scope(self): defining_scope = self._get_defining_scope() - if defining_scope.get_kind() == 'Function': + if defining_scope.get_kind() == "Function": defining_scope = defining_scope.parent return defining_scope @@ -91,25 +94,28 @@ def _change_holding_module(self, changes, renamer, getter, setter): defining_object = self._get_defining_scope().pyobject start, end = sourceutils.get_body_region(defining_object) - new_source = renamer.get_changed_module(pymodule=pymodule, - skip_start=start, skip_end=end) + new_source = renamer.get_changed_module( + pymodule=pymodule, skip_start=start, skip_end=end + ) if new_source is not None: pymodule = libutils.get_string_module( - self.project, new_source, self.resource) - class_scope = pymodule.get_scope().\ - get_inner_scope_for_line(class_scope.get_start()) - indents = sourceutils.get_indent(self.project) * ' ' - getter = 'def %s(self):\n%sreturn self.%s' % \ - (getter, indents, self.name) - setter = 'def %s(self, value):\n%sself.%s = value' % \ - (setter, indents, self.name) - new_source = sourceutils.add_methods(pymodule, class_scope, - [getter, setter]) + self.project, new_source, self.resource + ) + class_scope = pymodule.get_scope().get_inner_scope_for_line( + class_scope.get_start() + ) + indents = sourceutils.get_indent(self.project) * " " + getter = "def %s(self):\n%sreturn self.%s" % (getter, indents, self.name) + setter = "def %s(self, value):\n%sself.%s = value" % ( + setter, + indents, + self.name, + ) + new_source = sourceutils.add_methods(pymodule, class_scope, [getter, setter]) return new_source class GetterSetterRenameInModule(object): - def __init__(self, project, name, pyname, getter, setter): self.project = project self.name = name @@ -117,15 +123,16 @@ def __init__(self, project, name, pyname, getter, setter): self.getter = getter self.setter = setter - def get_changed_module(self, resource=None, pymodule=None, - skip_start=0, skip_end=0): - change_finder = _FindChangesForModule(self, resource, pymodule, - skip_start, skip_end) + def get_changed_module( + self, resource=None, pymodule=None, skip_start=0, skip_end=0 + ): + change_finder = _FindChangesForModule( + self, resource, pymodule, skip_start, skip_end + ) return change_finder.get_changed_module() class _FindChangesForModule(object): - def __init__(self, finder, resource, pymodule, skip_start, skip_end): self.project = finder.project self.finder = finder.finder @@ -141,46 +148,51 @@ def __init__(self, finder, resource, pymodule, skip_start, skip_end): def get_changed_module(self): result = [] - for occurrence in self.finder.find_occurrences(self.resource, - self.pymodule): + for occurrence in self.finder.find_occurrences(self.resource, self.pymodule): start, end = occurrence.get_word_range() if self.skip_start <= start < self.skip_end: continue self._manage_writes(start, result) - result.append(self.source[self.last_modified:start]) + result.append(self.source[self.last_modified : start]) if self._is_assigned_in_a_tuple_assignment(occurrence): raise exceptions.RefactoringError( - 'Cannot handle tuple assignments in encapsulate field.') + "Cannot handle tuple assignments in encapsulate field." + ) if occurrence.is_written(): assignment_type = self.worder.get_assignment_type(start) - if assignment_type == '=': - result.append(self.setter + '(') + if assignment_type == "=": + result.append(self.setter + "(") else: - var_name = self.source[occurrence.get_primary_range()[0]: - start] + self.getter + '()' - result.append(self.setter + '(' + var_name - + ' %s ' % assignment_type[:-1]) + var_name = ( + self.source[occurrence.get_primary_range()[0] : start] + + self.getter + + "()" + ) + result.append( + self.setter + "(" + var_name + " %s " % assignment_type[:-1] + ) current_line = self.lines.get_line_number(start) - start_line, end_line = self.pymodule.logical_lines.\ - logical_line_in(current_line) + start_line, end_line = self.pymodule.logical_lines.logical_line_in( + current_line + ) self.last_set = self.lines.get_line_end(end_line) - end = self.source.index('=', end) + 1 + end = self.source.index("=", end) + 1 self.set_index = len(result) else: - result.append(self.getter + '()') + result.append(self.getter + "()") self.last_modified = end if self.last_modified != 0: self._manage_writes(len(self.source), result) - result.append(self.source[self.last_modified:]) - return ''.join(result) + result.append(self.source[self.last_modified :]) + return "".join(result) return None def _manage_writes(self, offset, result): if self.last_set is not None and self.last_set <= offset: - result.append(self.source[self.last_modified:self.last_set]) - set_value = ''.join(result[self.set_index:]).strip() - del result[self.set_index:] - result.append(set_value + ')') + result.append(self.source[self.last_modified : self.last_set]) + set_value = "".join(result[self.set_index :]).strip() + del result[self.set_index :] + result.append(set_value + ")") self.last_modified = self.last_set self.last_set = None diff --git a/rope/refactor/extract.py b/rope/refactor/extract.py index 7a6f853b6..e3e011425 100644 --- a/rope/refactor/extract.py +++ b/rope/refactor/extract.py @@ -6,8 +6,7 @@ from rope.base.exceptions import RefactoringError from rope.base.utils import pycompat from rope.base.utils.datastructures import OrderedSet -from rope.refactor import (sourceutils, similarfinder, - patchedast, suites, usefunction) +from rope.refactor import sourceutils, similarfinder, patchedast, suites, usefunction # Extract refactoring has lots of special cases. I tried to split it @@ -38,8 +37,7 @@ class _ExtractRefactoring(object): kind_prefixes = {} - def __init__(self, project, resource, start_offset, end_offset, - variable=False): + def __init__(self, project, resource, start_offset, end_offset, variable=False): self.project = project self.resource = resource self.start_offset = self._fix_start(resource.read(), start_offset) @@ -71,13 +69,18 @@ def get_changes(self, extracted_name, similar=False, global_=False, kind=None): extracted_name, kind = self._get_kind_from_name(extracted_name, kind) info = _ExtractInfo( - self.project, self.resource, self.start_offset, self.end_offset, - extracted_name, variable=self._get_kind(kind) == 'variable', - similar=similar, make_global=global_) + self.project, + self.resource, + self.start_offset, + self.end_offset, + extracted_name, + variable=self._get_kind(kind) == "variable", + similar=similar, + make_global=global_, + ) info.kind = self._get_kind(kind) new_contents = _ExtractPerformer(info).extract() - changes = ChangeSet('Extract %s <%s>' % (info.kind, - extracted_name)) + changes = ChangeSet("Extract %s <%s>" % (info.kind, extracted_name)) changes.add_change(ChangeContents(self.resource, new_contents)) return changes @@ -99,7 +102,7 @@ def _get_kind(cls, kind): class ExtractMethod(_ExtractRefactoring): - kind = 'method' + kind = "method" allowed_kinds = ("function", "method", "staticmethod", "classmethod") kind_prefixes = {"@": "classmethod", "$": "staticmethod"} @@ -109,22 +112,23 @@ def _get_kind(cls, kind): class ExtractVariable(_ExtractRefactoring): - def __init__(self, *args, **kwds): kwds = dict(kwds) - kwds['variable'] = True + kwds["variable"] = True super(ExtractVariable, self).__init__(*args, **kwds) - kind = 'variable' + kind = "variable" def _get_kind(cls, kind): return cls.kind + class _ExtractInfo(object): """Holds information about the extract to be performed""" - def __init__(self, project, resource, start, end, new_name, - variable, similar, make_global): + def __init__( + self, project, resource, start, end, new_name, variable, similar, make_global + ): self.project = project self.resource = resource self.pymodule = project.get_pymodule(resource) @@ -140,17 +144,23 @@ def __init__(self, project, resource, start, end, new_name, self.make_global = make_global def _init_parts(self, start, end): - self.region = (self._choose_closest_line_end(start), - self._choose_closest_line_end(end, end=True)) + self.region = ( + self._choose_closest_line_end(start), + self._choose_closest_line_end(end, end=True), + ) start = self.logical_lines.logical_line_in( - self.lines.get_line_number(self.region[0]))[0] + self.lines.get_line_number(self.region[0]) + )[0] end = self.logical_lines.logical_line_in( - self.lines.get_line_number(self.region[1]))[1] + self.lines.get_line_number(self.region[1]) + )[1] self.region_lines = (start, end) - self.lines_region = (self.lines.get_line_start(self.region_lines[0]), - self.lines.get_line_end(self.region_lines[1])) + self.lines_region = ( + self.lines.get_line_start(self.region_lines[0]), + self.lines.get_line_end(self.region_lines[1]), + ) @property def logical_lines(self): @@ -159,33 +169,36 @@ def logical_lines(self): def _init_scope(self): start_line = self.region_lines[0] scope = self.global_scope.get_inner_scope_for_line(start_line) - if scope.get_kind() != 'Module' and scope.get_start() == start_line: + if scope.get_kind() != "Module" and scope.get_start() == start_line: scope = scope.parent self.scope = scope self.scope_region = self._get_scope_region(self.scope) def _get_scope_region(self, scope): - return (self.lines.get_line_start(scope.get_start()), - self.lines.get_line_end(scope.get_end()) + 1) + return ( + self.lines.get_line_start(scope.get_start()), + self.lines.get_line_end(scope.get_end()) + 1, + ) def _choose_closest_line_end(self, offset, end=False): lineno = self.lines.get_line_number(offset) line_start = self.lines.get_line_start(lineno) line_end = self.lines.get_line_end(lineno) - if self.source[line_start:offset].strip() == '': + if self.source[line_start:offset].strip() == "": if end: return line_start - 1 else: return line_start - elif self.source[offset:line_end].strip() == '': + elif self.source[offset:line_end].strip() == "": return min(line_end, len(self.source)) return offset @property def one_line(self): - return self.region != self.lines_region and \ - (self.logical_lines.logical_line_in(self.region_lines[0]) == - self.logical_lines.logical_line_in(self.region_lines[1])) + return self.region != self.lines_region and ( + self.logical_lines.logical_line_in(self.region_lines[0]) + == self.logical_lines.logical_line_in(self.region_lines[1]) + ) @property def global_(self): @@ -193,24 +206,21 @@ def global_(self): @property def method(self): - return self.scope.parent is not None and \ - self.scope.parent.get_kind() == 'Class' + return self.scope.parent is not None and self.scope.parent.get_kind() == "Class" @property def indents(self): - return sourceutils.get_indents(self.pymodule.lines, - self.region_lines[0]) + return sourceutils.get_indents(self.pymodule.lines, self.region_lines[0]) @property def scope_indents(self): if self.global_: return 0 - return sourceutils.get_indents(self.pymodule.lines, - self.scope.get_start()) + return sourceutils.get_indents(self.pymodule.lines, self.scope.get_start()) @property def extracted(self): - return self.source[self.region[0]:self.region[1]] + return self.source[self.region[0] : self.region[1]] _cached_parsed_extraced = None @@ -235,7 +245,9 @@ def returned(self): def returning_named_expr(self): """Does the extracted piece contains named expression/:= operator)""" if self._returning_named_expr is None: - self._returning_named_expr = usefunction._namedexpr_last(self._parsed_extracted) + self._returning_named_expr = usefunction._namedexpr_last( + self._parsed_extracted + ) return self._returning_named_expr @@ -253,7 +265,6 @@ def __init__(self, info): class _ExtractPerformer(object): - def __init__(self, info): self.info = info _ExceptionalConditionChecker()(self.info) @@ -271,8 +282,7 @@ def extract(self): def _replace_occurrences(self, content, extract_info): for match in extract_info.matches: - replacement = similarfinder.CodeTemplate( - extract_info.replacement_pattern) + replacement = similarfinder.CodeTemplate(extract_info.replacement_pattern) mapping = {} for name in replacement.get_names(): node = match.get_ast(name) @@ -282,8 +292,7 @@ def _replace_occurrences(self, content, extract_info): else: mapping[name] = name region = match.get_region() - content.add_change(region[0], region[1], - replacement.substitute(mapping)) + content.add_change(region[0], region[1], replacement.substitute(mapping)) def _collect_info(self): extract_collector = _ExtractCollector(self.info) @@ -297,8 +306,9 @@ def _find_matches(self, collector): finder = similarfinder.SimilarFinder(self.info.pymodule) matches = [] for start, end in regions: - region_matches = finder.get_matches(collector.body_pattern, - collector.checks, start, end) + region_matches = finder.get_matches( + collector.body_pattern, collector.checks, start, end + ) # Don't extract overlapping regions last_match_end = -1 for region_match in region_matches: @@ -312,7 +322,9 @@ def _find_matches(self, collector): @staticmethod def _is_assignment(region_match): - return isinstance(region_match.ast, (ast.Attribute, ast.Subscript)) and isinstance(region_match.ast.ctx, ast.Store) + return isinstance( + region_match.ast, (ast.Attribute, ast.Subscript) + ) and isinstance(region_match.ast.ctx, ast.Store) def _where_to_search(self): if self.info.similar: @@ -323,8 +335,10 @@ def _where_to_search(self): regions = [] method_kind = _get_function_kind(self.info.scope) for scope in class_scope.get_scopes(): - if method_kind == 'method' and \ - _get_function_kind(scope) != 'method': + if ( + method_kind == "method" + and _get_function_kind(scope) != "method" + ): continue start = self.info.lines.get_line_start(scope.get_start()) end = self.info.lines.get_line_end(scope.get_end()) @@ -334,8 +348,7 @@ def _where_to_search(self): if self.info.variable: return [self.info.scope_region] else: - return [self.info._get_scope_region( - self.info.scope.parent)] + return [self.info._get_scope_region(self.info.scope.parent)] else: return [self.info.region] @@ -346,8 +359,10 @@ def _find_definition_location(self, collector): start_line = self.info.logical_lines.logical_line_in(start)[0] matched_lines.append(start_line) location_finder = _DefinitionLocationFinder(self.info, matched_lines) - collector.definition_location = (location_finder.find_lineno(), - location_finder.find_indents()) + collector.definition_location = ( + location_finder.find_lineno(), + location_finder.find_indents(), + ) def _find_definition(self, collector): if self.info.variable: @@ -361,7 +376,6 @@ def _find_definition(self, collector): class _DefinitionLocationFinder(object): - def __init__(self, info, matched_lines): self.info = info self.matched_lines = matched_lines @@ -388,8 +402,7 @@ def _find_toplevel(self, scope): def find_indents(self): if self.info.variable and not self.info.make_global: - return sourceutils.get_indents(self.info.lines, - self._get_before_line()) + return sourceutils.get_indents(self.info.lines, self._get_before_line()) else: if self.info.global_ or self.info.make_global: return 0 @@ -404,7 +417,6 @@ def _get_after_scope(self): class _ExceptionalConditionChecker(object): - def __call__(self, info): self.base_conditions(info) if info.one_line: @@ -414,69 +426,80 @@ def __call__(self, info): def base_conditions(self, info): if info.region[1] > info.scope_region[1]: - raise RefactoringError('Bad region selected for extract method') + raise RefactoringError("Bad region selected for extract method") end_line = info.region_lines[1] end_scope = info.global_scope.get_inner_scope_for_line(end_line) if end_scope != info.scope and end_scope.get_end() != end_line: - raise RefactoringError('Bad region selected for extract method') + raise RefactoringError("Bad region selected for extract method") try: extracted = info.extracted if info.one_line: - extracted = '(%s)' % extracted + extracted = "(%s)" % extracted if _UnmatchedBreakOrContinueFinder.has_errors(extracted): - raise RefactoringError('A break/continue without having a ' - 'matching for/while loop.') + raise RefactoringError( + "A break/continue without having a " "matching for/while loop." + ) except SyntaxError: - raise RefactoringError('Extracted piece should ' - 'contain complete statements.') + raise RefactoringError( + "Extracted piece should " "contain complete statements." + ) def one_line_conditions(self, info): if self._is_region_on_a_word(info): - raise RefactoringError('Should extract complete statements.') + raise RefactoringError("Should extract complete statements.") if info.variable and not info.one_line: - raise RefactoringError('Extract variable should not ' - 'span multiple lines.') - if usefunction._named_expr_count(info._parsed_extracted) - usefunction._namedexpr_last(info._parsed_extracted): - raise RefactoringError('Extracted piece cannot ' - 'contain named expression (:= operator).') + raise RefactoringError( + "Extract variable should not " "span multiple lines." + ) + if usefunction._named_expr_count( + info._parsed_extracted + ) - usefunction._namedexpr_last(info._parsed_extracted): + raise RefactoringError( + "Extracted piece cannot " "contain named expression (:= operator)." + ) def multi_line_conditions(self, info): - node = _parse_text(info.source[info.region[0]:info.region[1]]) + node = _parse_text(info.source[info.region[0] : info.region[1]]) count = usefunction._return_count(node) extracted = info.extracted if count > 1: - raise RefactoringError('Extracted piece can have only one ' - 'return statement.') + raise RefactoringError( + "Extracted piece can have only one " "return statement." + ) if usefunction._yield_count(node): - raise RefactoringError('Extracted piece cannot ' - 'have yield statements.') - if not hasattr(ast, 'PyCF_ALLOW_TOP_LEVEL_AWAIT') and _AsyncStatementFinder.has_errors(extracted): - raise RefactoringError('Extracted piece can only have async/await ' - 'statements if Rope is running on Python ' - '3.8 or higher') + raise RefactoringError("Extracted piece cannot " "have yield statements.") + if not hasattr( + ast, "PyCF_ALLOW_TOP_LEVEL_AWAIT" + ) and _AsyncStatementFinder.has_errors(extracted): + raise RefactoringError( + "Extracted piece can only have async/await " + "statements if Rope is running on Python " + "3.8 or higher" + ) if count == 1 and not usefunction._returns_last(node): - raise RefactoringError('Return should be the last statement.') + raise RefactoringError("Return should be the last statement.") if info.region != info.lines_region: - raise RefactoringError('Extracted piece should ' - 'contain complete statements.') + raise RefactoringError( + "Extracted piece should " "contain complete statements." + ) def _is_region_on_a_word(self, info): - if info.region[0] > 0 and \ - self._is_on_a_word(info, info.region[0] - 1) or \ - self._is_on_a_word(info, info.region[1] - 1): + if ( + info.region[0] > 0 + and self._is_on_a_word(info, info.region[0] - 1) + or self._is_on_a_word(info, info.region[1] - 1) + ): return True def _is_on_a_word(self, info, offset): prev = info.source[offset] - if not (prev.isalnum() or prev == '_') or \ - offset + 1 == len(info.source): + if not (prev.isalnum() or prev == "_") or offset + 1 == len(info.source): return False next = info.source[offset + 1] - return next.isalnum() or next == '_' + return next.isalnum() or next == "_" class _ExtractMethodParts(object): - def __init__(self, info): self.info = info self.info_collector = self._create_info_collector() @@ -493,19 +516,23 @@ def _get_kind_by_scope(self): def _check_constraints(self): if self._extracting_staticmethod() or self._extracting_classmethod(): if not self.info.method: - raise RefactoringError("Cannot extract to staticmethod/classmethod outside class") + raise RefactoringError( + "Cannot extract to staticmethod/classmethod outside class" + ) def _extacting_from_staticmethod(self): - return self.info.method and _get_function_kind(self.info.scope) == "staticmethod" + return ( + self.info.method and _get_function_kind(self.info.scope) == "staticmethod" + ) def _extracting_from_classmethod(self): return self.info.method and _get_function_kind(self.info.scope) == "classmethod" def get_definition(self): if self.info.global_: - return '\n%s\n' % self._get_function_definition() + return "\n%s\n" % self._get_function_definition() else: - return '\n%s' % self._get_function_definition() + return "\n%s" % self._get_function_definition() def get_replacement_pattern(self): variables = [] @@ -523,29 +550,29 @@ def get_body_pattern(self): def _get_body(self): result = sourceutils.fix_indentation(self.info.extracted, 0) if self.info.one_line: - result = '(%s)' % result + result = "(%s)" % result return result def _find_temps(self): - return usefunction.find_temps(self.info.project, - self._get_body()) + return usefunction.find_temps(self.info.project, self._get_body()) def get_checks(self): if self.info.method and not self.info.make_global: - if _get_function_kind(self.info.scope) == 'method': + if _get_function_kind(self.info.scope) == "method": class_name = similarfinder._pydefined_to_str( - self.info.scope.parent.pyobject) - return {self._get_self_name(): 'type=' + class_name} + self.info.scope.parent.pyobject + ) + return {self._get_self_name(): "type=" + class_name} return {} def _create_info_collector(self): zero = self.info.scope.get_start() - 1 start_line = self.info.region_lines[0] - zero end_line = self.info.region_lines[1] - zero - info_collector = _FunctionInformationCollector(start_line, end_line, - self.info.global_) - body = self.info.source[self.info.scope_region[0]: - self.info.scope_region[1]] + info_collector = _FunctionInformationCollector( + start_line, end_line, self.info.global_ + ) + body = self.info.source[self.info.scope_region[0] : self.info.scope_region[1]] node = _parse_text(body) ast.walk(node, info_collector) return info_collector @@ -556,20 +583,20 @@ def _get_function_definition(self): result = [] self._append_decorators(result) - result.append('def %s:\n' % self._get_function_signature(args)) + result.append("def %s:\n" % self._get_function_signature(args)) unindented_body = self._get_unindented_function_body(returns) indents = sourceutils.get_indent(self.info.project) function_body = sourceutils.indent_lines(unindented_body, indents) result.append(function_body) - definition = ''.join(result) + definition = "".join(result) - return definition + '\n' + return definition + "\n" def _append_decorators(self, result): if self._extracting_staticmethod(): - result.append('@staticmethod\n') + result.append("@staticmethod\n") elif self._extracting_classmethod(): - result.append('@classmethod\n') + result.append("@classmethod\n") def _extracting_classmethod(self): return self.info.kind == "classmethod" @@ -579,23 +606,24 @@ def _extracting_staticmethod(self): def _get_function_signature(self, args): args = list(args) - prefix = '' + prefix = "" if self._extracting_method() or self._extracting_classmethod(): self_name = self._get_self_name() if self_name is None: - raise RefactoringError('Extracting a method from a function ' - 'with no self argument.') + raise RefactoringError( + "Extracting a method from a function " "with no self argument." + ) if self_name in args: args.remove(self_name) args.insert(0, self_name) - return prefix + self.info.new_name + \ - '(%s)' % self._get_comma_form(args) + return prefix + self.info.new_name + "(%s)" % self._get_comma_form(args) def _extracting_method(self): - return not self._extracting_staticmethod() and \ - (self.info.method and \ - not self.info.make_global and \ - _get_function_kind(self.info.scope) == 'method') + return not self._extracting_staticmethod() and ( + self.info.method + and not self.info.make_global + and _get_function_kind(self.info.scope) == "method" + ) def _get_self_name(self): if self._extracting_classmethod(): @@ -610,68 +638,72 @@ def _get_scope_self_name(self): return param_names[0] def _get_function_call(self, args): - return '{prefix}{name}({args})'.format( + return "{prefix}{name}({args})".format( prefix=self._get_function_call_prefix(args), name=self.info.new_name, - args=self._get_comma_form(args)) + args=self._get_comma_form(args), + ) def _get_function_call_prefix(self, args): - prefix = '' + prefix = "" if self.info.method and not self.info.make_global: if self._extracting_staticmethod() or self._extracting_classmethod(): - prefix = self.info.scope.parent.pyobject.get_name() + '.' + prefix = self.info.scope.parent.pyobject.get_name() + "." else: self_name = self._get_self_name() if self_name in args: args.remove(self_name) - prefix = self_name + '.' + prefix = self_name + "." return prefix def _get_comma_form(self, names): - return ', '.join(names) + return ", ".join(names) def _get_call(self): args = self._find_function_arguments() returns = self._find_function_returns() - call_prefix = '' + call_prefix = "" if returns and (not self.info.one_line or self.info.returning_named_expr): - assignment_operator = ' := ' if self.info.one_line else ' = ' + assignment_operator = " := " if self.info.one_line else " = " call_prefix = self._get_comma_form(returns) + assignment_operator if self.info.returned: - call_prefix = 'return ' + call_prefix = "return " return call_prefix + self._get_function_call(args) def _find_function_arguments(self): # if not make_global, do not pass any global names; they are # all visible. if self.info.global_ and not self.info.make_global: - return list(self.info_collector.read & self.info_collector.postread & self.info_collector.written) + return list( + self.info_collector.read + & self.info_collector.postread + & self.info_collector.written + ) if not self.info.one_line: - result = (self.info_collector.prewritten & - self.info_collector.read) - result |= (self.info_collector.prewritten & - self.info_collector.postread & - (self.info_collector.maybe_written - - self.info_collector.written)) + result = self.info_collector.prewritten & self.info_collector.read + result |= ( + self.info_collector.prewritten + & self.info_collector.postread + & (self.info_collector.maybe_written - self.info_collector.written) + ) return list(result) start = self.info.region[0] if start == self.info.lines_region[0]: - start = start + re.search('\\S', self.info.extracted).start() - function_definition = self.info.source[start:self.info.region[1]] + start = start + re.search("\\S", self.info.extracted).start() + function_definition = self.info.source[start : self.info.region[1]] read = _VariableReadsAndWritesFinder.find_reads_for_one_liners( - function_definition) + function_definition + ) return list(self.info_collector.prewritten.intersection(read)) def _find_function_returns(self): if self.info.one_line: - written = self.info_collector.written | \ - self.info_collector.maybe_written + written = self.info_collector.written | self.info_collector.maybe_written return list(written & self.info_collector.postread) if self.info.returned: return [] - written = self.info_collector.written | \ - self.info_collector.maybe_written + written = self.info_collector.written | self.info_collector.maybe_written return list(written & self.info_collector.postread) def _get_unindented_function_body(self, returns): @@ -683,23 +715,27 @@ def _get_multiline_function_body(self, returns): unindented_body = sourceutils.fix_indentation(self.info.extracted, 0) unindented_body = self._insert_globals(unindented_body) if returns: - unindented_body += '\nreturn %s' % self._get_comma_form(returns) + unindented_body += "\nreturn %s" % self._get_comma_form(returns) return unindented_body def _get_one_line_function_body(self): if self.info.returning_named_expr: - body = 'return ' + '(' + _join_lines(self.info.extracted) + ')' + body = "return " + "(" + _join_lines(self.info.extracted) + ")" else: - body = 'return ' + _join_lines(self.info.extracted) + body = "return " + _join_lines(self.info.extracted) return self._insert_globals(body) def _insert_globals(self, unindented_body): globals_in_body = self._get_globals_in_body(unindented_body) - globals_ = self.info_collector.globals_ & (self.info_collector.written | self.info_collector.maybe_written) + globals_ = self.info_collector.globals_ & ( + self.info_collector.written | self.info_collector.maybe_written + ) globals_ = globals_ - globals_in_body if globals_: - unindented_body = "global {}\n{}".format(", ".join(globals_), unindented_body) + unindented_body = "global {}\n{}".format( + ", ".join(globals_), unindented_body + ) return unindented_body @staticmethod @@ -709,18 +745,17 @@ def _get_globals_in_body(unindented_body): ast.walk(node, visitor) return visitor.globals_ -class _ExtractVariableParts(object): +class _ExtractVariableParts(object): def __init__(self, info): self.info = info def get_definition(self): - result = self.info.new_name + ' = ' + \ - _join_lines(self.info.extracted) + '\n' + result = self.info.new_name + " = " + _join_lines(self.info.extracted) + "\n" return result def get_body_pattern(self): - return '(%s)' % self.info.extracted.strip() + return "(%s)" % self.info.extracted.strip() def get_replacement_pattern(self): return self.info.new_name @@ -730,7 +765,6 @@ def get_checks(self): class _FunctionInformationCollector(object): - def __init__(self, start, end, is_global): self.start = start self.end = end @@ -852,8 +886,11 @@ def _handle_loop_context(self, node): def _get_argnames(arguments): - result = [pycompat.get_ast_arg_arg(node) for node in arguments.args - if isinstance(node, pycompat.ast_arg_type)] + result = [ + pycompat.get_ast_arg_arg(node) + for node in arguments.args + if isinstance(node, pycompat.ast_arg_type) + ] if arguments.vararg: result.append(pycompat.get_ast_arg_arg(arguments.vararg)) if arguments.kwarg: @@ -862,7 +899,6 @@ def _get_argnames(arguments): class _VariableReadsAndWritesFinder(object): - def __init__(self): self.written = set() self.read = set() @@ -885,7 +921,7 @@ def _Class(self, node): @staticmethod def find_reads_and_writes(code): - if code.strip() == '': + if code.strip() == "": return set(), set() node = _parse_text(code) visitor = _VariableReadsAndWritesFinder() @@ -894,7 +930,7 @@ def find_reads_and_writes(code): @staticmethod def find_reads_for_one_liners(code): - if code.strip() == '': + if code.strip() == "": return set(), set() node = _parse_text(code) visitor = _VariableReadsAndWritesFinder() @@ -905,7 +941,7 @@ def find_reads_for_one_liners(code): class _BaseErrorFinder(object): @classmethod def has_errors(cls, code): - if code.strip() == '': + if code.strip() == "": return False node = _parse_text(code) visitor = cls() @@ -914,7 +950,6 @@ def has_errors(cls, code): class _UnmatchedBreakOrContinueFinder(_BaseErrorFinder): - def __init__(self): self.error = False self.loop_count = 0 @@ -931,7 +966,7 @@ def loop_encountered(self, node): ast.walk(child, self) self.loop_count -= 1 if node.orelse: - if isinstance(node.orelse,(list,tuple)): + if isinstance(node.orelse, (list, tuple)): for node_ in node.orelse: ast.walk(node_, self) else: @@ -955,7 +990,6 @@ def _ClassDef(self, node): class _AsyncStatementFinder(_BaseErrorFinder): - def __init__(self): self.error = False @@ -991,9 +1025,12 @@ def _parse_text(body): except SyntaxError: # needed to parse expression containing := operator try: - node = ast.parse('(' + body + ')') + node = ast.parse("(" + body + ")") except SyntaxError: - node = ast.parse('async def __rope_placeholder__():\n' + sourceutils.fix_indentation(body, 4)) + node = ast.parse( + "async def __rope_placeholder__():\n" + + sourceutils.fix_indentation(body, 4) + ) node.body = node.body[0].body return node @@ -1001,8 +1038,8 @@ def _parse_text(body): def _join_lines(code): lines = [] for line in code.splitlines(): - if line.endswith('\\'): + if line.endswith("\\"): lines.append(line[:-1].strip()) else: lines.append(line.strip()) - return ' '.join(lines) + return " ".join(lines) diff --git a/rope/refactor/functionutils.py b/rope/refactor/functionutils.py index 58baf9174..f17438943 100644 --- a/rope/refactor/functionutils.py +++ b/rope/refactor/functionutils.py @@ -5,9 +5,9 @@ class DefinitionInfo(object): - - def __init__(self, function_name, is_method, args_with_defaults, - args_arg, keywords_arg): + def __init__( + self, function_name, is_method, args_with_defaults, args_arg, keywords_arg + ): self.function_name = function_name self.is_method = is_method self.args_with_defaults = args_with_defaults @@ -15,40 +15,45 @@ def __init__(self, function_name, is_method, args_with_defaults, self.keywords_arg = keywords_arg def to_string(self): - return '%s(%s)' % (self.function_name, self.arguments_to_string()) + return "%s(%s)" % (self.function_name, self.arguments_to_string()) def arguments_to_string(self, from_index=0): params = [] for arg, default in self.args_with_defaults: if default is not None: - params.append('%s=%s' % (arg, default)) + params.append("%s=%s" % (arg, default)) else: params.append(arg) if self.args_arg is not None: - params.append('*' + self.args_arg) + params.append("*" + self.args_arg) if self.keywords_arg: - params.append('**' + self.keywords_arg) - return ', '.join(params[from_index:]) + params.append("**" + self.keywords_arg) + return ", ".join(params[from_index:]) @staticmethod def _read(pyfunction, code): kind = pyfunction.get_kind() - is_method = kind == 'method' - is_lambda = kind == 'lambda' + is_method = kind == "method" + is_lambda = kind == "lambda" info = _FunctionParser(code, is_method, is_lambda) args, keywords = info.get_parameters() args_arg = None keywords_arg = None - if args and args[-1].startswith('**'): + if args and args[-1].startswith("**"): keywords_arg = args[-1][2:] del args[-1] - if args and args[-1].startswith('*'): + if args and args[-1].startswith("*"): args_arg = args[-1][1:] del args[-1] args_with_defaults = [(name, None) for name in args] args_with_defaults.extend(keywords) - return DefinitionInfo(info.get_function_name(), is_method, - args_with_defaults, args_arg, keywords_arg) + return DefinitionInfo( + info.get_function_name(), + is_method, + args_with_defaults, + args_arg, + keywords_arg, + ) @staticmethod def read(pyfunction): @@ -64,9 +69,16 @@ def read(pyfunction): class CallInfo(object): - - def __init__(self, function_name, args, keywords, args_arg, - keywords_arg, implicit_arg, constructor): + def __init__( + self, + function_name, + args, + keywords, + args_arg, + keywords_arg, + implicit_arg, + constructor, + ): self.function_name = function_name self.args = args self.keywords = keywords @@ -78,7 +90,7 @@ def __init__(self, function_name, args, keywords, args_arg, def to_string(self): function = self.function_name if self.implicit_arg: - function = self.args[0] + '.' + self.function_name + function = self.args[0] + "." + self.function_name params = [] start = 0 if self.implicit_arg or self.constructor: @@ -86,13 +98,12 @@ def to_string(self): if self.args[start:]: params.extend(self.args[start:]) if self.keywords: - params.extend(['%s=%s' % (name, value) - for name, value in self.keywords]) + params.extend(["%s=%s" % (name, value) for name, value in self.keywords]) if self.args_arg is not None: - params.append('*' + self.args_arg) + params.append("*" + self.args_arg) if self.keywords_arg: - params.append('**' + self.keywords_arg) - return '%s(%s)' % (function, ', '.join(params)) + params.append("**" + self.keywords_arg) + return "%s(%s)" % (function, ", ".join(params)) @staticmethod def read(primary, pyname, definition_info, code): @@ -103,48 +114,56 @@ def read(primary, pyname, definition_info, code): args, keywords = info.get_parameters() args_arg = None keywords_arg = None - if args and args[-1].startswith('**'): + if args and args[-1].startswith("**"): keywords_arg = args[-1][2:] del args[-1] - if args and args[-1].startswith('*'): + if args and args[-1].startswith("*"): args_arg = args[-1][1:] del args[-1] if is_constructor: args.insert(0, definition_info.args_with_defaults[0][0]) - return CallInfo(info.get_function_name(), args, keywords, args_arg, - keywords_arg, is_method_call or is_classmethod, - is_constructor) + return CallInfo( + info.get_function_name(), + args, + keywords, + args_arg, + keywords_arg, + is_method_call or is_classmethod, + is_constructor, + ) @staticmethod def _is_method_call(primary, pyname): - return primary is not None and \ - isinstance(primary.get_object().get_type(), - rope.base.pyobjects.PyClass) and \ - CallInfo._is_method(pyname) + return ( + primary is not None + and isinstance(primary.get_object().get_type(), rope.base.pyobjects.PyClass) + and CallInfo._is_method(pyname) + ) @staticmethod def _is_class(pyname): - return pyname is not None and \ - isinstance(pyname.get_object(), - rope.base.pyobjects.PyClass) + return pyname is not None and isinstance( + pyname.get_object(), rope.base.pyobjects.PyClass + ) @staticmethod def _is_method(pyname): - if pyname is not None and \ - isinstance(pyname.get_object(), rope.base.pyobjects.PyFunction): - return pyname.get_object().get_kind() == 'method' + if pyname is not None and isinstance( + pyname.get_object(), rope.base.pyobjects.PyFunction + ): + return pyname.get_object().get_kind() == "method" return False @staticmethod def _is_classmethod(pyname): - if pyname is not None and \ - isinstance(pyname.get_object(), rope.base.pyobjects.PyFunction): - return pyname.get_object().get_kind() == 'classmethod' + if pyname is not None and isinstance( + pyname.get_object(), rope.base.pyobjects.PyFunction + ): + return pyname.get_object().get_kind() == "classmethod" return False class ArgumentMapping(object): - def __init__(self, definition_info, call_info): self.call_info = call_info self.param_dict = {} @@ -180,37 +199,42 @@ def to_call_info(self, definition_info): break args.extend(self.args_arg) keywords.extend(self.keyword_args) - return CallInfo(self.call_info.function_name, args, keywords, - self.call_info.args_arg, self.call_info.keywords_arg, - self.call_info.implicit_arg, - self.call_info.constructor) + return CallInfo( + self.call_info.function_name, + args, + keywords, + self.call_info.args_arg, + self.call_info.keywords_arg, + self.call_info.implicit_arg, + self.call_info.constructor, + ) class _FunctionParser(object): - def __init__(self, call, implicit_arg, is_lambda=False): self.call = call self.implicit_arg = implicit_arg self.word_finder = worder.Worder(self.call) if is_lambda: - self.last_parens = self.call.rindex(':') + self.last_parens = self.call.rindex(":") else: - self.last_parens = self.call.rindex(')') - self.first_parens = self.word_finder._find_parens_start( - self.last_parens) + self.last_parens = self.call.rindex(")") + self.first_parens = self.word_finder._find_parens_start(self.last_parens) def get_parameters(self): - args, keywords = self.word_finder.get_parameters(self.first_parens, - self.last_parens) + args, keywords = self.word_finder.get_parameters( + self.first_parens, self.last_parens + ) if self.is_called_as_a_method(): - instance = self.call[:self.call.rindex('.', 0, self.first_parens)] + instance = self.call[: self.call.rindex(".", 0, self.first_parens)] args.insert(0, instance.strip()) return args, keywords def get_instance(self): if self.is_called_as_a_method(): return self.word_finder.get_primary_at( - self.call.rindex('.', 0, self.first_parens) - 1) + self.call.rindex(".", 0, self.first_parens) - 1 + ) def get_function_name(self): if self.is_called_as_a_method(): @@ -219,4 +243,4 @@ def get_function_name(self): return self.word_finder.get_primary_at(self.first_parens - 1) def is_called_as_a_method(self): - return self.implicit_arg and '.' in self.call[:self.first_parens] + return self.implicit_arg and "." in self.call[: self.first_parens] diff --git a/rope/refactor/importutils/__init__.py b/rope/refactor/importutils/__init__.py index 0f35f522c..0638a4dba 100644 --- a/rope/refactor/importutils/__init__.py +++ b/rope/refactor/importutils/__init__.py @@ -26,53 +26,58 @@ def __init__(self, project): def organize_imports(self, resource, offset=None): return self._perform_command_on_import_tools( - self.import_tools.organize_imports, resource, offset) + self.import_tools.organize_imports, resource, offset + ) def expand_star_imports(self, resource, offset=None): return self._perform_command_on_import_tools( - self.import_tools.expand_stars, resource, offset) + self.import_tools.expand_stars, resource, offset + ) def froms_to_imports(self, resource, offset=None): return self._perform_command_on_import_tools( - self.import_tools.froms_to_imports, resource, offset) + self.import_tools.froms_to_imports, resource, offset + ) def relatives_to_absolutes(self, resource, offset=None): return self._perform_command_on_import_tools( - self.import_tools.relatives_to_absolutes, resource, offset) + self.import_tools.relatives_to_absolutes, resource, offset + ) def handle_long_imports(self, resource, offset=None): return self._perform_command_on_import_tools( - self.import_tools.handle_long_imports, resource, offset) + self.import_tools.handle_long_imports, resource, offset + ) def _perform_command_on_import_tools(self, method, resource, offset): pymodule = self.project.get_pymodule(resource) before_performing = pymodule.source_code import_filter = None if offset is not None: - import_filter = self._line_filter( - pymodule.lines.get_line_number(offset)) + import_filter = self._line_filter(pymodule.lines.get_line_number(offset)) result = method(pymodule, import_filter=import_filter) if result is not None and result != before_performing: - changes = ChangeSet(method.__name__.replace('_', ' ') + - ' in <%s>' % resource.path) + changes = ChangeSet( + method.__name__.replace("_", " ") + " in <%s>" % resource.path + ) changes.add_change(ChangeContents(resource, result)) return changes def _line_filter(self, lineno): def import_filter(import_stmt): return import_stmt.start_line <= lineno < import_stmt.end_line + return import_filter class ImportTools(object): - def __init__(self, project): self.project = project def get_import(self, resource): """The import statement for `resource`""" module_name = libutils.modname(resource) - return NormalImport(((module_name, None), )) + return NormalImport(((module_name, None),)) def get_from_import(self, resource, name): """The from import statement for `name` in `resource`""" @@ -81,30 +86,33 @@ def get_from_import(self, resource, name): if isinstance(name, list): names = [(imported, None) for imported in name] else: - names = [(name, None), ] + names = [ + (name, None), + ] return FromImport(module_name, 0, tuple(names)) def module_imports(self, module, imports_filter=None): - return module_imports.ModuleImports(self.project, module, - imports_filter) + return module_imports.ModuleImports(self.project, module, imports_filter) def froms_to_imports(self, pymodule, import_filter=None): pymodule = self._clean_up_imports(pymodule, import_filter) module_imports = self.module_imports(pymodule, import_filter) for import_stmt in module_imports.imports: - if import_stmt.readonly or \ - not self._is_transformable_to_normal(import_stmt.import_info): + if import_stmt.readonly or not self._is_transformable_to_normal( + import_stmt.import_info + ): continue pymodule = self._from_to_normal(pymodule, import_stmt) # Adding normal imports in place of froms module_imports = self.module_imports(pymodule, import_filter) for import_stmt in module_imports.imports: - if not import_stmt.readonly and \ - self._is_transformable_to_normal(import_stmt.import_info): - import_stmt.import_info = \ - NormalImport(((import_stmt.import_info.module_name, - None),)) + if not import_stmt.readonly and self._is_transformable_to_normal( + import_stmt.import_info + ): + import_stmt.import_info = NormalImport( + ((import_stmt.import_info.module_name, None),) + ) module_imports.remove_duplicates() return module_imports.get_changed_source() @@ -122,13 +130,16 @@ def _from_to_normal(self, pymodule, import_stmt): if alias is not None: imported = alias occurrence_finder = occurrences.create_finder( - self.project, imported, pymodule[imported], imports=False) + self.project, imported, pymodule[imported], imports=False + ) source = rename.rename_in_module( - occurrence_finder, module_name + '.' + name, - pymodule=pymodule, replace_primary=True) + occurrence_finder, + module_name + "." + name, + pymodule=pymodule, + replace_primary=True, + ) if source is not None: - pymodule = libutils.get_string_module( - self.project, source, resource) + pymodule = libutils.get_string_module(self.project, source, resource) return pymodule def _clean_up_imports(self, pymodule, import_filter): @@ -137,20 +148,17 @@ def _clean_up_imports(self, pymodule, import_filter): module_with_imports.expand_stars() source = module_with_imports.get_changed_source() if source is not None: - pymodule = libutils.get_string_module( - self.project, source, resource) + pymodule = libutils.get_string_module(self.project, source, resource) source = self.relatives_to_absolutes(pymodule) if source is not None: - pymodule = libutils.get_string_module( - self.project, source, resource) + pymodule = libutils.get_string_module(self.project, source, resource) module_with_imports = self.module_imports(pymodule, import_filter) module_with_imports.remove_duplicates() module_with_imports.remove_unused_imports() source = module_with_imports.get_changed_source() if source is not None: - pymodule = libutils.get_string_module( - self.project, source, resource) + pymodule = libutils.get_string_module(self.project, source, resource) return pymodule def relatives_to_absolutes(self, pymodule, import_filter=None): @@ -170,9 +178,15 @@ def _is_transformable_to_normal(self, import_info): return False return True - def organize_imports(self, pymodule, - unused=True, duplicates=True, - selfs=True, sort=True, import_filter=None): + def organize_imports( + self, + pymodule, + unused=True, + duplicates=True, + selfs=True, + sort=True, + import_filter=None, + ): if unused or duplicates: module_imports = self.module_imports(pymodule, import_filter) if unused: @@ -184,7 +198,8 @@ def organize_imports(self, pymodule, source = module_imports.get_changed_source() if source is not None: pymodule = libutils.get_string_module( - self.project, source, pymodule.get_resource()) + self.project, source, pymodule.get_resource() + ) if selfs: pymodule = self._remove_self_imports(pymodule, import_filter) if sort: @@ -194,12 +209,13 @@ def organize_imports(self, pymodule, def _remove_self_imports(self, pymodule, import_filter=None): module_imports = self.module_imports(pymodule, import_filter) - to_be_fixed, to_be_renamed = \ - module_imports.get_self_import_fix_and_rename_list() + ( + to_be_fixed, + to_be_renamed, + ) = module_imports.get_self_import_fix_and_rename_list() for name in to_be_fixed: try: - pymodule = self._rename_in_module(pymodule, name, '', - till_dot=True) + pymodule = self._rename_in_module(pymodule, name, "", till_dot=True) except ValueError: # There is a self import with direct access to it return pymodule @@ -210,31 +226,33 @@ def _remove_self_imports(self, pymodule, import_filter=None): source = module_imports.get_changed_source() if source is not None: pymodule = libutils.get_string_module( - self.project, source, pymodule.get_resource()) + self.project, source, pymodule.get_resource() + ) return pymodule def _rename_in_module(self, pymodule, name, new_name, till_dot=False): - old_name = name.split('.')[-1] + old_name = name.split(".")[-1] old_pyname = rope.base.evaluate.eval_str(pymodule.get_scope(), name) occurrence_finder = occurrences.create_finder( - self.project, old_name, old_pyname, imports=False) + self.project, old_name, old_pyname, imports=False + ) changes = rope.base.codeanalyze.ChangeCollector(pymodule.source_code) - for occurrence in occurrence_finder.find_occurrences( - pymodule=pymodule): + for occurrence in occurrence_finder.find_occurrences(pymodule=pymodule): start, end = occurrence.get_primary_range() if till_dot: - new_end = pymodule.source_code.index('.', end) + 1 - space = pymodule.source_code[end:new_end - 1].strip() - if not space == '': + new_end = pymodule.source_code.index(".", end) + 1 + space = pymodule.source_code[end : new_end - 1].strip() + if not space == "": for c in space: - if not c.isspace() and c not in '\\': + if not c.isspace() and c not in "\\": raise ValueError() end = new_end changes.add_change(start, end, new_name) source = changes.get_changed() if source is not None: pymodule = libutils.get_string_module( - self.project, source, pymodule.get_resource()) + self.project, source, pymodule.get_resource() + ) return pymodule def sort_imports(self, pymodule, import_filter=None): @@ -242,22 +260,25 @@ def sort_imports(self, pymodule, import_filter=None): module_imports.sort_imports() return module_imports.get_changed_source() - def handle_long_imports(self, pymodule, maxdots=2, maxlength=27, - import_filter=None): + def handle_long_imports( + self, pymodule, maxdots=2, maxlength=27, import_filter=None + ): # IDEA: `maxdots` and `maxlength` can be specified in project config # adding new from imports module_imports = self.module_imports(pymodule, import_filter) to_be_fixed = module_imports.handle_long_imports(maxdots, maxlength) # performing the renaming pymodule = libutils.get_string_module( - self.project, module_imports.get_changed_source(), - resource=pymodule.get_resource()) + self.project, + module_imports.get_changed_source(), + resource=pymodule.get_resource(), + ) for name in to_be_fixed: - pymodule = self._rename_in_module(pymodule, name, - name.split('.')[-1]) + pymodule = self._rename_in_module(pymodule, name, name.split(".")[-1]) # organizing imports - return self.organize_imports(pymodule, selfs=False, sort=False, - import_filter=import_filter) + return self.organize_imports( + pymodule, selfs=False, sort=False, import_filter=import_filter + ) def get_imports(project, pydefined): @@ -285,20 +306,20 @@ def add_import(project, pymodule, module_name, name=None): names.append(name) candidates.append(from_import) # from pkg import mod - if '.' in module_name: - pkg, mod = module_name.rsplit('.', 1) + if "." in module_name: + pkg, mod = module_name.rsplit(".", 1) from_import = FromImport(pkg, 0, [(mod, None)]) - if project.prefs.get('prefer_module_from_imports'): + if project.prefs.get("prefer_module_from_imports"): selected_import = from_import candidates.append(from_import) if name: - names.append(mod + '.' + name) + names.append(mod + "." + name) else: names.append(mod) # import mod normal_import = NormalImport([(module_name, None)]) if name: - names.append(module_name + '.' + name) + names.append(module_name + "." + name) else: names.append(module_name) diff --git a/rope/refactor/importutils/actions.py b/rope/refactor/importutils/actions.py index 2839cf2e7..5414a06eb 100644 --- a/rope/refactor/importutils/actions.py +++ b/rope/refactor/importutils/actions.py @@ -5,10 +5,9 @@ class ImportInfoVisitor(object): - def dispatch(self, import_): try: - method_name = 'visit' + import_.import_info.__class__.__name__ + method_name = "visit" + import_.import_info.__class__.__name__ method = getattr(self, method_name) return method(import_, import_.import_info) except exceptions.ModuleNotFoundError: @@ -25,7 +24,6 @@ def visitFromImport(self, import_stmt, import_info): class RelativeToAbsoluteVisitor(ImportInfoVisitor): - def __init__(self, project, current_folder): self.to_be_absolute = [] self.project = project @@ -33,8 +31,7 @@ def __init__(self, project, current_folder): self.context = importinfo.ImportContext(project, current_folder) def visitNormalImport(self, import_stmt, import_info): - self.to_be_absolute.extend( - self._get_relative_to_absolute_list(import_info)) + self.to_be_absolute.extend(self._get_relative_to_absolute_list(import_info)) new_pairs = [] for name, alias in import_info.names_and_aliases: resource = self.project.find_module(name, folder=self.folder) @@ -44,7 +41,8 @@ def visitNormalImport(self, import_stmt, import_info): absolute_name = libutils.modname(resource) new_pairs.append((absolute_name, alias)) if not import_info._are_name_and_alias_lists_equal( - new_pairs, import_info.names_and_aliases): + new_pairs, import_info.names_and_aliases + ): import_stmt.import_info = importinfo.NormalImport(new_pairs) def _get_relative_to_absolute_list(self, import_info): @@ -67,11 +65,11 @@ def visitFromImport(self, import_stmt, import_info): absolute_name = libutils.modname(resource) if import_info.module_name != absolute_name: import_stmt.import_info = importinfo.FromImport( - absolute_name, 0, import_info.names_and_aliases) + absolute_name, 0, import_info.names_and_aliases + ) class FilteringVisitor(ImportInfoVisitor): - def __init__(self, project, folder, can_select): self.to_be_absolute = [] self.project = project @@ -84,6 +82,7 @@ def can_select_name_and_alias(name, alias): if alias is not None: imported = alias return can_select(imported) + return can_select_name_and_alias def visitNormalImport(self, import_stmt, import_info): @@ -107,11 +106,11 @@ def visitFromImport(self, import_stmt, import_info): if self.can_select(name, alias): new_pairs.append((name, alias)) return importinfo.FromImport( - import_info.module_name, import_info.level, new_pairs) + import_info.module_name, import_info.level, new_pairs + ) class RemovingVisitor(ImportInfoVisitor): - def __init__(self, project, folder, can_select): self.to_be_absolute = [] self.project = project @@ -148,46 +147,52 @@ def visitNormalImport(self, import_stmt, import_info): if not isinstance(self.import_info, import_info.__class__): return False # Adding ``import x`` and ``import x.y`` that results ``import x.y`` - if len(import_info.names_and_aliases) == \ - len(self.import_info.names_and_aliases) == 1: + if ( + len(import_info.names_and_aliases) + == len(self.import_info.names_and_aliases) + == 1 + ): imported1 = import_info.names_and_aliases[0] imported2 = self.import_info.names_and_aliases[0] if imported1[1] == imported2[1] is None: - if imported1[0].startswith(imported2[0] + '.'): + if imported1[0].startswith(imported2[0] + "."): return True - if imported2[0].startswith(imported1[0] + '.'): + if imported2[0].startswith(imported1[0] + "."): import_stmt.import_info = self.import_info return True # Multiple imports using a single import statement is discouraged # so we won't bother adding them. if self.import_info._are_name_and_alias_lists_equal( - import_info.names_and_aliases, - self.import_info.names_and_aliases): + import_info.names_and_aliases, self.import_info.names_and_aliases + ): return True def visitFromImport(self, import_stmt, import_info): - if isinstance(self.import_info, import_info.__class__) and \ - import_info.module_name == self.import_info.module_name and \ - import_info.level == self.import_info.level: + if ( + isinstance(self.import_info, import_info.__class__) + and import_info.module_name == self.import_info.module_name + and import_info.level == self.import_info.level + ): if import_info.is_star_import(): return True if self.import_info.is_star_import(): import_stmt.import_info = self.import_info return True if self.project.prefs.get("split_imports"): - return self.import_info.names_and_aliases == \ - import_info.names_and_aliases + return ( + self.import_info.names_and_aliases == import_info.names_and_aliases + ) new_pairs = list(import_info.names_and_aliases) for pair in self.import_info.names_and_aliases: if pair not in new_pairs: new_pairs.append(pair) import_stmt.import_info = importinfo.FromImport( - import_info.module_name, import_info.level, new_pairs) + import_info.module_name, import_info.level, new_pairs + ) return True class ExpandStarsVisitor(ImportInfoVisitor): - def __init__(self, project, folder, can_select): self.project = project self.filtering = FilteringVisitor(project, folder, can_select) @@ -202,15 +207,14 @@ def visitFromImport(self, import_stmt, import_info): for name in import_info.get_imported_names(self.context): new_pairs.append((name, None)) new_import = importinfo.FromImport( - import_info.module_name, import_info.level, new_pairs) - import_stmt.import_info = \ - self.filtering.visitFromImport(None, new_import) + import_info.module_name, import_info.level, new_pairs + ) + import_stmt.import_info = self.filtering.visitFromImport(None, new_import) else: self.filtering.dispatch(import_stmt) class SelfImportVisitor(ImportInfoVisitor): - def __init__(self, project, current_folder, resource): self.project = project self.folder = current_folder @@ -231,7 +235,8 @@ def visitNormalImport(self, import_stmt, import_info): else: new_pairs.append((name, alias)) if not import_info._are_name_and_alias_lists_equal( - new_pairs, import_info.names_and_aliases): + new_pairs, import_info.names_and_aliases + ): import_stmt.import_info = importinfo.NormalImport(new_pairs) def visitFromImport(self, import_stmt, import_info): @@ -246,8 +251,10 @@ def visitFromImport(self, import_stmt, import_info): for name, alias in import_info.names_and_aliases: try: result = pymodule[name].get_object() - if isinstance(result, pyobjects.PyModule) and \ - result.get_resource() == self.resource: + if ( + isinstance(result, pyobjects.PyModule) + and result.get_resource() == self.resource + ): imported = name if alias is not None: imported = alias @@ -257,9 +264,11 @@ def visitFromImport(self, import_stmt, import_info): except exceptions.AttributeNotFoundError: new_pairs.append((name, alias)) if not import_info._are_name_and_alias_lists_equal( - new_pairs, import_info.names_and_aliases): + new_pairs, import_info.names_and_aliases + ): import_stmt.import_info = importinfo.FromImport( - import_info.module_name, import_info.level, new_pairs) + import_info.module_name, import_info.level, new_pairs + ) def _importing_names_from_self(self, import_info, import_stmt): if not import_info.is_star_import(): @@ -270,7 +279,6 @@ def _importing_names_from_self(self, import_info, import_stmt): class SortingVisitor(ImportInfoVisitor): - def __init__(self, project, current_folder): self.project = project self.folder = current_folder @@ -283,14 +291,12 @@ def __init__(self, project, current_folder): def visitNormalImport(self, import_stmt, import_info): if import_info.names_and_aliases: name, alias = import_info.names_and_aliases[0] - resource = self.project.find_module( - name, folder=self.folder) + resource = self.project.find_module(name, folder=self.folder) self._check_imported_resource(import_stmt, resource, name) def visitFromImport(self, import_stmt, import_info): resource = import_info.get_imported_resource(self.context) - self._check_imported_resource(import_stmt, resource, - import_info.module_name) + self._check_imported_resource(import_stmt, resource, import_info.module_name) def _check_imported_resource(self, import_stmt, resource, imported_name): info = import_stmt.import_info @@ -298,14 +304,13 @@ def _check_imported_resource(self, import_stmt, resource, imported_name): self.in_project.add(import_stmt) elif _is_future(info): self.future.add(import_stmt) - elif imported_name.split('.')[0] in stdmods.standard_modules(): + elif imported_name.split(".")[0] in stdmods.standard_modules(): self.standard.add(import_stmt) else: self.third_party.add(import_stmt) class LongImportVisitor(ImportInfoVisitor): - def __init__(self, current_folder, project, maxdots, maxlength): self.maxdots = maxdots self.maxlength = maxlength @@ -318,19 +323,20 @@ def visitNormalImport(self, import_stmt, import_info): for name, alias in import_info.names_and_aliases: if alias is None and self._is_long(name): self.to_be_renamed.add(name) - last_dot = name.rindex('.') + last_dot = name.rindex(".") from_ = name[:last_dot] - imported = name[last_dot + 1:] + imported = name[last_dot + 1 :] self.new_imports.append( - importinfo.FromImport(from_, 0, ((imported, None), ))) + importinfo.FromImport(from_, 0, ((imported, None),)) + ) def _is_long(self, name): - return name.count('.') > self.maxdots or \ - ('.' in name and len(name) > self.maxlength) + return name.count(".") > self.maxdots or ( + "." in name and len(name) > self.maxlength + ) class RemovePyNameVisitor(ImportInfoVisitor): - def __init__(self, project, pymodule, pyname, folder): self.pymodule = pymodule self.pyname = pyname @@ -348,7 +354,8 @@ def visitFromImport(self, import_stmt, import_info): pass new_pairs.append((name, alias)) return importinfo.FromImport( - import_info.module_name, import_info.level, new_pairs) + import_info.module_name, import_info.level, new_pairs + ) def dispatch(self, import_): result = ImportInfoVisitor.dispatch(self, import_) @@ -357,5 +364,4 @@ def dispatch(self, import_): def _is_future(info): - return isinstance(info, importinfo.FromImport) and \ - info.module_name == '__future__' + return isinstance(info, importinfo.FromImport) and info.module_name == "__future__" diff --git a/rope/refactor/importutils/importinfo.py b/rope/refactor/importutils/importinfo.py index 114080aac..6452b691c 100644 --- a/rope/refactor/importutils/importinfo.py +++ b/rope/refactor/importutils/importinfo.py @@ -6,8 +6,9 @@ class ImportStatement(object): """ - def __init__(self, import_info, start_line, end_line, - main_statement=None, blank_lines=0): + def __init__( + self, import_info, start_line, end_line, main_statement=None, blank_lines=0 + ): self.start_line = start_line self.end_line = end_line self.readonly = False @@ -22,8 +23,11 @@ def _get_import_info(self): return self._import_info def _set_import_info(self, new_import): - if not self.readonly and \ - new_import is not None and not new_import == self._import_info: + if ( + not self.readonly + and new_import is not None + and not new_import == self._import_info + ): self._is_changed = True self._import_info = new_import @@ -49,21 +53,22 @@ def get_new_start(self): return self.new_start def is_changed(self): - return self._is_changed or (self.new_start is not None or - self.new_start != self.start_line) + return self._is_changed or ( + self.new_start is not None or self.new_start != self.start_line + ) def accept(self, visitor): return visitor.dispatch(self) class ImportInfo(object): - def get_imported_primaries(self, context): pass def get_imported_names(self, context): - return [primary.split('.')[0] - for primary in self.get_imported_primaries(context)] + return [ + primary.split(".")[0] for primary in self.get_imported_primaries(context) + ] def get_import_statement(self): pass @@ -83,8 +88,10 @@ def _are_name_and_alias_lists_equal(self, list1, list2): return True def __eq__(self, obj): - return isinstance(obj, self.__class__) and \ - self.get_import_statement() == obj.get_import_statement() + return ( + isinstance(obj, self.__class__) + and self.get_import_statement() == obj.get_import_statement() + ) def __ne__(self, obj): return not self.__eq__(obj) @@ -95,7 +102,6 @@ def get_empty_import(): class NormalImport(ImportInfo): - def __init__(self, names_and_aliases): self.names_and_aliases = names_and_aliases @@ -109,12 +115,12 @@ def get_imported_primaries(self, context): return result def get_import_statement(self): - result = 'import ' + result = "import " for name, alias in self.names_and_aliases: result += name if alias: - result += ' as ' + alias - result += ', ' + result += " as " + alias + result += ", " return result[:-2] def is_empty(self): @@ -122,17 +128,15 @@ def is_empty(self): class FromImport(ImportInfo): - def __init__(self, module_name, level, names_and_aliases): self.module_name = module_name self.level = level self.names_and_aliases = names_and_aliases def get_imported_primaries(self, context): - if self.names_and_aliases[0][0] == '*': + if self.names_and_aliases[0][0] == "*": module = self.get_imported_module(context) - return [name for name in module - if not name.startswith('_')] + return [name for name in module if not name.startswith("_")] result = [] for name, alias in self.names_and_aliases: if alias: @@ -147,11 +151,11 @@ def get_imported_resource(self, context): Returns `None` if module was not found. """ if self.level == 0: - return context.project.find_module( - self.module_name, folder=context.folder) + return context.project.find_module(self.module_name, folder=context.folder) else: return context.project.find_relative_module( - self.module_name, context.folder, self.level) + self.module_name, context.folder, self.level + ) def get_imported_module(self, context): """Get the imported `PyModule` @@ -160,27 +164,26 @@ def get_imported_module(self, context): could not be found. """ if self.level == 0: - return context.project.get_module( - self.module_name, context.folder) + return context.project.get_module(self.module_name, context.folder) else: return context.project.get_relative_module( - self.module_name, context.folder, self.level) + self.module_name, context.folder, self.level + ) def get_import_statement(self): - result = 'from ' + '.' * self.level + self.module_name + ' import ' + result = "from " + "." * self.level + self.module_name + " import " for name, alias in self.names_and_aliases: result += name if alias: - result += ' as ' + alias - result += ', ' + result += " as " + alias + result += ", " return result[:-2] def is_empty(self): return len(self.names_and_aliases) == 0 def is_star_import(self): - return len(self.names_and_aliases) > 0 and \ - self.names_and_aliases[0][0] == '*' + return len(self.names_and_aliases) > 0 and self.names_and_aliases[0][0] == "*" class EmptyImport(ImportInfo): @@ -195,7 +198,6 @@ def get_imported_primaries(self, context): class ImportContext(object): - def __init__(self, project, folder): self.project = project self.folder = folder diff --git a/rope/refactor/importutils/module_imports.py b/rope/refactor/importutils/module_imports.py index 2436879a4..ee745d11d 100644 --- a/rope/refactor/importutils/module_imports.py +++ b/rope/refactor/importutils/module_imports.py @@ -7,7 +7,6 @@ class ModuleImports(object): - def __init__(self, project, pymodule, import_filter=None): self.project = project self.pymodule = pymodule @@ -35,7 +34,7 @@ def _get_unbound_names(self, defined_pyobject): def _get_all_star_list(self, pymodule): result = set() try: - all_star_list = pymodule.get_attribute('__all__') + all_star_list = pymodule.get_attribute("__all__") except exceptions.AttributeNotFoundError: return result @@ -61,18 +60,22 @@ def _get_all_star_list(self, pymodule): return result def remove_unused_imports(self): - can_select = _OneTimeSelector(self._get_unbound_names(self.pymodule) | self._get_all_star_list(self.pymodule)) + can_select = _OneTimeSelector( + self._get_unbound_names(self.pymodule) + | self._get_all_star_list(self.pymodule) + ) visitor = actions.RemovingVisitor( - self.project, self._current_folder(), can_select) + self.project, self._current_folder(), can_select + ) for import_statement in self.imports: import_statement.accept(visitor) def get_used_imports(self, defined_pyobject): result = [] - can_select = _OneTimeSelector( - self._get_unbound_names(defined_pyobject)) + can_select = _OneTimeSelector(self._get_unbound_names(defined_pyobject)) visitor = actions.FilteringVisitor( - self.project, self._current_folder(), can_select) + self.project, self._current_folder(), can_select + ) for import_statement in self.imports: new_import = import_statement.accept(visitor) if new_import is not None and not new_import.is_empty(): @@ -80,9 +83,8 @@ def get_used_imports(self, defined_pyobject): return result def get_changed_source(self): - if (not self.project.prefs.get("pull_imports_to_top") and - not self.sorted): - return ''.join(self._rewrite_imports(self.imports)) + if not self.project.prefs.get("pull_imports_to_top") and not self.sorted: + return "".join(self._rewrite_imports(self.imports)) # Make sure we forward a removed import's preceding blank # lines count to the following import statement. @@ -92,8 +94,7 @@ def get_changed_source(self): stmt.blank_lines = max(prev_stmt.blank_lines, stmt.blank_lines) prev_stmt = stmt # The new list of imports. - imports = [stmt for stmt in self.imports - if not stmt.import_info.is_empty()] + imports = [stmt for stmt in self.imports if not stmt.import_info.is_empty()] after_removing = self._remove_imports(self.imports) first_non_blank = self._first_non_blank_line(after_removing, 0) @@ -105,16 +106,15 @@ def get_changed_source(self): sorted_imports = sorted(imports, key=self._get_location) for stmt in sorted_imports: if stmt != sorted_imports[0]: - result.append('\n' * stmt.blank_lines) - result.append(stmt.get_import_statement() + '\n') + result.append("\n" * stmt.blank_lines) + result.append(stmt.get_import_statement() + "\n") if sorted_imports and first_non_blank < len(after_removing): - result.append('\n' * self.separating_lines) + result.append("\n" * self.separating_lines) # Writing the body - first_after_imports = self._first_non_blank_line(after_removing, - first_import) + first_after_imports = self._first_non_blank_line(after_removing, first_import) result.extend(after_removing[first_after_imports:]) - return ''.join(result) + return "".join(result) def _get_import_location(self, stmt): start = stmt.get_new_start() @@ -137,9 +137,10 @@ def _remove_imports(self, imports): start, end = stmt.get_old_location() blank_lines = 0 if start != first_import_line: - blank_lines = _count_blank_lines(lines.__getitem__, start - 2, - last_index - 1, -1) - after_removing.extend(lines[last_index:start - 1 - blank_lines]) + blank_lines = _count_blank_lines( + lines.__getitem__, start - 2, last_index - 1, -1 + ) + after_removing.extend(lines[last_index : start - 1 - blank_lines]) last_index = end - 1 after_removing.extend(lines[last_index:]) return after_removing @@ -150,16 +151,15 @@ def _rewrite_imports(self, imports): last_index = 0 for stmt in imports: start, end = stmt.get_old_location() - after_rewriting.extend(lines[last_index:start - 1]) + after_rewriting.extend(lines[last_index : start - 1]) if not stmt.import_info.is_empty(): - after_rewriting.append(stmt.get_import_statement() + '\n') + after_rewriting.append(stmt.get_import_statement() + "\n") last_index = end - 1 after_rewriting.extend(lines[last_index:]) return after_rewriting def _first_non_blank_line(self, lines, lineno): - return lineno + _count_blank_lines(lines.__getitem__, lineno, - len(lines)) + return lineno + _count_blank_lines(lines.__getitem__, lineno, len(lines)) def add_import(self, import_info): visitor = actions.AddingVisitor(self.project, [import_info]) @@ -169,9 +169,11 @@ def add_import(self, import_info): else: lineno = self._get_new_import_lineno() blanks = self._get_new_import_blanks() - self.imports.append(importinfo.ImportStatement( - import_info, lineno, lineno, - blank_lines=blanks)) + self.imports.append( + importinfo.ImportStatement( + import_info, lineno, lineno, blank_lines=blanks + ) + ) def _get_new_import_blanks(self): return 0 @@ -183,22 +185,23 @@ def _get_new_import_lineno(self): def filter_names(self, can_select): visitor = actions.RemovingVisitor( - self.project, self._current_folder(), can_select) + self.project, self._current_folder(), can_select + ) for import_statement in self.imports: import_statement.accept(visitor) def expand_stars(self): can_select = _OneTimeSelector(self._get_unbound_names(self.pymodule)) visitor = actions.ExpandStarsVisitor( - self.project, self._current_folder(), can_select) + self.project, self._current_folder(), can_select + ) for import_statement in self.imports: import_statement.accept(visitor) def remove_duplicates(self): added_imports = [] for import_stmt in self.imports: - visitor = actions.AddingVisitor(self.project, - [import_stmt.import_info]) + visitor = actions.AddingVisitor(self.project, [import_stmt.import_info]) for added_import in added_imports: if added_import.accept(visitor): import_stmt.empty_import() @@ -215,8 +218,8 @@ def force_single_imports(self): for name_and_alias in import_info.names_and_aliases: if hasattr(import_info, "module_name"): new_import = importinfo.FromImport( - import_info.module_name, import_info.level, - [name_and_alias]) + import_info.module_name, import_info.level, [name_and_alias] + ) else: new_import = importinfo.NormalImport([name_and_alias]) self.add_import(new_import) @@ -224,7 +227,8 @@ def force_single_imports(self): def get_relative_to_absolute_list(self): visitor = actions.RelativeToAbsoluteVisitor( - self.project, self._current_folder()) + self.project, self._current_folder() + ) for import_stmt in self.imports: if not import_stmt.readonly: import_stmt.accept(visitor) @@ -232,7 +236,8 @@ def get_relative_to_absolute_list(self): def get_self_import_fix_and_rename_list(self): visitor = actions.SelfImportVisitor( - self.project, self._current_folder(), self.pymodule.get_resource()) + self.project, self._current_folder(), self.pymodule.get_resource() + ) for import_stmt in self.imports: if not import_stmt.readonly: import_stmt.accept(visitor) @@ -269,22 +274,27 @@ def _first_import_line(self): if self.pymodule.get_doc() is not None: lineno = 1 if len(nodes) > lineno: - if (isinstance(nodes[lineno], ast.Import) or - isinstance(nodes[lineno], ast.ImportFrom)): + if isinstance(nodes[lineno], ast.Import) or isinstance( + nodes[lineno], ast.ImportFrom + ): return nodes[lineno].lineno - lineno = self.pymodule.logical_lines.logical_line_in( - nodes[lineno].lineno)[0] + lineno = self.pymodule.logical_lines.logical_line_in(nodes[lineno].lineno)[ + 0 + ] else: lineno = self.pymodule.lines.length() - return lineno - _count_blank_lines(self.pymodule.lines.get_line, - lineno - 1, 1, -1) + return lineno - _count_blank_lines( + self.pymodule.lines.get_line, lineno - 1, 1, -1 + ) def _get_import_name(self, import_stmt): import_info = import_stmt.import_info if hasattr(import_info, "module_name"): - return "%s.%s" % (import_info.module_name, - import_info.names_and_aliases[0][0]) + return "%s.%s" % ( + import_info.module_name, + import_info.names_and_aliases[0][0], + ) else: return import_info.names_and_aliases[0][0] @@ -292,13 +302,13 @@ def _key_imports(self, stm1): str1 = stm1.get_import_statement() return str1.startswith("from "), str1 - #str1 = stmt1.get_import_statement() - #str2 = stmt2.get_import_statement() - #if str1.startswith('from ') and not str2.startswith('from '): + # str1 = stmt1.get_import_statement() + # str2 = stmt2.get_import_statement() + # if str1.startswith('from ') and not str2.startswith('from '): # return 1 - #if not str1.startswith('from ') and str2.startswith('from '): + # if not str1.startswith('from ') and str2.startswith('from '): # return -1 - #return cmp(str1, str2) + # return cmp(str1, str2) def _move_imports(self, imports, index, blank_lines): if imports: @@ -312,7 +322,8 @@ def _move_imports(self, imports, index, blank_lines): def handle_long_imports(self, maxdots, maxlength): visitor = actions.LongImportVisitor( - self._current_folder(), self.project, maxdots, maxlength) + self._current_folder(), self.project, maxdots, maxlength + ) for import_statement in self.imports: if not import_statement.readonly: import_statement.accept(visitor) @@ -322,8 +333,9 @@ def handle_long_imports(self, maxdots, maxlength): def remove_pyname(self, pyname): """Removes pyname when imported in ``from mod import x``""" - visitor = actions.RemovePyNameVisitor(self.project, self.pymodule, - pyname, self._current_folder()) + visitor = actions.RemovePyNameVisitor( + self.project, self.pymodule, pyname, self._current_folder() + ) for import_stmt in self.imports: import_stmt.accept(visitor) @@ -331,7 +343,7 @@ def remove_pyname(self, pyname): def _count_blank_lines(get_line, start, end, step=1): count = 0 for idx in range(start, end, step): - if get_line(idx).strip() == '': + if get_line(idx).strip() == "": count += 1 else: break @@ -339,7 +351,6 @@ def _count_blank_lines(get_line, start, end, step=1): class _OneTimeSelector(object): - def __init__(self, names): self.names = names self.selected_names = set() @@ -352,9 +363,9 @@ def __call__(self, imported_primary): return False def _get_dotted_tokens(self, imported_primary): - tokens = imported_primary.split('.') + tokens = imported_primary.split(".") for i in range(len(tokens)): - yield '.'.join(tokens[:i + 1]) + yield ".".join(tokens[: i + 1]) def _can_name_be_added(self, imported_primary): for name in self._get_dotted_tokens(imported_primary): @@ -364,13 +375,16 @@ def _can_name_be_added(self, imported_primary): class _UnboundNameFinder(object): - def __init__(self, pyobject): self.pyobject = pyobject def _visit_child_scope(self, node): - pyobject = self.pyobject.get_module().get_scope().\ - get_inner_scope_for_line(node.lineno).pyobject + pyobject = ( + self.pyobject.get_module() + .get_scope() + .get_inner_scope_for_line(node.lineno) + .pyobject + ) visitor = _LocalUnboundNameFinder(pyobject, self) for child in ast.get_child_nodes(node): ast.walk(child, visitor) @@ -382,8 +396,7 @@ def _ClassDef(self, node): self._visit_child_scope(node) def _Name(self, node): - if self._get_root()._is_node_interesting(node) and \ - not self.is_bound(node.id): + if self._get_root()._is_node_interesting(node) and not self.is_bound(node.id): self.add_unbound(node.id) def _Attribute(self, node): @@ -393,9 +406,10 @@ def _Attribute(self, node): node = node.value if isinstance(node, ast.Name): result.append(node.id) - primary = '.'.join(reversed(result)) - if self._get_root()._is_node_interesting(node) and \ - not self.is_bound(primary): + primary = ".".join(reversed(result)) + if self._get_root()._is_node_interesting(node) and not self.is_bound( + primary + ): self.add_unbound(primary) else: ast.walk(node, self) @@ -411,14 +425,12 @@ def add_unbound(self, name): class _GlobalUnboundNameFinder(_UnboundNameFinder): - def __init__(self, pymodule, wanted_pyobject): super(_GlobalUnboundNameFinder, self).__init__(pymodule) self.unbound = set() self.names = set() for name, pyname in pymodule._get_structural_attributes().items(): - if not isinstance(pyname, (pynames.ImportedName, - pynames.ImportedModule)): + if not isinstance(pyname, (pynames.ImportedName, pynames.ImportedModule)): self.names.add(name) wanted_scope = wanted_pyobject.get_scope() self.start = wanted_scope.get_start() @@ -428,20 +440,19 @@ def _get_root(self): return self def is_bound(self, primary, propagated=False): - name = primary.split('.')[0] + name = primary.split(".")[0] return name in self.names def add_unbound(self, name): - names = name.split('.') + names = name.split(".") for i in range(len(names)): - self.unbound.add('.'.join(names[:i + 1])) + self.unbound.add(".".join(names[: i + 1])) def _is_node_interesting(self, node): return self.start <= node.lineno < self.end class _LocalUnboundNameFinder(_UnboundNameFinder): - def __init__(self, pyobject, parent): super(_LocalUnboundNameFinder, self).__init__(pyobject) self.parent = parent @@ -450,7 +461,7 @@ def _get_root(self): return self.parent._get_root() def is_bound(self, primary, propagated=False): - name = primary.split('.')[0] + name = primary.split(".")[0] if propagated: names = self.pyobject.get_scope().get_propagated_names() else: @@ -464,7 +475,6 @@ def add_unbound(self, name): class _GlobalImportFinder(object): - def __init__(self, pymodule): self.current_folder = None if pymodule.get_resource(): @@ -478,16 +488,18 @@ def visit_import(self, node, end_line): start_line = node.lineno import_statement = importinfo.ImportStatement( importinfo.NormalImport(self._get_names(node.names)), - start_line, end_line, self._get_text(start_line, end_line), - blank_lines=self._count_empty_lines_before(start_line)) + start_line, + end_line, + self._get_text(start_line, end_line), + blank_lines=self._count_empty_lines_before(start_line), + ) self.imports.append(import_statement) def _count_empty_lines_before(self, lineno): return _count_blank_lines(self.lines.get_line, lineno - 1, 0, -1) def _count_empty_lines_after(self, lineno): - return _count_blank_lines(self.lines.get_line, lineno + 1, - self.lines.length()) + return _count_blank_lines(self.lines.get_line, lineno + 1, self.lines.length()) def get_separating_line_count(self): if not self.imports: @@ -498,21 +510,27 @@ def _get_text(self, start_line, end_line): result = [] for index in range(start_line, end_line): result.append(self.lines.get_line(index)) - return '\n'.join(result) + return "\n".join(result) def visit_from(self, node, end_line): level = 0 if node.level: level = node.level import_info = importinfo.FromImport( - node.module or '', # see comment at rope.base.ast.walk - level, self._get_names(node.names)) + node.module or "", # see comment at rope.base.ast.walk + level, + self._get_names(node.names), + ) start_line = node.lineno - self.imports.append(importinfo.ImportStatement( - import_info, node.lineno, end_line, - self._get_text(start_line, end_line), - blank_lines= - self._count_empty_lines_before(start_line))) + self.imports.append( + importinfo.ImportStatement( + import_info, + node.lineno, + end_line, + self._get_text(start_line, end_line), + blank_lines=self._count_empty_lines_before(start_line), + ) + ) def _get_names(self, alias_names): result = [] diff --git a/rope/refactor/inline.py b/rope/refactor/inline.py index 467edefaa..d7f272eaf 100644 --- a/rope/refactor/inline.py +++ b/rope/refactor/inline.py @@ -20,11 +20,25 @@ import rope.base.exceptions import rope.refactor.functionutils -from rope.base import (pynames, pyobjects, codeanalyze, - taskhandle, evaluate, worder, utils, libutils) +from rope.base import ( + pynames, + pyobjects, + codeanalyze, + taskhandle, + evaluate, + worder, + utils, + libutils, +) from rope.base.change import ChangeSet, ChangeContents -from rope.refactor import (occurrences, rename, sourceutils, - importutils, move, change_signature) +from rope.refactor import ( + occurrences, + rename, + sourceutils, + importutils, + move, + change_signature, +) def unique_prefix(): @@ -42,8 +56,10 @@ def create_inline(project, resource, offset): """ pyname = _get_pyname(project, resource, offset) - message = 'Inline refactoring should be performed on ' \ - 'a method, local variable or parameter.' + message = ( + "Inline refactoring should be performed on " + "a method, local variable or parameter." + ) if pyname is None: raise rope.base.exceptions.RefactoringError(message) if isinstance(pyname, pynames.ImportedName): @@ -59,7 +75,6 @@ def create_inline(project, resource, offset): class _Inliner(object): - def __init__(self, project, resource, offset): self.project = project self.pyname = _get_pyname(project, resource, offset) @@ -77,25 +92,24 @@ def get_kind(self): class InlineMethod(_Inliner): - def __init__(self, *args, **kwds): super(InlineMethod, self).__init__(*args, **kwds) self.pyfunction = self.pyname.get_object() self.pymodule = self.pyfunction.get_module() self.resource = self.pyfunction.get_module().get_resource() self.occurrence_finder = occurrences.create_finder( - self.project, self.name, self.pyname) - self.normal_generator = _DefinitionGenerator(self.project, - self.pyfunction) + self.project, self.name, self.pyname + ) + self.normal_generator = _DefinitionGenerator(self.project, self.pyfunction) self._init_imports() def _init_imports(self): body = sourceutils.get_body(self.pyfunction) - body, imports = move.moving_code_with_imports( - self.project, self.resource, body) + body, imports = move.moving_code_with_imports(self.project, self.resource, body) self.imports = imports self.others_generator = _DefinitionGenerator( - self.project, self.pyfunction, body=body) + self.project, self.pyfunction, body=body + ) def _get_scope_range(self): scope = self.pyfunction.get_scope() @@ -103,49 +117,57 @@ def _get_scope_range(self): start_line = scope.get_start() if self.pyfunction.decorators: decorators = self.pyfunction.decorators - if hasattr(decorators[0], 'lineno'): + if hasattr(decorators[0], "lineno"): start_line = decorators[0].lineno start_offset = lines.get_line_start(start_line) - end_offset = min(lines.get_line_end(scope.end) + 1, - len(self.pymodule.source_code)) + end_offset = min( + lines.get_line_end(scope.end) + 1, len(self.pymodule.source_code) + ) return (start_offset, end_offset) - def get_changes(self, remove=True, only_current=False, resources=None, - task_handle=taskhandle.NullTaskHandle()): + def get_changes( + self, + remove=True, + only_current=False, + resources=None, + task_handle=taskhandle.NullTaskHandle(), + ): """Get the changes this refactoring makes If `remove` is `False` the definition will not be removed. If `only_current` is `True`, the the current occurrence will be inlined, only. """ - changes = ChangeSet('Inline method <%s>' % self.name) + changes = ChangeSet("Inline method <%s>" % self.name) if resources is None: resources = self.project.get_python_files() if only_current: resources = [self.original] if remove: resources.append(self.resource) - job_set = task_handle.create_jobset('Collecting Changes', - len(resources)) + job_set = task_handle.create_jobset("Collecting Changes", len(resources)) for file in resources: job_set.started_job(file.path) if file == self.resource: - changes.add_change(self._defining_file_changes( - changes, remove=remove, only_current=only_current)) + changes.add_change( + self._defining_file_changes( + changes, remove=remove, only_current=only_current + ) + ) else: aim = None if only_current and self.original == file: aim = self.offset handle = _InlineFunctionCallsForModuleHandle( - self.project, file, self.others_generator, aim) + self.project, file, self.others_generator, aim + ) result = move.ModuleSkipRenamer( - self.occurrence_finder, file, handle).get_changed_module() + self.occurrence_finder, file, handle + ).get_changed_module() if result is not None: - result = _add_imports(self.project, result, - file, self.imports) + result = _add_imports(self.project, result, file, self.imports) if remove: - result = _remove_from(self.project, self.pyname, - result, file) + result = _remove_from(self.project, self.pyname, result, file) changes.add_change(ChangeContents(file, result)) job_set.finished_job() return changes @@ -156,12 +178,11 @@ def _get_removed_range(self): start, end = self._get_scope_range() end_line = scope.get_end() for i in range(end_line + 1, lines.length()): - if lines.get_line(i).strip() == '': + if lines.get_line(i).strip() == "": end_line = i else: break - end = min(lines.get_line_end(end_line) + 1, - len(self.pymodule.source_code)) + end = min(lines.get_line_end(end_line) + 1, len(self.pymodule.source_code)) return (start, end) def _defining_file_changes(self, changes, remove, only_current): @@ -174,22 +195,28 @@ def _defining_file_changes(self, changes, remove, only_current): # we don't want to change any of them aim = len(self.resource.read()) + 100 handle = _InlineFunctionCallsForModuleHandle( - self.project, self.resource, - self.normal_generator, aim_offset=aim) + self.project, self.resource, self.normal_generator, aim_offset=aim + ) replacement = None if remove: replacement = self._get_method_replacement() result = move.ModuleSkipRenamer( - self.occurrence_finder, self.resource, handle, start_offset, - end_offset, replacement).get_changed_module() + self.occurrence_finder, + self.resource, + handle, + start_offset, + end_offset, + replacement, + ).get_changed_module() return ChangeContents(self.resource, result) def _get_method_replacement(self): if self._is_the_last_method_of_a_class(): indents = sourceutils.get_indents( - self.pymodule.lines, self.pyfunction.get_scope().get_start()) - return ' ' * indents + 'pass\n' - return '' + self.pymodule.lines, self.pyfunction.get_scope().get_start() + ) + return " " * indents + "pass\n" + return "" def _is_the_last_method_of_a_class(self): pyclass = self.pyfunction.parent @@ -198,17 +225,18 @@ def _is_the_last_method_of_a_class(self): class_start, class_end = sourceutils.get_body_region(pyclass) source = self.pymodule.source_code func_start, func_end = self._get_scope_range() - if source[class_start:func_start].strip() == '' and \ - source[func_end:class_end].strip() == '': + if ( + source[class_start:func_start].strip() == "" + and source[func_end:class_end].strip() == "" + ): return True return False def get_kind(self): - return 'method' + return "method" class InlineVariable(_Inliner): - def __init__(self, *args, **kwds): super(InlineVariable, self).__init__(*args, **kwds) self.pymodule = self.pyname.get_definition_location()[0] @@ -219,10 +247,17 @@ def __init__(self, *args, **kwds): def _check_exceptional_conditions(self): if len(self.pyname.assignments) != 1: raise rope.base.exceptions.RefactoringError( - 'Local variable should be assigned once for inlining.') - - def get_changes(self, remove=True, only_current=False, resources=None, - docs=False, task_handle=taskhandle.NullTaskHandle()): + "Local variable should be assigned once for inlining." + ) + + def get_changes( + self, + remove=True, + only_current=False, + resources=None, + docs=False, + task_handle=taskhandle.NullTaskHandle(), + ): if resources is None: if rename._is_local(self.pyname): resources = [self.resource] @@ -232,9 +267,8 @@ def get_changes(self, remove=True, only_current=False, resources=None, resources = [self.original] if remove and self.original != self.resource: resources.append(self.resource) - changes = ChangeSet('Inline variable <%s>' % self.name) - jobset = task_handle.create_jobset('Calculating changes', - len(resources)) + changes = ChangeSet("Inline variable <%s>" % self.name) + jobset = task_handle.create_jobset("Calculating changes", len(resources)) for resource in resources: jobset.started_job(resource.path) @@ -244,8 +278,7 @@ def get_changes(self, remove=True, only_current=False, resources=None, else: result = self._change_module(resource, remove, only_current) if result is not None: - result = _add_imports(self.project, result, - resource, self.imports) + result = _add_imports(self.project, result, resource, self.imports) changes.add_change(ChangeContents(resource, result)) jobset.finished_job() return changes @@ -254,45 +287,53 @@ def _change_main_module(self, remove, only_current, docs): region = None if only_current and self.original == self.resource: region = self.region - return _inline_variable(self.project, self.pymodule, self.pyname, - self.name, remove=remove, region=region, - docs=docs) + return _inline_variable( + self.project, + self.pymodule, + self.pyname, + self.name, + remove=remove, + region=region, + docs=docs, + ) def _init_imports(self): vardef = _getvardef(self.pymodule, self.pyname) self.imported, self.imports = move.moving_code_with_imports( - self.project, self.resource, vardef) + self.project, self.resource, vardef + ) def _change_module(self, resource, remove, only_current): - filters = [occurrences.NoImportsFilter(), - occurrences.PyNameFilter(self.pyname)] + filters = [occurrences.NoImportsFilter(), occurrences.PyNameFilter(self.pyname)] if only_current and resource == self.original: + def check_aim(occurrence): start, end = occurrence.get_primary_range() if self.offset < start or end < self.offset: return False + filters.insert(0, check_aim) finder = occurrences.Finder(self.project, self.name, filters=filters) changed = rename.rename_in_module( - finder, self.imported, resource=resource, replace_primary=True) + finder, self.imported, resource=resource, replace_primary=True + ) if changed and remove: - changed = _remove_from(self.project, self.pyname, - changed, resource) + changed = _remove_from(self.project, self.pyname, changed, resource) return changed def get_kind(self): - return 'variable' + return "variable" class InlineParameter(_Inliner): - def __init__(self, *args, **kwds): super(InlineParameter, self).__init__(*args, **kwds) resource, offset = self._function_location() index = self.pyname.index self.changers = [change_signature.ArgumentDefaultInliner(index)] - self.signature = change_signature.ChangeSignature(self.project, - resource, offset) + self.signature = change_signature.ChangeSignature( + self.project, resource, offset + ) def _function_location(self): pymodule, lineno = self.pyname.get_definition_location() @@ -311,17 +352,17 @@ def get_changes(self, **kwds): return self.signature.get_changes(self.changers, **kwds) def get_kind(self): - return 'parameter' + return "parameter" def _join_lines(lines): definition_lines = [] for unchanged_line in lines: line = unchanged_line.strip() - if line.endswith('\\'): + if line.endswith("\\"): line = line[:-1].strip() definition_lines.append(line) - joined = ' '.join(definition_lines) + joined = " ".join(definition_lines) return joined @@ -347,39 +388,43 @@ def _get_definition_info(self): def _get_definition_params(self): definition_info = self.definition_info paramdict = dict([pair for pair in definition_info.args_with_defaults]) - if definition_info.args_arg is not None or \ - definition_info.keywords_arg is not None: + if ( + definition_info.args_arg is not None + or definition_info.keywords_arg is not None + ): raise rope.base.exceptions.RefactoringError( - 'Cannot inline functions with list and keyword arguements.') - if self.pyfunction.get_kind() == 'classmethod': - paramdict[definition_info.args_with_defaults[0][0]] = \ - self.pyfunction.parent.get_name() + "Cannot inline functions with list and keyword arguements." + ) + if self.pyfunction.get_kind() == "classmethod": + paramdict[ + definition_info.args_with_defaults[0][0] + ] = self.pyfunction.parent.get_name() return paramdict def get_function_name(self): return self.pyfunction.get_name() - def get_definition(self, primary, pyname, call, host_vars=[], - returns=False): + def get_definition(self, primary, pyname, call, host_vars=[], returns=False): # caching already calculated definitions - return self._calculate_definition(primary, pyname, call, - host_vars, returns) + return self._calculate_definition(primary, pyname, call, host_vars, returns) def _calculate_header(self, primary, pyname, call): # A header is created which initializes parameters # to the values passed to the function. call_info = rope.refactor.functionutils.CallInfo.read( - primary, pyname, self.definition_info, call) + primary, pyname, self.definition_info, call + ) paramdict = self.definition_params mapping = rope.refactor.functionutils.ArgumentMapping( - self.definition_info, call_info) + self.definition_info, call_info + ) for param_name, value in mapping.param_dict.items(): paramdict[param_name] = value - header = '' + header = "" to_be_inlined = [] for name, value in paramdict.items(): if name != value and value is not None: - header += name + ' = ' + value.replace('\n', ' ') + '\n' + header += name + " = " + value.replace("\n", " ") + "\n" to_be_inlined.append(name) return header, to_be_inlined @@ -390,32 +435,33 @@ def _calculate_definition(self, primary, pyname, call, host_vars, returns): source = header + self.body mod = libutils.get_string_module(self.project, source) name_dict = mod.get_scope().get_names() - all_names = [x for x in name_dict if - not isinstance(name_dict[x], - rope.base.builtins.BuiltinName)] + all_names = [ + x + for x in name_dict + if not isinstance(name_dict[x], rope.base.builtins.BuiltinName) + ] # If there is a name conflict, all variable names # inside the inlined function are renamed if len(set(all_names).intersection(set(host_vars))) > 0: prefix = next(_DefinitionGenerator.unique_prefix) - guest = libutils.get_string_module(self.project, source, - self.resource) + guest = libutils.get_string_module(self.project, source, self.resource) to_be_inlined = [prefix + item for item in to_be_inlined] for item in all_names: pyname = guest[item] - occurrence_finder = occurrences.create_finder(self.project, - item, pyname) - source = rename.rename_in_module(occurrence_finder, - prefix + item, pymodule=guest) - guest = libutils.get_string_module( - self.project, source, self.resource) - - #parameters not reassigned inside the functions are now inlined. + occurrence_finder = occurrences.create_finder( + self.project, item, pyname + ) + source = rename.rename_in_module( + occurrence_finder, prefix + item, pymodule=guest + ) + guest = libutils.get_string_module(self.project, source, self.resource) + + # parameters not reassigned inside the functions are now inlined. for name in to_be_inlined: - pymodule = libutils.get_string_module( - self.project, source, self.resource) + pymodule = libutils.get_string_module(self.project, source, self.resource) pyname = pymodule[name] source = _inline_variable(self.project, pymodule, pyname, name) @@ -425,58 +471,55 @@ def _replace_returns_with(self, source, returns): result = [] returned = None last_changed = 0 - for match in _DefinitionGenerator._get_return_pattern().finditer( - source): + for match in _DefinitionGenerator._get_return_pattern().finditer(source): for key, value in match.groupdict().items(): - if value and key == 'return': - result.append(source[last_changed:match.start('return')]) + if value and key == "return": + result.append(source[last_changed : match.start("return")]) if returns: - self._check_nothing_after_return(source, - match.end('return')) - beg_idx = match.end('return') + self._check_nothing_after_return(source, match.end("return")) + beg_idx = match.end("return") returned = _join_lines( - source[beg_idx:len(source)].splitlines()) + source[beg_idx : len(source)].splitlines() + ) last_changed = len(source) else: - current = match.end('return') - while current < len(source) and \ - source[current] in ' \t': + current = match.end("return") + while current < len(source) and source[current] in " \t": current += 1 last_changed = current - if current == len(source) or source[current] == '\n': - result.append('pass') + if current == len(source) or source[current] == "\n": + result.append("pass") result.append(source[last_changed:]) - return ''.join(result), returned + return "".join(result), returned def _check_nothing_after_return(self, source, offset): lines = codeanalyze.SourceLinesAdapter(source) lineno = lines.get_line_number(offset) logical_lines = codeanalyze.LogicalLineFinder(lines) lineno = logical_lines.logical_line_in(lineno)[1] - if source[lines.get_line_end(lineno):len(source)].strip() != '': + if source[lines.get_line_end(lineno) : len(source)].strip() != "": raise rope.base.exceptions.RefactoringError( - 'Cannot inline functions with statements ' + - 'after return statement.') + "Cannot inline functions with statements " + "after return statement." + ) @classmethod def _get_return_pattern(cls): - if not hasattr(cls, '_return_pattern'): + if not hasattr(cls, "_return_pattern"): + def named_pattern(name, list_): return "(?P<%s>" % name + "|".join(list_) + ")" - comment_pattern = named_pattern('comment', [r'#[^\n]*']) - string_pattern = named_pattern('string', - [codeanalyze.get_string_pattern()]) - return_pattern = r'\b(?Preturn)\b' - cls._return_pattern = re.compile(comment_pattern + "|" + - string_pattern + "|" + - return_pattern) + + comment_pattern = named_pattern("comment", [r"#[^\n]*"]) + string_pattern = named_pattern("string", [codeanalyze.get_string_pattern()]) + return_pattern = r"\b(?Preturn)\b" + cls._return_pattern = re.compile( + comment_pattern + "|" + string_pattern + "|" + return_pattern + ) return cls._return_pattern class _InlineFunctionCallsForModuleHandle(object): - - def __init__(self, project, resource, - definition_generator, aim_offset=None): + def __init__(self, project, resource, definition_generator, aim_offset=None): """Inlines occurrences If `aim` is not `None` only the occurrences that intersect @@ -491,7 +534,8 @@ def __init__(self, project, resource, def occurred_inside_skip(self, change_collector, occurrence): if not occurrence.is_defined(): raise rope.base.exceptions.RefactoringError( - 'Cannot inline functions that reference themselves') + "Cannot inline functions that reference themselves" + ) def occurred_outside_skip(self, change_collector, occurrence): start, end = occurrence.get_primary_range() @@ -501,38 +545,47 @@ def occurred_outside_skip(self, change_collector, occurrence): # the function is referenced outside an import statement if not occurrence.is_called(): raise rope.base.exceptions.RefactoringError( - 'Reference to inlining function other than function call' - ' in ' % (self.resource.path, start)) + "Reference to inlining function other than function call" + " in " % (self.resource.path, start) + ) if self.aim is not None and (self.aim < start or self.aim > end): return end_parens = self._find_end_parens(self.source, end - 1) lineno = self.lines.get_line_number(start) - start_line, end_line = self.pymodule.logical_lines.\ - logical_line_in(lineno) + start_line, end_line = self.pymodule.logical_lines.logical_line_in(lineno) line_start = self.lines.get_line_start(start_line) line_end = self.lines.get_line_end(end_line) - returns = self.source[line_start:start].strip() != '' or \ - self.source[end_parens:line_end].strip() != '' + returns = ( + self.source[line_start:start].strip() != "" + or self.source[end_parens:line_end].strip() != "" + ) indents = sourceutils.get_indents(self.lines, start_line) primary, pyname = occurrence.get_primary_and_pyname() host = self.pymodule scope = host.scope.get_inner_scope_for_line(lineno) definition, returned = self.generator.get_definition( - primary, pyname, self.source[start:end_parens], scope.get_names(), - returns=returns) + primary, + pyname, + self.source[start:end_parens], + scope.get_names(), + returns=returns, + ) end = min(line_end + 1, len(self.source)) change_collector.add_change( - line_start, end, sourceutils.fix_indentation(definition, indents)) + line_start, end, sourceutils.fix_indentation(definition, indents) + ) if returns: name = returned if name is None: - name = 'None' + name = "None" change_collector.add_change( - line_end, end, self.source[line_start:start] + name + - self.source[end_parens:end]) + line_end, + end, + self.source[line_start:start] + name + self.source[end_parens:end], + ) def _find_end_parens(self, source, offset): finder = worder.Worder(source) @@ -557,22 +610,29 @@ def lines(self): return self.pymodule.lines -def _inline_variable(project, pymodule, pyname, name, - remove=True, region=None, docs=False): +def _inline_variable( + project, pymodule, pyname, name, remove=True, region=None, docs=False +): definition = _getvardef(pymodule, pyname) start, end = _assigned_lineno(pymodule, pyname) - occurrence_finder = occurrences.create_finder(project, name, pyname, - docs=docs) + occurrence_finder = occurrences.create_finder(project, name, pyname, docs=docs) changed_source = rename.rename_in_module( - occurrence_finder, definition, pymodule=pymodule, - replace_primary=True, writes=False, region=region) + occurrence_finder, + definition, + pymodule=pymodule, + replace_primary=True, + writes=False, + region=region, + ) if changed_source is None: changed_source = pymodule.source_code if remove: lines = codeanalyze.SourceLinesAdapter(changed_source) - source = changed_source[:lines.get_line_start(start)] + \ - changed_source[lines.get_line_end(end) + 1:] + source = ( + changed_source[: lines.get_line_start(start)] + + changed_source[lines.get_line_end(end) + 1 :] + ) else: source = changed_source return source @@ -583,12 +643,13 @@ def _getvardef(pymodule, pyname): lines = pymodule.lines start, end = _assigned_lineno(pymodule, pyname) definition_with_assignment = _join_lines( - [lines.get_line(n) for n in range(start, end + 1)]) + [lines.get_line(n) for n in range(start, end + 1)] + ) if assignment.levels: - raise rope.base.exceptions.RefactoringError( - 'Cannot inline tuple assignments.') - definition = definition_with_assignment[definition_with_assignment. - index('=') + 1:].strip() + raise rope.base.exceptions.RefactoringError("Cannot inline tuple assignments.") + definition = definition_with_assignment[ + definition_with_assignment.index("=") + 1 : + ].strip() return definition diff --git a/rope/refactor/introduce_factory.py b/rope/refactor/introduce_factory.py index bbaf347e9..58ff31435 100644 --- a/rope/refactor/introduce_factory.py +++ b/rope/refactor/introduce_factory.py @@ -2,29 +2,34 @@ import rope.base.pyobjects from rope.base import libutils from rope.base import taskhandle, evaluate -from rope.base.change import (ChangeSet, ChangeContents) +from rope.base.change import ChangeSet, ChangeContents from rope.refactor import rename, occurrences, sourceutils, importutils class IntroduceFactory(object): - def __init__(self, project, resource, offset): self.project = project self.offset = offset this_pymodule = self.project.get_pymodule(resource) self.old_pyname = evaluate.eval_location(this_pymodule, offset) - if self.old_pyname is None or \ - not isinstance(self.old_pyname.get_object(), - rope.base.pyobjects.PyClass): + if self.old_pyname is None or not isinstance( + self.old_pyname.get_object(), rope.base.pyobjects.PyClass + ): raise rope.base.exceptions.RefactoringError( - 'Introduce factory should be performed on a class.') + "Introduce factory should be performed on a class." + ) self.old_name = self.old_pyname.get_object().get_name() self.pymodule = self.old_pyname.get_object().get_module() self.resource = self.pymodule.get_resource() - def get_changes(self, factory_name, global_factory=False, resources=None, - task_handle=taskhandle.NullTaskHandle()): + def get_changes( + self, + factory_name, + global_factory=False, + resources=None, + task_handle=taskhandle.NullTaskHandle(), + ): """Get the changes this refactoring makes `factory_name` indicates the name of the factory function to @@ -38,21 +43,18 @@ def get_changes(self, factory_name, global_factory=False, resources=None, """ if resources is None: resources = self.project.get_python_files() - changes = ChangeSet('Introduce factory method <%s>' % factory_name) - job_set = task_handle.create_jobset('Collecting Changes', - len(resources)) - self._change_module(resources, changes, factory_name, - global_factory, job_set) + changes = ChangeSet("Introduce factory method <%s>" % factory_name) + job_set = task_handle.create_jobset("Collecting Changes", len(resources)) + self._change_module(resources, changes, factory_name, global_factory, job_set) return changes def get_name(self): """Return the name of the class""" return self.old_name - def _change_module(self, resources, changes, - factory_name, global_, job_set): + def _change_module(self, resources, changes, factory_name, global_, job_set): if global_: - replacement = '__rope_factory_%s_' % factory_name + replacement = "__rope_factory_%s_" % factory_name else: replacement = self._new_function_name(factory_name, global_) @@ -62,15 +64,16 @@ def _change_module(self, resources, changes, self._change_resource(changes, factory_name, global_) job_set.finished_job() continue - changed_code = self._rename_occurrences(file_, replacement, - global_) + changed_code = self._rename_occurrences(file_, replacement, global_) if changed_code is not None: if global_: new_pymodule = libutils.get_string_module( - self.project, changed_code, self.resource) + self.project, changed_code, self.resource + ) modname = libutils.modname(self.resource) changed_code, imported = importutils.add_import( - self.project, new_pymodule, modname, factory_name) + self.project, new_pymodule, modname, factory_name + ) changed_code = changed_code.replace(replacement, imported) changes.add_change(ChangeContents(file_, changed_code)) job_set.finished_job() @@ -78,18 +81,18 @@ def _change_module(self, resources, changes, def _change_resource(self, changes, factory_name, global_): class_scope = self.old_pyname.get_object().get_scope() source_code = self._rename_occurrences( - self.resource, self._new_function_name(factory_name, - global_), global_) + self.resource, self._new_function_name(factory_name, global_), global_ + ) if source_code is None: source_code = self.pymodule.source_code else: self.pymodule = libutils.get_string_module( - self.project, source_code, resource=self.resource) + self.project, source_code, resource=self.resource + ) lines = self.pymodule.lines start = self._get_insertion_offset(class_scope, lines) result = source_code[:start] - result += self._get_factory_method(lines, class_scope, - factory_name, global_) + result += self._get_factory_method(lines, class_scope, factory_name, global_) result += source_code[start:] changes.add_change(ChangeContents(self.resource, result)) @@ -100,21 +103,26 @@ def _get_insertion_offset(self, class_scope, lines): start = lines.get_line_end(start_line) + 1 return start - def _get_factory_method(self, lines, class_scope, - factory_name, global_): - unit_indents = ' ' * sourceutils.get_indent(self.project) + def _get_factory_method(self, lines, class_scope, factory_name, global_): + unit_indents = " " * sourceutils.get_indent(self.project) if global_: if self._get_scope_indents(lines, class_scope) > 0: raise rope.base.exceptions.RefactoringError( - 'Cannot make global factory method for nested classes.') - return ('\ndef %s(*args, **kwds):\n%sreturn %s(*args, **kwds)\n' % - (factory_name, unit_indents, self.old_name)) - unindented_factory = \ - ('@staticmethod\ndef %s(*args, **kwds):\n' % factory_name + - '%sreturn %s(*args, **kwds)\n' % (unit_indents, self.old_name)) - indents = self._get_scope_indents(lines, class_scope) + \ - sourceutils.get_indent(self.project) - return '\n' + sourceutils.indent_lines(unindented_factory, indents) + "Cannot make global factory method for nested classes." + ) + return "\ndef %s(*args, **kwds):\n%sreturn %s(*args, **kwds)\n" % ( + factory_name, + unit_indents, + self.old_name, + ) + unindented_factory = ( + "@staticmethod\ndef %s(*args, **kwds):\n" % factory_name + + "%sreturn %s(*args, **kwds)\n" % (unit_indents, self.old_name) + ) + indents = self._get_scope_indents(lines, class_scope) + sourceutils.get_indent( + self.project + ) + return "\n" + sourceutils.indent_lines(unindented_factory, indents) def _get_scope_indents(self, lines, scope): return sourceutils.get_indents(lines, scope.get_start()) @@ -123,13 +131,16 @@ def _new_function_name(self, factory_name, global_): if global_: return factory_name else: - return self.old_name + '.' + factory_name + return self.old_name + "." + factory_name def _rename_occurrences(self, file_, changed_name, global_factory): - finder = occurrences.create_finder(self.project, self.old_name, - self.old_pyname, only_calls=True) - result = rename.rename_in_module(finder, changed_name, resource=file_, - replace_primary=global_factory) + finder = occurrences.create_finder( + self.project, self.old_name, self.old_pyname, only_calls=True + ) + result = rename.rename_in_module( + finder, changed_name, resource=file_, replace_primary=global_factory + ) return result + IntroduceFactoryRefactoring = IntroduceFactory diff --git a/rope/refactor/introduce_parameter.py b/rope/refactor/introduce_parameter.py index 43d6f755b..d5e1936ee 100644 --- a/rope/refactor/introduce_parameter.py +++ b/rope/refactor/introduce_parameter.py @@ -40,56 +40,54 @@ def __init__(self, project, resource, offset): self.offset = offset self.pymodule = self.project.get_pymodule(self.resource) scope = self.pymodule.get_scope().get_inner_scope_for_offset(offset) - if scope.get_kind() != 'Function': + if scope.get_kind() != "Function": raise exceptions.RefactoringError( - 'Introduce parameter should be performed inside functions') + "Introduce parameter should be performed inside functions" + ) self.pyfunction = scope.pyobject self.name, self.pyname = self._get_name_and_pyname() if self.pyname is None: raise exceptions.RefactoringError( - 'Cannot find the definition of <%s>' % self.name) + "Cannot find the definition of <%s>" % self.name + ) def _get_primary(self): word_finder = worder.Worder(self.resource.read()) return word_finder.get_primary_at(self.offset) def _get_name_and_pyname(self): - return (worder.get_name_at(self.resource, self.offset), - evaluate.eval_location(self.pymodule, self.offset)) + return ( + worder.get_name_at(self.resource, self.offset), + evaluate.eval_location(self.pymodule, self.offset), + ) def get_changes(self, new_parameter): definition_info = functionutils.DefinitionInfo.read(self.pyfunction) - definition_info.args_with_defaults.append((new_parameter, - self._get_primary())) + definition_info.args_with_defaults.append((new_parameter, self._get_primary())) collector = codeanalyze.ChangeCollector(self.resource.read()) header_start, header_end = self._get_header_offsets() body_start, body_end = sourceutils.get_body_region(self.pyfunction) - collector.add_change(header_start, header_end, - definition_info.to_string()) - self._change_function_occurances(collector, body_start, - body_end, new_parameter) - changes = rope.base.change.ChangeSet('Introduce parameter <%s>' % - new_parameter) - change = rope.base.change.ChangeContents(self.resource, - collector.get_changed()) + collector.add_change(header_start, header_end, definition_info.to_string()) + self._change_function_occurances(collector, body_start, body_end, new_parameter) + changes = rope.base.change.ChangeSet("Introduce parameter <%s>" % new_parameter) + change = rope.base.change.ChangeContents(self.resource, collector.get_changed()) changes.add_change(change) return changes def _get_header_offsets(self): lines = self.pymodule.lines start_line = self.pyfunction.get_scope().get_start() - end_line = self.pymodule.logical_lines.\ - logical_line_in(start_line)[1] + end_line = self.pymodule.logical_lines.logical_line_in(start_line)[1] start = lines.get_line_start(start_line) end = lines.get_line_end(end_line) - start = self.pymodule.source_code.find('def', start) + 4 - end = self.pymodule.source_code.rfind(':', start, end) + start = self.pymodule.source_code.find("def", start) + 4 + end = self.pymodule.source_code.rfind(":", start, end) return start, end - def _change_function_occurances(self, collector, function_start, - function_end, new_name): - finder = occurrences.create_finder(self.project, self.name, - self.pyname) + def _change_function_occurances( + self, collector, function_start, function_end, new_name + ): + finder = occurrences.create_finder(self.project, self.name, self.pyname) for occurrence in finder.find_occurrences(resource=self.resource): start, end = occurrence.get_primary_range() if function_start <= start < function_end: diff --git a/rope/refactor/localtofield.py b/rope/refactor/localtofield.py index f276070f7..15057e39a 100644 --- a/rope/refactor/localtofield.py +++ b/rope/refactor/localtofield.py @@ -3,7 +3,6 @@ class LocalToField(object): - def __init__(self, project, resource, offset): self.project = project self.resource = resource @@ -15,35 +14,39 @@ def get_changes(self): pyname = evaluate.eval_location(this_pymodule, self.offset) if not self._is_a_method_local(pyname): raise exceptions.RefactoringError( - 'Convert local variable to field should be performed on \n' - 'a local variable of a method.') + "Convert local variable to field should be performed on \n" + "a local variable of a method." + ) pymodule, lineno = pyname.get_definition_location() function_scope = pymodule.get_scope().get_inner_scope_for_line(lineno) # Not checking redefinition - #self._check_redefinition(name, function_scope) + # self._check_redefinition(name, function_scope) new_name = self._get_field_name(function_scope.pyobject, name) - changes = Rename(self.project, self.resource, self.offset).\ - get_changes(new_name, resources=[self.resource]) + changes = Rename(self.project, self.resource, self.offset).get_changes( + new_name, resources=[self.resource] + ) return changes def _check_redefinition(self, name, function_scope): class_scope = function_scope.parent if name in class_scope.pyobject: - raise exceptions.RefactoringError( - 'The field %s already exists' % name) + raise exceptions.RefactoringError("The field %s already exists" % name) def _get_field_name(self, pyfunction, name): self_name = pyfunction.get_param_names()[0] - new_name = self_name + '.' + name + new_name = self_name + "." + name return new_name def _is_a_method_local(self, pyname): pymodule, lineno = pyname.get_definition_location() holding_scope = pymodule.get_scope().get_inner_scope_for_line(lineno) parent = holding_scope.parent - return isinstance(pyname, pynames.AssignedName) and \ - pyname in holding_scope.get_names().values() and \ - holding_scope.get_kind() == 'Function' and \ - parent is not None and parent.get_kind() == 'Class' + return ( + isinstance(pyname, pynames.AssignedName) + and pyname in holding_scope.get_names().values() + and holding_scope.get_kind() == "Function" + and parent is not None + and parent.get_kind() == "Class" + ) diff --git a/rope/refactor/method_object.py b/rope/refactor/method_object.py index 29ce429db..d65d3b135 100644 --- a/rope/refactor/method_object.py +++ b/rope/refactor/method_object.py @@ -6,48 +6,56 @@ class MethodObject(object): - def __init__(self, project, resource, offset): self.project = project this_pymodule = self.project.get_pymodule(resource) pyname = evaluate.eval_location(this_pymodule, offset) - if pyname is None or not isinstance(pyname.get_object(), - pyobjects.PyFunction): + if pyname is None or not isinstance(pyname.get_object(), pyobjects.PyFunction): raise exceptions.RefactoringError( - 'Replace method with method object refactoring should be ' - 'performed on a function.') + "Replace method with method object refactoring should be " + "performed on a function." + ) self.pyfunction = pyname.get_object() self.pymodule = self.pyfunction.get_module() self.resource = self.pymodule.get_resource() def get_new_class(self, name): body = sourceutils.fix_indentation( - self._get_body(), sourceutils.get_indent(self.project) * 2) - return 'class %s(object):\n\n%s%sdef __call__(self):\n%s' % \ - (name, self._get_init(), - ' ' * sourceutils.get_indent(self.project), body) + self._get_body(), sourceutils.get_indent(self.project) * 2 + ) + return "class %s(object):\n\n%s%sdef __call__(self):\n%s" % ( + name, + self._get_init(), + " " * sourceutils.get_indent(self.project), + body, + ) def get_changes(self, classname=None, new_class_name=None): if new_class_name is not None: warnings.warn( - 'new_class_name parameter is deprecated; use classname', - DeprecationWarning, stacklevel=2) + "new_class_name parameter is deprecated; use classname", + DeprecationWarning, + stacklevel=2, + ) classname = new_class_name collector = codeanalyze.ChangeCollector(self.pymodule.source_code) start, end = sourceutils.get_body_region(self.pyfunction) indents = sourceutils.get_indents( - self.pymodule.lines, self.pyfunction.get_scope().get_start()) + \ - sourceutils.get_indent(self.project) - new_contents = ' ' * indents + 'return %s(%s)()\n' % \ - (classname, ', '.join(self._get_parameter_names())) + self.pymodule.lines, self.pyfunction.get_scope().get_start() + ) + sourceutils.get_indent(self.project) + new_contents = " " * indents + "return %s(%s)()\n" % ( + classname, + ", ".join(self._get_parameter_names()), + ) collector.add_change(start, end, new_contents) insertion = self._get_class_insertion_point() - collector.add_change(insertion, insertion, - '\n\n' + self.get_new_class(classname)) - changes = change.ChangeSet( - 'Replace method with method object refactoring') - changes.add_change(change.ChangeContents(self.resource, - collector.get_changed())) + collector.add_change( + insertion, insertion, "\n\n" + self.get_new_class(classname) + ) + changes = change.ChangeSet("Replace method with method object refactoring") + changes.add_change( + change.ChangeContents(self.resource, collector.get_changed()) + ) return changes def _get_class_insertion_point(self): @@ -60,31 +68,29 @@ def _get_class_insertion_point(self): def _get_body(self): body = sourceutils.get_body(self.pyfunction) for param in self._get_parameter_names(): - body = param + ' = None\n' + body - pymod = libutils.get_string_module( - self.project, body, self.resource) + body = param + " = None\n" + body + pymod = libutils.get_string_module(self.project, body, self.resource) pyname = pymod[param] finder = occurrences.create_finder(self.project, param, pyname) - result = rename.rename_in_module(finder, 'self.' + param, - pymodule=pymod) - body = result[result.index('\n') + 1:] + result = rename.rename_in_module(finder, "self." + param, pymodule=pymod) + body = result[result.index("\n") + 1 :] return body def _get_init(self): params = self._get_parameter_names() - indents = ' ' * sourceutils.get_indent(self.project) + indents = " " * sourceutils.get_indent(self.project) if not params: - return '' - header = indents + 'def __init__(self' - body = '' + return "" + header = indents + "def __init__(self" + body = "" for arg in params: new_name = arg - if arg == 'self': - new_name = 'host' - header += ', %s' % new_name - body += indents * 2 + 'self.%s = %s\n' % (arg, new_name) - header += '):' - return '%s\n%s\n' % (header, body) + if arg == "self": + new_name = "host" + header += ", %s" % new_name + body += indents * 2 + "self.%s = %s\n" % (arg, new_name) + header += "):" + return "%s\n%s\n" % (header, body) def _get_parameter_names(self): return self.pyfunction.get_param_names() diff --git a/rope/refactor/move.py b/rope/refactor/move.py index d82275a07..4a984467e 100644 --- a/rope/refactor/move.py +++ b/rope/refactor/move.py @@ -4,11 +4,18 @@ based on inputs. """ -from rope.base import (pyobjects, codeanalyze, exceptions, pynames, - taskhandle, evaluate, worder, libutils) +from rope.base import ( + pyobjects, + codeanalyze, + exceptions, + pynames, + taskhandle, + evaluate, + worder, + libutils, +) from rope.base.change import ChangeSet, ChangeContents, MoveResource -from rope.refactor import importutils, rename, occurrences, sourceutils, \ - functionutils +from rope.refactor import importutils, rename, occurrences, sourceutils, functionutils def create_move(project, resource, offset=None): @@ -24,19 +31,23 @@ def create_move(project, resource, offset=None): pyname = evaluate.eval_location(this_pymodule, offset) if pyname is not None: pyobject = pyname.get_object() - if isinstance(pyobject, pyobjects.PyModule) or \ - isinstance(pyobject, pyobjects.PyPackage): + if isinstance(pyobject, pyobjects.PyModule) or isinstance( + pyobject, pyobjects.PyPackage + ): return MoveModule(project, pyobject.get_resource()) - if isinstance(pyobject, pyobjects.PyFunction) and \ - isinstance(pyobject.parent, pyobjects.PyClass): + if isinstance(pyobject, pyobjects.PyFunction) and isinstance( + pyobject.parent, pyobjects.PyClass + ): return MoveMethod(project, resource, offset) - if isinstance(pyobject, pyobjects.PyDefinedObject) and \ - isinstance(pyobject.parent, pyobjects.PyModule) or \ - isinstance(pyname, pynames.AssignedName): + if ( + isinstance(pyobject, pyobjects.PyDefinedObject) + and isinstance(pyobject.parent, pyobjects.PyModule) + or isinstance(pyname, pynames.AssignedName) + ): return MoveGlobal(project, resource, offset) raise exceptions.RefactoringError( - 'Move only works on global classes/functions/variables, modules and ' - 'methods.') + "Move only works on global classes/functions/variables, modules and " "methods." + ) class MoveMethod(object): @@ -54,12 +65,16 @@ def __init__(self, project, resource, offset): pyname = evaluate.eval_location(this_pymodule, offset) self.method_name = worder.get_name_at(resource, offset) self.pyfunction = pyname.get_object() - if self.pyfunction.get_kind() != 'method': - raise exceptions.RefactoringError('Only normal methods' - ' can be moved.') - - def get_changes(self, dest_attr, new_name=None, resources=None, - task_handle=taskhandle.NullTaskHandle()): + if self.pyfunction.get_kind() != "method": + raise exceptions.RefactoringError("Only normal methods" " can be moved.") + + def get_changes( + self, + dest_attr, + new_name=None, + resources=None, + task_handle=taskhandle.NullTaskHandle(), + ): """Return the changes needed for this refactoring Parameters: @@ -72,18 +87,20 @@ def get_changes(self, dest_attr, new_name=None, resources=None, will be applied to all python files. """ - changes = ChangeSet('Moving method <%s>' % self.method_name) + changes = ChangeSet("Moving method <%s>" % self.method_name) if resources is None: resources = self.project.get_python_files() if new_name is None: new_name = self.get_method_name() - resource1, start1, end1, new_content1 = \ - self._get_changes_made_by_old_class(dest_attr, new_name) + resource1, start1, end1, new_content1 = self._get_changes_made_by_old_class( + dest_attr, new_name + ) collector1 = codeanalyze.ChangeCollector(resource1.read()) collector1.add_change(start1, end1, new_content1) - resource2, start2, end2, new_content2 = \ - self._get_changes_made_by_new_class(dest_attr, new_name) + resource2, start2, end2, new_content2 = self._get_changes_made_by_new_class( + dest_attr, new_name + ) if resource1 == resource2: collector1.add_change(start2, end2, new_content2) else: @@ -94,15 +111,16 @@ def get_changes(self, dest_attr, new_name=None, resources=None, new_imports = self._get_used_imports(import_tools) if new_imports: goal_pymodule = libutils.get_string_module( - self.project, result, resource2) + self.project, result, resource2 + ) result = _add_imports_to_module( - import_tools, goal_pymodule, new_imports) + import_tools, goal_pymodule, new_imports + ) if resource2 in resources: changes.add_change(ChangeContents(resource2, result)) if resource1 in resources: - changes.add_change(ChangeContents(resource1, - collector1.get_changed())) + changes.add_change(ChangeContents(resource1, collector1.get_changed())) return changes def get_method_name(self): @@ -114,86 +132,97 @@ def _get_used_imports(self, import_tools): def _get_changes_made_by_old_class(self, dest_attr, new_name): pymodule = self.pyfunction.get_module() indents = self._get_scope_indents(self.pyfunction) - body = 'return self.%s.%s(%s)\n' % ( - dest_attr, new_name, self._get_passed_arguments_string()) + body = "return self.%s.%s(%s)\n" % ( + dest_attr, + new_name, + self._get_passed_arguments_string(), + ) region = sourceutils.get_body_region(self.pyfunction) - return (pymodule.get_resource(), region[0], region[1], - sourceutils.fix_indentation(body, indents)) + return ( + pymodule.get_resource(), + region[0], + region[1], + sourceutils.fix_indentation(body, indents), + ) def _get_scope_indents(self, pyobject): pymodule = pyobject.get_module() return sourceutils.get_indents( - pymodule.lines, pyobject.get_scope().get_start()) + \ - sourceutils.get_indent(self.project) + pymodule.lines, pyobject.get_scope().get_start() + ) + sourceutils.get_indent(self.project) def _get_changes_made_by_new_class(self, dest_attr, new_name): old_pyclass = self.pyfunction.parent if dest_attr not in old_pyclass: raise exceptions.RefactoringError( - 'Destination attribute <%s> not found' % dest_attr) + "Destination attribute <%s> not found" % dest_attr + ) pyclass = old_pyclass[dest_attr].get_object().get_type() if not isinstance(pyclass, pyobjects.PyClass): raise exceptions.RefactoringError( - 'Unknown class type for attribute <%s>' % dest_attr) + "Unknown class type for attribute <%s>" % dest_attr + ) pymodule = pyclass.get_module() resource = pyclass.get_module().get_resource() start, end = sourceutils.get_body_region(pyclass) - pre_blanks = '\n' - if pymodule.source_code[start:end].strip() != 'pass': - pre_blanks = '\n\n' + pre_blanks = "\n" + if pymodule.source_code[start:end].strip() != "pass": + pre_blanks = "\n\n" start = end indents = self._get_scope_indents(pyclass) body = pre_blanks + sourceutils.fix_indentation( - self.get_new_method(new_name), indents) + self.get_new_method(new_name), indents + ) return resource, start, end, body def get_new_method(self, name): - return '%s\n%s' % ( + return "%s\n%s" % ( self._get_new_header(name), - sourceutils.fix_indentation(self._get_body(), - sourceutils.get_indent(self.project))) + sourceutils.fix_indentation( + self._get_body(), sourceutils.get_indent(self.project) + ), + ) def _get_unchanged_body(self): return sourceutils.get_body(self.pyfunction) - def _get_body(self, host='host'): + def _get_body(self, host="host"): self_name = self._get_self_name() - body = self_name + ' = None\n' + self._get_unchanged_body() + body = self_name + " = None\n" + self._get_unchanged_body() pymodule = libutils.get_string_module(self.project, body) - finder = occurrences.create_finder( - self.project, self_name, pymodule[self_name]) + finder = occurrences.create_finder(self.project, self_name, pymodule[self_name]) result = rename.rename_in_module(finder, host, pymodule=pymodule) if result is None: result = body - return result[result.index('\n') + 1:] + return result[result.index("\n") + 1 :] def _get_self_name(self): return self.pyfunction.get_param_names()[0] def _get_new_header(self, name): - header = 'def %s(self' % name + header = "def %s(self" % name if self._is_host_used(): - header += ', host' + header += ", host" definition_info = functionutils.DefinitionInfo.read(self.pyfunction) others = definition_info.arguments_to_string(1) if others: - header += ', ' + others - return header + '):' + header += ", " + others + return header + "):" def _get_passed_arguments_string(self): - result = '' + result = "" if self._is_host_used(): - result = 'self' + result = "self" definition_info = functionutils.DefinitionInfo.read(self.pyfunction) others = definition_info.arguments_to_string(1) if others: if result: - result += ', ' + result += ", " result += others return result def _is_host_used(self): - return self._get_body('__old_self') != self._get_unchanged_body() + return self._get_body("__old_self") != self._get_unchanged_body() class MoveGlobal(object): @@ -205,8 +234,8 @@ def __init__(self, project, resource, offset): self.old_pyname = evaluate.eval_location(this_pymodule, offset) if self.old_pyname is None: raise exceptions.RefactoringError( - 'Move refactoring should be performed on a ' - 'class/function/variable.') + "Move refactoring should be performed on a " "class/function/variable." + ) if self._is_variable(self.old_pyname): self.old_name = worder.get_name_at(resource, offset) pymodule = this_pymodule @@ -215,27 +244,31 @@ def __init__(self, project, resource, offset): pymodule = self.old_pyname.get_object().get_module() self._check_exceptional_conditions() self.source = pymodule.get_resource() - self.tools = _MoveTools(self.project, self.source, - self.old_pyname, self.old_name) + self.tools = _MoveTools( + self.project, self.source, self.old_pyname, self.old_name + ) self.import_tools = self.tools.import_tools def _import_filter(self, stmt): - module_name = libutils.modname(self.source) - - if isinstance(stmt.import_info, importutils.NormalImport): - # Affect any statement that imports the source module - return any(module_name == name - for name, alias in stmt.import_info.names_and_aliases) - elif isinstance(stmt.import_info, importutils.FromImport): - # Affect statements importing from the source package - if '.' in module_name: - package_name, basename = module_name.rsplit('.', 1) - if (stmt.import_info.module_name == package_name and - any(basename == name - for name, alias in stmt.import_info.names_and_aliases)): - return True - return stmt.import_info.module_name == module_name - return False + module_name = libutils.modname(self.source) + + if isinstance(stmt.import_info, importutils.NormalImport): + # Affect any statement that imports the source module + return any( + module_name == name + for name, alias in stmt.import_info.names_and_aliases + ) + elif isinstance(stmt.import_info, importutils.FromImport): + # Affect statements importing from the source package + if "." in module_name: + package_name, basename = module_name.rsplit(".", 1) + if stmt.import_info.module_name == package_name and any( + basename == name + for name, alias in stmt.import_info.names_and_aliases + ): + return True + return stmt.import_info.module_name == module_name + return False def _check_exceptional_conditions(self): if self._is_variable(self.old_pyname): @@ -244,43 +277,46 @@ def _check_exceptional_conditions(self): pymodule.get_scope().get_name(self.old_name) except exceptions.NameNotFoundError: self._raise_refactoring_error() - elif not (isinstance(self.old_pyname.get_object(), - pyobjects.PyDefinedObject) and - self._is_global(self.old_pyname.get_object())): + elif not ( + isinstance(self.old_pyname.get_object(), pyobjects.PyDefinedObject) + and self._is_global(self.old_pyname.get_object()) + ): self._raise_refactoring_error() def _raise_refactoring_error(self): raise exceptions.RefactoringError( - 'Move refactoring should be performed on a global class, function ' - 'or variable.') + "Move refactoring should be performed on a global class, function " + "or variable." + ) def _is_global(self, pyobject): return pyobject.get_scope().parent == pyobject.get_module().get_scope() def _is_variable(self, pyname): - return isinstance(pyname, pynames.AssignedName) + return isinstance(pyname, pynames.AssignedName) - def get_changes(self, dest, resources=None, - task_handle=taskhandle.NullTaskHandle()): + def get_changes( + self, dest, resources=None, task_handle=taskhandle.NullTaskHandle() + ): if resources is None: resources = self.project.get_python_files() if dest is None or not dest.exists(): - raise exceptions.RefactoringError( - 'Move destination does not exist.') - if dest.is_folder() and dest.has_child('__init__.py'): - dest = dest.get_child('__init__.py') + raise exceptions.RefactoringError("Move destination does not exist.") + if dest.is_folder() and dest.has_child("__init__.py"): + dest = dest.get_child("__init__.py") if dest.is_folder(): raise exceptions.RefactoringError( - 'Move destination for non-modules should not be folders.') + "Move destination for non-modules should not be folders." + ) if self.source == dest: raise exceptions.RefactoringError( - 'Moving global elements to the same module.') + "Moving global elements to the same module." + ) return self._calculate_changes(dest, resources, task_handle) def _calculate_changes(self, dest, resources, task_handle): - changes = ChangeSet('Moving global <%s>' % self.old_name) - job_set = task_handle.create_jobset('Collecting Changes', - len(resources)) + changes = ChangeSet("Moving global <%s>" % self.old_name) + job_set = task_handle.create_jobset("Collecting Changes", len(resources)) for file_ in resources: job_set.started_job(file_.path) if file_ == self.source: @@ -290,20 +326,20 @@ def _calculate_changes(self, dest, resources, task_handle): elif self.tools.occurs_in_module(resource=file_): pymodule = self.project.get_pymodule(file_) # Changing occurrences - placeholder = '__rope_renaming_%s_' % self.old_name - source = self.tools.rename_in_module(placeholder, - resource=file_) + placeholder = "__rope_renaming_%s_" % self.old_name + source = self.tools.rename_in_module(placeholder, resource=file_) should_import = source is not None # Removing out of date imports pymodule = self.tools.new_pymodule(pymodule, source) source = self.import_tools.organize_imports( - pymodule, sort=False, import_filter=self._import_filter) + pymodule, sort=False, import_filter=self._import_filter + ) # Adding new import if should_import: pymodule = self.tools.new_pymodule(pymodule, source) source, imported = importutils.add_import( - self.project, pymodule, self._new_modname(dest), - self.old_name) + self.project, pymodule, self._new_modname(dest), self.old_name + ) source = source.replace(placeholder, imported) source = self.tools.new_source(pymodule, source) if source != file_.read(): @@ -312,22 +348,22 @@ def _calculate_changes(self, dest, resources, task_handle): return changes def _source_module_changes(self, dest): - placeholder = '__rope_moving_%s_' % self.old_name + placeholder = "__rope_moving_%s_" % self.old_name handle = _ChangeMoveOccurrencesHandle(placeholder) occurrence_finder = occurrences.create_finder( - self.project, self.old_name, self.old_pyname) + self.project, self.old_name, self.old_pyname + ) start, end = self._get_moving_region() - renamer = ModuleSkipRenamer(occurrence_finder, self.source, - handle, start, end) + renamer = ModuleSkipRenamer(occurrence_finder, self.source, handle, start, end) source = renamer.get_changed_module() pymodule = libutils.get_string_module(self.project, source, self.source) source = self.import_tools.organize_imports(pymodule, sort=False) if handle.occurred: - pymodule = libutils.get_string_module( - self.project, source, self.source) + pymodule = libutils.get_string_module(self.project, source, self.source) # Adding new import source, imported = importutils.add_import( - self.project, pymodule, self._new_modname(dest), self.old_name) + self.project, pymodule, self._new_modname(dest), self.old_name + ) source = source.replace(placeholder, imported) return ChangeContents(self.source, source) @@ -349,41 +385,44 @@ def _dest_module_changes(self, dest): if module_with_imports.imports: lineno = module_with_imports.imports[-1].end_line - 1 else: - while lineno < pymodule.lines.length() and \ - pymodule.lines.get_line(lineno + 1).\ - lstrip().startswith('#'): + while lineno < pymodule.lines.length() and pymodule.lines.get_line( + lineno + 1 + ).lstrip().startswith("#"): lineno += 1 if lineno > 0: cut = pymodule.lines.get_line_end(lineno) + 1 - result = source[:cut] + '\n\n' + moving + source[cut:] + result = source[:cut] + "\n\n" + moving + source[cut:] else: result = moving + source # Organizing imports source = result pymodule = libutils.get_string_module(self.project, source, dest) - source = self.import_tools.organize_imports(pymodule, sort=False, - unused=False) + source = self.import_tools.organize_imports(pymodule, sort=False, unused=False) # Remove unused imports of the old module pymodule = libutils.get_string_module(self.project, source, dest) source = self.import_tools.organize_imports( - pymodule, sort=False, selfs=False, unused=True, - import_filter=self._import_filter) + pymodule, + sort=False, + selfs=False, + unused=True, + import_filter=self._import_filter, + ) return ChangeContents(dest, source) def _get_moving_element_with_imports(self): return moving_code_with_imports( - self.project, self.source, self._get_moving_element()) + self.project, self.source, self._get_moving_element() + ) def _get_module_with_imports(self, source_code, resource): - pymodule = libutils.get_string_module( - self.project, source_code, resource) + pymodule = libutils.get_string_module(self.project, source_code, resource) return self.import_tools.module_imports(pymodule) def _get_moving_element(self): start, end = self._get_moving_region() moving = self.source.read()[start:end] - return moving.rstrip() + '\n' + return moving.rstrip() + "\n" def _get_moving_region(self): pymodule = self.project.get_pymodule(self.source) @@ -391,7 +430,8 @@ def _get_moving_region(self): if self._is_variable(self.old_pyname): logical_lines = pymodule.logical_lines lineno = logical_lines.logical_line_in( - self.old_pyname.get_definition_location()[1])[0] + self.old_pyname.get_definition_location()[1] + )[0] start = lines.get_line_start(lineno) end_line = logical_lines.logical_line_in(lineno)[1] else: @@ -401,12 +441,11 @@ def _get_moving_region(self): # Include comment lines before the definition start_line = lines.get_line_number(start) - while start_line > 1 and lines.get_line(start_line - 1).startswith('#'): - start_line -= 1 + while start_line > 1 and lines.get_line(start_line - 1).startswith("#"): + start_line -= 1 start = lines.get_line_start(start_line) - while end_line < lines.length() and \ - lines.get_line(end_line + 1).strip() == '': + while end_line < lines.length() and lines.get_line(end_line + 1).strip() == "": end_line += 1 end = min(lines.get_line_end(end_line) + 1, len(pymodule.source_code)) return start, end @@ -417,8 +456,7 @@ def _add_imports2(self, pymodule, new_imports): return pymodule, False else: resource = pymodule.get_resource() - pymodule = libutils.get_string_module( - self.project, source, resource) + pymodule = libutils.get_string_module(self.project, source, resource) return pymodule, True @@ -427,43 +465,42 @@ class MoveModule(object): def __init__(self, project, resource): self.project = project - if not resource.is_folder() and resource.name == '__init__.py': + if not resource.is_folder() and resource.name == "__init__.py": resource = resource.parent - if resource.is_folder() and not resource.has_child('__init__.py'): - raise exceptions.RefactoringError( - 'Cannot move non-package folder.') - dummy_pymodule = libutils.get_string_module(self.project, '') - self.old_pyname = pynames.ImportedModule(dummy_pymodule, - resource=resource) + if resource.is_folder() and not resource.has_child("__init__.py"): + raise exceptions.RefactoringError("Cannot move non-package folder.") + dummy_pymodule = libutils.get_string_module(self.project, "") + self.old_pyname = pynames.ImportedModule(dummy_pymodule, resource=resource) self.source = self.old_pyname.get_object().get_resource() if self.source.is_folder(): self.old_name = self.source.name else: self.old_name = self.source.name[:-3] - self.tools = _MoveTools(self.project, self.source, - self.old_pyname, self.old_name) + self.tools = _MoveTools( + self.project, self.source, self.old_pyname, self.old_name + ) self.import_tools = self.tools.import_tools - def get_changes(self, dest, resources=None, - task_handle=taskhandle.NullTaskHandle()): + def get_changes( + self, dest, resources=None, task_handle=taskhandle.NullTaskHandle() + ): if resources is None: resources = self.project.get_python_files() if dest is None or not dest.is_folder(): raise exceptions.RefactoringError( - 'Move destination for modules should be packages.') + "Move destination for modules should be packages." + ) return self._calculate_changes(dest, resources, task_handle) def _calculate_changes(self, dest, resources, task_handle): - changes = ChangeSet('Moving module <%s>' % self.old_name) - job_set = task_handle.create_jobset('Collecting changes', - len(resources)) + changes = ChangeSet("Moving module <%s>" % self.old_name) + job_set = task_handle.create_jobset("Collecting changes", len(resources)) for module in resources: job_set.started_job(module.path) if module == self.source: self._change_moving_module(changes, dest) else: - source = self._change_occurrences_in_module(dest, - resource=module) + source = self._change_occurrences_in_module(dest, resource=module) if source is not None: changes.add_change(ChangeContents(module, source)) job_set.finished_job() @@ -474,7 +511,7 @@ def _calculate_changes(self, dest, resources, task_handle): def _new_modname(self, dest): destname = libutils.modname(dest) if destname: - return destname + '.' + self.old_name + return destname + "." + self.old_name return self.old_name def _new_import(self, dest): @@ -490,10 +527,8 @@ def _change_moving_module(self, changes, dest): if source != self.source.read(): changes.add_change(ChangeContents(self.source, source)) - def _change_occurrences_in_module(self, dest, pymodule=None, - resource=None): - if not self.tools.occurs_in_module(pymodule=pymodule, - resource=resource): + def _change_occurrences_in_module(self, dest, pymodule=None, resource=None): + if not self.tools.occurs_in_module(pymodule=pymodule, resource=resource): return if pymodule is None: pymodule = self.project.get_pymodule(resource) @@ -502,8 +537,7 @@ def _change_occurrences_in_module(self, dest, pymodule=None, changed = False source = None if libutils.modname(dest): - changed = self._change_import_statements(dest, new_name, - module_imports) + changed = self._change_import_statements(dest, new_name, module_imports) if changed: source = module_imports.get_changed_source() source = self.tools.new_source(pymodule, source) @@ -511,10 +545,14 @@ def _change_occurrences_in_module(self, dest, pymodule=None, new_import = self._new_import(dest) source = self.tools.rename_in_module( - new_name, imports=True, pymodule=pymodule, - resource=resource if not changed else None) + new_name, + imports=True, + pymodule=pymodule, + resource=resource if not changed else None, + ) should_import = self.tools.occurs_in_module( - pymodule=pymodule, resource=resource, imports=False) + pymodule=pymodule, resource=resource, imports=False + ) pymodule = self.tools.new_pymodule(pymodule, source) source = self.tools.remove_old_imports(pymodule) if should_import: @@ -531,12 +569,13 @@ def _change_import_statements(self, dest, new_name, module_imports): changed = False for import_stmt in module_imports.imports: - if not any(name_and_alias[0] == self.old_name - for name_and_alias in - import_stmt.import_info.names_and_aliases) and \ - not any(name_and_alias[0] == libutils.modname(self.source) - for name_and_alias in - import_stmt.import_info.names_and_aliases): + if not any( + name_and_alias[0] == self.old_name + for name_and_alias in import_stmt.import_info.names_and_aliases + ) and not any( + name_and_alias[0] == libutils.modname(self.source) + for name_and_alias in import_stmt.import_info.names_and_aliases + ): continue # Case 1: Look for normal imports of the moving module. @@ -544,27 +583,35 @@ def _change_import_statements(self, dest, new_name, module_imports): continue # Case 2: The moving module is from-imported. - changed = self._handle_moving_in_from_import_stmt( - dest, import_stmt, module_imports, parent_module) or changed + changed = ( + self._handle_moving_in_from_import_stmt( + dest, import_stmt, module_imports, parent_module + ) + or changed + ) # Case 3: Names are imported from the moving module. context = importutils.importinfo.ImportContext(self.project, None) - if not import_stmt.import_info.is_empty() and \ - import_stmt.import_info.get_imported_resource(context) == \ - moving_module: + if ( + not import_stmt.import_info.is_empty() + and import_stmt.import_info.get_imported_resource(context) + == moving_module + ): import_stmt.import_info = importutils.FromImport( - new_name, import_stmt.import_info.level, - import_stmt.import_info.names_and_aliases) + new_name, + import_stmt.import_info.level, + import_stmt.import_info.names_and_aliases, + ) changed = True return changed - def _handle_moving_in_from_import_stmt(self, dest, import_stmt, - module_imports, parent_module): + def _handle_moving_in_from_import_stmt( + self, dest, import_stmt, module_imports, parent_module + ): changed = False context = importutils.importinfo.ImportContext(self.project, None) - if import_stmt.import_info.get_imported_resource(context) == \ - parent_module: + if import_stmt.import_info.get_imported_resource(context) == parent_module: imports = import_stmt.import_info.names_and_aliases new_imports = [] for name, alias in imports: @@ -572,8 +619,8 @@ def _handle_moving_in_from_import_stmt(self, dest, import_stmt, if name == self.old_name: changed = True new_import = importutils.FromImport( - libutils.modname(dest), 0, - [(self.old_name, alias)]) + libutils.modname(dest), 0, [(self.old_name, alias)] + ) module_imports.add_import(new_import) else: new_imports.append((name, alias)) @@ -585,14 +632,14 @@ def _handle_moving_in_from_import_stmt(self, dest, import_stmt, import_stmt.import_info = importutils.FromImport( import_stmt.import_info.module_name, import_stmt.import_info.level, - new_imports) + new_imports, + ) else: import_stmt.empty_import() return changed class _ChangeMoveOccurrencesHandle(object): - def __init__(self, new_name): self.new_name = new_name self.occurred = False @@ -607,7 +654,6 @@ def occurred_outside_skip(self, change_collector, occurrence): class _MoveTools(object): - def __init__(self, project, source, pyname, old_name): self.project = project self.source = source @@ -626,44 +672,53 @@ class CanSelect(object): def __call__(self, name): try: - if name == self.old_name and \ - pymodule[name].get_object() == \ - self.old_pyname.get_object(): + if ( + name == self.old_name + and pymodule[name].get_object() == self.old_pyname.get_object() + ): self.changed = True return False except exceptions.AttributeNotFoundError: pass return True + can_select = CanSelect() module_with_imports.filter_names(can_select) new_source = module_with_imports.get_changed_source() if old_source != new_source: return new_source - def rename_in_module(self, new_name, pymodule=None, - imports=False, resource=None): + def rename_in_module(self, new_name, pymodule=None, imports=False, resource=None): occurrence_finder = self._create_finder(imports) source = rename.rename_in_module( - occurrence_finder, new_name, replace_primary=True, - pymodule=pymodule, resource=resource) + occurrence_finder, + new_name, + replace_primary=True, + pymodule=pymodule, + resource=resource, + ) return source def occurs_in_module(self, pymodule=None, resource=None, imports=True): finder = self._create_finder(imports) - for occurrence in finder.find_occurrences(pymodule=pymodule, - resource=resource): + for occurrence in finder.find_occurrences(pymodule=pymodule, resource=resource): return True return False def _create_finder(self, imports): - return occurrences.create_finder(self.project, self.old_name, - self.old_pyname, imports=imports, - keywords=False) + return occurrences.create_finder( + self.project, + self.old_name, + self.old_pyname, + imports=imports, + keywords=False, + ) def new_pymodule(self, pymodule, source): if source is not None: return libutils.get_string_module( - self.project, source, pymodule.get_resource()) + self.project, source, pymodule.get_resource() + ) return pymodule def new_source(self, pymodule, source): @@ -690,11 +745,12 @@ def moving_code_with_imports(project, resource, source): # section, but imports would be added between them. lines = codeanalyze.SourceLinesAdapter(source) start = 1 - while start < lines.length() and lines.get_line(start).startswith('#'): + while start < lines.length() and lines.get_line(start).startswith("#"): start += 1 - moving_prefix = source[:lines.get_line_start(start)] + moving_prefix = source[: lines.get_line_start(start)] pymodule = libutils.get_string_module( - project, source[lines.get_line_start(start):], resource) + project, source[lines.get_line_start(start) :], resource + ) origin = project.get_pymodule(resource) @@ -718,8 +774,7 @@ def moving_code_with_imports(project, resource, source): # extracting imports after changes module_imports = import_tools.module_imports(pymodule) - imports = [import_stmt.import_info - for import_stmt in module_imports.imports] + imports = [import_stmt.import_info for import_stmt in module_imports.imports] start = 1 if module_imports.imports: start = module_imports.imports[-1].end_line @@ -728,12 +783,11 @@ def moving_code_with_imports(project, resource, source): start += 1 # Reinsert the prefix which was removed at the beginning - moving = moving_prefix + source[lines.get_line_start(start):] + moving = moving_prefix + source[lines.get_line_start(start) :] return moving, imports class ModuleSkipRenamerHandle(object): - def occurred_outside_skip(self, change_collector, occurrence): pass @@ -749,8 +803,15 @@ class ModuleSkipRenamer(object): """ - def __init__(self, occurrence_finder, resource, handle=None, - skip_start=0, skip_end=0, replacement=''): + def __init__( + self, + occurrence_finder, + resource, + handle=None, + skip_start=0, + skip_end=0, + replacement="", + ): """Constructor if replacement is `None` the region is not changed. Otherwise @@ -770,10 +831,10 @@ def get_changed_module(self): source = self.resource.read() change_collector = codeanalyze.ChangeCollector(source) if self.replacement is not None: - change_collector.add_change(self.skip_start, self.skip_end, - self.replacement) - for occurrence in self.occurrence_finder.find_occurrences( - self.resource): + change_collector.add_change( + self.skip_start, self.skip_end, self.replacement + ) + for occurrence in self.occurrence_finder.find_occurrences(self.resource): start, end = occurrence.get_primary_range() if self.skip_start <= start < self.skip_end: self.handle.occurred_inside_skip(change_collector, occurrence) diff --git a/rope/refactor/multiproject.py b/rope/refactor/multiproject.py index ac243bdaf..0be744eab 100644 --- a/rope/refactor/multiproject.py +++ b/rope/refactor/multiproject.py @@ -9,7 +9,6 @@ class MultiProjectRefactoring(object): - def __init__(self, refactoring, projects, addpath=True): """Create a multiproject proxy for the main refactoring @@ -22,24 +21,22 @@ def __init__(self, refactoring, projects, addpath=True): def __call__(self, project, *args, **kwds): """Create the refactoring""" - return _MultiRefactoring(self.refactoring, self.projects, - self.addpath, project, *args, **kwds) + return _MultiRefactoring( + self.refactoring, self.projects, self.addpath, project, *args, **kwds + ) class _MultiRefactoring(object): - - def __init__(self, refactoring, other_projects, addpath, - project, *args, **kwds): + def __init__(self, refactoring, other_projects, addpath, project, *args, **kwds): self.refactoring = refactoring self.projects = [project] + other_projects for other_project in other_projects: for folder in self.project.get_source_folders(): - other_project.get_prefs().add('python_path', folder.real_path) + other_project.get_prefs().add("python_path", folder.real_path) self.refactorings = [] for other in self.projects: args, kwds = self._resources_for_args(other, args, kwds) - self.refactorings.append( - self.refactoring(other, *args, **kwds)) + self.refactorings.append(self.refactoring(other, *args, **kwds)) def get_all_changes(self, *args, **kwds): """Get a project to changes dict""" @@ -54,13 +51,14 @@ def __getattr__(self, name): def _resources_for_args(self, project, args, kwds): newargs = [self._change_project_resource(project, arg) for arg in args] - newkwds = dict((name, self._change_project_resource(project, value)) - for name, value in kwds.items()) + newkwds = dict( + (name, self._change_project_resource(project, value)) + for name, value in kwds.items() + ) return newargs, newkwds def _change_project_resource(self, project, obj): - if isinstance(obj, resources.Resource) and \ - obj.project != project: + if isinstance(obj, resources.Resource) and obj.project != project: return libutils.path_to_resource(project, obj.real_path) return obj diff --git a/rope/refactor/occurrences.py b/rope/refactor/occurrences.py index b27e74bf9..e61819169 100644 --- a/rope/refactor/occurrences.py +++ b/rope/refactor/occurrences.py @@ -71,8 +71,9 @@ def __init__(self, project, name, filters=[lambda o: True], docs=False): def find_occurrences(self, resource=None, pymodule=None): """Generate `Occurrence` instances""" - tools = _OccurrenceToolsCreator(self.project, resource=resource, - pymodule=pymodule, docs=self.docs) + tools = _OccurrenceToolsCreator( + self.project, resource=resource, pymodule=pymodule, docs=self.docs + ) for offset in self._textual_finder.find_offsets(tools.source_code): occurrence = Occurrence(tools, offset) for filter in self.filters: @@ -84,9 +85,18 @@ def find_occurrences(self, resource=None, pymodule=None): break -def create_finder(project, name, pyname, only_calls=False, imports=True, - unsure=None, docs=False, instance=None, in_hierarchy=False, - keywords=True): +def create_finder( + project, + name, + pyname, + only_calls=False, + imports=True, + unsure=None, + docs=False, + instance=None, + in_hierarchy=False, + keywords=True, +): """A factory for `Finder` Based on the arguments it creates a list of filters. `instance` @@ -118,7 +128,6 @@ def create_finder(project, name, pyname, only_calls=False, imports=True, class Occurrence(object): - def __init__(self, tools, offset): self.tools = tools self.offset = offset @@ -142,27 +151,26 @@ def get_pyname(self): @utils.saveit def get_primary_and_pyname(self): try: - return self.tools.name_finder.get_primary_and_pyname_at( - self.offset) + return self.tools.name_finder.get_primary_and_pyname_at(self.offset) except exceptions.BadIdentifierError: pass @utils.saveit def is_in_import_statement(self): - return (self.tools.word_finder.is_from_statement(self.offset) or - self.tools.word_finder.is_import_statement(self.offset)) + return self.tools.word_finder.is_from_statement( + self.offset + ) or self.tools.word_finder.is_import_statement(self.offset) def is_called(self): return self.tools.word_finder.is_a_function_being_called(self.offset) def is_defined(self): - return self.tools.word_finder.is_a_class_or_function_name_in_header( - self.offset) + return self.tools.word_finder.is_a_class_or_function_name_in_header(self.offset) def is_a_fixed_primary(self): return self.tools.word_finder.is_a_class_or_function_name_in_header( - self.offset) or \ - self.tools.word_finder.is_a_name_after_from_import(self.offset) + self.offset + ) or self.tools.word_finder.is_a_name_after_from_import(self.offset) def is_written(self): return self.tools.word_finder.is_assigned_here(self.offset) @@ -171,8 +179,7 @@ def is_unsure(self): return unsure_pyname(self.get_pyname()) def is_function_keyword_parameter(self): - return self.tools.word_finder.is_function_keyword_parameter( - self.offset) + return self.tools.word_finder.is_function_keyword_parameter(self.offset) @property @utils.saveit @@ -187,13 +194,14 @@ def same_pyname(expected, pyname): return False if expected == pyname: return True - if type(expected) not in (pynames.ImportedModule, pynames.ImportedName) \ - and type(pyname) not in \ - (pynames.ImportedModule, pynames.ImportedName): + if type(expected) not in (pynames.ImportedModule, pynames.ImportedName) and type( + pyname + ) not in (pynames.ImportedModule, pynames.ImportedName): return False - return expected.get_definition_location() == \ - pyname.get_definition_location() and \ - expected.get_object() == pyname.get_object() + return ( + expected.get_definition_location() == pyname.get_definition_location() + and expected.get_object() == pyname.get_object() + ) def unsure_pyname(pyname, unbound=True): @@ -243,7 +251,7 @@ def _get_containing_class(self, pyname): if isinstance(pyname, pynames.DefinedName): scope = pyname.get_object().get_scope() parent = scope.parent - if parent is not None and parent.get_kind() == 'Class': + if parent is not None and parent.get_kind() == "Class": return parent.pyobject def _get_root_classes(self, pyclass, name): @@ -294,15 +302,16 @@ def __call__(self, occurrence): class _TextualFinder(object): - def __init__(self, name, docs=False): self.name = name self.docs = docs - self.comment_pattern = _TextualFinder.any('comment', [r'#[^\n]*']) + self.comment_pattern = _TextualFinder.any("comment", [r"#[^\n]*"]) self.string_pattern = _TextualFinder.any( - 'string', [codeanalyze.get_string_pattern()]) + "string", [codeanalyze.get_string_pattern()] + ) self.f_string_pattern = _TextualFinder.any( - 'fstring', [codeanalyze.get_formatted_string_pattern()]) + "fstring", [codeanalyze.get_formatted_string_pattern()] + ) self.pattern = self._get_occurrence_pattern(self.name) def find_offsets(self, source): @@ -317,12 +326,12 @@ def find_offsets(self, source): def _re_search(self, source): for match in self.pattern.finditer(source): - if match.groupdict()['occurrence']: - yield match.start('occurrence') - elif utils.pycompat.PY36 and match.groupdict()['fstring']: - f_string = match.groupdict()['fstring'] + if match.groupdict()["occurrence"]: + yield match.start("occurrence") + elif utils.pycompat.PY36 and match.groupdict()["fstring"]: + f_string = match.groupdict()["fstring"] for occurrence_node in self._search_in_f_string(f_string): - yield match.start('fstring') + occurrence_node.col_offset + yield match.start("fstring") + occurrence_node.col_offset def _search_in_f_string(self, f_string): tree = ast.parse(f_string) @@ -336,16 +345,15 @@ def _normal_search(self, source): try: found = source.index(self.name, current) current = found + len(self.name) - if (found == 0 or - not self._is_id_char(source[found - 1])) and \ - (current == len(source) or - not self._is_id_char(source[current])): + if (found == 0 or not self._is_id_char(source[found - 1])) and ( + current == len(source) or not self._is_id_char(source[current]) + ): yield found except ValueError: break def _is_id_char(self, c): - return c.isalnum() or c == '_' + return c.isalnum() or c == "_" def _fast_file_query(self, source): try: @@ -361,20 +369,24 @@ def _get_source(self, resource, pymodule): return pymodule.source_code def _get_occurrence_pattern(self, name): - occurrence_pattern = _TextualFinder.any('occurrence', - ['\\b' + name + '\\b']) - pattern = re.compile(occurrence_pattern + '|' + self.comment_pattern + - '|' + self.string_pattern + '|' + - self.f_string_pattern) + occurrence_pattern = _TextualFinder.any("occurrence", ["\\b" + name + "\\b"]) + pattern = re.compile( + occurrence_pattern + + "|" + + self.comment_pattern + + "|" + + self.string_pattern + + "|" + + self.f_string_pattern + ) return pattern @staticmethod def any(name, list_): - return '(?P<%s>' % name + '|'.join(list_) + ')' + return "(?P<%s>" % name + "|".join(list_) + ")" class _OccurrenceToolsCreator(object): - def __init__(self, project, resource=None, pymodule=None, docs=False): self.project = project self.__resource = resource diff --git a/rope/refactor/patchedast.py b/rope/refactor/patchedast.py index 78af8c9c2..c8e2e3e3a 100644 --- a/rope/refactor/patchedast.py +++ b/rope/refactor/patchedast.py @@ -35,7 +35,7 @@ def patch_ast(node, source, sorted_children=False): them. """ - if hasattr(node, 'region'): + if hasattr(node, "region"): return node walker = _PatchingASTWalker(source, children=sorted_children) ast.call_for_nodes(node, walker) @@ -59,7 +59,7 @@ def write_ast(patched_ast_node): result.append(write_ast(child)) else: result.append(child) - return ''.join(result) + return "".join(result) class MismatchedTokenError(exceptions.RopeError): @@ -67,7 +67,6 @@ class MismatchedTokenError(exceptions.RopeError): class _PatchingASTWalker(object): - def __init__(self, source, children=False): self.source = _Source(source) self.children = children @@ -82,22 +81,26 @@ def __init__(self, source, children=False): exec_in_or_comma = object() def __call__(self, node): - method = getattr(self, '_' + node.__class__.__name__, None) + method = getattr(self, "_" + node.__class__.__name__, None) if method is not None: return method(node) # ???: Unknown node; what should we do here? - warnings.warn('Unknown node type <%s>; please report!' - % node.__class__.__name__, RuntimeWarning) + warnings.warn( + "Unknown node type <%s>; please report!" % node.__class__.__name__, + RuntimeWarning, + ) node.region = (self.source.offset, self.source.offset) if self.children: node.sorted_children = ast.get_children(node) def _handle(self, node, base_children, eat_parens=False, eat_spaces=False): - if hasattr(node, 'region'): + if hasattr(node, "region"): # ???: The same node was seen twice; what should we do? warnings.warn( - 'Node <%s> has been already patched; please report!' % - node.__class__.__name__, RuntimeWarning) + "Node <%s> has been already patched; please report!" + % node.__class__.__name__, + RuntimeWarning, + ) return base_children = collections.deque(base_children) self.children_stack.append(base_children) @@ -117,10 +120,11 @@ def _handle(self, node, base_children, eat_parens=False, eat_spaces=False): else: if child is self.String: region = self.source.consume_string( - end=self._find_next_statement_start()) + end=self._find_next_statement_start() + ) elif child is self.Number: region = self.source.consume_number() - elif child == '!=': + elif child == "!=": # INFO: This has been added to handle deprecated ``<>`` region = self.source.consume_not_equal() elif child == self.semicolon_or_as_in_except: @@ -137,11 +141,13 @@ def _handle(self, node, base_children, eat_parens=False, eat_spaces=False): elif child == self.exec_close_paren_or_space: region = self.source.consume_exec_close_paren_or_space() else: - if hasattr(ast, 'JoinedStr') and isinstance(node, (ast.JoinedStr, ast.FormattedValue)): + if hasattr(ast, "JoinedStr") and isinstance( + node, (ast.JoinedStr, ast.FormattedValue) + ): region = self.source.consume_joined_string(child) else: region = self.source.consume(child) - child = self.source[region[0]:region[1]] + child = self.source[region[0] : region[1]] token_start = region[0] if not first_token: formats.append(self.source[offset:token_start]) @@ -154,12 +160,11 @@ def _handle(self, node, base_children, eat_parens=False, eat_spaces=False): children.append(child) start = self._handle_parens(children, start, formats) if eat_parens: - start = self._eat_surrounding_parens( - children, suspected_start, start) + start = self._eat_surrounding_parens(children, suspected_start, start) if eat_spaces: if self.children: children.appendleft(self.source[0:start]) - end_spaces = self.source[self.source.offset:] + end_spaces = self.source[self.source.offset :] self.source.consume(end_spaces) if self.children: children.append(end_spaces) @@ -175,13 +180,13 @@ def _handle_parens(self, children, start, formats): old_end = self.source.offset new_end = None for i in range(closes): - new_end = self.source.consume(')')[1] + new_end = self.source.consume(")")[1] if new_end is not None: if self.children: children.append(self.source[old_end:new_end]) new_start = start for i in range(opens): - new_start = self.source.rfind_token('(', 0, new_start) + new_start = self.source.rfind_token("(", 0, new_start) if new_start != start: if self.children: children.appendleft(self.source[new_start:start]) @@ -189,18 +194,18 @@ def _handle_parens(self, children, start, formats): return start def _eat_surrounding_parens(self, children, suspected_start, start): - index = self.source.rfind_token('(', suspected_start, start) + index = self.source.rfind_token("(", suspected_start, start) if index is not None: old_start = start old_offset = self.source.offset start = index if self.children: - children.appendleft(self.source[start + 1:old_start]) - children.appendleft('(') - token_start, token_end = self.source.consume(')') + children.appendleft(self.source[start + 1 : old_start]) + children.appendleft("(") + token_start, token_end = self.source.consume(")") if self.children: children.append(self.source[old_offset:token_start]) - children.append(')') + children.append(")") return start def _count_needed_parens(self, children): @@ -209,20 +214,20 @@ def _count_needed_parens(self, children): for child in children: if not isinstance(child, basestring): continue - if child == '' or child[0] in '\'"': + if child == "" or child[0] in "'\"": continue index = 0 while index < len(child): - if child[index] == ')': + if child[index] == ")": if opens > 0: opens -= 1 else: start += 1 - if child[index] == '(': + if child[index] == "(": opens += 1 - if child[index] == '#': + if child[index] == "#": try: - index = child.index('\n', index) + index = child.index("\n", index) except ValueError: break index += 1 @@ -232,115 +237,135 @@ def _find_next_statement_start(self): for children in reversed(self.children_stack): for child in children: if isinstance(child, ast.stmt): - return child.col_offset \ - + self.lines.get_line_start(child.lineno) + return child.col_offset + self.lines.get_line_start(child.lineno) return len(self.source.source) - _operators = {'And': 'and', 'Or': 'or', 'Add': '+', 'Sub': '-', - 'Mult': '*', 'Div': '/', 'Mod': '%', 'Pow': '**', - 'LShift': '<<', 'RShift': '>>', 'BitOr': '|', 'BitAnd': '&', - 'BitXor': '^', 'FloorDiv': '//', 'Invert': '~', - 'Not': 'not', 'UAdd': '+', 'USub': '-', 'Eq': '==', - 'NotEq': '!=', 'Lt': '<', 'LtE': '<=', 'Gt': '>', - 'GtE': '>=', 'Is': 'is', 'IsNot': 'is not', 'In': 'in', - 'NotIn': 'not in'} + _operators = { + "And": "and", + "Or": "or", + "Add": "+", + "Sub": "-", + "Mult": "*", + "Div": "/", + "Mod": "%", + "Pow": "**", + "LShift": "<<", + "RShift": ">>", + "BitOr": "|", + "BitAnd": "&", + "BitXor": "^", + "FloorDiv": "//", + "Invert": "~", + "Not": "not", + "UAdd": "+", + "USub": "-", + "Eq": "==", + "NotEq": "!=", + "Lt": "<", + "LtE": "<=", + "Gt": ">", + "GtE": ">=", + "Is": "is", + "IsNot": "is not", + "In": "in", + "NotIn": "not in", + } def _get_op(self, node): - return self._operators[node.__class__.__name__].split(' ') + return self._operators[node.__class__.__name__].split(" ") def _Attribute(self, node): - self._handle(node, [node.value, '.', node.attr]) + self._handle(node, [node.value, ".", node.attr]) def _Assert(self, node): - children = ['assert', node.test] + children = ["assert", node.test] if node.msg: - children.append(',') + children.append(",") children.append(node.msg) self._handle(node, children) def _Assign(self, node): - children = self._child_nodes(node.targets, '=') - children.append('=') + children = self._child_nodes(node.targets, "=") + children.append("=") children.append(node.value) self._handle(node, children) def _AugAssign(self, node): children = [node.target] children.extend(self._get_op(node.op)) - children.extend(['=', node.value]) + children.extend(["=", node.value]) self._handle(node, children) def _AnnAssign(self, node): - children = [node.target, ':', node.annotation] + children = [node.target, ":", node.annotation] if node.value is not None: - children.append('=') + children.append("=") children.append(node.value) self._handle(node, children) def _Repr(self, node): - self._handle(node, ['`', node.value, '`']) + self._handle(node, ["`", node.value, "`"]) def _BinOp(self, node): children = [node.left] + self._get_op(node.op) + [node.right] self._handle(node, children) def _BoolOp(self, node): - self._handle(node, self._child_nodes(node.values, - self._get_op(node.op)[0])) + self._handle(node, self._child_nodes(node.values, self._get_op(node.op)[0])) def _Break(self, node): - self._handle(node, ['break']) + self._handle(node, ["break"]) def _Call(self, node): def _arg_sort_key(node): - if isinstance(node, ast.keyword): - return (node.value.lineno, node.value.col_offset) - return (node.lineno, node.col_offset) + if isinstance(node, ast.keyword): + return (node.value.lineno, node.value.col_offset) + return (node.lineno, node.col_offset) - children = [node.func, '('] + children = [node.func, "("] unstarred_args = [] starred_and_keywords = list(node.keywords) for i, arg in enumerate(node.args): - if hasattr(ast, 'Starred') and isinstance(arg, ast.Starred): - starred_and_keywords.append(arg) - else: - unstarred_args.append(arg) - if getattr(node, 'starargs', None): - starred_and_keywords.append(node.starargs) + if hasattr(ast, "Starred") and isinstance(arg, ast.Starred): + starred_and_keywords.append(arg) + else: + unstarred_args.append(arg) + if getattr(node, "starargs", None): + starred_and_keywords.append(node.starargs) starred_and_keywords.sort(key=_arg_sort_key) - children.extend(self._child_nodes(unstarred_args, ',')) + children.extend(self._child_nodes(unstarred_args, ",")) # positional args come before keywords, *args comes after all # positional args, and **kwargs comes last if starred_and_keywords: - if len(children) > 2: - children.append(',') - for i, arg in enumerate(starred_and_keywords): - if arg == getattr(node, 'starargs', None): - children.append('*') - children.append(arg) - if i + 1 < len(starred_and_keywords): - children.append(',') - - if getattr(node, 'kwargs', None): if len(children) > 2: - children.append(',') - children.extend(['**', node.kwargs]) - children.append(')') + children.append(",") + for i, arg in enumerate(starred_and_keywords): + if arg == getattr(node, "starargs", None): + children.append("*") + children.append(arg) + if i + 1 < len(starred_and_keywords): + children.append(",") + + if getattr(node, "kwargs", None): + if len(children) > 2: + children.append(",") + children.extend(["**", node.kwargs]) + children.append(")") self._handle(node, children) def _ClassDef(self, node): children = [] - if getattr(node, 'decorator_list', None): + if getattr(node, "decorator_list", None): for decorator in node.decorator_list: - children.append('@') + children.append("@") children.append(decorator) - children.extend(['class', node.name]) + children.extend(["class", node.name]) if node.bases: - children.append('(') - children.extend(self._child_nodes(node.bases, ',')) - children.append(')') - children.append(':') + children.append("(") + children.extend(self._child_nodes(node.bases, ",")) + children.append(")") + children.append(":") children.extend(node.body) self._handle(node, children) @@ -353,7 +378,7 @@ def _Compare(self, node): self._handle(node, children) def _Delete(self, node): - self._handle(node, ['del'] + self._child_nodes(node.targets, ',')) + self._handle(node, ["del"] + self._child_nodes(node.targets, ",")) def _Constant(self, node): if isinstance(node.value, basestring): @@ -369,7 +394,7 @@ def _Constant(self, node): return if node.value is Ellipsis: - self._handle(node, ['...']) + self._handle(node, ["..."]) return assert False @@ -385,14 +410,23 @@ def _Bytes(self, node): def _JoinedStr(self, node): def start_quote_char(): - possible_quotes = [(self.source.source.find(q, start, end), q) for q in QUOTE_CHARS] - quote_pos, quote_char = min((pos, q) for pos, q in possible_quotes if pos != -1) - return self.source[start:quote_pos + len(quote_char)] + possible_quotes = [ + (self.source.source.find(q, start, end), q) for q in QUOTE_CHARS + ] + quote_pos, quote_char = min( + (pos, q) for pos, q in possible_quotes if pos != -1 + ) + return self.source[start : quote_pos + len(quote_char)] def end_quote_char(): - possible_quotes = [(self.source.source.rfind(q, start, end), q) for q in reversed(QUOTE_CHARS)] - _, quote_pos, quote_char = max((len(q), pos, q) for pos, q in possible_quotes if pos != -1) - return self.source[end - len(quote_char):end] + possible_quotes = [ + (self.source.source.rfind(q, start, end), q) + for q in reversed(QUOTE_CHARS) + ] + _, quote_pos, quote_char = max( + (len(q), pos, q) for pos, q in possible_quotes if pos != -1 + ) + return self.source[end - len(quote_char) : end] QUOTE_CHARS = ['"""', "'''", '"', "'"] offset = self.source.offset @@ -411,52 +445,52 @@ def end_quote_char(): def _FormattedValue(self, node): children = [] - children.append('{') + children.append("{") children.append(node.value) if node.format_spec: - children.append(':') + children.append(":") for val in node.format_spec.values: if isinstance(val, ast.FormattedValue): children.append(val.value) else: children.append(val.s) - children.append('}') + children.append("}") self._handle(node, children) def _Continue(self, node): - self._handle(node, ['continue']) + self._handle(node, ["continue"]) def _Dict(self, node): children = [] - children.append('{') + children.append("{") if node.keys: for index, (key, value) in enumerate(zip(node.keys, node.values)): if key is None: # PEP-448 dict unpacking: {a: b, **unpack} - children.extend(['**', value]) + children.extend(["**", value]) else: - children.extend([key, ':', value]) + children.extend([key, ":", value]) if index < len(node.keys) - 1: - children.append(',') - children.append('}') + children.append(",") + children.append("}") self._handle(node, children) def _Ellipsis(self, node): - self._handle(node, ['...']) + self._handle(node, ["..."]) def _Expr(self, node): self._handle(node, [node.value]) def _NamedExpr(self, node): - children = [node.target, ':=', node.value] + children = [node.target, ":=", node.value] self._handle(node, children) def _Exec(self, node): - children = ['exec', self.exec_open_paren_or_space, node.body] + children = ["exec", self.exec_open_paren_or_space, node.body] if node.globals: children.extend([self.exec_in_or_comma, node.globals]) if node.locals: - children.extend([',', node.locals]) + children.extend([",", node.locals]) children.append(self.exec_close_paren_or_space) self._handle(node, children) @@ -464,19 +498,19 @@ def _ExtSlice(self, node): children = [] for index, dim in enumerate(node.dims): if index > 0: - children.append(',') + children.append(",") children.append(dim) self._handle(node, children) def _handle_for_loop_node(self, node, is_async): if is_async: - children = ['async', 'for'] + children = ["async", "for"] else: - children = ['for'] - children.extend([node.target, 'in', node.iter, ':']) + children = ["for"] + children.extend([node.target, "in", node.iter, ":"]) children.extend(node.body) if node.orelse: - children.extend(['else', ':']) + children.extend(["else", ":"]) children.extend(node.orelse) self._handle(node, children) @@ -487,37 +521,36 @@ def _AsyncFor(self, node): self._handle_for_loop_node(node, is_async=True) def _ImportFrom(self, node): - children = ['from'] + children = ["from"] if node.level: - children.append('.' * node.level) + children.append("." * node.level) # see comment at rope.base.ast.walk - children.extend([node.module or '', - 'import']) - children.extend(self._child_nodes(node.names, ',')) + children.extend([node.module or "", "import"]) + children.extend(self._child_nodes(node.names, ",")) self._handle(node, children) def _alias(self, node): children = [node.name] if node.asname: - children.extend(['as', node.asname]) + children.extend(["as", node.asname]) self._handle(node, children) def _handle_function_def_node(self, node, is_async): children = [] try: - decorators = getattr(node, 'decorator_list') + decorators = getattr(node, "decorator_list") except AttributeError: - decorators = getattr(node, 'decorators', None) + decorators = getattr(node, "decorators", None) if decorators: for decorator in decorators: - children.append('@') + children.append("@") children.append(decorator) if is_async: - children.extend(['async', 'def']) + children.extend(["async", "def"]) else: - children.extend(['def']) - children.extend([node.name, '(', node.args]) - children.extend([')', ':']) + children.extend(["def"]) + children.extend([node.name, "(", node.args]) + children.extend([")", ":"]) children.extend(node.body) self._handle(node, children) @@ -530,20 +563,19 @@ def _AsyncFunctionDef(self, node): def _arguments(self, node): children = [] args = list(node.args) - defaults = [None] * (len(args) - len(node.defaults)) + \ - list(node.defaults) + defaults = [None] * (len(args) - len(node.defaults)) + list(node.defaults) for index, (arg, default) in enumerate(zip(args, defaults)): if index > 0: - children.append(',') + children.append(",") self._add_args_to_children(children, arg, default) if node.vararg is not None: if args: - children.append(',') - children.extend(['*', pycompat.get_ast_arg_arg(node.vararg)]) + children.append(",") + children.extend(["*", pycompat.get_ast_arg_arg(node.vararg)]) if node.kwarg is not None: if args or node.vararg is not None: - children.append(',') - children.extend(['**', pycompat.get_ast_arg_arg(node.kwarg)]) + children.append(",") + children.extend(["**", pycompat.get_ast_arg_arg(node.kwarg)]) self._handle(node, children) def _add_args_to_children(self, children, arg, default): @@ -552,19 +584,19 @@ def _add_args_to_children(self, children, arg, default): else: children.append(arg) if default is not None: - children.append('=') + children.append("=") children.append(default) def _add_tuple_parameter(self, children, arg): - children.append('(') + children.append("(") for index, token in enumerate(arg): if index > 0: - children.append(',') + children.append(",") if isinstance(token, (list, tuple)): self._add_tuple_parameter(children, token) else: children.append(token) - children.append(')') + children.append(")") def _GeneratorExp(self, node): children = [node.elt] @@ -572,30 +604,30 @@ def _GeneratorExp(self, node): self._handle(node, children, eat_parens=True) def _comprehension(self, node): - children = ['for', node.target, 'in', node.iter] + children = ["for", node.target, "in", node.iter] if node.ifs: for if_ in node.ifs: - children.append('if') + children.append("if") children.append(if_) self._handle(node, children) def _Global(self, node): - children = self._child_nodes(node.names, ',') - children.insert(0, 'global') + children = self._child_nodes(node.names, ",") + children.insert(0, "global") self._handle(node, children) def _If(self, node): if self._is_elif(node): - children = ['elif'] + children = ["elif"] else: - children = ['if'] - children.extend([node.test, ':']) + children = ["if"] + children.extend([node.test, ":"]) children.extend(node.body) if node.orelse: if len(node.orelse) == 1 and self._is_elif(node.orelse[0]): pass else: - children.extend(['else', ':']) + children.extend(["else", ":"]) children.extend(node.orelse) self._handle(node, children) @@ -603,18 +635,17 @@ def _is_elif(self, node): if not isinstance(node, ast.If): return False offset = self.lines.get_line_start(node.lineno) + node.col_offset - word = self.source[offset:offset + 4] + word = self.source[offset : offset + 4] # XXX: This is a bug; the offset does not point to the first - alt_word = self.source[offset - 5:offset - 1] - return 'elif' in (word, alt_word) + alt_word = self.source[offset - 5 : offset - 1] + return "elif" in (word, alt_word) def _IfExp(self, node): - return self._handle(node, [node.body, 'if', node.test, - 'else', node.orelse]) + return self._handle(node, [node.body, "if", node.test, "else", node.orelse]) def _Import(self, node): - children = ['import'] - children.extend(self._child_nodes(node.names, ',')) + children = ["import"] + children.extend(self._child_nodes(node.names, ",")) self._handle(node, children) def _keyword(self, node): @@ -622,42 +653,42 @@ def _keyword(self, node): if node.arg is None: children.append(node.value) else: - children.extend([node.arg, '=', node.value]) + children.extend([node.arg, "=", node.value]) self._handle(node, children) def _Lambda(self, node): - self._handle(node, ['lambda', node.args, ':', node.body]) + self._handle(node, ["lambda", node.args, ":", node.body]) def _List(self, node): - self._handle(node, ['['] + self._child_nodes(node.elts, ',') + [']']) + self._handle(node, ["["] + self._child_nodes(node.elts, ",") + ["]"]) def _ListComp(self, node): - children = ['[', node.elt] + children = ["[", node.elt] children.extend(node.generators) - children.append(']') + children.append("]") self._handle(node, children) def _Set(self, node): if node.elts: - self._handle(node, - ['{'] + self._child_nodes(node.elts, ',') + ['}']) + self._handle(node, ["{"] + self._child_nodes(node.elts, ",") + ["}"]) return # Python doesn't have empty set literals - warnings.warn('Tried to handle empty literal; please report!', - RuntimeWarning) - self._handle(node, ['set(', ')']) + warnings.warn( + "Tried to handle empty literal; please report!", RuntimeWarning + ) + self._handle(node, ["set(", ")"]) def _SetComp(self, node): - children = ['{', node.elt] + children = ["{", node.elt] children.extend(node.generators) - children.append('}') + children.append("}") self._handle(node, children) def _DictComp(self, node): - children = ['{'] - children.extend([node.key, ':', node.value]) + children = ["{"] + children.extend([node.key, ":", node.value]) children.extend(node.generators) - children.append('}') + children.append("}") self._handle(node, children) def _Module(self, node): @@ -673,23 +704,22 @@ def _arg(self, node): self._handle(node, [node.arg]) def _Pass(self, node): - self._handle(node, ['pass']) + self._handle(node, ["pass"]) def _Print(self, node): - children = ['print'] + children = ["print"] if node.dest: - children.extend(['>>', node.dest]) + children.extend([">>", node.dest]) if node.values: - children.append(',') - children.extend(self._child_nodes(node.values, ',')) + children.append(",") + children.extend(self._child_nodes(node.values, ",")) if not node.nl: - children.append(',') + children.append(",") self._handle(node, children) def _Raise(self, node): - def get_python3_raise_children(node): - children = ['raise'] + children = ["raise"] if node.exc: children.append(node.exc) if node.cause: @@ -697,16 +727,17 @@ def get_python3_raise_children(node): return children def get_python2_raise_children(node): - children = ['raise'] + children = ["raise"] if node.type: children.append(node.type) if node.inst: - children.append(',') + children.append(",") children.append(node.inst) if node.tback: - children.append(',') + children.append(",") children.append(node.tback) return children + if pycompat.PY2: children = get_python2_raise_children(node) else: @@ -714,7 +745,7 @@ def get_python2_raise_children(node): self._handle(node, children) def _Return(self, node): - children = ['return'] + children = ["return"] if node.value: children.append(node.value) self._handle(node, children) @@ -723,7 +754,7 @@ def _Sliceobj(self, node): children = [] for index, slice in enumerate(node.nodes): if index > 0: - children.append(':') + children.append(":") if slice: children.append(slice) self._handle(node, children) @@ -732,17 +763,17 @@ def _Index(self, node): self._handle(node, [node.value]) def _Subscript(self, node): - self._handle(node, [node.value, '[', node.slice, ']']) + self._handle(node, [node.value, "[", node.slice, "]"]) def _Slice(self, node): children = [] if node.lower: children.append(node.lower) - children.append(':') + children.append(":") if node.upper: children.append(node.upper) if node.step: - children.append(':') + children.append(":") children.append(node.step) self._handle(node, children) @@ -756,26 +787,28 @@ def _TryFinally(self, node): not_empty_body = not bool(len(node.body)) elif pycompat.PY3: try: - is_there_except_handler = isinstance(node.handlers[0], ast.ExceptHandler) + is_there_except_handler = isinstance( + node.handlers[0], ast.ExceptHandler + ) not_empty_body = True except IndexError: pass children = [] if not_empty_body or not is_there_except_handler: - children.extend(['try', ':']) + children.extend(["try", ":"]) children.extend(node.body) if pycompat.PY3: children.extend(node.handlers) - children.extend(['finally', ':']) + children.extend(["finally", ":"]) children.extend(node.finalbody) self._handle(node, children) def _TryExcept(self, node): - children = ['try', ':'] + children = ["try", ":"] children.extend(node.body) children.extend(node.handlers) if node.orelse: - children.extend(['else', ':']) + children.extend(["else", ":"]) children.extend(node.orelse) self._handle(node, children) @@ -790,23 +823,22 @@ def _ExceptHandler(self, node): def _excepthandler(self, node): # self._handle(node, [self.semicolon_or_as_in_except]) - children = ['except'] + children = ["except"] if node.type: children.append(node.type) if node.name: children.append(self.semicolon_or_as_in_except) children.append(node.name) - children.append(':') + children.append(":") children.extend(node.body) self._handle(node, children) def _Tuple(self, node): if node.elts: - self._handle(node, self._child_nodes(node.elts, ','), - eat_parens=True) + self._handle(node, self._child_nodes(node.elts, ","), eat_parens=True) else: - self._handle(node, ['(', ')']) + self._handle(node, ["(", ")"]) def _UnaryOp(self, node): children = self._get_op(node.op) @@ -814,32 +846,32 @@ def _UnaryOp(self, node): self._handle(node, children) def _Await(self, node): - children = ['await'] + children = ["await"] if node.value: children.append(node.value) self._handle(node, children) def _Yield(self, node): - children = ['yield'] + children = ["yield"] if node.value: children.append(node.value) self._handle(node, children) def _While(self, node): - children = ['while', node.test, ':'] + children = ["while", node.test, ":"] children.extend(node.body) if node.orelse: - children.extend(['else', ':']) + children.extend(["else", ":"]) children.extend(node.orelse) self._handle(node, children) def _With(self, node): children = [] for item in pycompat.get_ast_with_items(node): - children.extend(['with', item.context_expr]) + children.extend(["with", item.context_expr]) if item.optional_vars: - children.extend(['as', item.optional_vars]) - children.append(':') + children.extend(["as", item.optional_vars]) + children.append(":") children.extend(node.body) self._handle(node, children) @@ -854,8 +886,8 @@ def _child_nodes(self, nodes, separator): def _Starred(self, node): self._handle(node, [node.value]) -class _Source(object): +class _Source(object): def __init__(self, source): self.source = source self.offset = 0 @@ -870,8 +902,8 @@ def consume(self, token, skip_comment=True): self._skip_comment() except (ValueError, TypeError) as e: raise MismatchedTokenError( - 'Token <%s> at %s cannot be matched' % - (token, self._get_location())) + "Token <%s> at %s cannot be matched" % (token, self._get_location()) + ) self.offset = new_offset + len(token) return (new_offset, self.offset) @@ -884,40 +916,38 @@ def consume_string(self, end=None): if _Source._string_pattern is None: string_pattern = codeanalyze.get_string_pattern() formatted_string_pattern = codeanalyze.get_formatted_string_pattern() - original = r'(?:%s)|(?:%s)' % (string_pattern, formatted_string_pattern) - pattern = r'(%s)((\s|\\\n|#[^\n]*\n)*(%s))*' % \ - (original, original) + original = r"(?:%s)|(?:%s)" % (string_pattern, formatted_string_pattern) + pattern = r"(%s)((\s|\\\n|#[^\n]*\n)*(%s))*" % (original, original) _Source._string_pattern = re.compile(pattern) repattern = _Source._string_pattern return self._consume_pattern(repattern, end) def consume_number(self): if _Source._number_pattern is None: - _Source._number_pattern = re.compile( - self._get_number_pattern()) + _Source._number_pattern = re.compile(self._get_number_pattern()) repattern = _Source._number_pattern return self._consume_pattern(repattern) def consume_not_equal(self): if _Source._not_equals_pattern is None: - _Source._not_equals_pattern = re.compile(r'<>|!=') + _Source._not_equals_pattern = re.compile(r"<>|!=") repattern = _Source._not_equals_pattern return self._consume_pattern(repattern) def consume_except_as_or_semicolon(self): - repattern = re.compile(r'as|,') + repattern = re.compile(r"as|,") return self._consume_pattern(repattern) def consume_exec_open_paren_or_space(self): - repattern = re.compile(r'\(|') + repattern = re.compile(r"\(|") return self._consume_pattern(repattern) def consume_exec_in_or_comma(self): - repattern = re.compile(r'in|,') + repattern = re.compile(r"in|,") return self._consume_pattern(repattern) def consume_exec_close_paren_or_space(self): - repattern = re.compile(r'\)|') + repattern = re.compile(r"\)|") return self._consume_pattern(repattern) def _good_token(self, token, offset, start=None): @@ -925,20 +955,20 @@ def _good_token(self, token, offset, start=None): if start is None: start = self.offset try: - comment_index = self.source.rindex('#', start, offset) + comment_index = self.source.rindex("#", start, offset) except ValueError: return True try: - new_line_index = self.source.rindex('\n', start, offset) + new_line_index = self.source.rindex("\n", start, offset) except ValueError: return False return comment_index < new_line_index def _skip_comment(self): - self.offset = self.source.index('\n', self.offset + 1) + self.offset = self.source.index("\n", self.offset + 1) def _get_location(self): - lines = self.source[:self.offset].split('\n') + lines = self.source[: self.offset].split("\n") return (len(lines), len(lines[-1])) def _consume_pattern(self, repattern, end=None): @@ -955,7 +985,7 @@ def _consume_pattern(self, repattern, end=None): def till_token(self, token): new_offset = self.source.index(token, self.offset) - return self[self.offset:new_offset] + return self[self.offset : new_offset] def rfind_token(self, token, start, end): index = start @@ -970,7 +1000,7 @@ def rfind_token(self, token, start, end): return None def from_offset(self, offset): - return self[offset:self.offset] + return self[offset : self.offset] def find_backwards(self, pattern, offset): return self.source.rindex(pattern, 0, offset) @@ -983,8 +1013,8 @@ def __getslice__(self, i, j): def _get_number_pattern(self): # HACK: It is merely an approaximation and does the job - integer = r'\-?(0x[\da-fA-F]+|\d+)[lL]?' - return r'(%s(\.\d*)?|(\.\d+))([eE][-+]?\d+)?[jJ]?' % integer + integer = r"\-?(0x[\da-fA-F]+|\d+)[lL]?" + return r"(%s(\.\d*)?|(\.\d+))([eE][-+]?\d+)?[jJ]?" % integer _string_pattern = None _number_pattern = None diff --git a/rope/refactor/rename.py b/rope/refactor/rename.py index 8f7390b03..2ea41a24e 100644 --- a/rope/refactor/rename.py +++ b/rope/refactor/rename.py @@ -1,7 +1,15 @@ import warnings -from rope.base import (exceptions, pyobjects, pynames, taskhandle, - evaluate, worder, codeanalyze, libutils) +from rope.base import ( + exceptions, + pyobjects, + pynames, + taskhandle, + evaluate, + worder, + codeanalyze, + libutils, +) from rope.base.change import ChangeSet, ChangeContents, MoveResource from rope.refactor import occurrences @@ -21,19 +29,20 @@ def __init__(self, project, resource, offset=None): if offset is not None: self.old_name = worder.get_name_at(self.resource, offset) this_pymodule = self.project.get_pymodule(self.resource) - self.old_instance, self.old_pyname = \ - evaluate.eval_location2(this_pymodule, offset) + self.old_instance, self.old_pyname = evaluate.eval_location2( + this_pymodule, offset + ) if self.old_pyname is None: raise exceptions.RefactoringError( - 'Rename refactoring should be performed' - ' on resolvable python identifiers.') + "Rename refactoring should be performed" + " on resolvable python identifiers." + ) else: - if not resource.is_folder() and resource.name == '__init__.py': + if not resource.is_folder() and resource.name == "__init__.py": resource = resource.parent - dummy_pymodule = libutils.get_string_module(self.project, '') + dummy_pymodule = libutils.get_string_module(self.project, "") self.old_instance = None - self.old_pyname = pynames.ImportedModule(dummy_pymodule, - resource=resource) + self.old_pyname = pynames.ImportedModule(dummy_pymodule, resource=resource) if resource.is_folder(): self.old_name = resource.name else: @@ -42,9 +51,16 @@ def __init__(self, project, resource, offset=None): def get_old_name(self): return self.old_name - def get_changes(self, new_name, in_file=None, in_hierarchy=False, - unsure=None, docs=False, resources=None, - task_handle=taskhandle.NullTaskHandle()): + def get_changes( + self, + new_name, + in_file=None, + in_hierarchy=False, + unsure=None, + docs=False, + resources=None, + task_handle=taskhandle.NullTaskHandle(), + ): """Get the changes needed for this refactoring Parameters: @@ -68,30 +84,38 @@ def get_changes(self, new_name, in_file=None, in_hierarchy=False, """ if unsure in (True, False): warnings.warn( - 'unsure parameter should be a function that returns ' - 'True or False', DeprecationWarning, stacklevel=2) + "unsure parameter should be a function that returns " "True or False", + DeprecationWarning, + stacklevel=2, + ) def unsure_func(value=unsure): return value + unsure = unsure_func if in_file is not None: warnings.warn( - '`in_file` argument has been deprecated; use `resources` ' - 'instead. ', DeprecationWarning, stacklevel=2) + "`in_file` argument has been deprecated; use `resources` " "instead. ", + DeprecationWarning, + stacklevel=2, + ) if in_file: resources = [self.resource] if _is_local(self.old_pyname): resources = [self.resource] if resources is None: resources = self.project.get_python_files() - changes = ChangeSet('Renaming <%s> to <%s>' % - (self.old_name, new_name)) + changes = ChangeSet("Renaming <%s> to <%s>" % (self.old_name, new_name)) finder = occurrences.create_finder( - self.project, self.old_name, self.old_pyname, unsure=unsure, - docs=docs, instance=self.old_instance, - in_hierarchy=in_hierarchy and self.is_method()) - job_set = task_handle.create_jobset('Collecting Changes', - len(resources)) + self.project, + self.old_name, + self.old_pyname, + unsure=unsure, + docs=docs, + instance=self.old_instance, + in_hierarchy=in_hierarchy and self.is_method(), + ) + job_set = task_handle.create_jobset("Collecting Changes", len(resources)) for file_ in resources: job_set.started_job(file_.path) new_content = rename_in_module(finder, new_name, resource=file_) @@ -107,7 +131,7 @@ def unsure_func(value=unsure): def _is_allowed_to_move(self, resources, resource): if resource.is_folder(): try: - return resource.get_child('__init__.py') in resources + return resource.get_child("__init__.py") in resources except exceptions.ResourceNotFoundError: return False else: @@ -120,18 +144,20 @@ def _is_renaming_a_module(self): def is_method(self): pyname = self.old_pyname - return isinstance(pyname, pynames.DefinedName) and \ - isinstance(pyname.get_object(), pyobjects.PyFunction) and \ - isinstance(pyname.get_object().parent, pyobjects.PyClass) + return ( + isinstance(pyname, pynames.DefinedName) + and isinstance(pyname.get_object(), pyobjects.PyFunction) + and isinstance(pyname.get_object().parent, pyobjects.PyClass) + ) def _rename_module(self, resource, new_name, changes): if not resource.is_folder(): - new_name = new_name + '.py' + new_name = new_name + ".py" parent_path = resource.parent.path - if parent_path == '': + if parent_path == "": new_location = new_name else: - new_location = parent_path + '/' + new_name + new_location = parent_path + "/" + new_name changes.add_change(MoveResource(resource, new_location)) @@ -162,30 +188,49 @@ def get_old_name(self): def _get_scope_offset(self): lines = self.pymodule.lines - scope = self.pymodule.get_scope().\ - get_inner_scope_for_line(lines.get_line_number(self.offset)) + scope = self.pymodule.get_scope().get_inner_scope_for_line( + lines.get_line_number(self.offset) + ) start = lines.get_line_start(scope.get_start()) end = lines.get_line_end(scope.get_end()) return start, end def get_changes(self, new_name, only_calls=False, reads=True, writes=True): - changes = ChangeSet('Changing <%s> occurrences to <%s>' % - (self.old_name, new_name)) + changes = ChangeSet( + "Changing <%s> occurrences to <%s>" % (self.old_name, new_name) + ) scope_start, scope_end = self._get_scope_offset() finder = occurrences.create_finder( - self.project, self.old_name, self.old_pyname, - imports=False, only_calls=only_calls) + self.project, + self.old_name, + self.old_pyname, + imports=False, + only_calls=only_calls, + ) new_contents = rename_in_module( - finder, new_name, pymodule=self.pymodule, replace_primary=True, - region=(scope_start, scope_end), reads=reads, writes=writes) + finder, + new_name, + pymodule=self.pymodule, + replace_primary=True, + region=(scope_start, scope_end), + reads=reads, + writes=writes, + ) if new_contents is not None: changes.add_change(ChangeContents(self.resource, new_contents)) return changes -def rename_in_module(occurrences_finder, new_name, resource=None, - pymodule=None, replace_primary=False, region=None, - reads=True, writes=True): +def rename_in_module( + occurrences_finder, + new_name, + resource=None, + pymodule=None, + replace_primary=False, + region=None, + reads=True, + writes=True, +): """Returns the changed source or `None` if there is no changes""" if resource is not None: source_code = resource.read() @@ -199,8 +244,9 @@ def rename_in_module(occurrences_finder, new_name, resource=None, start, end = occurrence.get_primary_range() else: start, end = occurrence.get_word_range() - if (not reads and not occurrence.is_written()) or \ - (not writes and occurrence.is_written()): + if (not reads and not occurrence.is_written()) or ( + not writes and occurrence.is_written() + ): continue if region is None or region[0] <= start < region[1]: change_collector.add_change(start, end, new_name) @@ -212,9 +258,13 @@ def _is_local(pyname): if lineno is None: return False scope = module.get_scope().get_inner_scope_for_line(lineno) - if isinstance(pyname, pynames.DefinedName) and \ - scope.get_kind() in ('Function', 'Class'): + if isinstance(pyname, pynames.DefinedName) and scope.get_kind() in ( + "Function", + "Class", + ): scope = scope.parent - return scope.get_kind() == 'Function' and \ - pyname in scope.get_names().values() and \ - isinstance(pyname, pynames.AssignedName) + return ( + scope.get_kind() == "Function" + and pyname in scope.get_names().values() + and isinstance(pyname, pynames.AssignedName) + ) diff --git a/rope/refactor/restructure.py b/rope/refactor/restructure.py index 365caf888..e65a9b464 100644 --- a/rope/refactor/restructure.py +++ b/rope/refactor/restructure.py @@ -72,8 +72,7 @@ class Restructure(object): """ - def __init__(self, project, pattern, goal, args=None, - imports=None, wildcards=None): + def __init__(self, project, pattern, goal, args=None, imports=None, wildcards=None): """Construct a restructuring See class pydoc for more info about the arguments. @@ -91,8 +90,13 @@ def __init__(self, project, pattern, goal, args=None, self.wildcards = wildcards self.template = similarfinder.CodeTemplate(self.goal) - def get_changes(self, checks=None, imports=None, resources=None, - task_handle=taskhandle.NullTaskHandle()): + def get_changes( + self, + checks=None, + imports=None, + resources=None, + task_handle=taskhandle.NullTaskHandle(), + ): """Get the changes needed by this restructuring `resources` can be a list of `rope.base.resources.File` to @@ -117,45 +121,54 @@ def get_changes(self, checks=None, imports=None, resources=None, """ if checks is not None: warnings.warn( - 'The use of checks parameter is deprecated; ' - 'use the args parameter of the constructor instead.', - DeprecationWarning, stacklevel=2) + "The use of checks parameter is deprecated; " + "use the args parameter of the constructor instead.", + DeprecationWarning, + stacklevel=2, + ) for name, value in checks.items(): self.args[name] = similarfinder._pydefined_to_str(value) if imports is not None: warnings.warn( - 'The use of imports parameter is deprecated; ' - 'use imports parameter of the constructor, instead.', - DeprecationWarning, stacklevel=2) + "The use of imports parameter is deprecated; " + "use imports parameter of the constructor, instead.", + DeprecationWarning, + stacklevel=2, + ) self.imports = imports - changes = change.ChangeSet('Restructuring <%s> to <%s>' % - (self.pattern, self.goal)) + changes = change.ChangeSet( + "Restructuring <%s> to <%s>" % (self.pattern, self.goal) + ) if resources is not None: - files = [resource for resource in resources - if libutils.is_python_file(self.project, resource)] + files = [ + resource + for resource in resources + if libutils.is_python_file(self.project, resource) + ] else: files = self.project.get_python_files() - job_set = task_handle.create_jobset('Collecting Changes', len(files)) + job_set = task_handle.create_jobset("Collecting Changes", len(files)) for resource in files: job_set.started_job(resource.path) pymodule = self.project.get_pymodule(resource) - finder = similarfinder.SimilarFinder(pymodule, - wildcards=self.wildcards) + finder = similarfinder.SimilarFinder(pymodule, wildcards=self.wildcards) matches = list(finder.get_matches(self.pattern, self.args)) computer = self._compute_changes(matches, pymodule) result = computer.get_changed() if result is not None: - imported_source = self._add_imports(resource, result, - self.imports) - changes.add_change(change.ChangeContents(resource, - imported_source)) + imported_source = self._add_imports(resource, result, self.imports) + changes.add_change(change.ChangeContents(resource, imported_source)) job_set.finished_job() return changes def _compute_changes(self, matches, pymodule): return _ChangeComputer( - pymodule.source_code, pymodule.get_ast(), - pymodule.lines, self.template, matches) + pymodule.source_code, + pymodule.get_ast(), + pymodule.lines, + self.template, + matches, + ) def _add_imports(self, resource, source, imports): if not imports: @@ -169,10 +182,10 @@ def _add_imports(self, resource, source, imports): def _get_import_infos(self, resource, imports): pymodule = libutils.get_string_module( - self.project, '\n'.join(imports), resource) + self.project, "\n".join(imports), resource + ) imports = module_imports.ModuleImports(self.project, pymodule) - return [imports.import_info - for imports in imports.imports] + return [imports.import_info for imports in imports.imports] def make_checks(self, string_checks): """Convert str to str dicts to str to PyObject dicts @@ -182,20 +195,21 @@ def make_checks(self, string_checks): """ checks = {} for key, value in string_checks.items(): - is_pyname = not key.endswith('.object') and \ - not key.endswith('.type') + is_pyname = not key.endswith(".object") and not key.endswith(".type") evaluated = self._evaluate(value, is_pyname=is_pyname) if evaluated is not None: checks[key] = evaluated return checks def _evaluate(self, code, is_pyname=True): - attributes = code.split('.') + attributes = code.split(".") pyname = None - if attributes[0] in ('__builtin__', '__builtins__'): + if attributes[0] in ("__builtin__", "__builtins__"): + class _BuiltinsStub(object): def get_attribute(self, name): return builtins.builtins[name] + pyobject = _BuiltinsStub() else: pyobject = self.project.get_module(attributes[0]) @@ -222,7 +236,6 @@ def replace(code, pattern, goal): class _ChangeComputer(object): - def __init__(self, code, ast, lines, goal, matches): self.source = code self.goal = goal @@ -255,16 +268,16 @@ def get_changed(self): return collector.get_changed() def _is_expression(self): - return self.matches and isinstance(self.matches[0], - similarfinder.ExpressionMatch) + return self.matches and isinstance( + self.matches[0], similarfinder.ExpressionMatch + ) def _get_matched_text(self, match): mapping = {} for name in self.goal.get_names(): node = match.get_ast(name) if node is None: - raise similarfinder.BadNameInCheckError( - 'Unknown name <%s>' % name) + raise similarfinder.BadNameInCheckError("Unknown name <%s>" % name) force = self._is_expression() and match.ast == node mapping[name] = self._get_node_text(node, force) unindented = self.goal.substitute(mapping) @@ -278,8 +291,9 @@ def _get_node_text(self, node, force=False): collector = codeanalyze.ChangeCollector(main_text) for node in self._get_nearest_roots(node): sub_start, sub_end = patchedast.node_region(node) - collector.add_change(sub_start - start, sub_end - start, - self._get_node_text(node)) + collector.add_change( + sub_start - start, sub_end - start, self._get_node_text(node) + ) result = collector.get_changed() if result is None: return main_text @@ -291,9 +305,9 @@ def _auto_indent(self, offset, text): result = [] for index, line in enumerate(text.splitlines(True)): if index != 0 and line.strip(): - result.append(' ' * indents) + result.append(" " * indents) result.append(line) - return ''.join(result) + return "".join(result) def _get_nearest_roots(self, node): if node not in self._nearest_roots: diff --git a/rope/refactor/similarfinder.py b/rope/refactor/similarfinder.py index 709832151..302fcd927 100644 --- a/rope/refactor/similarfinder.py +++ b/rope/refactor/similarfinder.py @@ -4,7 +4,7 @@ import rope.refactor.wildcards from rope.base import libutils from rope.base import codeanalyze, exceptions, ast, builtins -from rope.refactor import (patchedast, wildcards) +from rope.refactor import patchedast, wildcards from rope.refactor.patchedast import MismatchedTokenError @@ -26,15 +26,17 @@ def __init__(self, pymodule, wildcards=None): self.source = pymodule.source_code try: self.raw_finder = RawSimilarFinder( - pymodule.source_code, pymodule.get_ast(), self._does_match) + pymodule.source_code, pymodule.get_ast(), self._does_match + ) except MismatchedTokenError: print("in file %s" % pymodule.resource.path) raise self.pymodule = pymodule if wildcards is None: self.wildcards = {} - for wildcard in [rope.refactor.wildcards. - DefaultWildcard(pymodule.pycore.project)]: + for wildcard in [ + rope.refactor.wildcards.DefaultWildcard(pymodule.pycore.project) + ]: self.wildcards[wildcard.get_name()] = wildcard else: self.wildcards = wildcards @@ -44,20 +46,19 @@ def get_matches(self, code, args={}, start=0, end=None): if end is None: end = len(self.source) skip_region = None - if 'skip' in args.get('', {}): - resource, region = args['']['skip'] + if "skip" in args.get("", {}): + resource, region = args[""]["skip"] if resource == self.pymodule.get_resource(): skip_region = region - return self.raw_finder.get_matches(code, start=start, end=end, - skip=skip_region) + return self.raw_finder.get_matches(code, start=start, end=end, skip=skip_region) def get_match_regions(self, *args, **kwds): for match in self.get_matches(*args, **kwds): yield match.get_region() def _does_match(self, node, name): - arg = self.args.get(name, '') - kind = 'default' + arg = self.args.get(name, "") + kind = "default" if isinstance(arg, (tuple, list)): kind = arg[0] arg = arg[1] @@ -74,7 +75,7 @@ def __init__(self, source, node=None, does_match=None): node = ast.parse(source) except SyntaxError: # needed to parse expression containing := operator - node = ast.parse('(' + source + ')') + node = ast.parse("(" + source + ")") if does_match is None: self.does_match = self._simple_does_match else: @@ -87,7 +88,7 @@ def _simple_does_match(self, node, name): def _init_using_ast(self, node, source): self.source = source self._matched_asts = {} - if not hasattr(node, 'region'): + if not hasattr(node, "region"): patchedast.patch_ast(node, source) self.ast = node @@ -105,16 +106,14 @@ def get_matches(self, code, start=0, end=None, skip=None): for match in self._get_matched_asts(code): match_start, match_end = match.get_region() if start <= match_start and match_end <= end: - if skip is not None and (skip[0] < match_end and - skip[1] > match_start): + if skip is not None and (skip[0] < match_end and skip[1] > match_start): continue yield match def _get_matched_asts(self, code): if code not in self._matched_asts: wanted = self._create_pattern(code) - matches = _ASTMatcher(self.ast, wanted, - self.does_match).find_matches() + matches = _ASTMatcher(self.ast, wanted, self.does_match).find_matches() self._matched_asts[code] = matches return self._matched_asts[code] @@ -140,7 +139,6 @@ def _replace_wildcards(self, expression): class _ASTMatcher(object): - def __init__(self, body, pattern, does_match): """Searches the given pattern in the body AST. @@ -178,7 +176,7 @@ def _check_statements(self, node): def __check_stmt_list(self, nodes): for index in range(len(nodes)): if len(nodes) - index >= len(self.pattern): - current_stmts = nodes[index:index + len(self.pattern)] + current_stmts = nodes[index : index + len(self.pattern)] mapping = {} if self._match_stmts(current_stmts, mapping): self.matches.append(StatementMatch(current_stmts, mapping)) @@ -201,8 +199,7 @@ def _match_nodes(self, expected, node, mapping): if not self._match_nodes(child1, child2, mapping): return False elif isinstance(child1, (list, tuple)): - if not isinstance(child2, (list, tuple)) or \ - len(child1) != len(child2): + if not isinstance(child2, (list, tuple)) or len(child1) != len(child2): return False for c1, c2 in zip(child1, child2): if not self._match_nodes(c1, c2, mapping): @@ -215,8 +212,7 @@ def _match_nodes(self, expected, node, mapping): def _get_children(self, node): """Return not `ast.expr_context` children of `node`""" children = ast.get_children(node) - return [child for child in children - if not isinstance(child, ast.expr_context)] + return [child for child in children if not isinstance(child, ast.expr_context)] def _match_stmts(self, current_stmts, mapping): if len(current_stmts) != len(self.pattern): @@ -238,7 +234,6 @@ def _match_wildcard(self, node1, node2, mapping): class Match(object): - def __init__(self, mapping): self.mapping = mapping @@ -251,7 +246,6 @@ def get_ast(self, name): class ExpressionMatch(Match): - def __init__(self, ast, mapping): super(ExpressionMatch, self).__init__(mapping) self.ast = ast @@ -261,7 +255,6 @@ def get_region(self): class StatementMatch(Match): - def __init__(self, ast_list, mapping): super(StatementMatch, self).__init__(mapping) self.ast_list = ast_list @@ -271,7 +264,6 @@ def get_region(self): class CodeTemplate(object): - def __init__(self, template): self.template = template self._find_names() @@ -279,10 +271,9 @@ def __init__(self, template): def _find_names(self): self.names = {} for match in CodeTemplate._get_pattern().finditer(self.template): - if 'name' in match.groupdict() and \ - match.group('name') is not None: - start, end = match.span('name') - name = self.template[start + 2:end - 1] + if "name" in match.groupdict() and match.group("name") is not None: + start, end = match.span("name") + name = self.template[start + 2 : end - 1] if name not in self.names: self.names[name] = [] self.names[name].append((start, end)) @@ -305,9 +296,13 @@ def substitute(self, mapping): @classmethod def _get_pattern(cls): if cls._match_pattern is None: - pattern = codeanalyze.get_comment_pattern() + '|' + \ - codeanalyze.get_string_pattern() + '|' + \ - r'(?P\$\{[^\s\$\}]*\})' + pattern = ( + codeanalyze.get_comment_pattern() + + "|" + + codeanalyze.get_string_pattern() + + "|" + + r"(?P\$\{[^\s\$\}]*\})" + ) cls._match_pattern = re.compile(pattern) return cls._match_pattern @@ -315,11 +310,11 @@ def _get_pattern(cls): class _RopeVariable(object): """Transform and identify rope inserted wildcards""" - _normal_prefix = '__rope__variable_normal_' - _any_prefix = '__rope__variable_any_' + _normal_prefix = "__rope__variable_normal_" + _any_prefix = "__rope__variable_any_" def get_var(self, name): - if name.startswith('?'): + if name.startswith("?"): return self._get_any(name) else: return self._get_normal(name) @@ -329,9 +324,9 @@ def is_var(self, name): def get_base(self, name): if self._is_normal(name): - return name[len(self._normal_prefix):] + return name[len(self._normal_prefix) :] if self._is_var(name): - return '?' + name[len(self._any_prefix):] + return "?" + name[len(self._any_prefix) :] def _get_normal(self, name): return self._normal_prefix + name @@ -352,23 +347,23 @@ def make_pattern(code, variables): def does_match(node, name): return isinstance(node, ast.Name) and node.id == name + finder = RawSimilarFinder(code, does_match=does_match) for variable in variables: - for match in finder.get_matches('${%s}' % variable): + for match in finder.get_matches("${%s}" % variable): start, end = match.get_region() - collector.add_change(start, end, '${%s}' % variable) + collector.add_change(start, end, "${%s}" % variable) result = collector.get_changed() return result if result is not None else code def _pydefined_to_str(pydefined): address = [] - if isinstance(pydefined, - (builtins.BuiltinClass, builtins.BuiltinFunction)): - return '__builtins__.' + pydefined.get_name() + if isinstance(pydefined, (builtins.BuiltinClass, builtins.BuiltinFunction)): + return "__builtins__." + pydefined.get_name() else: while pydefined.parent is not None: address.insert(0, pydefined.get_name()) pydefined = pydefined.parent module_name = libutils.modname(pydefined.resource) - return '.'.join(module_name.split('.') + address) + return ".".join(module_name.split(".") + address) diff --git a/rope/refactor/sourceutils.py b/rope/refactor/sourceutils.py index 9b8429066..159476595 100644 --- a/rope/refactor/sourceutils.py +++ b/rope/refactor/sourceutils.py @@ -7,9 +7,9 @@ def get_indents(lines, lineno): def find_minimum_indents(source_code): result = 80 - lines = source_code.split('\n') + lines = source_code.split("\n") for line in lines: - if line.strip() == '': + if line.strip() == "": continue result = min(result, codeanalyze.count_line_indents(line)) return result @@ -21,15 +21,15 @@ def indent_lines(source_code, amount): lines = source_code.splitlines(True) result = [] for l in lines: - if l.strip() == '': - result.append('\n') + if l.strip() == "": + result.append("\n") continue if amount < 0: indents = codeanalyze.count_line_indents(l) - result.append(max(0, indents + amount) * ' ' + l.lstrip()) + result.append(max(0, indents + amount) * " " + l.lstrip()) else: - result.append(' ' * amount + l) - return ''.join(result) + result.append(" " * amount + l) + return "".join(result) def fix_indentation(code, new_indents): @@ -45,15 +45,17 @@ def add_methods(pymodule, class_scope, methods_sources): if class_scope.get_scopes(): insertion_line = class_scope.get_scopes()[-1].get_end() insertion_offset = lines.get_line_end(insertion_line) - methods = '\n\n' + '\n\n'.join(methods_sources) + methods = "\n\n" + "\n\n".join(methods_sources) indented_methods = fix_indentation( - methods, get_indents(lines, class_scope.get_start()) + - get_indent(pymodule.pycore.project)) + methods, + get_indents(lines, class_scope.get_start()) + + get_indent(pymodule.pycore.project), + ) result = [] result.append(source_code[:insertion_offset]) result.append(indented_methods) result.append(source_code[insertion_offset:]) - return ''.join(result) + return "".join(result) def get_body(pyfunction): @@ -80,7 +82,7 @@ def get_body_region(defined): if scope_start[1] >= start_line: # a one-liner! # XXX: what if colon appears in a string - start = pymodule.source_code.index(':', start) + 1 + start = pymodule.source_code.index(":", start) + 1 while pymodule.source_code[start].isspace(): start += 1 end = min(lines.get_line_end(scope.end) + 1, len(pymodule.source_code)) @@ -88,4 +90,4 @@ def get_body_region(defined): def get_indent(project): - return project.prefs.get('indent_size', 4) + return project.prefs.get("indent_size", 4) diff --git a/rope/refactor/suites.py b/rope/refactor/suites.py index 687850808..b7f5c9a66 100644 --- a/rope/refactor/suites.py +++ b/rope/refactor/suites.py @@ -18,6 +18,7 @@ def find_visible_for_suite(root, lines): def valid(suite): return suite is not None and not suite.ignored + if valid(suite1) and not valid(suite2): return line1 if not valid(suite1) and valid(suite2): @@ -42,7 +43,7 @@ def valid(suite): def ast_suite_tree(node): - if hasattr(node, 'lineno'): + if hasattr(node, "lineno"): lineno = node.lineno else: lineno = 1 @@ -50,7 +51,6 @@ def ast_suite_tree(node): class Suite(object): - def __init__(self, child_nodes, lineno, parent=None, ignored=False): self.parent = parent self.lineno = lineno @@ -98,7 +98,6 @@ def _get_level(self): class _SuiteWalker(object): - def __init__(self, suite): self.suite = suite self.suites = [] @@ -122,7 +121,9 @@ def _TryFinally(self, node): proceed_to_except_handler = isinstance(node.body[0], ast.TryExcept) elif pycompat.PY3: try: - proceed_to_except_handler = isinstance(node.handlers[0], ast.ExceptHandler) + proceed_to_except_handler = isinstance( + node.handlers[0], ast.ExceptHandler + ) except IndexError: pass if proceed_to_except_handler: @@ -150,9 +151,7 @@ def _add_if_like_node(self, node): self.suites.append(Suite(node.orelse, node.lineno, self.suite)) def _FunctionDef(self, node): - self.suites.append(Suite(node.body, node.lineno, - self.suite, ignored=True)) + self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True)) def _ClassDef(self, node): - self.suites.append(Suite(node.body, node.lineno, - self.suite, ignored=True)) + self.suites.append(Suite(node.body, node.lineno, self.suite, ignored=True)) diff --git a/rope/refactor/topackage.py b/rope/refactor/topackage.py index f36a6d528..54efee48c 100644 --- a/rope/refactor/topackage.py +++ b/rope/refactor/topackage.py @@ -1,27 +1,24 @@ import rope.refactor.importutils -from rope.base.change import ChangeSet, ChangeContents, MoveResource, \ - CreateFolder +from rope.base.change import ChangeSet, ChangeContents, MoveResource, CreateFolder class ModuleToPackage(object): - def __init__(self, project, resource): self.project = project self.resource = resource def get_changes(self): - changes = ChangeSet('Transform <%s> module to package' % - self.resource.path) + changes = ChangeSet("Transform <%s> module to package" % self.resource.path) new_content = self._transform_relatives_to_absolute(self.resource) if new_content is not None: changes.add_change(ChangeContents(self.resource, new_content)) parent = self.resource.parent name = self.resource.name[:-3] changes.add_change(CreateFolder(parent, name)) - parent_path = parent.path + '/' + parent_path = parent.path + "/" if not parent.path: - parent_path = '' - new_path = parent_path + '%s/__init__.py' % name + parent_path = "" + new_path = parent_path + "%s/__init__.py" % name if self.resource.project == self.project: changes.add_change(MoveResource(self.resource, new_path)) return changes diff --git a/rope/refactor/usefunction.py b/rope/refactor/usefunction.py index ec1311604..20c32005a 100644 --- a/rope/refactor/usefunction.py +++ b/rope/refactor/usefunction.py @@ -1,5 +1,4 @@ -from rope.base import (change, taskhandle, evaluate, - exceptions, pyobjects, pynames, ast) +from rope.base import change, taskhandle, evaluate, exceptions, pyobjects, pynames, ast from rope.base import libutils from rope.refactor import restructure, sourceutils, similarfinder @@ -13,42 +12,46 @@ def __init__(self, project, resource, offset): this_pymodule = project.get_pymodule(resource) pyname = evaluate.eval_location(this_pymodule, offset) if pyname is None: - raise exceptions.RefactoringError('Unresolvable name selected') + raise exceptions.RefactoringError("Unresolvable name selected") self.pyfunction = pyname.get_object() - if not isinstance(self.pyfunction, pyobjects.PyFunction) or \ - not isinstance(self.pyfunction.parent, pyobjects.PyModule): + if not isinstance(self.pyfunction, pyobjects.PyFunction) or not isinstance( + self.pyfunction.parent, pyobjects.PyModule + ): raise exceptions.RefactoringError( - 'Use function works for global functions, only.') + "Use function works for global functions, only." + ) self.resource = self.pyfunction.get_module().get_resource() self._check_returns() def _check_returns(self): node = self.pyfunction.get_ast() if _yield_count(node): - raise exceptions.RefactoringError('Use function should not ' - 'be used on generators.') + raise exceptions.RefactoringError( + "Use function should not " "be used on generators." + ) returns = _return_count(node) if returns > 1: - raise exceptions.RefactoringError('usefunction: Function has more ' - 'than one return statement.') + raise exceptions.RefactoringError( + "usefunction: Function has more " "than one return statement." + ) if returns == 1 and not _returns_last(node): - raise exceptions.RefactoringError('usefunction: return should ' - 'be the last statement.') + raise exceptions.RefactoringError( + "usefunction: return should " "be the last statement." + ) - def get_changes(self, resources=None, - task_handle=taskhandle.NullTaskHandle()): + def get_changes(self, resources=None, task_handle=taskhandle.NullTaskHandle()): if resources is None: resources = self.project.get_python_files() - changes = change.ChangeSet('Using function <%s>' % - self.pyfunction.get_name()) + changes = change.ChangeSet("Using function <%s>" % self.pyfunction.get_name()) if self.resource in resources: newresources = list(resources) newresources.remove(self.resource) for c in self._restructure(newresources, task_handle).changes: changes.add_change(c) if self.resource in resources: - for c in self._restructure([self.resource], task_handle, - others=False).changes: + for c in self._restructure( + [self.resource], task_handle, others=False + ).changes: changes.add_change(c) return changes @@ -60,16 +63,16 @@ def _restructure(self, resources, task_handle, others=True): goal = self._make_goal(import_=others) imports = None if others: - imports = ['import %s' % self._module_name()] + imports = ["import %s" % self._module_name()] body_region = sourceutils.get_body_region(self.pyfunction) - args_value = {'skip': (self.resource, body_region)} - args = {'': args_value} + args_value = {"skip": (self.resource, body_region)} + args = {"": args_value} restructuring = restructure.Restructure( - self.project, pattern, goal, args=args, imports=imports) - return restructuring.get_changes(resources=resources, - task_handle=task_handle) + self.project, pattern, goal, args=args, imports=imports + ) + return restructuring.get_changes(resources=resources, task_handle=task_handle) def _find_temps(self): return find_temps(self.project, self._get_body()) @@ -80,18 +83,17 @@ def _module_name(self): def _make_pattern(self): params = self.pyfunction.get_param_names() body = self._get_body() - body = restructure.replace(body, 'return', 'pass') + body = restructure.replace(body, "return", "pass") wildcards = list(params) wildcards.extend(self._find_temps()) if self._does_return(): if self._is_expression(): - replacement = '${%s}' % self._rope_returned + replacement = "${%s}" % self._rope_returned else: - replacement = '%s = ${%s}' % (self._rope_result, - self._rope_returned) + replacement = "%s = ${%s}" % (self._rope_result, self._rope_returned) body = restructure.replace( - body, 'return ${%s}' % self._rope_returned, - replacement) + body, "return ${%s}" % self._rope_returned, replacement + ) wildcards.append(self._rope_result) return similarfinder.make_pattern(body, wildcards) @@ -102,27 +104,26 @@ def _make_goal(self, import_=False): params = self.pyfunction.get_param_names() function_name = self.pyfunction.get_name() if import_: - function_name = self._module_name() + '.' + function_name - goal = '%s(%s)' % (function_name, - ', ' .join(('${%s}' % p) for p in params)) + function_name = self._module_name() + "." + function_name + goal = "%s(%s)" % (function_name, ", ".join(("${%s}" % p) for p in params)) if self._does_return() and not self._is_expression(): - goal = '${%s} = %s' % (self._rope_result, goal) + goal = "${%s} = %s" % (self._rope_result, goal) return goal def _does_return(self): body = self._get_body() - removed_return = restructure.replace(body, 'return ${result}', '') + removed_return = restructure.replace(body, "return ${result}", "") return removed_return != body def _is_expression(self): return len(self.pyfunction.get_ast().body) == 1 - _rope_result = '_rope__result' - _rope_returned = '_rope__returned' + _rope_result = "_rope__result" + _rope_returned = "_rope__returned" def find_temps(project, code): - code = 'def f():\n' + sourceutils.indent_lines(code, 4) + code = "def f():\n" + sourceutils.indent_lines(code, 4) pymodule = libutils.get_string_module(project, code) result = [] function_scope = pymodule.get_scope().get_scopes()[0] @@ -137,9 +138,13 @@ def _returns_last(node): def _namedexpr_last(node): - if not hasattr(ast, 'NamedExpr'): # python<3.8 + if not hasattr(ast, "NamedExpr"): # python<3.8 return False - return bool(node.body) and len(node.body) == 1 and isinstance(node.body[-1].value, ast.NamedExpr) + return ( + bool(node.body) + and len(node.body) == 1 + and isinstance(node.body[-1].value, ast.NamedExpr) + ) def _yield_count(node): @@ -161,7 +166,6 @@ def _named_expr_count(node): class _ReturnOrYieldFinder(object): - def __init__(self): self.returns = 0 self.named_expression = 0 diff --git a/rope/refactor/wildcards.py b/rope/refactor/wildcards.py index 90040c794..09fa04ebd 100644 --- a/rope/refactor/wildcards.py +++ b/rope/refactor/wildcards.py @@ -3,7 +3,6 @@ class Wildcard(object): - def get_name(self): """Return the name of this wildcard""" @@ -12,7 +11,6 @@ def matches(self, suspect, arg): class Suspect(object): - def __init__(self, pymodule, node, name): self.name = name self.pymodule = pymodule @@ -38,9 +36,9 @@ def __init__(self, project): self.project = project def get_name(self): - return 'default' + return "default" - def matches(self, suspect, arg=''): + def matches(self, suspect, arg=""): args = parse_arg(arg) if not self._check_exact(args, suspect): @@ -52,20 +50,19 @@ def matches(self, suspect, arg=''): def _check_object(self, args, suspect): kind = None expected = None - unsure = args.get('unsure', False) - for check in ['name', 'object', 'type', 'instance']: + unsure = args.get("unsure", False) + for check in ["name", "object", "type", "instance"]: if check in args: kind = check expected = args[check] if expected is not None: - checker = _CheckObject(self.project, expected, - kind, unsure=unsure) + checker = _CheckObject(self.project, expected, kind, unsure=unsure) return checker(suspect.pymodule, suspect.node) return True def _check_exact(self, args, suspect): node = suspect.node - if args.get('exact'): + if args.get("exact"): if not isinstance(node, ast.Name) or not node.id == suspect.name: return False else: @@ -78,10 +75,10 @@ def parse_arg(arg): if isinstance(arg, dict): return arg result = {} - tokens = arg.split(',') + tokens = arg.split(",") for token in tokens: - if '=' in token: - parts = token.split('=', 1) + if "=" in token: + parts = token.split("=", 1) result[parts[0].strip()] = parts[1].strip() else: result[token.strip()] = True @@ -89,8 +86,7 @@ def parse_arg(arg): class _CheckObject(object): - - def __init__(self, project, expected, kind='object', unsure=False): + def __init__(self, project, expected, kind="object", unsure=False): self.project = project self.kind = kind self.unsure = unsure @@ -100,17 +96,17 @@ def __call__(self, pymodule, node): pyname = self._evaluate_node(pymodule, node) if pyname is None or self.expected is None: return self.unsure - if self._unsure_pyname(pyname, unbound=self.kind == 'name'): + if self._unsure_pyname(pyname, unbound=self.kind == "name"): return True - if self.kind == 'name': + if self.kind == "name": return self._same_pyname(self.expected, pyname) else: pyobject = pyname.get_object() - if self.kind == 'object': + if self.kind == "object": objects = [pyobject] - if self.kind == 'type': + if self.kind == "type": objects = [pyobject.get_type()] - if self.kind == 'instance': + if self.kind == "instance": objects = [pyobject] objects.extend(self._get_super_classes(pyobject)) objects.extend(self._get_super_classes(pyobject.get_type())) @@ -137,17 +133,16 @@ def _unsure_pyname(self, pyname, unbound=True): return self.unsure and occurrences.unsure_pyname(pyname, unbound) def _split_name(self, name): - parts = name.split('.') + parts = name.split(".") expression, kind = parts[0], parts[-1] if len(parts) == 1: - kind = 'name' + kind = "name" return expression, kind def _evaluate_node(self, pymodule, node): scope = pymodule.get_scope().get_inner_scope_for_line(node.lineno) expression = node - if isinstance(expression, ast.Name) and \ - isinstance(expression.ctx, ast.Store): + if isinstance(expression, ast.Name) and isinstance(expression.ctx, ast.Store): start, end = patchedast.node_region(expression) text = pymodule.source_code[start:end] return evaluate.eval_str(scope, text) @@ -155,9 +150,10 @@ def _evaluate_node(self, pymodule, node): return evaluate.eval_node(scope, expression) def _evaluate(self, code): - attributes = code.split('.') + attributes = code.split(".") pyname = None - if attributes[0] in ('__builtin__', '__builtins__'): + if attributes[0] in ("__builtin__", "__builtins__"): + class _BuiltinsStub(object): def get_attribute(self, name): return builtins.builtins[name] @@ -167,6 +163,7 @@ def __getitem__(self, name): def __contains__(self, name): return name in builtins.builtins + pyobject = _BuiltinsStub() else: pyobject = self.project.get_module(attributes[0]) diff --git a/ropetest/__init__.py b/ropetest/__init__.py index 3fd2cebcc..f47bdfefe 100644 --- a/ropetest/__init__.py +++ b/ropetest/__init__.py @@ -1,4 +1,5 @@ import sys + try: import unittest2 as unittest except ImportError: @@ -44,7 +45,7 @@ def suite(): return result -if __name__ == '__main__': +if __name__ == "__main__": runner = unittest.TextTestRunner() result = runner.run(suite()) sys.exit(not result.wasSuccessful()) diff --git a/ropetest/advanced_oi_test.py b/ropetest/advanced_oi_test.py index d50249652..89146de18 100644 --- a/ropetest/advanced_oi_test.py +++ b/ropetest/advanced_oi_test.py @@ -12,7 +12,6 @@ class DynamicOITest(unittest.TestCase): - def setUp(self): super(DynamicOITest, self).setUp() self.project = testutils.sample_project(validate_objectdb=True) @@ -23,213 +22,236 @@ def tearDown(self): super(DynamicOITest, self).tearDown() def test_simple_dti(self): - mod = testutils.create_module(self.project, 'mod') - code = 'def a_func(arg):\n return eval("arg")\n' \ - 'a_var = a_func(a_func)\n' + mod = testutils.create_module(self.project, "mod") + code = 'def a_func(arg):\n return eval("arg")\n' "a_var = a_func(a_func)\n" mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - self.assertEqual(pymod['a_func'].get_object(), - pymod['a_var'].get_object()) + self.assertEqual(pymod["a_func"].get_object(), pymod["a_var"].get_object()) def test_module_dti(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - code = 'import mod1\ndef a_func(arg):\n return eval("arg")\n' \ - 'a_var = a_func(mod1)\n' + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + code = ( + 'import mod1\ndef a_func(arg):\n return eval("arg")\n' + "a_var = a_func(mod1)\n" + ) mod2.write(code) self.pycore.run_module(mod2).wait_process() pymod2 = self.project.get_pymodule(mod2) - self.assertEqual(self.project.get_pymodule(mod1), - pymod2['a_var'].get_object()) + self.assertEqual(self.project.get_pymodule(mod1), pymod2["a_var"].get_object()) def test_class_from_another_module_dti(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - code1 = 'class AClass(object):\n pass\n' - code2 = 'from mod1 import AClass\n' \ - '\ndef a_func(arg):\n return eval("arg")\n' \ - 'a_var = a_func(AClass)\n' + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + code1 = "class AClass(object):\n pass\n" + code2 = ( + "from mod1 import AClass\n" + '\ndef a_func(arg):\n return eval("arg")\n' + "a_var = a_func(AClass)\n" + ) mod1.write(code1) mod2.write(code2) self.pycore.run_module(mod2).wait_process() - #pymod1 = self.project.get_pymodule(mod1) + # pymod1 = self.project.get_pymodule(mod1) pymod2 = self.project.get_pymodule(mod2) - self.assertEqual(pymod2['AClass'].get_object(), - pymod2['a_var'].get_object()) + self.assertEqual(pymod2["AClass"].get_object(), pymod2["a_var"].get_object()) def test_class_dti(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class AClass(object):\n pass\n' \ - '\ndef a_func(arg):\n return eval("arg")\n' \ - 'a_var = a_func(AClass)\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class AClass(object):\n pass\n" + '\ndef a_func(arg):\n return eval("arg")\n' + "a_var = a_func(AClass)\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - self.assertEqual(pymod['AClass'].get_object(), - pymod['a_var'].get_object()) + self.assertEqual(pymod["AClass"].get_object(), pymod["a_var"].get_object()) def test_instance_dti(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class AClass(object):\n pass\n' \ - '\ndef a_func(arg):\n return eval("arg()")\n' \ - 'a_var = a_func(AClass)\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class AClass(object):\n pass\n" + '\ndef a_func(arg):\n return eval("arg()")\n' + "a_var = a_func(AClass)\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - self.assertEqual(pymod['AClass'].get_object(), - pymod['a_var'].get_object().get_type()) + self.assertEqual( + pymod["AClass"].get_object(), pymod["a_var"].get_object().get_type() + ) def test_method_dti(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class AClass(object):\n def a_method(self, arg):\n' \ - ' return eval("arg()")\n' \ - 'an_instance = AClass()\n' \ - 'a_var = an_instance.a_method(AClass)\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class AClass(object):\n def a_method(self, arg):\n" + ' return eval("arg()")\n' + "an_instance = AClass()\n" + "a_var = an_instance.a_method(AClass)\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - self.assertEqual(pymod['AClass'].get_object(), - pymod['a_var'].get_object().get_type()) + self.assertEqual( + pymod["AClass"].get_object(), pymod["a_var"].get_object().get_type() + ) def test_function_argument_dti(self): - mod = testutils.create_module(self.project, 'mod') - code = 'def a_func(arg):\n pass\n' \ - 'a_func(a_func)\n' + mod = testutils.create_module(self.project, "mod") + code = "def a_func(arg):\n pass\n" "a_func(a_func)\n" mod.write(code) self.pycore.run_module(mod).wait_process() pyscope = self.project.get_pymodule(mod).get_scope() - self.assertEqual(pyscope['a_func'].get_object(), - pyscope.get_scopes()[0]['arg'].get_object()) + self.assertEqual( + pyscope["a_func"].get_object(), pyscope.get_scopes()[0]["arg"].get_object() + ) def test_classes_with_the_same_name(self): - mod = testutils.create_module(self.project, 'mod') - code = 'def a_func(arg):\n class AClass(object):\n' \ - ' pass\n return eval("arg")\n' \ - 'class AClass(object):\n pass\n' \ - 'a_var = a_func(AClass)\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "def a_func(arg):\n class AClass(object):\n" + ' pass\n return eval("arg")\n' + "class AClass(object):\n pass\n" + "a_var = a_func(AClass)\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - self.assertEqual(pymod['AClass'].get_object(), - pymod['a_var'].get_object()) + self.assertEqual(pymod["AClass"].get_object(), pymod["a_var"].get_object()) def test_nested_classes(self): - mod = testutils.create_module(self.project, 'mod') - code = 'def a_func():\n class AClass(object):\n' \ - ' pass\n return AClass\n' \ - 'def another_func(arg):\n return eval("arg")\n' \ - 'a_var = another_func(a_func())\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "def a_func():\n class AClass(object):\n" + " pass\n return AClass\n" + 'def another_func(arg):\n return eval("arg")\n' + "a_var = another_func(a_func())\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pyscope = self.project.get_pymodule(mod).get_scope() - self.assertEqual(pyscope.get_scopes()[0]['AClass'].get_object(), - pyscope['a_var'].get_object()) + self.assertEqual( + pyscope.get_scopes()[0]["AClass"].get_object(), + pyscope["a_var"].get_object(), + ) def test_function_argument_dti2(self): - mod = testutils.create_module(self.project, 'mod') - code = 'def a_func(arg, a_builtin_type):\n pass\n' \ - 'a_func(a_func, [])\n' + mod = testutils.create_module(self.project, "mod") + code = "def a_func(arg, a_builtin_type):\n pass\n" "a_func(a_func, [])\n" mod.write(code) self.pycore.run_module(mod).wait_process() pyscope = self.project.get_pymodule(mod).get_scope() - self.assertEqual(pyscope['a_func'].get_object(), - pyscope.get_scopes()[0]['arg'].get_object()) + self.assertEqual( + pyscope["a_func"].get_object(), pyscope.get_scopes()[0]["arg"].get_object() + ) def test_dti_and_concluded_data_invalidation(self): - mod = testutils.create_module(self.project, 'mod') - code = 'def a_func(arg):\n return eval("arg")\n' \ - 'a_var = a_func(a_func)\n' + mod = testutils.create_module(self.project, "mod") + code = 'def a_func(arg):\n return eval("arg")\n' "a_var = a_func(a_func)\n" mod.write(code) pymod = self.project.get_pymodule(mod) - pymod['a_var'].get_object() + pymod["a_var"].get_object() self.pycore.run_module(mod).wait_process() - self.assertEqual(pymod['a_func'].get_object(), - pymod['a_var'].get_object()) + self.assertEqual(pymod["a_func"].get_object(), pymod["a_var"].get_object()) def test_list_objects_and_dynamicoi(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C(object):\n pass\n' \ - 'def a_func(arg):\n return eval("arg")\n' \ - 'a_var = a_func([C()])[0]\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C(object):\n pass\n" + 'def a_func(arg):\n return eval("arg")\n' + "a_var = a_func([C()])[0]\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_for_loops_and_dynamicoi(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C(object):\n pass\n' \ - 'def a_func(arg):\n return eval("arg")\n' \ - 'for c in a_func([C()]):\n a_var = c\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C(object):\n pass\n" + 'def a_func(arg):\n return eval("arg")\n' + "for c in a_func([C()]):\n a_var = c\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_dict_objects_and_dynamicoi(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C(object):\n pass\n' \ - 'def a_func(arg):\n return eval("arg")\n' \ - 'a_var = a_func({1: C()})[1]\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C(object):\n pass\n" + 'def a_func(arg):\n return eval("arg")\n' + "a_var = a_func({1: C()})[1]\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_dict_keys_and_dynamicoi(self): - mod = testutils.create_module(self.project, 'mod') + mod = testutils.create_module(self.project, "mod") if pycompat.PY3: - code = 'class C(object):\n pass\n' \ - 'def a_func(arg):\n return eval("arg")\n' \ - 'a_var = list(a_func({C(): 1}))[0]\n' + code = ( + "class C(object):\n pass\n" + 'def a_func(arg):\n return eval("arg")\n' + "a_var = list(a_func({C(): 1}))[0]\n" + ) else: - code = 'class C(object):\n pass\n' \ - 'def a_func(arg):\n return eval("arg")\n' \ - 'a_var = a_func({C(): 1}).keys()[0]\n' + code = ( + "class C(object):\n pass\n" + 'def a_func(arg):\n return eval("arg")\n' + "a_var = a_func({C(): 1}).keys()[0]\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_dict_keys_and_dynamicoi2(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'def a_func(arg):\n return eval("arg")\n' \ - 'a, b = a_func((C1(), C2()))\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + 'def a_func(arg):\n return eval("arg")\n' + "a, b = a_func((C1(), C2()))\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_strs_and_dynamicoi(self): - mod = testutils.create_module(self.project, 'mod') - code = 'def a_func(arg):\n return eval("arg")\n' \ - 'a_var = a_func("hey")\n' + mod = testutils.create_module(self.project, "mod") + code = 'def a_func(arg):\n return eval("arg")\n' 'a_var = a_func("hey")\n' mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - a_var = pymod['a_var'].get_object() + a_var = pymod["a_var"].get_object() self.assertTrue(isinstance(a_var.get_type(), rope.base.builtins.Str)) def test_textual_transformations(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C(object):\n pass\ndef f():' \ - '\n pass\na_var = C()\n' \ - 'a_list = [C()]\na_str = "hey"\na_file = open("file.txt")\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C(object):\n pass\ndef f():" + "\n pass\na_var = C()\n" + 'a_list = [C()]\na_str = "hey"\na_file = open("file.txt")\n' + ) mod.write(code) to_pyobject = rope.base.oi.transform.TextualToPyObject(self.project) to_textual = rope.base.oi.transform.PyObjectToTextual(self.project) @@ -237,185 +259,196 @@ def test_textual_transformations(self): def complex_to_textual(pyobject): return to_textual.transform( - to_pyobject.transform(to_textual.transform(pyobject))) + to_pyobject.transform(to_textual.transform(pyobject)) + ) test_variables = [ - ('C', ('defined', 'mod.py', 'C')), - ('f', ('defined', 'mod.py', 'f')), - ('a_var', ('instance', ('defined', 'mod.py', 'C'))), - ('a_list', - ('builtin', 'list', ('instance', ('defined', 'mod.py', 'C')))), - ('a_str', ('builtin', 'str')), - ('a_file', ('builtin', 'file')), + ("C", ("defined", "mod.py", "C")), + ("f", ("defined", "mod.py", "f")), + ("a_var", ("instance", ("defined", "mod.py", "C"))), + ("a_list", ("builtin", "list", ("instance", ("defined", "mod.py", "C")))), + ("a_str", ("builtin", "str")), + ("a_file", ("builtin", "file")), ] test_cases = [(pymod[v].get_object(), r) for v, r in test_variables] test_cases += [ - (pymod, ('defined', 'mod.py')), - (rope.base.builtins.builtins['enumerate'].get_object(), - ('builtin', 'function', 'enumerate')) + (pymod, ("defined", "mod.py")), + ( + rope.base.builtins.builtins["enumerate"].get_object(), + ("builtin", "function", "enumerate"), + ), ] for var, result in test_cases: self.assertEqual(to_textual.transform(var), result) self.assertEqual(complex_to_textual(var), result) def test_arguments_with_keywords(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'def a_func(arg):\n return eval("arg")\n' \ - 'a = a_func(arg=C1())\nb = a_func(arg=C2())\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + 'def a_func(arg):\n return eval("arg")\n' + "a = a_func(arg=C1())\nb = a_func(arg=C2())\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_a_function_with_different_returns(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'def a_func(arg):\n return eval("arg")\n' \ - 'a = a_func(C1())\nb = a_func(C2())\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + 'def a_func(arg):\n return eval("arg")\n' + "a = a_func(C1())\nb = a_func(C2())\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_a_function_with_different_returns2(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'def a_func(p):\n if p == C1:\n return C1()\n' \ - ' else:\n return C2()\n' \ - 'a = a_func(C1)\nb = a_func(C2)\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "def a_func(p):\n if p == C1:\n return C1()\n" + " else:\n return C2()\n" + "a = a_func(C1)\nb = a_func(C2)\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_ignoring_star_args(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'def a_func(p, *args):' \ - '\n if p == C1:\n return C1()\n' \ - ' else:\n return C2()\n' \ - 'a = a_func(C1, 1)\nb = a_func(C2, 2)\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "def a_func(p, *args):" + "\n if p == C1:\n return C1()\n" + " else:\n return C2()\n" + "a = a_func(C1, 1)\nb = a_func(C2, 2)\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_ignoring_double_star_args(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'def a_func(p, *kwds, **args):\n ' \ - 'if p == C1:\n return C1()\n' \ - ' else:\n return C2()\n' \ - 'a = a_func(C1, kwd=1)\nb = a_func(C2, kwd=2)\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "def a_func(p, *kwds, **args):\n " + "if p == C1:\n return C1()\n" + " else:\n return C2()\n" + "a = a_func(C1, kwd=1)\nb = a_func(C2, kwd=2)\n" + ) mod.write(code) self.pycore.run_module(mod).wait_process() pymod = self.project.get_pymodule(mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_invalidating_data_after_changing(self): - mod = testutils.create_module(self.project, 'mod') - code = 'def a_func(arg):\n return eval("arg")\n' \ - 'a_var = a_func(a_func)\n' + mod = testutils.create_module(self.project, "mod") + code = 'def a_func(arg):\n return eval("arg")\n' "a_var = a_func(a_func)\n" mod.write(code) self.pycore.run_module(mod).wait_process() - mod.write(code.replace('a_func', 'newfunc')) + mod.write(code.replace("a_func", "newfunc")) mod.write(code) pymod = self.project.get_pymodule(mod) - self.assertNotEqual(pymod['a_func'].get_object(), - pymod['a_var'].get_object()) + self.assertNotEqual(pymod["a_func"].get_object(), pymod["a_var"].get_object()) def test_invalidating_data_after_moving(self): - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('class C(object):\n pass\n') - mod = testutils.create_module(self.project, 'mod') - code = 'import mod2\ndef a_func(arg):\n return eval(arg)\n' \ - 'a_var = a_func("mod2.C")\n' + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("class C(object):\n pass\n") + mod = testutils.create_module(self.project, "mod") + code = ( + "import mod2\ndef a_func(arg):\n return eval(arg)\n" + 'a_var = a_func("mod2.C")\n' + ) mod.write(code) self.pycore.run_module(mod).wait_process() - mod.move('newmod.py') - pymod = self.project.get_module('newmod') + mod.move("newmod.py") + pymod = self.project.get_module("newmod") pymod2 = self.project.get_pymodule(mod2) - self.assertEqual(pymod2['C'].get_object(), - pymod['a_var'].get_object()) + self.assertEqual(pymod2["C"].get_object(), pymod["a_var"].get_object()) class NewStaticOITest(unittest.TestCase): - def setUp(self): super(NewStaticOITest, self).setUp() self.project = testutils.sample_project(validate_objectdb=True) self.pycore = self.project.pycore - self.mod = testutils.create_module(self.project, 'mod') + self.mod = testutils.create_module(self.project, "mod") def tearDown(self): testutils.remove_project(self.project) super(NewStaticOITest, self).tearDown() def test_static_oi_for_simple_function_calls(self): - code = 'class C(object):\n pass\ndef f(p):\n pass\nf(C())\n' + code = "class C(object):\n pass\ndef f(p):\n pass\nf(C())\n" self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - f_scope = pymod['f'].get_object().get_scope() - p_type = f_scope['p'].get_object().get_type() + c_class = pymod["C"].get_object() + f_scope = pymod["f"].get_object().get_scope() + p_type = f_scope["p"].get_object().get_type() self.assertEqual(c_class, p_type) def test_static_oi_not_failing_when_callin_callables(self): - code = 'class C(object):\n pass\nC()\n' + code = "class C(object):\n pass\nC()\n" self.mod.write(code) self.pycore.analyze_module(self.mod) def test_static_oi_for_nested_calls(self): - code = 'class C(object):\n pass\ndef f(p):\n pass\n' \ - 'def g(p):\n return p\nf(g(C()))\n' + code = ( + "class C(object):\n pass\ndef f(p):\n pass\n" + "def g(p):\n return p\nf(g(C()))\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - f_scope = pymod['f'].get_object().get_scope() - p_type = f_scope['p'].get_object().get_type() + c_class = pymod["C"].get_object() + f_scope = pymod["f"].get_object().get_scope() + p_type = f_scope["p"].get_object().get_type() self.assertEqual(c_class, p_type) def test_static_oi_class_methods(self): - code = 'class C(object):\n def f(self, p):\n pass\n' \ - 'C().f(C())' + code = "class C(object):\n def f(self, p):\n pass\n" "C().f(C())" self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - f_scope = c_class['f'].get_object().get_scope() - p_type = f_scope['p'].get_object().get_type() + c_class = pymod["C"].get_object() + f_scope = c_class["f"].get_object().get_scope() + p_type = f_scope["p"].get_object().get_type() self.assertEqual(c_class, p_type) def test_static_oi_preventing_soi_maximum_recursion_exceptions(self): - code = 'item = {}\nfor item in item.keys():\n pass\n' + code = "item = {}\nfor item in item.keys():\n pass\n" self.mod.write(code) try: self.pycore.analyze_module(self.mod) @@ -423,229 +456,262 @@ def test_static_oi_preventing_soi_maximum_recursion_exceptions(self): self.fail(str(e)) def test_static_oi_for_infer_return_typs_from_funcs_based_on_params(self): - code = 'class C(object):\n pass\ndef func(p):\n return p\n' \ - 'a_var = func(C())\n' + code = ( + "class C(object):\n pass\ndef func(p):\n return p\n" + "a_var = func(C())\n" + ) self.mod.write(code) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_a_function_with_different_returns(self): - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'def a_func(arg):\n return arg\n' \ - 'a = a_func(C1())\nb = a_func(C2())\n' + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "def a_func(arg):\n return arg\n" + "a = a_func(C1())\nb = a_func(C2())\n" + ) self.mod.write(code) pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_not_reporting_out_of_date_information(self): - code = 'class C1(object):\n pass\n' \ - 'def f(arg):\n return C1()\na_var = f('')\n' + code = ( + "class C1(object):\n pass\n" + "def f(arg):\n return C1()\na_var = f(" + ")\n" + ) self.mod.write(code) pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - a_var = pymod['a_var'].get_object() + c1_class = pymod["C1"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c1_class, a_var.get_type()) - self.mod.write(code.replace('C1', 'C2')) + self.mod.write(code.replace("C1", "C2")) pymod = self.project.get_pymodule(self.mod) - c2_class = pymod['C2'].get_object() - a_var = pymod['a_var'].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c2_class, a_var.get_type()) def test_invalidating_concluded_data_in_a_function(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('def func(arg):\n temp = arg\n return temp\n') - mod2.write('import mod1\n' - 'class C1(object):\n pass\n' - 'class C2(object):\n pass\n' - 'a_var = mod1.func(C1())\n') + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("def func(arg):\n temp = arg\n return temp\n") + mod2.write( + "import mod1\n" + "class C1(object):\n pass\n" + "class C2(object):\n pass\n" + "a_var = mod1.func(C1())\n" + ) pymod2 = self.project.get_pymodule(mod2) - c1_class = pymod2['C1'].get_object() - a_var = pymod2['a_var'].get_object() + c1_class = pymod2["C1"].get_object() + a_var = pymod2["a_var"].get_object() self.assertEqual(c1_class, a_var.get_type()) - mod2.write(mod2.read()[:mod2.read().rfind('C1()')] + 'C2())\n') + mod2.write(mod2.read()[: mod2.read().rfind("C1()")] + "C2())\n") pymod2 = self.project.get_pymodule(mod2) - c2_class = pymod2['C2'].get_object() - a_var = pymod2['a_var'].get_object() + c2_class = pymod2["C2"].get_object() + a_var = pymod2["a_var"].get_object() self.assertEqual(c2_class, a_var.get_type()) def test_handling_generator_functions_for_strs(self): - self.mod.write('class C(object):\n pass\ndef f(p):\n yield p()\n' - 'for c in f(C):\n a_var = c\n') + self.mod.write( + "class C(object):\n pass\ndef f(p):\n yield p()\n" + "for c in f(C):\n a_var = c\n" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) # TODO: Returning a generator for functions that yield unknowns @unittest.skip("Returning a generator that yields unknowns") def xxx_test_handl_generator_functions_when_unknown_type_is_yielded(self): - self.mod.write('class C(object):\n pass' - '\ndef f():\n yield eval("C()")\n' - 'a_var = f()\n') + self.mod.write( + "class C(object):\n pass" + '\ndef f():\n yield eval("C()")\n' + "a_var = f()\n" + ) pymod = self.project.get_pymodule(self.mod) - a_var = pymod['a_var'].get_object() - self.assertTrue(isinstance(a_var.get_type(), - rope.base.builtins.Generator)) + a_var = pymod["a_var"].get_object() + self.assertTrue(isinstance(a_var.get_type(), rope.base.builtins.Generator)) def test_static_oi_for_lists_depending_on_append_function(self): - code = 'class C(object):\n pass\nl = list()\n' \ - 'l.append(C())\na_var = l.pop()\n' + code = ( + "class C(object):\n pass\nl = list()\n" + "l.append(C())\na_var = l.pop()\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_static_oi_for_lists_per_object_for_get_item(self): - code = 'class C(object):\n pass\nl = list()\n' \ - 'l.append(C())\na_var = l[0]\n' + code = ( + "class C(object):\n pass\nl = list()\n" "l.append(C())\na_var = l[0]\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_static_oi_for_lists_per_object_for_fields(self): - code = 'class C(object):\n pass\n' \ - 'class A(object):\n ' \ - 'def __init__(self):\n self.l = []\n' \ - ' def set(self):\n self.l.append(C())\n' \ - 'a = A()\na.set()\na_var = a.l[0]\n' + code = ( + "class C(object):\n pass\n" + "class A(object):\n " + "def __init__(self):\n self.l = []\n" + " def set(self):\n self.l.append(C())\n" + "a = A()\na.set()\na_var = a.l[0]\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_static_oi_for_lists_per_object_for_set_item(self): - code = 'class C(object):\n pass\nl = [None]\n' \ - 'l[0] = C()\na_var = l[0]\n' + code = "class C(object):\n pass\nl = [None]\n" "l[0] = C()\na_var = l[0]\n" self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_static_oi_for_lists_per_object_for_extending_lists(self): - code = 'class C(object):\n pass\nl = []\n' \ - 'l.append(C())\nl2 = []\nl2.extend(l)\na_var = l2[0]\n' + code = ( + "class C(object):\n pass\nl = []\n" + "l.append(C())\nl2 = []\nl2.extend(l)\na_var = l2[0]\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_static_oi_for_lists_per_object_for_iters(self): - code = 'class C(object):\n pass\n' \ - 'l = []\nl.append(C())\n' \ - 'for c in l:\n a_var = c\n' + code = ( + "class C(object):\n pass\n" + "l = []\nl.append(C())\n" + "for c in l:\n a_var = c\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_static_oi_for_dicts_depending_on_append_function(self): - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'd = {}\nd[C1()] = C2()\na, b = d.popitem()\n' + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "d = {}\nd[C1()] = C2()\na, b = d.popitem()\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_static_oi_for_dicts_depending_on_for_loops(self): - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'd = {}\nd[C1()] = C2()\n' \ - 'for k, v in d.items():\n a = k\n b = v\n' + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "d = {}\nd[C1()] = C2()\n" + "for k, v in d.items():\n a = k\n b = v\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_static_oi_for_dicts_depending_on_update(self): - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'd = {}\nd[C1()] = C2()\n' \ - 'd2 = {}\nd2.update(d)\na, b = d2.popitem()\n' + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "d = {}\nd[C1()] = C2()\n" + "d2 = {}\nd2.update(d)\na, b = d2.popitem()\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_static_oi_for_dicts_depending_on_update_on_seqs(self): - code = 'class C1(object):\n pass\nclass C2(object):\n pass\n' \ - 'd = {}\nd.update([(C1(), C2())])\na, b = d.popitem()\n' + code = ( + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "d = {}\nd.update([(C1(), C2())])\na, b = d.popitem()\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_static_oi_for_sets_per_object_for_set_item(self): - code = 'class C(object):\n pass\ns = set()\n' \ - 's.add(C())\na_var = s.pop() \n' + code = ( + "class C(object):\n pass\ns = set()\n" "s.add(C())\na_var = s.pop() \n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_properties_and_calling_get_property(self): - code = 'class C1(object):\n pass\n' \ - 'class C2(object):\n c1 = C1()\n' \ - ' def get_c1(self):\n return self.c1\n' \ - ' p = property(get_c1)\nc2 = C2()\na_var = c2.p\n' + code = ( + "class C1(object):\n pass\n" + "class C2(object):\n c1 = C1()\n" + " def get_c1(self):\n return self.c1\n" + " p = property(get_c1)\nc2 = C2()\na_var = c2.p\n" + ) self.mod.write(code) pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - a_var = pymod['a_var'].get_object() + c1_class = pymod["C1"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c1_class, a_var.get_type()) def test_soi_on_constructors(self): - code = 'class C1(object):\n pass\n' \ - 'class C2(object):\n' \ - ' def __init__(self, arg):\n self.attr = arg\n' \ - 'c2 = C2(C1())\na_var = c2.attr' + code = ( + "class C1(object):\n pass\n" + "class C2(object):\n" + " def __init__(self, arg):\n self.attr = arg\n" + "c2 = C2(C1())\na_var = c2.attr" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - a_var = pymod['a_var'].get_object() + c1_class = pymod["C1"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c1_class, a_var.get_type()) def test_soi_on_literal_assignment(self): @@ -653,27 +719,26 @@ def test_soi_on_literal_assignment(self): self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - a_var = pymod['a_var'].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(Str, type(a_var.get_type())) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_soi_on_typed_assignment(self): - code = 'a_var: str' + code = "a_var: str" self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - a_var = pymod['a_var'].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(Str, type(a_var.get_type())) def test_not_saving_unknown_function_returns(self): - mod2 = testutils.create_module(self.project, 'mod2') - self.mod.write('class C(object):\n pass\nl = []\nl.append(C())\n') - mod2.write('import mod\ndef f():\n ' - 'return mod.l.pop()\na_var = f()\n') + mod2 = testutils.create_module(self.project, "mod2") + self.mod.write("class C(object):\n pass\nl = []\nl.append(C())\n") + mod2.write("import mod\ndef f():\n " "return mod.l.pop()\na_var = f()\n") pymod = self.project.get_pymodule(self.mod) pymod2 = self.project.get_pymodule(mod2) - c_class = pymod['C'].get_object() - a_var = pymod2['a_var'] + c_class = pymod["C"].get_object() + a_var = pymod2["a_var"] self.pycore.analyze_module(mod2) self.assertNotEqual(c_class, a_var.get_object().get_type()) @@ -682,97 +747,103 @@ def test_not_saving_unknown_function_returns(self): self.assertEqual(c_class, a_var.get_object().get_type()) def test_using_the_best_callinfo(self): - code = 'class C1(object):\n pass\n' \ - 'def f(arg1, arg2, arg3):\n pass\n' \ - 'f("", None, C1())\nf("", C1(), None)\n' + code = ( + "class C1(object):\n pass\n" + "def f(arg1, arg2, arg3):\n pass\n" + 'f("", None, C1())\nf("", C1(), None)\n' + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - f_scope = pymod['f'].get_object().get_scope() - arg2 = f_scope['arg2'].get_object() + c1_class = pymod["C1"].get_object() + f_scope = pymod["f"].get_object().get_scope() + arg2 = f_scope["arg2"].get_object() self.assertEqual(c1_class, arg2.get_type()) def test_call_function_and_parameters(self): - code = 'class A(object):\n def __call__(self, p):\n pass\n' \ - 'A()("")\n' + code = ( + "class A(object):\n def __call__(self, p):\n pass\n" 'A()("")\n' + ) self.mod.write(code) self.pycore.analyze_module(self.mod) scope = self.project.get_pymodule(self.mod).get_scope() - p_object = scope.get_scopes()[0].get_scopes()[0]['p'].get_object() - self.assertTrue(isinstance(p_object.get_type(), - rope.base.builtins.Str)) + p_object = scope.get_scopes()[0].get_scopes()[0]["p"].get_object() + self.assertTrue(isinstance(p_object.get_type(), rope.base.builtins.Str)) def test_report_change_in_libutils(self): - self.project.prefs['automatic_soa'] = True - code = 'class C(object):\n pass\ndef f(p):\n pass\nf(C())\n' - with open(self.mod.real_path, 'w') as mod_file: + self.project.prefs["automatic_soa"] = True + code = "class C(object):\n pass\ndef f(p):\n pass\nf(C())\n" + with open(self.mod.real_path, "w") as mod_file: mod_file.write(code) - rope.base.libutils.report_change(self.project, self.mod.real_path, '') + rope.base.libutils.report_change(self.project, self.mod.real_path, "") pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - f_scope = pymod['f'].get_object().get_scope() - p_type = f_scope['p'].get_object().get_type() + c_class = pymod["C"].get_object() + f_scope = pymod["f"].get_object().get_scope() + p_type = f_scope["p"].get_object().get_type() self.assertEqual(c_class, p_type) def test_report_libutils_and_analyze_all_modules(self): - code = 'class C(object):\n pass\ndef f(p):\n pass\nf(C())\n' + code = "class C(object):\n pass\ndef f(p):\n pass\nf(C())\n" self.mod.write(code) rope.base.libutils.analyze_modules(self.project) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - f_scope = pymod['f'].get_object().get_scope() - p_type = f_scope['p'].get_object().get_type() + c_class = pymod["C"].get_object() + f_scope = pymod["f"].get_object().get_scope() + p_type = f_scope["p"].get_object().get_type() self.assertEqual(c_class, p_type) def test_validation_problems_for_objectdb_retrievals(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('l = []\nvar = l.pop()\n') - mod2.write('import mod1\n\nclass C(object):\n pass\n' - 'mod1.l.append(C())\n') + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("l = []\nvar = l.pop()\n") + mod2.write("import mod1\n\nclass C(object):\n pass\n" "mod1.l.append(C())\n") self.pycore.analyze_module(mod2) pymod2 = self.project.get_pymodule(mod2) - c_class = pymod2['C'].get_object() + c_class = pymod2["C"].get_object() pymod1 = self.project.get_pymodule(mod1) - var_pyname = pymod1['var'] + var_pyname = pymod1["var"] self.assertEqual(c_class, var_pyname.get_object().get_type()) mod2.write('import mod1\n\nmod1.l.append("")\n') - self.assertNotEqual(c_class, var_pyname.get_object().get_type(), - 'Class `C` no more exists') + self.assertNotEqual( + c_class, var_pyname.get_object().get_type(), "Class `C` no more exists" + ) def test_validation_problems_for_changing_builtin_types(self): - mod1 = testutils.create_module(self.project, 'mod1') + mod1 = testutils.create_module(self.project, "mod1") mod1.write('l = []\nl.append("")\n') self.pycore.analyze_module(mod1) mod1.write('l = {}\nv = l["key"]\n') pymod1 = self.project.get_pymodule(mod1) # noqa - var = pymod1['v'].get_object() # noqa + var = pymod1["v"].get_object() # noqa def test_always_returning_containing_class_for_selfs(self): - code = 'class A(object):\n def f(p):\n return p\n' \ - 'class B(object):\n pass\nb = B()\nb.f()\n' + code = ( + "class A(object):\n def f(p):\n return p\n" + "class B(object):\n pass\nb = B()\nb.f()\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod) pymod = self.project.get_pymodule(self.mod) - a_class = pymod['A'].get_object() + a_class = pymod["A"].get_object() f_scope = a_class.get_scope().get_scopes()[0] - p_type = f_scope['p'].get_object().get_type() + p_type = f_scope["p"].get_object().get_type() self.assertEqual(a_class, p_type) def test_following_function_calls_when_asked_to(self): - code = 'class A(object):\n pass\n' \ - 'class C(object):\n' \ - ' def __init__(self, arg):\n' \ - ' self.attr = arg\n' \ - 'def f(p):\n return C(p)\n' \ - 'c = f(A())\nx = c.attr\n' + code = ( + "class A(object):\n pass\n" + "class C(object):\n" + " def __init__(self, arg):\n" + " self.attr = arg\n" + "def f(p):\n return C(p)\n" + "c = f(A())\nx = c.attr\n" + ) self.mod.write(code) self.pycore.analyze_module(self.mod, followed_calls=1) pymod = self.project.get_pymodule(self.mod) - a_class = pymod['A'].get_object() - x_var = pymod['x'].get_object().get_type() + a_class = pymod["A"].get_object() + x_var = pymod["x"].get_object().get_type() self.assertEqual(a_class, x_var) diff --git a/ropetest/builtinstest.py b/ropetest/builtinstest.py index 0e17072b1..1892d581c 100644 --- a/ropetest/builtinstest.py +++ b/ropetest/builtinstest.py @@ -7,508 +7,551 @@ from ropetest import testutils from rope.base.builtins import Dict -class BuiltinTypesTest(unittest.TestCase): +class BuiltinTypesTest(unittest.TestCase): def setUp(self): super(BuiltinTypesTest, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - self.mod = testutils.create_module(self.project, 'mod') + self.mod = testutils.create_module(self.project, "mod") def tearDown(self): testutils.remove_project(self.project) super(BuiltinTypesTest, self).tearDown() def test_simple_case(self): - self.mod.write('l = []\n') + self.mod.write("l = []\n") pymod = self.project.get_pymodule(self.mod) - self.assertTrue('append' in pymod['l'].get_object()) + self.assertTrue("append" in pymod["l"].get_object()) def test_holding_type_information(self): - self.mod.write('class C(object):\n pass\n' - 'l = [C()]\na_var = l.pop()\n') + self.mod.write("class C(object):\n pass\n" "l = [C()]\na_var = l.pop()\n") pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_get_items(self): - self.mod.write('class C(object):' - '\n def __getitem__(self, i):\n return C()\n' - 'c = C()\na_var = c[0]') + self.mod.write( + "class C(object):" + "\n def __getitem__(self, i):\n return C()\n" + "c = C()\na_var = c[0]" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_get_items_for_lists(self): - self.mod.write('class C(object):\n pass\nl = [C()]\na_var = l[0]\n') + self.mod.write("class C(object):\n pass\nl = [C()]\na_var = l[0]\n") pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_get_items_from_slices(self): - self.mod.write('class C(object):\n pass' - '\nl = [C()]\na_var = l[:].pop()\n') + self.mod.write("class C(object):\n pass" "\nl = [C()]\na_var = l[:].pop()\n") pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_simple_for_loops(self): - self.mod.write('class C(object):\n pass\nl = [C()]\n' - 'for c in l:\n a_var = c\n') + self.mod.write( + "class C(object):\n pass\nl = [C()]\n" "for c in l:\n a_var = c\n" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_definition_location_for_loop_variables(self): - self.mod.write('class C(object):\n pass\nl = [C()]\n' - 'for c in l:\n pass\n') + self.mod.write( + "class C(object):\n pass\nl = [C()]\n" "for c in l:\n pass\n" + ) pymod = self.project.get_pymodule(self.mod) - c_var = pymod['c'] + c_var = pymod["c"] self.assertEqual((pymod, 4), c_var.get_definition_location()) def test_simple_case_for_dicts(self): - self.mod.write('d = {}\n') + self.mod.write("d = {}\n") pymod = self.project.get_pymodule(self.mod) - self.assertTrue('get' in pymod['d'].get_object()) + self.assertTrue("get" in pymod["d"].get_object()) def test_get_item_for_dicts(self): - self.mod.write('class C(object):\n pass\n' - 'd = {1: C()}\na_var = d[1]\n') + self.mod.write("class C(object):\n pass\n" "d = {1: C()}\na_var = d[1]\n") pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_dict_function_parent(self): - self.mod.write('d = {1: 2}\n' - 'a_var = d.keys()') + self.mod.write("d = {1: 2}\n" "a_var = d.keys()") pymod = self.project.get_pymodule(self.mod) - a_var = pymod['d'].get_object()['keys'].get_object() + a_var = pymod["d"].get_object()["keys"].get_object() self.assertEqual(type(a_var.parent), Dict) def test_popping_dicts(self): - self.mod.write('class C(object):\n pass\n' - 'd = {1: C()}\na_var = d.pop(1)\n') + self.mod.write( + "class C(object):\n pass\n" "d = {1: C()}\na_var = d.pop(1)\n" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_getting_keys_from_dicts(self): - self.mod.write('class C1(object):\n pass\n' - 'class C2(object):\n pass\n' - 'd = {C1(): C2()}\nfor c in d.keys():\n a_var = c\n') + self.mod.write( + "class C1(object):\n pass\n" + "class C2(object):\n pass\n" + "d = {C1(): C2()}\nfor c in d.keys():\n a_var = c\n" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C1'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C1"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_getting_values_from_dicts(self): - self.mod.write('class C1(object):\n pass\n' - 'class C2(object):\n pass\n' - 'd = {C1(): C2()}\nfor c in d.values():' - '\n a_var = c\n') - pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C2'].get_object() - a_var = pymod['a_var'].get_object() + self.mod.write( + "class C1(object):\n pass\n" + "class C2(object):\n pass\n" + "d = {C1(): C2()}\nfor c in d.values():" + "\n a_var = c\n" + ) + pymod = self.project.get_pymodule(self.mod) + c_class = pymod["C2"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_getting_iterkeys_from_dicts(self): - self.mod.write('class C1(object):\n pass' - '\nclass C2(object):\n pass\n' - 'd = {C1(): C2()}\nfor c in d.keys():\n a_var = c\n') + self.mod.write( + "class C1(object):\n pass" + "\nclass C2(object):\n pass\n" + "d = {C1(): C2()}\nfor c in d.keys():\n a_var = c\n" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C1'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C1"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_getting_itervalues_from_dicts(self): - self.mod.write('class C1(object):\n pass' - '\nclass C2(object):\n pass\n' - 'd = {C1(): C2()}\nfor c in d.values():' - '\n a_var = c\n') - pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C2'].get_object() - a_var = pymod['a_var'].get_object() + self.mod.write( + "class C1(object):\n pass" + "\nclass C2(object):\n pass\n" + "d = {C1(): C2()}\nfor c in d.values():" + "\n a_var = c\n" + ) + pymod = self.project.get_pymodule(self.mod) + c_class = pymod["C2"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_using_copy_for_dicts(self): - self.mod.write('class C1(object):\n pass' - '\nclass C2(object):\n pass\n' - 'd = {C1(): C2()}\nfor c in d.copy():\n a_var = c\n') + self.mod.write( + "class C1(object):\n pass" + "\nclass C2(object):\n pass\n" + "d = {C1(): C2()}\nfor c in d.copy():\n a_var = c\n" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C1'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C1"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_tuple_assignments_for_items(self): - self.mod.write('class C1(object):\n pass' - '\nclass C2(object):\n pass\n' - 'd = {C1(): C2()}\nkey, value = d.items()[0]\n') - pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - key = pymod['key'].get_object() - value = pymod['value'].get_object() + self.mod.write( + "class C1(object):\n pass" + "\nclass C2(object):\n pass\n" + "d = {C1(): C2()}\nkey, value = d.items()[0]\n" + ) + pymod = self.project.get_pymodule(self.mod) + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + key = pymod["key"].get_object() + value = pymod["value"].get_object() self.assertEqual(c1_class, key.get_type()) self.assertEqual(c2_class, value.get_type()) def test_tuple_assignment_for_lists(self): - self.mod.write('class C(object):\n pass\n' - 'l = [C(), C()]\na, b = l\n') + self.mod.write("class C(object):\n pass\n" "l = [C(), C()]\na, b = l\n") pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c_class, a_var.get_type()) self.assertEqual(c_class, b_var.get_type()) def test_tuple_assignments_for_iteritems_in_fors(self): - self.mod.write('class C1(object):\n pass\n' - 'class C2(object):\n pass\n' - 'd = {C1(): C2()}\n' - 'for x, y in d.items():\n a = x;\n b = y\n') - pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + self.mod.write( + "class C1(object):\n pass\n" + "class C2(object):\n pass\n" + "d = {C1(): C2()}\n" + "for x, y in d.items():\n a = x;\n b = y\n" + ) + pymod = self.project.get_pymodule(self.mod) + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_simple_tuple_assignments(self): - self.mod.write('class C1(object):' - '\n pass\nclass C2(object):\n pass\n' - 'a, b = C1(), C2()\n') - pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + self.mod.write( + "class C1(object):" + "\n pass\nclass C2(object):\n pass\n" + "a, b = C1(), C2()\n" + ) + pymod = self.project.get_pymodule(self.mod) + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_overriding_builtin_names(self): - self.mod.write('class C(object):\n pass\nlist = C\n') + self.mod.write("class C(object):\n pass\nlist = C\n") pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - list_var = pymod['list'].get_object() + c_class = pymod["C"].get_object() + list_var = pymod["list"].get_object() self.assertEqual(c_class, list_var) def test_simple_builtin_scope_test(self): - self.mod.write('l = list()\n') + self.mod.write("l = list()\n") pymod = self.project.get_pymodule(self.mod) - self.assertTrue('append' in pymod['l'].get_object()) + self.assertTrue("append" in pymod["l"].get_object()) def test_simple_sets(self): - self.mod.write('s = set()\n') + self.mod.write("s = set()\n") pymod = self.project.get_pymodule(self.mod) - self.assertTrue('add' in pymod['s'].get_object()) + self.assertTrue("add" in pymod["s"].get_object()) def test_making_lists_using_the_passed_argument_to_init(self): - self.mod.write('class C(object):\n pass\nl1 = [C()]\n' - 'l2 = list(l1)\na_var = l2.pop()') + self.mod.write( + "class C(object):\n pass\nl1 = [C()]\n" "l2 = list(l1)\na_var = l2.pop()" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_making_tuples_using_the_passed_argument_to_init(self): - self.mod.write('class C(object):\n pass\nl1 = [C()]\n' - 'l2 = tuple(l1)\na_var = l2[0]') + self.mod.write( + "class C(object):\n pass\nl1 = [C()]\n" "l2 = tuple(l1)\na_var = l2[0]" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_making_sets_using_the_passed_argument_to_init(self): - self.mod.write('class C(object):\n pass\nl1 = [C()]\n' - 'l2 = set(l1)\na_var = l2.pop()') + self.mod.write( + "class C(object):\n pass\nl1 = [C()]\n" "l2 = set(l1)\na_var = l2.pop()" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_making_dicts_using_the_passed_argument_to_init(self): - self.mod.write('class C1(object):\n pass\n' - 'class C2(object):\n pass\n' - 'l1 = [(C1(), C2())]\n' - 'l2 = dict(l1)\na, b = l2.items()[0]') - pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + self.mod.write( + "class C1(object):\n pass\n" + "class C2(object):\n pass\n" + "l1 = [(C1(), C2())]\n" + "l2 = dict(l1)\na, b = l2.items()[0]" + ) + pymod = self.project.get_pymodule(self.mod) + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_range_builtin_function(self): - self.mod.write('l = range(1)\n') + self.mod.write("l = range(1)\n") pymod = self.project.get_pymodule(self.mod) - l = pymod['l'].get_object() - self.assertTrue('append' in l) + l = pymod["l"].get_object() + self.assertTrue("append" in l) def test_reversed_builtin_function(self): - self.mod.write('class C(object):\n pass\nl = [C()]\n' - 'for x in reversed(l):\n a_var = x\n') + self.mod.write( + "class C(object):\n pass\nl = [C()]\n" + "for x in reversed(l):\n a_var = x\n" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_sorted_builtin_function(self): - self.mod.write('class C(object):\n pass\nl = [C()]\n' - 'a_var = sorted(l).pop()\n') + self.mod.write( + "class C(object):\n pass\nl = [C()]\n" "a_var = sorted(l).pop()\n" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_super_builtin_function(self): self.mod.write( - 'class C(object):\n pass\n' - 'class A(object):\n def a_f(self):\n return C()\n' - 'class B(A):\n def b_f(self):\n ' - 'return super(B, self).a_f()\n' - 'a_var = B.b_f()\n') - pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + "class C(object):\n pass\n" + "class A(object):\n def a_f(self):\n return C()\n" + "class B(A):\n def b_f(self):\n " + "return super(B, self).a_f()\n" + "a_var = B.b_f()\n" + ) + pymod = self.project.get_pymodule(self.mod) + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_file_builtin_type(self): self.mod.write('for line in open("file.txt"):\n a_var = line\n') pymod = self.project.get_pymodule(self.mod) - a_var = pymod['a_var'].get_object() + a_var = pymod["a_var"].get_object() self.assertTrue(isinstance(a_var.get_type(), builtins.Str)) def test_property_builtin_type(self): - self.mod.write('p = property()\n') + self.mod.write("p = property()\n") pymod = self.project.get_pymodule(self.mod) - p_var = pymod['p'].get_object() - self.assertTrue('fget' in p_var) + p_var = pymod["p"].get_object() + self.assertTrue("fget" in p_var) def test_lambda_functions(self): - self.mod.write('l = lambda: 1\n') + self.mod.write("l = lambda: 1\n") pymod = self.project.get_pymodule(self.mod) - l_var = pymod['l'].get_object() - self.assertEqual(pyobjects.get_base_type('Function'), - l_var.get_type()) + l_var = pymod["l"].get_object() + self.assertEqual(pyobjects.get_base_type("Function"), l_var.get_type()) def test_lambda_function_definition(self): - self.mod.write('l = lambda x, y = 2, *a, **b: x + y\n') + self.mod.write("l = lambda x, y = 2, *a, **b: x + y\n") pymod = self.project.get_pymodule(self.mod) - l_var = pymod['l'].get_object() + l_var = pymod["l"].get_object() self.assertTrue(l_var.get_name() is not None) self.assertEqual(len(l_var.get_param_names()), 4) - self.assertEqual((pymod, 1), - pymod['l'].get_definition_location()) + self.assertEqual((pymod, 1), pymod["l"].get_definition_location()) def test_lambdas_that_return_unknown(self): - self.mod.write('a_var = (lambda: None)()\n') + self.mod.write("a_var = (lambda: None)()\n") pymod = self.project.get_pymodule(self.mod) - a_var = pymod['a_var'].get_object() + a_var = pymod["a_var"].get_object() self.assertTrue(a_var is not None) def test_builtin_zip_function(self): self.mod.write( - 'class C1(object):\n pass\nclass C2(object):\n pass\n' - 'c1_list = [C1()]\nc2_list = [C2()]\n' - 'a, b = zip(c1_list, c2_list)[0]') - pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "c1_list = [C1()]\nc2_list = [C2()]\n" + "a, b = zip(c1_list, c2_list)[0]" + ) + pymod = self.project.get_pymodule(self.mod) + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_builtin_zip_function_with_more_than_two_args(self): self.mod.write( - 'class C1(object):\n pass\nclass C2(object):\n pass\n' - 'c1_list = [C1()]\nc2_list = [C2()]\n' - 'a, b, c = zip(c1_list, c2_list, c1_list)[0]') - pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - c2_class = pymod['C2'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() - c_var = pymod['c'].get_object() + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "c1_list = [C1()]\nc2_list = [C2()]\n" + "a, b, c = zip(c1_list, c2_list, c1_list)[0]" + ) + pymod = self.project.get_pymodule(self.mod) + c1_class = pymod["C1"].get_object() + c2_class = pymod["C2"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() + c_var = pymod["c"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) self.assertEqual(c1_class, c_var.get_type()) def test_wrong_arguments_to_zip_function(self): self.mod.write( - 'class C1(object):\n pass\nc1_list = [C1()]\n' - 'a, b = zip(c1_list, 1)[0]') + "class C1(object):\n pass\nc1_list = [C1()]\n" + "a, b = zip(c1_list, 1)[0]" + ) pymod = self.project.get_pymodule(self.mod) - c1_class = pymod['C1'].get_object() - a_var = pymod['a'].get_object() - b_var = pymod['b'].get_object() # noqa + c1_class = pymod["C1"].get_object() + a_var = pymod["a"].get_object() + b_var = pymod["b"].get_object() # noqa self.assertEqual(c1_class, a_var.get_type()) def test_enumerate_builtin_function(self): - self.mod.write('class C(object):\n pass\nl = [C()]\n' - 'for i, x in enumerate(l):\n a_var = x\n') + self.mod.write( + "class C(object):\n pass\nl = [C()]\n" + "for i, x in enumerate(l):\n a_var = x\n" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_builtin_class_get_name(self): - self.assertEqual('object', - builtins.builtins['object'].get_object().get_name()) + self.assertEqual("object", builtins.builtins["object"].get_object().get_name()) self.assertEqual( - 'property', builtins.builtins['property'].get_object().get_name()) + "property", builtins.builtins["property"].get_object().get_name() + ) def test_star_args_and_double_star_args(self): - self.mod.write('def func(p, *args, **kwds):\n pass\n') + self.mod.write("def func(p, *args, **kwds):\n pass\n") pymod = self.project.get_pymodule(self.mod) - func_scope = pymod['func'].get_object().get_scope() - args = func_scope['args'].get_object() - kwds = func_scope['kwds'].get_object() + func_scope = pymod["func"].get_object().get_scope() + args = func_scope["args"].get_object() + kwds = func_scope["kwds"].get_object() self.assertTrue(isinstance(args.get_type(), builtins.List)) self.assertTrue(isinstance(kwds.get_type(), builtins.Dict)) def test_simple_list_comprehension_test(self): - self.mod.write('a_var = [i for i in range(10)]\n') + self.mod.write("a_var = [i for i in range(10)]\n") pymod = self.project.get_pymodule(self.mod) - a_var = pymod['a_var'].get_object() + a_var = pymod["a_var"].get_object() self.assertTrue(isinstance(a_var.get_type(), builtins.List)) def test_simple_list_generator_expression(self): - self.mod.write('a_var = (i for i in range(10))\n') + self.mod.write("a_var = (i for i in range(10))\n") pymod = self.project.get_pymodule(self.mod) - a_var = pymod['a_var'].get_object() + a_var = pymod["a_var"].get_object() self.assertTrue(isinstance(a_var.get_type(), builtins.Iterator)) def test_iter_builtin_function(self): - self.mod.write('class C(object):\n pass\nl = [C()]\n' - 'for c in iter(l):\n a_var = c\n') + self.mod.write( + "class C(object):\n pass\nl = [C()]\n" + "for c in iter(l):\n a_var = c\n" + ) pymod = self.project.get_pymodule(self.mod) - c_class = pymod['C'].get_object() - a_var = pymod['a_var'].get_object() + c_class = pymod["C"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_simple_int_type(self): - self.mod.write('l = 1\n') + self.mod.write("l = 1\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual(builtins.builtins['int'].get_object(), - pymod['l'].get_object().get_type()) + self.assertEqual( + builtins.builtins["int"].get_object(), pymod["l"].get_object().get_type() + ) def test_simple_float_type(self): - self.mod.write('l = 1.0\n') + self.mod.write("l = 1.0\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual(builtins.builtins['float'].get_object(), - pymod['l'].get_object().get_type()) + self.assertEqual( + builtins.builtins["float"].get_object(), pymod["l"].get_object().get_type() + ) def test_simple_float_type2(self): - self.mod.write('l = 1e1\n') + self.mod.write("l = 1e1\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual(builtins.builtins['float'].get_object(), - pymod['l'].get_object().get_type()) + self.assertEqual( + builtins.builtins["float"].get_object(), pymod["l"].get_object().get_type() + ) def test_simple_complex_type(self): - self.mod.write('l = 1.0j\n') + self.mod.write("l = 1.0j\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual(builtins.builtins['complex'].get_object(), - pymod['l'].get_object().get_type()) + self.assertEqual( + builtins.builtins["complex"].get_object(), + pymod["l"].get_object().get_type(), + ) def test_handling_unaryop_on_ints(self): - self.mod.write('l = -(1)\n') + self.mod.write("l = -(1)\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual(builtins.builtins['int'].get_object(), - pymod['l'].get_object().get_type()) + self.assertEqual( + builtins.builtins["int"].get_object(), pymod["l"].get_object().get_type() + ) def test_handling_binop_on_ints(self): - self.mod.write('l = 1 + 1\n') + self.mod.write("l = 1 + 1\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual(builtins.builtins['int'].get_object(), - pymod['l'].get_object().get_type()) + self.assertEqual( + builtins.builtins["int"].get_object(), pymod["l"].get_object().get_type() + ) def test_handling_compares(self): - self.mod.write('l = 1 == 1\n') + self.mod.write("l = 1 == 1\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual(builtins.builtins['bool'].get_object(), - pymod['l'].get_object().get_type()) + self.assertEqual( + builtins.builtins["bool"].get_object(), pymod["l"].get_object().get_type() + ) def test_handling_boolops(self): - self.mod.write('l = 1 and 2\n') + self.mod.write("l = 1 and 2\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual(builtins.builtins['int'].get_object(), - pymod['l'].get_object().get_type()) + self.assertEqual( + builtins.builtins["int"].get_object(), pymod["l"].get_object().get_type() + ) def test_binary_or_left_value_unknown(self): - code = 'var = (asdsd or 3)\n' + code = "var = (asdsd or 3)\n" pymod = libutils.get_string_module(self.project, code) - self.assertEqual(builtins.builtins['int'].get_object(), - pymod['var'].get_object().get_type()) + self.assertEqual( + builtins.builtins["int"].get_object(), pymod["var"].get_object().get_type() + ) def test_unknown_return_object(self): - src = 'import sys\n' \ - 'def foo():\n' \ - ' res = set(sys.builtin_module_names)\n' \ - ' if foo: res.add(bar)\n' - self.project.prefs['import_dynload_stdmods'] = True + src = ( + "import sys\n" + "def foo():\n" + " res = set(sys.builtin_module_names)\n" + " if foo: res.add(bar)\n" + ) + self.project.prefs["import_dynload_stdmods"] = True self.mod.write(src) self.project.pycore.analyze_module(self.mod) def test_abstractmethods_attribute(self): # see http://bugs.python.org/issue10006 for details - src = 'class SubType(type): pass\nsubtype = SubType()\n' + src = "class SubType(type): pass\nsubtype = SubType()\n" self.mod.write(src) self.project.pycore.analyze_module(self.mod) class BuiltinModulesTest(unittest.TestCase): - def setUp(self): super(BuiltinModulesTest, self).setUp() self.project = testutils.sample_project( - extension_modules=['time', 'invalid', 'invalid.sub']) - self.mod = testutils.create_module(self.project, 'mod') + extension_modules=["time", "invalid", "invalid.sub"] + ) + self.mod = testutils.create_module(self.project, "mod") def tearDown(self): testutils.remove_project(self.project) super(BuiltinModulesTest, self).tearDown() def test_simple_case(self): - self.mod.write('import time') + self.mod.write("import time") pymod = self.project.get_pymodule(self.mod) - self.assertTrue('time' in pymod['time'].get_object()) + self.assertTrue("time" in pymod["time"].get_object()) def test_ignored_extensions(self): - self.mod.write('import os') + self.mod.write("import os") pymod = self.project.get_pymodule(self.mod) - self.assertTrue('rename' not in pymod['os'].get_object()) + self.assertTrue("rename" not in pymod["os"].get_object()) def test_ignored_extensions_2(self): - self.mod.write('import os') + self.mod.write("import os") pymod = self.project.get_pymodule(self.mod) - self.assertTrue('rename' not in pymod['os'].get_object()) + self.assertTrue("rename" not in pymod["os"].get_object()) def test_nonexistent_modules(self): - self.mod.write('import invalid') + self.mod.write("import invalid") pymod = self.project.get_pymodule(self.mod) - pymod['invalid'].get_object() + pymod["invalid"].get_object() def test_nonexistent_modules_2(self): - self.mod.write('import invalid\nimport invalid.sub') + self.mod.write("import invalid\nimport invalid.sub") pymod = self.project.get_pymodule(self.mod) - invalid = pymod['invalid'].get_object() - self.assertTrue('sub' in invalid) + invalid = pymod["invalid"].get_object() + self.assertTrue("sub" in invalid) def test_time_in_std_mods(self): import rope.base.stdmods - self.assertTrue('time' in rope.base.stdmods.standard_modules()) + + self.assertTrue("time" in rope.base.stdmods.standard_modules()) def test_timemodule_normalizes_to_time(self): import rope.base.stdmods - self.assertEqual( - rope.base.stdmods.normalize_so_name('timemodule.so'), 'time') + + self.assertEqual(rope.base.stdmods.normalize_so_name("timemodule.so"), "time") diff --git a/ropetest/codeanalyzetest.py b/ropetest/codeanalyzetest.py index e43159fb8..96d2f7857 100644 --- a/ropetest/codeanalyzetest.py +++ b/ropetest/codeanalyzetest.py @@ -1,4 +1,5 @@ from textwrap import dedent + try: import unittest2 as unittest except ImportError: @@ -7,13 +8,11 @@ import rope.base.evaluate from rope.base import libutils from rope.base import exceptions, worder, codeanalyze -from rope.base.codeanalyze import (SourceLinesAdapter, - LogicalLineFinder, get_block_start) +from rope.base.codeanalyze import SourceLinesAdapter, LogicalLineFinder, get_block_start from ropetest import testutils class SourceLinesAdapterTest(unittest.TestCase): - def setUp(self): super(SourceLinesAdapterTest, self).setUp() @@ -21,38 +20,37 @@ def tearDown(self): super(SourceLinesAdapterTest, self).tearDown() def test_source_lines_simple(self): - to_lines = SourceLinesAdapter('line1\nline2\n') - self.assertEqual('line1', to_lines.get_line(1)) - self.assertEqual('line2', to_lines.get_line(2)) - self.assertEqual('', to_lines.get_line(3)) + to_lines = SourceLinesAdapter("line1\nline2\n") + self.assertEqual("line1", to_lines.get_line(1)) + self.assertEqual("line2", to_lines.get_line(2)) + self.assertEqual("", to_lines.get_line(3)) self.assertEqual(3, to_lines.length()) def test_source_lines_get_line_number(self): - to_lines = SourceLinesAdapter('line1\nline2\n') + to_lines = SourceLinesAdapter("line1\nline2\n") self.assertEqual(1, to_lines.get_line_number(0)) self.assertEqual(1, to_lines.get_line_number(5)) self.assertEqual(2, to_lines.get_line_number(7)) self.assertEqual(3, to_lines.get_line_number(12)) def test_source_lines_get_line_start(self): - to_lines = SourceLinesAdapter('line1\nline2\n') + to_lines = SourceLinesAdapter("line1\nline2\n") self.assertEqual(0, to_lines.get_line_start(1)) self.assertEqual(6, to_lines.get_line_start(2)) self.assertEqual(12, to_lines.get_line_start(3)) def test_source_lines_get_line_end(self): - to_lines = SourceLinesAdapter('line1\nline2\n') + to_lines = SourceLinesAdapter("line1\nline2\n") self.assertEqual(5, to_lines.get_line_end(1)) self.assertEqual(11, to_lines.get_line_end(2)) self.assertEqual(12, to_lines.get_line_end(3)) def test_source_lines_last_line_with_no_new_line(self): - to_lines = SourceLinesAdapter('line1') + to_lines = SourceLinesAdapter("line1") self.assertEqual(1, to_lines.get_line_number(5)) class WordRangeFinderTest(unittest.TestCase): - def setUp(self): super(WordRangeFinderTest, self).setUp() @@ -65,62 +63,60 @@ def _find_primary(self, code, offset): return result def test_keyword_before_parens(self): - code = 'if (a_var).an_attr:\n pass\n' - self.assertEqual('(a_var).an_attr', - self._find_primary(code, code.index(':'))) + code = "if (a_var).an_attr:\n pass\n" + self.assertEqual("(a_var).an_attr", self._find_primary(code, code.index(":"))) def test_inside_parans(self): - code = 'a_func(a_var)' - self.assertEqual('a_var', self._find_primary(code, 10)) + code = "a_func(a_var)" + self.assertEqual("a_var", self._find_primary(code, 10)) def test_simple_names(self): - code = 'a_var = 10' - self.assertEqual('a_var', self._find_primary(code, 3)) + code = "a_var = 10" + self.assertEqual("a_var", self._find_primary(code, 3)) def test_function_calls(self): - code = 'sample_function()' - self.assertEqual('sample_function', self._find_primary(code, 10)) + code = "sample_function()" + self.assertEqual("sample_function", self._find_primary(code, 10)) def test_attribute_accesses(self): - code = 'a_var.an_attr' - self.assertEqual('a_var.an_attr', self._find_primary(code, 10)) + code = "a_var.an_attr" + self.assertEqual("a_var.an_attr", self._find_primary(code, 10)) def test_word_finder_on_word_beginning(self): - code = 'print(a_var)\n' + code = "print(a_var)\n" word_finder = worder.Worder(code) - result = word_finder.get_word_at(code.index('a_var')) - self.assertEqual('a_var', result) + result = word_finder.get_word_at(code.index("a_var")) + self.assertEqual("a_var", result) def test_word_finder_on_primary_beginning(self): - code = 'print(a_var)\n' - result = self._find_primary(code, code.index('a_var')) - self.assertEqual('a_var', result) + code = "print(a_var)\n" + result = self._find_primary(code, code.index("a_var")) + self.assertEqual("a_var", result) def test_word_finder_on_word_ending(self): - code = 'print(a_var)\n' + code = "print(a_var)\n" word_finder = worder.Worder(code) - result = word_finder.get_word_at(code.index('a_var') + 5) - self.assertEqual('a_var', result) + result = word_finder.get_word_at(code.index("a_var") + 5) + self.assertEqual("a_var", result) def test_word_finder_on_primary_ending(self): - code = 'print(a_var)\n' - result = self._find_primary(code, code.index('a_var') + 5) - self.assertEqual('a_var', result) + code = "print(a_var)\n" + result = self._find_primary(code, code.index("a_var") + 5) + self.assertEqual("a_var", result) def test_word_finder_on_primaries_with_dots_inside_parens(self): - code = '(a_var.\nattr)' - result = self._find_primary(code, code.index('attr') + 1) - self.assertEqual('a_var.\nattr', result) + code = "(a_var.\nattr)" + result = self._find_primary(code, code.index("attr") + 1) + self.assertEqual("a_var.\nattr", result) def test_word_finder_on_primary_like_keyword(self): - code = 'is_keyword = False\n' + code = "is_keyword = False\n" result = self._find_primary(code, 1) - self.assertEqual('is_keyword', result) + self.assertEqual("is_keyword", result) def test_keyword_before_parens_no_space(self): - code = 'if(a_var).an_attr:\n pass\n' - self.assertEqual('(a_var).an_attr', - self._find_primary(code, code.index(':'))) + code = "if(a_var).an_attr:\n pass\n" + self.assertEqual("(a_var).an_attr", self._find_primary(code, code.index(":"))) def test_strings(self): code = '"a string".split()' @@ -128,25 +124,26 @@ def test_strings(self): def test_function_calls2(self): code = 'file("afile.txt").read()' - self.assertEqual('file("afile.txt").read', - self._find_primary(code, 18)) + self.assertEqual('file("afile.txt").read', self._find_primary(code, 18)) def test_parens(self): code = '("afile.txt").split()' self.assertEqual('("afile.txt").split', self._find_primary(code, 18)) def test_function_with_no_param(self): - code = 'AClass().a_func()' - self.assertEqual('AClass().a_func', self._find_primary(code, 12)) + code = "AClass().a_func()" + self.assertEqual("AClass().a_func", self._find_primary(code, 12)) def test_function_with_multiple_param(self): code = 'AClass(a_param, another_param, "a string").a_func()' - self.assertEqual('AClass(a_param, another_param, "a string").a_func', - self._find_primary(code, 44)) + self.assertEqual( + 'AClass(a_param, another_param, "a string").a_func', + self._find_primary(code, 44), + ) def test_param_expressions(self): - code = 'AClass(an_object.an_attr).a_func()' - self.assertEqual('an_object.an_attr', self._find_primary(code, 20)) + code = "AClass(an_object.an_attr).a_func()" + self.assertEqual("an_object.an_attr", self._find_primary(code, 20)) def test_string_parens(self): code = 'a_func("(").an_attr' @@ -154,110 +151,103 @@ def test_string_parens(self): def test_extra_spaces(self): code = 'a_func ( "(" ) . an_attr' - self.assertEqual('a_func ( "(" ) . an_attr', - self._find_primary(code, 26)) + self.assertEqual('a_func ( "(" ) . an_attr', self._find_primary(code, 26)) def test_relative_import(self): code = "from .module import smt" - self.assertEqual('.module', - self._find_primary(code, 5)) + self.assertEqual(".module", self._find_primary(code, 5)) def test_functions_on_ending_parens(self): - code = 'A()' - self.assertEqual('A()', self._find_primary(code, 2)) + code = "A()" + self.assertEqual("A()", self._find_primary(code, 2)) def test_splitted_statement(self): - word_finder = worder.Worder('an_object.an_attr') - self.assertEqual(('an_object', 'an_at', 10), - word_finder.get_splitted_primary_before(15)) + word_finder = worder.Worder("an_object.an_attr") + self.assertEqual( + ("an_object", "an_at", 10), word_finder.get_splitted_primary_before(15) + ) def test_empty_splitted_statement(self): - word_finder = worder.Worder('an_attr') - self.assertEqual(('', 'an_at', 0), - word_finder.get_splitted_primary_before(5)) + word_finder = worder.Worder("an_attr") + self.assertEqual(("", "an_at", 0), word_finder.get_splitted_primary_before(5)) def test_empty_splitted_statement2(self): - word_finder = worder.Worder('an_object.') - self.assertEqual(('an_object', '', 10), - word_finder.get_splitted_primary_before(10)) + word_finder = worder.Worder("an_object.") + self.assertEqual( + ("an_object", "", 10), word_finder.get_splitted_primary_before(10) + ) def test_empty_splitted_statement3(self): - word_finder = worder.Worder('') - self.assertEqual(('', '', 0), - word_finder.get_splitted_primary_before(0)) + word_finder = worder.Worder("") + self.assertEqual(("", "", 0), word_finder.get_splitted_primary_before(0)) def test_empty_splitted_statement4(self): - word_finder = worder.Worder('a_var = ') - self.assertEqual(('', '', 8), - word_finder.get_splitted_primary_before(8)) + word_finder = worder.Worder("a_var = ") + self.assertEqual(("", "", 8), word_finder.get_splitted_primary_before(8)) def test_empty_splitted_statement5(self): - word_finder = worder.Worder('a.') - self.assertEqual(('a', '', 2), - word_finder.get_splitted_primary_before(2)) + word_finder = worder.Worder("a.") + self.assertEqual(("a", "", 2), word_finder.get_splitted_primary_before(2)) def test_operators_inside_parens(self): - code = '(a_var + another_var).reverse()' - self.assertEqual('(a_var + another_var).reverse', - self._find_primary(code, 25)) + code = "(a_var + another_var).reverse()" + self.assertEqual("(a_var + another_var).reverse", self._find_primary(code, 25)) def test_dictionaries(self): code = 'print({1: "one", 2: "two"}.keys())' - self.assertEqual('{1: "one", 2: "two"}.keys', - self._find_primary(code, 29)) + self.assertEqual('{1: "one", 2: "two"}.keys', self._find_primary(code, 29)) def test_following_parens(self): - code = 'a_var = a_func()()' - result = self._find_primary(code, code.index(')(') + 3) - self.assertEqual('a_func()()', result) + code = "a_var = a_func()()" + result = self._find_primary(code, code.index(")(") + 3) + self.assertEqual("a_func()()", result) def test_comments_for_finding_statements(self): - code = '# var2 . \n var3' - self.assertEqual('var3', self._find_primary(code, code.index('3'))) + code = "# var2 . \n var3" + self.assertEqual("var3", self._find_primary(code, code.index("3"))) def test_str_in_comments_for_finding_statements(self): code = '# "var2" . \n var3' - self.assertEqual('var3', self._find_primary(code, code.index('3'))) + self.assertEqual("var3", self._find_primary(code, code.index("3"))) def test_comments_for_finding_statements2(self): code = 'var1 + "# var2".\n var3' - self.assertEqual('var3', self._find_primary(code, 21)) + self.assertEqual("var3", self._find_primary(code, 21)) def test_comments_for_finding_statements3(self): code = '"" + # var2.\n var3' - self.assertEqual('var3', self._find_primary(code, 21)) + self.assertEqual("var3", self._find_primary(code, 21)) def test_import_statement_finding(self): - code = 'import mod\na_var = 10\n' + code = "import mod\na_var = 10\n" word_finder = worder.Worder(code) - self.assertTrue(word_finder.is_import_statement(code.index('mod') + 1)) - self.assertFalse(word_finder.is_import_statement( - code.index('a_var') + 1)) + self.assertTrue(word_finder.is_import_statement(code.index("mod") + 1)) + self.assertFalse(word_finder.is_import_statement(code.index("a_var") + 1)) def test_import_statement_finding2(self): - code = 'import a.b.c.d\nresult = a.b.c.d.f()\n' + code = "import a.b.c.d\nresult = a.b.c.d.f()\n" word_finder = worder.Worder(code) - self.assertFalse(word_finder.is_import_statement(code.rindex('d') + 1)) + self.assertFalse(word_finder.is_import_statement(code.rindex("d") + 1)) def test_word_parens_range(self): - code = 's = str()\ns.title()\n' + code = "s = str()\ns.title()\n" word_finder = worder.Worder(code) - result = word_finder.get_word_parens_range(code.rindex('()') - 1) + result = word_finder.get_word_parens_range(code.rindex("()") - 1) self.assertEqual((len(code) - 3, len(code) - 1), result) def test_getting_primary_before_get_index(self): - code = '\na = (b + c).d[0]()\n' + code = "\na = (b + c).d[0]()\n" result = self._find_primary(code, len(code) - 2) - self.assertEqual('(b + c).d[0]()', result) + self.assertEqual("(b + c).d[0]()", result) def test_getting_primary_and_strings_at_the_end_of_line(self): - code = 'f(\'\\\'\')\n' + code = "f('\\'')\n" result = self._find_primary(code, len(code) - 1) # noqa def test_getting_primary_and_not_crossing_newlines(self): - code = '\na = (b + c)\n(4 + 1).x\n' + code = "\na = (b + c)\n(4 + 1).x\n" result = self._find_primary(code, len(code) - 1) - self.assertEqual('(4 + 1).x', result) + self.assertEqual("(4 + 1).x", result) # XXX: cancatenated string literals def xxx_test_getting_primary_cancatenating_strs(self): @@ -266,7 +256,7 @@ def xxx_test_getting_primary_cancatenating_strs(self): self.assertEqual('"b" "c"', result) def test_is_a_function_being_called_with_parens_on_next_line(self): - code = 'func\n(1, 2)\n' + code = "func\n(1, 2)\n" word_finder = worder.Worder(code) self.assertFalse(word_finder.is_a_function_being_called(1)) @@ -288,39 +278,39 @@ def test_get_word_parens_range_and_string_literals(self): self.assertEqual((1, len(code) - 1), result) def test_is_assigned_here_for_equality_test(self): - code = 'a == 1\n' + code = "a == 1\n" word_finder = worder.Worder(code) self.assertFalse(word_finder.is_assigned_here(0)) def test_is_assigned_here_for_not_equal_test(self): - code = 'a != 1\n' + code = "a != 1\n" word_finder = worder.Worder(code) self.assertFalse(word_finder.is_assigned_here(0)) # XXX: is_assigned_here should work for tuple assignments def xxx_test_is_assigned_here_for_tuple_assignment(self): - code = 'a, b = (1, 2)\n' + code = "a, b = (1, 2)\n" word_finder = worder.Worder(code) self.assertTrue(word_finder.is_assigned_here(0)) def test_is_from_with_from_import_and_multiline_parens(self): - code = 'from mod import \\\n (f,\n g, h)\n' + code = "from mod import \\\n (f,\n g, h)\n" word_finder = worder.Worder(code) - self.assertTrue(word_finder.is_from_statement(code.rindex('g'))) + self.assertTrue(word_finder.is_from_statement(code.rindex("g"))) def test_is_from_with_from_import_and_line_breaks_in_the_middle(self): - code = 'from mod import f,\\\n g\n' + code = "from mod import f,\\\n g\n" word_finder = worder.Worder(code) - self.assertTrue(word_finder.is_from_statement(code.rindex('g'))) + self.assertTrue(word_finder.is_from_statement(code.rindex("g"))) def test_one_letter_function_keyword_arguments(self): - code = 'f(p=1)\n' + code = "f(p=1)\n" word_finder = worder.Worder(code) - index = code.rindex('p') + index = code.rindex("p") self.assertTrue(word_finder.is_function_keyword_parameter(index)) def test_find_parens_start(self): - code = 'f(p)\n' + code = "f(p)\n" finder = worder.Worder(code) self.assertEqual(1, finder.find_parens_start_from_inside(2)) @@ -330,31 +320,35 @@ def test_underlined_find_parens_start(self): self.assertEqual(1, finder._find_parens_start(len(code) - 2)) def test_find_parens_start_with_multiple_entries(self): - code = 'myfunc(p1, p2, p3\n' + code = "myfunc(p1, p2, p3\n" finder = worder.Worder(code) - self.assertEqual(code.index('('), - finder.find_parens_start_from_inside(len(code) - 1)) + self.assertEqual( + code.index("("), finder.find_parens_start_from_inside(len(code) - 1) + ) def test_find_parens_start_with_nested_parens(self): - code = 'myfunc(p1, (p2, p3), p4\n' + code = "myfunc(p1, (p2, p3), p4\n" finder = worder.Worder(code) - self.assertEqual(code.index('('), - finder.find_parens_start_from_inside(len(code) - 1)) + self.assertEqual( + code.index("("), finder.find_parens_start_from_inside(len(code) - 1) + ) def test_find_parens_start_with_parens_in_strs(self): code = 'myfunc(p1, "(", p4\n' finder = worder.Worder(code) - self.assertEqual(code.index('('), - finder.find_parens_start_from_inside(len(code) - 1)) + self.assertEqual( + code.index("("), finder.find_parens_start_from_inside(len(code) - 1) + ) def test_find_parens_start_with_parens_in_strs_in_multiple_lines(self): code = 'myfunc (\np1\n , \n "(" \n, \np4\n' finder = worder.Worder(code) - self.assertEqual(code.index('('), - finder.find_parens_start_from_inside(len(code) - 1)) + self.assertEqual( + code.index("("), finder.find_parens_start_from_inside(len(code) - 1) + ) def test_is_on_function_keyword(self): - code = 'myfunc(va' + code = "myfunc(va" finder = worder.Worder(code) self.assertTrue(finder.is_on_function_call_keyword(len(code) - 1)) @@ -365,7 +359,6 @@ def test_get_word_range_with_fstring(self): class ScopeNameFinderTest(unittest.TestCase): - def setUp(self): super(ScopeNameFinderTest, self).setUp() self.project = testutils.sample_project() @@ -377,141 +370,140 @@ def tearDown(self): # FIXME: in normal scopes the interpreter raises `UnboundLocalName` # exception, but not in class bodies def xxx_test_global_name_in_class_body(self): - code = 'a_var = 10\nclass C(object):\n a_var = a_var\n' + code = "a_var = 10\nclass C(object):\n a_var = a_var\n" scope = libutils.get_string_scope(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(scope.pyobject) result = name_finder.get_pyname_at(len(code) - 3) - self.assertEqual(scope['a_var'], result) + self.assertEqual(scope["a_var"], result) def test_class_variable_attribute_in_class_body(self): - code = 'a_var = 10\nclass C(object):\n a_var = a_var\n' + code = "a_var = 10\nclass C(object):\n a_var = a_var\n" scope = libutils.get_string_scope(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(scope.pyobject) - a_var_pyname = scope['C'].get_object()['a_var'] + a_var_pyname = scope["C"].get_object()["a_var"] result = name_finder.get_pyname_at(len(code) - 12) self.assertEqual(a_var_pyname, result) def test_class_variable_attribute_in_class_body2(self): - code = 'a_var = 10\nclass C(object):\n a_var \\\n= a_var\n' + code = "a_var = 10\nclass C(object):\n a_var \\\n= a_var\n" scope = libutils.get_string_scope(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(scope.pyobject) - a_var_pyname = scope['C'].get_object()['a_var'] + a_var_pyname = scope["C"].get_object()["a_var"] result = name_finder.get_pyname_at(len(code) - 12) self.assertEqual(a_var_pyname, result) def test_class_method_attribute_in_class_body(self): - code = 'class C(object):\n def a_method(self):\n pass\n' + code = "class C(object):\n def a_method(self):\n pass\n" scope = libutils.get_string_scope(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(scope.pyobject) - a_method_pyname = scope['C'].get_object()['a_method'] - result = name_finder.get_pyname_at(code.index('a_method') + 2) + a_method_pyname = scope["C"].get_object()["a_method"] + result = name_finder.get_pyname_at(code.index("a_method") + 2) self.assertEqual(a_method_pyname, result) def test_inner_class_attribute_in_class_body(self): - code = 'class C(object):\n class CC(object):\n pass\n' + code = "class C(object):\n class CC(object):\n pass\n" scope = libutils.get_string_scope(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(scope.pyobject) - a_class_pyname = scope['C'].get_object()['CC'] - result = name_finder.get_pyname_at(code.index('CC') + 2) + a_class_pyname = scope["C"].get_object()["CC"] + result = name_finder.get_pyname_at(code.index("CC") + 2) self.assertEqual(a_class_pyname, result) def test_class_method_in_class_body_but_not_indexed(self): - code = 'class C(object):\n def func(self, func):\n pass\n' + code = "class C(object):\n def func(self, func):\n pass\n" scope = libutils.get_string_scope(self.project, code) - a_func_pyname = scope.get_scopes()[0].get_scopes()[0]['func'] + a_func_pyname = scope.get_scopes()[0].get_scopes()[0]["func"] name_finder = rope.base.evaluate.ScopeNameFinder(scope.pyobject) - result = name_finder.get_pyname_at(code.index(', func') + 3) + result = name_finder.get_pyname_at(code.index(", func") + 3) self.assertEqual(a_func_pyname, result) def test_function_but_not_indexed(self): - code = 'def a_func(a_func):\n pass\n' + code = "def a_func(a_func):\n pass\n" scope = libutils.get_string_scope(self.project, code) - a_func_pyname = scope['a_func'] + a_func_pyname = scope["a_func"] name_finder = rope.base.evaluate.ScopeNameFinder(scope.pyobject) - result = name_finder.get_pyname_at(code.index('a_func') + 3) + result = name_finder.get_pyname_at(code.index("a_func") + 3) self.assertEqual(a_func_pyname, result) def test_modules_after_from_statements(self): root_folder = self.project.root - mod = testutils.create_module(self.project, 'mod', root_folder) - mod.write('def a_func():\n pass\n') - code = 'from mod import a_func\n' + mod = testutils.create_module(self.project, "mod", root_folder) + mod.write("def a_func():\n pass\n") + code = "from mod import a_func\n" scope = libutils.get_string_scope(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(scope.pyobject) mod_pyobject = self.project.get_pymodule(mod) - found_pyname = name_finder.get_pyname_at(code.index('mod') + 1) + found_pyname = name_finder.get_pyname_at(code.index("mod") + 1) self.assertEqual(mod_pyobject, found_pyname.get_object()) def test_renaming_functions_with_from_import_and_parens(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('def afunc():\n pass\n') - code = 'from mod1 import (\n afunc as func)\n' + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("def afunc():\n pass\n") + code = "from mod1 import (\n afunc as func)\n" scope = libutils.get_string_scope(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(scope.pyobject) mod_pyobject = self.project.get_pymodule(mod1) - afunc = mod_pyobject['afunc'] - found_pyname = name_finder.get_pyname_at(code.index('afunc') + 1) + afunc = mod_pyobject["afunc"] + found_pyname = name_finder.get_pyname_at(code.index("afunc") + 1) self.assertEqual(afunc.get_object(), found_pyname.get_object()) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_relative_modules_after_from_statements(self): - pkg1 = testutils.create_package(self.project, 'pkg1') - pkg2 = testutils.create_package(self.project, 'pkg2', pkg1) - mod1 = testutils.create_module(self.project, 'mod1', pkg1) - mod2 = testutils.create_module(self.project, 'mod2', pkg2) - mod1.write('def a_func():\n pass\n') - code = 'from ..mod1 import a_func\n' + pkg1 = testutils.create_package(self.project, "pkg1") + pkg2 = testutils.create_package(self.project, "pkg2", pkg1) + mod1 = testutils.create_module(self.project, "mod1", pkg1) + mod2 = testutils.create_module(self.project, "mod2", pkg2) + mod1.write("def a_func():\n pass\n") + code = "from ..mod1 import a_func\n" mod2.write(code) mod2_scope = self.project.get_pymodule(mod2).get_scope() name_finder = rope.base.evaluate.ScopeNameFinder(mod2_scope.pyobject) mod1_pyobject = self.project.get_pymodule(mod1) - found_pyname = name_finder.get_pyname_at(code.index('mod1') + 1) + found_pyname = name_finder.get_pyname_at(code.index("mod1") + 1) self.assertEqual(mod1_pyobject, found_pyname.get_object()) def test_relative_modules_after_from_statements2(self): - mod1 = testutils.create_module(self.project, 'mod1') - pkg1 = testutils.create_package(self.project, 'pkg1') - pkg2 = testutils.create_package(self.project, 'pkg2', pkg1) - mod2 = testutils.create_module(self.project, 'mod2', pkg2) # noqa - mod1.write('import pkg1.pkg2.mod2') + mod1 = testutils.create_module(self.project, "mod1") + pkg1 = testutils.create_package(self.project, "pkg1") + pkg2 = testutils.create_package(self.project, "pkg2", pkg1) + mod2 = testutils.create_module(self.project, "mod2", pkg2) # noqa + mod1.write("import pkg1.pkg2.mod2") mod1_scope = self.project.get_pymodule(mod1).get_scope() name_finder = rope.base.evaluate.ScopeNameFinder(mod1_scope.pyobject) pkg2_pyobject = self.project.get_pymodule(pkg2) - found_pyname = name_finder.get_pyname_at(mod1.read().index('pkg2') + 1) + found_pyname = name_finder.get_pyname_at(mod1.read().index("pkg2") + 1) self.assertEqual(pkg2_pyobject, found_pyname.get_object()) def test_get_pyname_at_on_language_keywords(self): - code = 'def a_func(a_func):\n pass\n' + code = "def a_func(a_func):\n pass\n" pymod = libutils.get_string_module(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(pymod) with self.assertRaises(exceptions.RopeError): - name_finder.get_pyname_at(code.index('pass')) + name_finder.get_pyname_at(code.index("pass")) def test_one_liners(self): - code = 'var = 1\ndef f(): var = 2\nprint(var)\n' + code = "var = 1\ndef f(): var = 2\nprint(var)\n" pymod = libutils.get_string_module(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(pymod) - pyname = name_finder.get_pyname_at(code.rindex('var')) - self.assertEqual(pymod['var'], pyname) + pyname = name_finder.get_pyname_at(code.rindex("var")) + self.assertEqual(pymod["var"], pyname) def test_one_liners_with_line_breaks(self): - code = 'var = 1\ndef f(\n): var = 2\nprint(var)\n' + code = "var = 1\ndef f(\n): var = 2\nprint(var)\n" pymod = libutils.get_string_module(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(pymod) - pyname = name_finder.get_pyname_at(code.rindex('var')) - self.assertEqual(pymod['var'], pyname) + pyname = name_finder.get_pyname_at(code.rindex("var")) + self.assertEqual(pymod["var"], pyname) def test_one_liners_with_line_breaks2(self): - code = 'var = 1\ndef f(\np): var = 2\nprint(var)\n' + code = "var = 1\ndef f(\np): var = 2\nprint(var)\n" pymod = libutils.get_string_module(self.project, code) name_finder = rope.base.evaluate.ScopeNameFinder(pymod) - pyname = name_finder.get_pyname_at(code.rindex('var')) - self.assertEqual(pymod['var'], pyname) + pyname = name_finder.get_pyname_at(code.rindex("var")) + self.assertEqual(pymod["var"], pyname) class LogicalLineFinderTest(unittest.TestCase): - def setUp(self): super(LogicalLineFinderTest, self).setUp() @@ -522,70 +514,73 @@ def _logical_finder(self, code): return LogicalLineFinder(SourceLinesAdapter(code)) def test_normal_lines(self): - code = 'a_var = 10' + code = "a_var = 10" line_finder = self._logical_finder(code) self.assertEqual((1, 1), line_finder.logical_line_in(1)) def test_normal_lines2(self): - code = 'another = 10\na_var = 20\n' + code = "another = 10\na_var = 20\n" line_finder = self._logical_finder(code) self.assertEqual((1, 1), line_finder.logical_line_in(1)) self.assertEqual((2, 2), line_finder.logical_line_in(2)) def test_implicit_continuation(self): - code = 'a_var = 3 + \\\n 4 + \\\n 5' + code = "a_var = 3 + \\\n 4 + \\\n 5" line_finder = self._logical_finder(code) self.assertEqual((1, 3), line_finder.logical_line_in(2)) def test_explicit_continuation(self): - code = 'print(2)\na_var = (3 + \n 4, \n 5)\n' + code = "print(2)\na_var = (3 + \n 4, \n 5)\n" line_finder = self._logical_finder(code) self.assertEqual((2, 4), line_finder.logical_line_in(2)) def test_explicit_continuation_comments(self): - code = '#\na_var = 3\n' + code = "#\na_var = 3\n" line_finder = self._logical_finder(code) self.assertEqual((2, 2), line_finder.logical_line_in(2)) def test_multiple_indented_ifs(self): - code = 'if True:\n if True:\n ' \ - 'if True:\n pass\n a = 10\n' + code = ( + "if True:\n if True:\n " + "if True:\n pass\n a = 10\n" + ) line_finder = self._logical_finder(code) self.assertEqual((5, 5), line_finder.logical_line_in(5)) def test_list_comprehensions_and_fors(self): - code = 'a_list = [i\n for i in range(10)]\n' + code = "a_list = [i\n for i in range(10)]\n" line_finder = self._logical_finder(code) self.assertEqual((1, 2), line_finder.logical_line_in(2)) def test_generator_expressions_and_fors(self): - code = 'a_list = (i\n for i in range(10))\n' + code = "a_list = (i\n for i in range(10))\n" line_finder = self._logical_finder(code) self.assertEqual((1, 2), line_finder.logical_line_in(2)) def test_fors_and_block_start(self): - code = 'l = range(10)\nfor i in l:\n print(i)\n' + code = "l = range(10)\nfor i in l:\n print(i)\n" self.assertEqual(2, get_block_start(SourceLinesAdapter(code), 2)) def test_problems_with_inner_indentations(self): - code = 'if True:\n if True:\n if True:\n pass\n' \ - ' a = \\\n 1\n' + code = ( + "if True:\n if True:\n if True:\n pass\n" + " a = \\\n 1\n" + ) line_finder = self._logical_finder(code) self.assertEqual((5, 6), line_finder.logical_line_in(6)) def test_problems_with_inner_indentations2(self): - code = 'if True:\n if True:\n pass\n' \ - 'a = 1\n' + code = "if True:\n if True:\n pass\n" "a = 1\n" line_finder = self._logical_finder(code) self.assertEqual((4, 4), line_finder.logical_line_in(4)) def test_logical_lines_for_else(self): - code = 'if True:\n pass\nelse:\n pass\n' + code = "if True:\n pass\nelse:\n pass\n" line_finder = self._logical_finder(code) self.assertEqual((3, 3), line_finder.logical_line_in(3)) def test_logical_lines_for_lines_with_wrong_continues(self): - code = 'var = 1 + \\' + code = "var = 1 + \\" line_finder = self._logical_finder(code) self.assertEqual((1, 1), line_finder.logical_line_in(1)) @@ -600,28 +595,30 @@ def test_logical_lines_for_multiline_string_with_escaped_quotes(self): self.assertEqual((2, 2), line_finder.logical_line_in(2)) def test_generating_line_starts(self): - code = 'a = 1\na = 2\n\na = 3\n' + code = "a = 1\na = 2\n\na = 3\n" line_finder = self._logical_finder(code) self.assertEqual([1, 2, 4], list(line_finder.generate_starts())) def test_generating_line_starts2(self): - code = 'a = 1\na = 2\n\na = \\ 3\n' + code = "a = 1\na = 2\n\na = \\ 3\n" line_finder = self._logical_finder(code) self.assertEqual([2, 4], list(line_finder.generate_starts(2))) def test_generating_line_starts3(self): - code = 'a = 1\na = 2\n\na = \\ 3\n' + code = "a = 1\na = 2\n\na = \\ 3\n" line_finder = self._logical_finder(code) self.assertEqual([2], list(line_finder.generate_starts(2, 3))) def test_generating_line_starts_for_multi_line_statements(self): - code = '\na = \\\n 1 + \\\n 1\n' + code = "\na = \\\n 1 + \\\n 1\n" line_finder = self._logical_finder(code) self.assertEqual([2], list(line_finder.generate_starts())) def test_generating_line_starts_and_unmatched_deindents(self): - code = 'if True:\n if True:\n if True:\n' \ - ' a = 1\n b = 1\n' + code = ( + "if True:\n if True:\n if True:\n" + " a = 1\n b = 1\n" + ) line_finder = self._logical_finder(code) self.assertEqual([4, 5], list(line_finder.generate_starts(4))) @@ -642,16 +639,14 @@ def bar(): class TokenizerLogicalLineFinderTest(LogicalLineFinderTest): - def _logical_finder(self, code): lines = SourceLinesAdapter(code) return codeanalyze.CachingLogicalLineFinder( - lines, codeanalyze.tokenizer_generator) + lines, codeanalyze.tokenizer_generator + ) class CustomLogicalLineFinderTest(LogicalLineFinderTest): - def _logical_finder(self, code): lines = SourceLinesAdapter(code) - return codeanalyze.CachingLogicalLineFinder( - lines, codeanalyze.custom_generator) + return codeanalyze.CachingLogicalLineFinder(lines, codeanalyze.custom_generator) diff --git a/ropetest/contrib/__init__.py b/ropetest/contrib/__init__.py index e1515484c..4860f1443 100644 --- a/ropetest/contrib/__init__.py +++ b/ropetest/contrib/__init__.py @@ -1,4 +1,5 @@ import sys + try: import unittest2 as unittest except ImportError: @@ -15,21 +16,21 @@ def suite(): result = unittest.TestSuite() - result.addTests(unittest.makeSuite(ropetest.contrib.generatetest. - GenerateTest)) + result.addTests(unittest.makeSuite(ropetest.contrib.generatetest.GenerateTest)) result.addTests(ropetest.contrib.codeassisttest.suite()) result.addTests(ropetest.contrib.autoimporttest.suite()) result.addTests(ropetest.contrib.findittest.suite()) - result.addTests(unittest.makeSuite(ropetest.contrib.changestacktest. - ChangeStackTest)) - result.addTests(unittest.makeSuite(ropetest.contrib.fixmodnamestest. - FixModuleNamesTest)) - result.addTests(unittest.makeSuite(ropetest.contrib.finderrorstest. - FindErrorsTest)) + result.addTests( + unittest.makeSuite(ropetest.contrib.changestacktest.ChangeStackTest) + ) + result.addTests( + unittest.makeSuite(ropetest.contrib.fixmodnamestest.FixModuleNamesTest) + ) + result.addTests(unittest.makeSuite(ropetest.contrib.finderrorstest.FindErrorsTest)) return result -if __name__ == '__main__': +if __name__ == "__main__": runner = unittest.TextTestRunner() result = runner.run(suite()) sys.exit(not result.wasSuccessful()) diff --git a/ropetest/contrib/autoimporttest.py b/ropetest/contrib/autoimporttest.py index 1a7afc39d..995eba6b9 100644 --- a/ropetest/contrib/autoimporttest.py +++ b/ropetest/contrib/autoimporttest.py @@ -8,13 +8,12 @@ class AutoImportTest(unittest.TestCase): - def setUp(self): super(AutoImportTest, self).setUp() - self.project = testutils.sample_project(extension_modules=['sys']) - self.mod1 = testutils.create_module(self.project, 'mod1') - self.pkg = testutils.create_package(self.project, 'pkg') - self.mod2 = testutils.create_module(self.project, 'mod2', self.pkg) + self.project = testutils.sample_project(extension_modules=["sys"]) + self.mod1 = testutils.create_module(self.project, "mod1") + self.pkg = testutils.create_package(self.project, "pkg") + self.mod2 = testutils.create_module(self.project, "mod2", self.pkg) self.importer = autoimport.AutoImport(self.project, observe=False) def tearDown(self): @@ -22,123 +21,117 @@ def tearDown(self): super(AutoImportTest, self).tearDown() def test_simple_case(self): - self.assertEqual([], self.importer.import_assist('A')) + self.assertEqual([], self.importer.import_assist("A")) def test_update_resource(self): - self.mod1.write('myvar = None\n') + self.mod1.write("myvar = None\n") self.importer.update_resource(self.mod1) - self.assertEqual([('myvar', 'mod1')], - self.importer.import_assist('myva')) + self.assertEqual([("myvar", "mod1")], self.importer.import_assist("myva")) def test_update_module(self): - self.mod1.write('myvar = None') - self.importer.update_module('mod1') - self.assertEqual([('myvar', 'mod1')], - self.importer.import_assist('myva')) + self.mod1.write("myvar = None") + self.importer.update_module("mod1") + self.assertEqual([("myvar", "mod1")], self.importer.import_assist("myva")) def test_update_non_existent_module(self): - self.importer.update_module('does_not_exists_this') - self.assertEqual([], self.importer.import_assist('myva')) + self.importer.update_module("does_not_exists_this") + self.assertEqual([], self.importer.import_assist("myva")) def test_module_with_syntax_errors(self): - self.mod1.write('this is a syntax error\n') + self.mod1.write("this is a syntax error\n") self.importer.update_resource(self.mod1) - self.assertEqual([], self.importer.import_assist('myva')) + self.assertEqual([], self.importer.import_assist("myva")) def test_excluding_imported_names(self): - self.mod1.write('import pkg\n') + self.mod1.write("import pkg\n") self.importer.update_resource(self.mod1) - self.assertEqual([], self.importer.import_assist('pkg')) + self.assertEqual([], self.importer.import_assist("pkg")) def test_get_modules(self): - self.mod1.write('myvar = None\n') + self.mod1.write("myvar = None\n") self.importer.update_resource(self.mod1) - self.assertEqual(['mod1'], self.importer.get_modules('myvar')) + self.assertEqual(["mod1"], self.importer.get_modules("myvar")) def test_get_modules_inside_packages(self): - self.mod1.write('myvar = None\n') - self.mod2.write('myvar = None\n') + self.mod1.write("myvar = None\n") + self.mod2.write("myvar = None\n") self.importer.update_resource(self.mod1) self.importer.update_resource(self.mod2) - self.assertEqual(set(['mod1', 'pkg.mod2']), - set(self.importer.get_modules('myvar'))) + self.assertEqual( + set(["mod1", "pkg.mod2"]), set(self.importer.get_modules("myvar")) + ) def test_trivial_insertion_line(self): - result = self.importer.find_insertion_line('') + result = self.importer.find_insertion_line("") self.assertEqual(1, result) def test_insertion_line(self): - result = self.importer.find_insertion_line('import mod\n') + result = self.importer.find_insertion_line("import mod\n") self.assertEqual(2, result) def test_insertion_line_with_pydocs(self): - result = self.importer.find_insertion_line( - '"""docs\n\ndocs"""\nimport mod\n') + result = self.importer.find_insertion_line('"""docs\n\ndocs"""\nimport mod\n') self.assertEqual(5, result) def test_insertion_line_with_multiple_imports(self): - result = self.importer.find_insertion_line( - 'import mod1\n\nimport mod2\n') + result = self.importer.find_insertion_line("import mod1\n\nimport mod2\n") self.assertEqual(4, result) def test_insertion_line_with_blank_lines(self): - result = self.importer.find_insertion_line( - 'import mod1\n\n# comment\n') + result = self.importer.find_insertion_line("import mod1\n\n# comment\n") self.assertEqual(2, result) def test_empty_cache(self): - self.mod1.write('myvar = None\n') + self.mod1.write("myvar = None\n") self.importer.update_resource(self.mod1) - self.assertEqual(['mod1'], self.importer.get_modules('myvar')) + self.assertEqual(["mod1"], self.importer.get_modules("myvar")) self.importer.clear_cache() - self.assertEqual([], self.importer.get_modules('myvar')) + self.assertEqual([], self.importer.get_modules("myvar")) def test_not_caching_underlined_names(self): - self.mod1.write('_myvar = None\n') + self.mod1.write("_myvar = None\n") self.importer.update_resource(self.mod1, underlined=False) - self.assertEqual([], self.importer.get_modules('_myvar')) + self.assertEqual([], self.importer.get_modules("_myvar")) self.importer.update_resource(self.mod1, underlined=True) - self.assertEqual(['mod1'], self.importer.get_modules('_myvar')) + self.assertEqual(["mod1"], self.importer.get_modules("_myvar")) def test_caching_underlined_names_passing_to_the_constructor(self): importer = autoimport.AutoImport(self.project, False, True) - self.mod1.write('_myvar = None\n') + self.mod1.write("_myvar = None\n") importer.update_resource(self.mod1) - self.assertEqual(['mod1'], importer.get_modules('_myvar')) + self.assertEqual(["mod1"], importer.get_modules("_myvar")) def test_name_locations(self): - self.mod1.write('myvar = None\n') + self.mod1.write("myvar = None\n") self.importer.update_resource(self.mod1) - self.assertEqual([(self.mod1, 1)], - self.importer.get_name_locations('myvar')) + self.assertEqual([(self.mod1, 1)], self.importer.get_name_locations("myvar")) def test_name_locations_with_multiple_occurrences(self): - self.mod1.write('myvar = None\n') - self.mod2.write('\nmyvar = None\n') + self.mod1.write("myvar = None\n") + self.mod2.write("\nmyvar = None\n") self.importer.update_resource(self.mod1) self.importer.update_resource(self.mod2) - self.assertEqual(set([(self.mod1, 1), (self.mod2, 2)]), - set(self.importer.get_name_locations('myvar'))) + self.assertEqual( + set([(self.mod1, 1), (self.mod2, 2)]), + set(self.importer.get_name_locations("myvar")), + ) def test_handling_builtin_modules(self): - self.importer.update_module('sys') - self.assertTrue('sys' in self.importer.get_modules('exit')) + self.importer.update_module("sys") + self.assertTrue("sys" in self.importer.get_modules("exit")) def test_submodules(self): - self.assertEqual(set([self.mod1]), - autoimport.submodules(self.mod1)) - self.assertEqual(set([self.mod2, self.pkg]), - autoimport.submodules(self.pkg)) + self.assertEqual(set([self.mod1]), autoimport.submodules(self.mod1)) + self.assertEqual(set([self.mod2, self.pkg]), autoimport.submodules(self.pkg)) class AutoImportObservingTest(unittest.TestCase): - def setUp(self): super(AutoImportObservingTest, self).setUp() self.project = testutils.sample_project() - self.mod1 = testutils.create_module(self.project, 'mod1') - self.pkg = testutils.create_package(self.project, 'pkg') - self.mod2 = testutils.create_module(self.project, 'mod2', self.pkg) + self.mod1 = testutils.create_module(self.project, "mod1") + self.pkg = testutils.create_package(self.project, "pkg") + self.mod2 = testutils.create_module(self.project, "mod2", self.pkg) self.importer = autoimport.AutoImport(self.project, observe=True) def tearDown(self): @@ -146,15 +139,15 @@ def tearDown(self): super(AutoImportObservingTest, self).tearDown() def test_writing_files(self): - self.mod1.write('myvar = None\n') - self.assertEqual(['mod1'], self.importer.get_modules('myvar')) + self.mod1.write("myvar = None\n") + self.assertEqual(["mod1"], self.importer.get_modules("myvar")) def test_moving_files(self): - self.mod1.write('myvar = None\n') - self.mod1.move('mod3.py') - self.assertEqual(['mod3'], self.importer.get_modules('myvar')) + self.mod1.write("myvar = None\n") + self.mod1.move("mod3.py") + self.assertEqual(["mod3"], self.importer.get_modules("myvar")) def test_removing_files(self): - self.mod1.write('myvar = None\n') + self.mod1.write("myvar = None\n") self.mod1.remove() - self.assertEqual([], self.importer.get_modules('myvar')) + self.assertEqual([], self.importer.get_modules("myvar")) diff --git a/ropetest/contrib/changestacktest.py b/ropetest/contrib/changestacktest.py index ffbd86dfc..89c35d6eb 100644 --- a/ropetest/contrib/changestacktest.py +++ b/ropetest/contrib/changestacktest.py @@ -11,7 +11,6 @@ class ChangeStackTest(unittest.TestCase): - def setUp(self): super(ChangeStackTest, self).setUp() self.project = testutils.sample_project() @@ -21,15 +20,15 @@ def tearDown(self): super(ChangeStackTest, self).tearDown() def test_change_stack(self): - myfile = self.project.root.create_file('myfile.txt') - myfile.write('1') + myfile = self.project.root.create_file("myfile.txt") + myfile.write("1") stack = rope.contrib.changestack.ChangeStack(self.project) - stack.push(rope.base.change.ChangeContents(myfile, '2')) - self.assertEqual('2', myfile.read()) - stack.push(rope.base.change.ChangeContents(myfile, '3')) - self.assertEqual('3', myfile.read()) + stack.push(rope.base.change.ChangeContents(myfile, "2")) + self.assertEqual("2", myfile.read()) + stack.push(rope.base.change.ChangeContents(myfile, "3")) + self.assertEqual("3", myfile.read()) stack.pop_all() - self.assertEqual('1', myfile.read()) + self.assertEqual("1", myfile.read()) changes = stack.merged() self.project.do(changes) - self.assertEqual('3', myfile.read()) + self.assertEqual("3", myfile.read()) diff --git a/ropetest/contrib/codeassisttest.py b/ropetest/contrib/codeassisttest.py index 37adf4b34..8bd3d38ae 100644 --- a/ropetest/contrib/codeassisttest.py +++ b/ropetest/contrib/codeassisttest.py @@ -2,16 +2,23 @@ import os.path from textwrap import dedent + try: import unittest2 as unittest except ImportError: import unittest from rope.base import exceptions -from rope.contrib.codeassist import (get_definition_location, get_doc, - starting_expression, code_assist, - sorted_proposals, starting_offset, - get_calltip, get_canonical_path) +from rope.contrib.codeassist import ( + get_definition_location, + get_doc, + starting_expression, + code_assist, + sorted_proposals, + starting_offset, + get_calltip, + get_canonical_path, +) from ropetest import testutils try: @@ -21,7 +28,6 @@ class CodeAssistTest(unittest.TestCase): - def setUp(self): super(CodeAssistTest, self).setUp() self.project = testutils.sample_project() @@ -36,440 +42,473 @@ def _assist(self, code, offset=None, **args): return code_assist(self.project, code, offset, **args) def test_simple_assist(self): - self._assist('', 0) + self._assist("", 0) def assert_completion_in_result(self, name, scope, result, type=None): for proposal in result: if proposal.name == name: - self.assertEqual(scope, proposal.scope, - "proposal <%s> has wrong scope, expected " - "%r, got %r" % (name, scope, proposal.scope)) + self.assertEqual( + scope, + proposal.scope, + "proposal <%s> has wrong scope, expected " + "%r, got %r" % (name, scope, proposal.scope), + ) if type is not None: - self.assertEqual(type, proposal.type, - "proposal <%s> has wrong type, expected " - "%r, got %r" % - (name, type, proposal.type)) + self.assertEqual( + type, + proposal.type, + "proposal <%s> has wrong type, expected " + "%r, got %r" % (name, type, proposal.type), + ) return - self.fail('completion <%s> not proposed' % name) + self.fail("completion <%s> not proposed" % name) def assert_completion_not_in_result(self, name, scope, result): for proposal in result: if proposal.name == name and proposal.scope == scope: - self.fail('completion <%s> was proposed' % name) + self.fail("completion <%s> was proposed" % name) def test_completing_global_variables(self): - code = 'my_global = 10\nt = my' + code = "my_global = 10\nt = my" result = self._assist(code) - self.assert_completion_in_result('my_global', 'global', result) + self.assert_completion_in_result("my_global", "global", result) def test_not_proposing_unmatched_vars(self): - code = 'my_global = 10\nt = you' + code = "my_global = 10\nt = you" result = self._assist(code) - self.assert_completion_not_in_result('my_global', 'global', result) + self.assert_completion_not_in_result("my_global", "global", result) def test_not_proposing_unmatched_vars_with_underlined_starting(self): - code = 'my_global = 10\nt = your_' + code = "my_global = 10\nt = your_" result = self._assist(code) - self.assert_completion_not_in_result('my_global', 'global', result) + self.assert_completion_not_in_result("my_global", "global", result) def test_not_proposing_local_assigns_as_global_completions(self): - code = 'def f(): my_global = 10\nt = my_' + code = "def f(): my_global = 10\nt = my_" result = self._assist(code) - self.assert_completion_not_in_result('my_global', 'global', result) + self.assert_completion_not_in_result("my_global", "global", result) def test_proposing_functions(self): - code = 'def my_func(): return 2\nt = my_' + code = "def my_func(): return 2\nt = my_" result = self._assist(code) - self.assert_completion_in_result('my_func', 'global', result) + self.assert_completion_in_result("my_func", "global", result) def test_proposing_classes(self): - code = 'class Sample(object): pass\nt = Sam' + code = "class Sample(object): pass\nt = Sam" result = self._assist(code) - self.assert_completion_in_result('Sample', 'global', result) + self.assert_completion_in_result("Sample", "global", result) def test_proposing_each_name_at_most_once(self): - code = 'variable = 10\nvariable = 20\nt = vari' + code = "variable = 10\nvariable = 20\nt = vari" result = self._assist(code) - count = len([x for x in result - if x.name == 'variable' and x.scope == 'global']) + count = len([x for x in result if x.name == "variable" and x.scope == "global"]) self.assertEqual(1, count) def test_throwing_exception_in_case_of_syntax_errors(self): - code = 'sample (sdf+)\n' + code = "sample (sdf+)\n" with self.assertRaises(exceptions.ModuleSyntaxError): self._assist(code, maxfixes=0) def test_fixing_errors_with_maxfixes(self): - code = 'def f():\n sldj sldj\ndef g():\n ran' + code = "def f():\n sldj sldj\ndef g():\n ran" result = self._assist(code, maxfixes=2) self.assertTrue(len(result) > 0) def test_ignoring_errors_in_current_line(self): - code = 'def my_func():\n return 2\nt = ' + code = "def my_func():\n return 2\nt = " result = self._assist(code) - self.assert_completion_in_result('my_func', 'global', result) + self.assert_completion_in_result("my_func", "global", result) def test_not_reporting_variables_in_current_line(self): - code = 'def my_func(): return 2\nt = my_' + code = "def my_func(): return 2\nt = my_" result = self._assist(code) - self.assert_completion_not_in_result('my_', 'global', result) + self.assert_completion_not_in_result("my_", "global", result) def test_completion_result(self): - code = 'my_global = 10\nt = my' + code = "my_global = 10\nt = my" self.assertEqual(len(code) - 2, starting_offset(code, len(code))) def test_completing_imported_names(self): - code = 'import sys\na = sy' + code = "import sys\na = sy" result = self._assist(code) - self.assert_completion_in_result('sys', 'imported', result) + self.assert_completion_in_result("sys", "imported", result) def test_completing_imported_names_with_as(self): - code = 'import sys as mysys\na = mys' + code = "import sys as mysys\na = mys" result = self._assist(code) - self.assert_completion_in_result('mysys', 'imported', result) + self.assert_completion_in_result("mysys", "imported", result) def test_not_completing_imported_names_with_as(self): - code = 'import sys as mysys\na = sy' + code = "import sys as mysys\na = sy" result = self._assist(code) - self.assert_completion_not_in_result('sys', 'global', result) + self.assert_completion_not_in_result("sys", "global", result) def test_including_matching_builtins_types(self): - code = 'my_var = Excep' + code = "my_var = Excep" result = self._assist(code) - self.assert_completion_in_result('Exception', 'builtin', result) - self.assert_completion_not_in_result('zip', 'builtin', result) + self.assert_completion_in_result("Exception", "builtin", result) + self.assert_completion_not_in_result("zip", "builtin", result) def test_including_matching_builtins_functions(self): - code = 'my_var = zi' + code = "my_var = zi" result = self._assist(code) - self.assert_completion_in_result('zip', 'builtin', result) + self.assert_completion_in_result("zip", "builtin", result) def test_builtin_instances(self): # ``import_dynload_stdmods`` pref is disabled for test project. # we need to have it enabled to make pycore._find_module() # load ``sys`` module. - self.project.prefs['import_dynload_stdmods'] = True - code = 'from sys import stdout\nstdout.wr' + self.project.prefs["import_dynload_stdmods"] = True + code = "from sys import stdout\nstdout.wr" result = self._assist(code) - self.assert_completion_in_result('write', 'builtin', result) - self.assert_completion_in_result('writelines', 'builtin', result) + self.assert_completion_in_result("write", "builtin", result) + self.assert_completion_in_result("writelines", "builtin", result) def test_including_keywords(self): - code = 'fo' + code = "fo" result = self._assist(code) - self.assert_completion_in_result('for', 'keyword', result) + self.assert_completion_in_result("for", "keyword", result) def test_not_reporting_proposals_after_dot(self): - code = 'a_dict = {}\nkey = 3\na_dict.ke' + code = "a_dict = {}\nkey = 3\na_dict.ke" result = self._assist(code) - self.assert_completion_not_in_result('key', 'global', result) + self.assert_completion_not_in_result("key", "global", result) def test_proposing_local_variables_in_functions(self): - code = 'def f(self):\n my_var = 10\n my_' + code = "def f(self):\n my_var = 10\n my_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'local', result) + self.assert_completion_in_result("my_var", "local", result) def test_local_variables_override_global_ones(self): - code = 'my_var = 20\ndef f(self):\n my_var = 10\n my_' + code = "my_var = 20\ndef f(self):\n my_var = 10\n my_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'local', result) + self.assert_completion_in_result("my_var", "local", result) def test_not_including_class_body_variables(self): - code = 'class C(object):\n my_var = 20\n' \ - ' def f(self):\n a = 20\n my_' + code = ( + "class C(object):\n my_var = 20\n" + " def f(self):\n a = 20\n my_" + ) result = self._assist(code) - self.assert_completion_not_in_result('my_var', 'local', result) + self.assert_completion_not_in_result("my_var", "local", result) def test_nested_functions(self): - code = 'def my_func():\n func_var = 20\n ' \ - 'def inner_func():\n a = 20\n func' + code = ( + "def my_func():\n func_var = 20\n " + "def inner_func():\n a = 20\n func" + ) result = self._assist(code) - self.assert_completion_in_result('func_var', 'local', result) + self.assert_completion_in_result("func_var", "local", result) def test_scope_endpoint_selection(self): code = "def my_func():\n func_var = 20\n" result = self._assist(code) - self.assert_completion_not_in_result('func_var', 'local', result) + self.assert_completion_not_in_result("func_var", "local", result) def test_scope_better_endpoint_selection(self): code = "if True:\n def f():\n my_var = 10\n my_" result = self._assist(code) - self.assert_completion_not_in_result('my_var', 'local', result) + self.assert_completion_not_in_result("my_var", "local", result) def test_imports_inside_function(self): code = "def f():\n import sys\n sy" result = self._assist(code) - self.assert_completion_in_result('sys', 'imported', result) + self.assert_completion_in_result("sys", "imported", result) def test_imports_inside_function_dont_mix_with_globals(self): code = "def f():\n import sys\nsy" result = self._assist(code) - self.assert_completion_not_in_result('sys', 'local', result) + self.assert_completion_not_in_result("sys", "local", result) def test_nested_classes_local_names(self): - code = 'global_var = 10\n' \ - 'def my_func():\n' \ - ' func_var = 20\n' \ - ' class C(object):\n' \ - ' def another_func(self):\n' \ - ' local_var = 10\n' \ - ' func' - result = self._assist(code) - self.assert_completion_in_result('func_var', 'local', result) + code = ( + "global_var = 10\n" + "def my_func():\n" + " func_var = 20\n" + " class C(object):\n" + " def another_func(self):\n" + " local_var = 10\n" + " func" + ) + result = self._assist(code) + self.assert_completion_in_result("func_var", "local", result) def test_nested_classes_global(self): - code = 'global_var = 10\n' \ - 'def my_func():\n' \ - ' func_var = 20\n' \ - ' class C(object):\n' \ - ' def another_func(self):\n' \ - ' local_var = 10\n' \ - ' globa' - result = self._assist(code) - self.assert_completion_in_result('global_var', 'global', result) + code = ( + "global_var = 10\n" + "def my_func():\n" + " func_var = 20\n" + " class C(object):\n" + " def another_func(self):\n" + " local_var = 10\n" + " globa" + ) + result = self._assist(code) + self.assert_completion_in_result("global_var", "global", result) def test_nested_classes_global_function(self): - code = 'global_var = 10\n' \ - 'def my_func():\n' \ - ' func_var = 20\n' \ - ' class C(object):\n' \ - ' def another_func(self):\n' \ - ' local_var = 10\n' \ - ' my_f' - result = self._assist(code) - self.assert_completion_in_result('my_func', 'global', result) + code = ( + "global_var = 10\n" + "def my_func():\n" + " func_var = 20\n" + " class C(object):\n" + " def another_func(self):\n" + " local_var = 10\n" + " my_f" + ) + result = self._assist(code) + self.assert_completion_in_result("my_func", "global", result) def test_proposing_function_parameters_in_functions(self): - code = 'def my_func(my_param):\n my_var = 20\n my_' + code = "def my_func(my_param):\n my_var = 20\n my_" result = self._assist(code) - self.assert_completion_in_result('my_param', 'local', result) + self.assert_completion_in_result("my_param", "local", result) def test_proposing_function_keyword_parameters_in_functions(self): - code = 'def my_func(my_param, *my_list, **my_kws):\n' \ - ' my_var = 20\n' \ - ' my_' + code = ( + "def my_func(my_param, *my_list, **my_kws):\n" " my_var = 20\n" " my_" + ) result = self._assist(code) - self.assert_completion_in_result('my_param', 'local', result) - self.assert_completion_in_result('my_list', 'local', result) - self.assert_completion_in_result('my_kws', 'local', result) + self.assert_completion_in_result("my_param", "local", result) + self.assert_completion_in_result("my_list", "local", result) + self.assert_completion_in_result("my_kws", "local", result) def test_not_proposing_unmatching_function_parameters_in_functions(self): code = "def my_func(my_param):\n my_var = 20\n you_" result = self._assist(code) - self.assert_completion_not_in_result('my_param', 'local', result) + self.assert_completion_not_in_result("my_param", "local", result) def test_ignoring_current_statement(self): code = "my_var = 10\nmy_tuple = (10, \n my_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_ignoring_current_statement_brackets_continuation(self): code = "my_var = 10\n'hello'[10:\n my_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_ignoring_current_statement_explicit_continuation(self): code = "my_var = 10\nmy_var2 = 2 + \\\n my_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_ignor_current_statement_while_the_first_stmnt_of_the_block(self): code = "my_var = 10\ndef f():\n my_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_ignor_current_stmnt_while_current_line_ends_with_a_colon(self): code = "my_var = 10\nif my_:\n pass" result = self._assist(code, 18) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_ignoring_string_contents(self): code = "my_var = '('\nmy_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_ignoring_comment_contents(self): code = "my_var = 10 #(\nmy_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_ignoring_string_contents_backslash_plus_quotes(self): code = "my_var = '\\''\nmy_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_ignoring_string_contents_backslash_plus_backslash(self): code = "my_var = '\\\\'\nmy_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_not_proposing_later_defined_variables_in_current_block(self): code = "my_\nmy_var = 10\n" result = self._assist(code, 3, later_locals=False) - self.assert_completion_not_in_result('my_var', 'global', result) + self.assert_completion_not_in_result("my_var", "global", result) def test_not_proposing_later_defined_variables_in_current_function(self): code = "def f():\n my_\n my_var = 10\n" result = self._assist(code, 16, later_locals=False) - self.assert_completion_not_in_result('my_var', 'local', result) + self.assert_completion_not_in_result("my_var", "local", result) def test_ignoring_string_contents_with_triple_quotes(self): code = "my_var = '''(\n'('''\nmy_" result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_ignoring_string_contents_with_triple_quotes_and_backslash(self): code = 'my_var = """\\"""("""\nmy_' result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_ignor_str_contents_with_triple_quotes_and_double_backslash(self): code = 'my_var = """\\\\"""\nmy_' result = self._assist(code) - self.assert_completion_in_result('my_var', 'global', result) + self.assert_completion_in_result("my_var", "global", result) def test_reporting_params_when_in_the_first_line_of_a_function(self): - code = 'def f(param):\n para' + code = "def f(param):\n para" result = self._assist(code) - self.assert_completion_in_result('param', 'local', result) + self.assert_completion_in_result("param", "local", result) def test_code_assist_when_having_a_two_line_function_header(self): - code = 'def f(param1,\n param2):\n para' + code = "def f(param1,\n param2):\n para" result = self._assist(code) - self.assert_completion_in_result('param1', 'local', result) + self.assert_completion_in_result("param1", "local", result) def test_code_assist_with_function_with_two_line_return(self): - code = 'def f(param1, param2):\n return(param1,\n para' + code = "def f(param1, param2):\n return(param1,\n para" result = self._assist(code) - self.assert_completion_in_result('param2', 'local', result) + self.assert_completion_in_result("param2", "local", result) def test_get_definition_location(self): - code = 'def a_func():\n pass\na_func()' + code = "def a_func():\n pass\na_func()" result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, 1), result) def test_get_definition_location_underlined_names(self): - code = 'def a_sample_func():\n pass\na_sample_func()' + code = "def a_sample_func():\n pass\na_sample_func()" result = get_definition_location(self.project, code, len(code) - 11) self.assertEqual((None, 1), result) def test_get_definition_location_dotted_names_method(self): - code = 'class AClass(object):\n' \ - ' @staticmethod\n' \ - ' def a_method():\n' \ - ' pass\n' \ - 'AClass.a_method()' + code = ( + "class AClass(object):\n" + " @staticmethod\n" + " def a_method():\n" + " pass\n" + "AClass.a_method()" + ) result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, 3), result) def test_get_definition_location_dotted_names_property(self): - code = 'class AClass(object):\n' \ - ' @property\n' \ - ' @somedecorator\n' \ - ' def a_method():\n' \ - ' pass\n' \ - 'AClass.a_method()' + code = ( + "class AClass(object):\n" + " @property\n" + " @somedecorator\n" + " def a_method():\n" + " pass\n" + "AClass.a_method()" + ) result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, 4), result) def test_get_definition_location_dotted_names_free_function(self): - code = '@custom_decorator\n' \ - 'def a_method():\n' \ - ' pass\n' \ - 'a_method()' + code = "@custom_decorator\n" "def a_method():\n" " pass\n" "a_method()" result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, 2), result) - @testutils.only_for_versions_higher('3.5') + @testutils.only_for_versions_higher("3.5") def test_get_definition_location_dotted_names_async_def(self): - code = 'class AClass(object):\n' \ - ' @property\n' \ - ' @decorator2\n' \ - ' async def a_method():\n' \ - ' pass\n' \ - 'AClass.a_method()' + code = ( + "class AClass(object):\n" + " @property\n" + " @decorator2\n" + " async def a_method():\n" + " pass\n" + "AClass.a_method()" + ) result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, 4), result) def test_get_definition_location_dotted_names_class(self): - code = '@custom_decorator\n' \ - 'class AClass(object):\n' \ - ' def a_method():\n' \ - ' pass\n' \ - 'AClass.a_method()' + code = ( + "@custom_decorator\n" + "class AClass(object):\n" + " def a_method():\n" + " pass\n" + "AClass.a_method()" + ) result = get_definition_location(self.project, code, len(code) - 12) self.assertEqual((None, 2), result) def test_get_definition_location_dotted_names_with_space(self): - code = 'class AClass(object):\n' \ - ' @staticmethod\n' \ - ' def a_method():\n' \ - ' \n' \ - ' pass\n' \ - 'AClass.a_method()' + code = ( + "class AClass(object):\n" + " @staticmethod\n" + " def a_method():\n" + " \n" + " pass\n" + "AClass.a_method()" + ) result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, 3), result) def test_get_definition_location_dotted_names_inline_body(self): - code = 'class AClass(object):\n' \ - ' @staticmethod\n' \ - ' def a_method(): pass\n' \ - 'AClass.a_method()' + code = ( + "class AClass(object):\n" + " @staticmethod\n" + " def a_method(): pass\n" + "AClass.a_method()" + ) result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, 3), result) def test_get_definition_location_dotted_names_inline_body_split_arg(self): - code = 'class AClass(object):\n' \ - ' @staticmethod\n' \ - ' def a_method(\n' \ - ' self,\n' \ - ' arg1\n' \ - ' ): pass\n' \ - 'AClass.a_method()' + code = ( + "class AClass(object):\n" + " @staticmethod\n" + " def a_method(\n" + " self,\n" + " arg1\n" + " ): pass\n" + "AClass.a_method()" + ) result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, 3), result) def test_get_definition_location_dotted_module_names(self): - module_resource = testutils.create_module(self.project, 'mod') - module_resource.write('def a_func():\n pass\n') - code = 'import mod\nmod.a_func()' + module_resource = testutils.create_module(self.project, "mod") + module_resource.write("def a_func():\n pass\n") + code = "import mod\nmod.a_func()" result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((module_resource, 1), result) def test_get_definition_location_for_nested_packages(self): - mod1 = testutils.create_module(self.project, 'mod1') - pkg1 = testutils.create_package(self.project, 'pkg1') - pkg2 = testutils.create_package(self.project, 'pkg2', pkg1) - mod1.write('import pkg1.pkg2.mod2') - - init_dot_py = pkg2.get_child('__init__.py') - found_pyname = get_definition_location(self.project, mod1.read(), - mod1.read().index('pkg2') + 1) + mod1 = testutils.create_module(self.project, "mod1") + pkg1 = testutils.create_package(self.project, "pkg1") + pkg2 = testutils.create_package(self.project, "pkg2", pkg1) + mod1.write("import pkg1.pkg2.mod2") + + init_dot_py = pkg2.get_child("__init__.py") + found_pyname = get_definition_location( + self.project, mod1.read(), mod1.read().index("pkg2") + 1 + ) self.assertEqual(init_dot_py, found_pyname[0]) def test_get_definition_location_unknown(self): - code = 'a_func()\n' + code = "a_func()\n" result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, None), result) def test_get_definition_location_dot_spaces(self): - code = 'class AClass(object):\n ' \ - '@staticmethod\n def a_method():\n' \ - ' pass\nAClass.\\\n a_method()' + code = ( + "class AClass(object):\n " + "@staticmethod\n def a_method():\n" + " pass\nAClass.\\\n a_method()" + ) result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, 3), result) def test_get_definition_location_dot_line_break_inside_parens(self): - code = 'class A(object):\n def a_method(self):\n pass\n' + \ - '(A.\na_method)' - result = get_definition_location(self.project, code, - code.rindex('a_method') + 1) + code = ( + "class A(object):\n def a_method(self):\n pass\n" + + "(A.\na_method)" + ) + result = get_definition_location( + self.project, code, code.rindex("a_method") + 1 + ) self.assertEqual((None, 2), result) def test_if_scopes_in_other_scopes_for_get_definition_location(self): - code = 'def f(a_var):\n pass\na_var = 10\n' + \ - 'if True:\n' + \ - ' print(a_var)\n' + code = ( + "def f(a_var):\n pass\na_var = 10\n" + + "if True:\n" + + " print(a_var)\n" + ) result = get_definition_location(self.project, code, len(code) - 3) self.assertEqual((None, 3), result) @@ -487,508 +526,495 @@ def bar(): self.assertEqual((None, 6), result) def test_code_assists_in_parens(self): - code = 'def a_func(a_var):\n pass\na_var = 10\na_func(a_' + code = "def a_func(a_var):\n pass\na_var = 10\na_func(a_" result = self._assist(code) - self.assert_completion_in_result('a_var', 'global', result) + self.assert_completion_in_result("a_var", "global", result) def test_simple_type_inferencing(self): - code = 'class Sample(object):\n' \ - ' def __init__(self, a_param):\n' \ - ' pass\n' \ - ' def a_method(self):\n' \ - ' pass\n' \ - 'Sample("hey").a_' + code = ( + "class Sample(object):\n" + " def __init__(self, a_param):\n" + " pass\n" + " def a_method(self):\n" + " pass\n" + 'Sample("hey").a_' + ) result = self._assist(code) - self.assert_completion_in_result('a_method', 'attribute', result) + self.assert_completion_in_result("a_method", "attribute", result) def test_proposals_sorter(self): - code = 'def my_sample_function(self):\n' + \ - ' my_sample_var = 20\n' + \ - ' my_sample_' + code = ( + "def my_sample_function(self):\n" + + " my_sample_var = 20\n" + + " my_sample_" + ) proposals = sorted_proposals(self._assist(code)) - self.assertEqual('my_sample_var', proposals[0].name) - self.assertEqual('my_sample_function', proposals[1].name) + self.assertEqual("my_sample_var", proposals[0].name) + self.assertEqual("my_sample_function", proposals[1].name) def test_proposals_sorter_for_methods_and_attributes(self): - code = 'class A(object):\n' + \ - ' def __init__(self):\n' + \ - ' self.my_a_var = 10\n' + \ - ' def my_b_func(self):\n' + \ - ' pass\n' + \ - ' def my_c_func(self):\n' + \ - ' pass\n' + \ - 'a_var = A()\n' + \ - 'a_var.my_' + code = ( + "class A(object):\n" + + " def __init__(self):\n" + + " self.my_a_var = 10\n" + + " def my_b_func(self):\n" + + " pass\n" + + " def my_c_func(self):\n" + + " pass\n" + + "a_var = A()\n" + + "a_var.my_" + ) proposals = sorted_proposals(self._assist(code)) - self.assertEqual('my_b_func', proposals[0].name) - self.assertEqual('my_c_func', proposals[1].name) - self.assertEqual('my_a_var', proposals[2].name) + self.assertEqual("my_b_func", proposals[0].name) + self.assertEqual("my_c_func", proposals[1].name) + self.assertEqual("my_a_var", proposals[2].name) def test_proposals_sorter_for_global_methods_and_funcs(self): - code = 'def my_b_func(self):\n' + \ - ' pass\n' + \ - 'my_a_var = 10\n' + \ - 'my_' + code = "def my_b_func(self):\n" + " pass\n" + "my_a_var = 10\n" + "my_" proposals = sorted_proposals(self._assist(code)) - self.assertEqual('my_b_func', proposals[0].name) - self.assertEqual('my_a_var', proposals[1].name) + self.assertEqual("my_b_func", proposals[0].name) + self.assertEqual("my_a_var", proposals[1].name) def test_proposals_sorter_underlined_methods(self): - code = 'class A(object):\n' + \ - ' def _my_func(self):\n' + \ - ' self.my_a_var = 10\n' + \ - ' def my_func(self):\n' + \ - ' pass\n' + \ - 'a_var = A()\n' + \ - 'a_var.' + code = ( + "class A(object):\n" + + " def _my_func(self):\n" + + " self.my_a_var = 10\n" + + " def my_func(self):\n" + + " pass\n" + + "a_var = A()\n" + + "a_var." + ) proposals = sorted_proposals(self._assist(code)) - self.assertEqual('my_func', proposals[0].name) - self.assertEqual('_my_func', proposals[1].name) + self.assertEqual("my_func", proposals[0].name) + self.assertEqual("_my_func", proposals[1].name) def test_proposals_sorter_and_scope_prefs(self): - code = 'my_global_var = 1\n' \ - 'def func(self):\n' \ - ' my_local_var = 2\n' \ - ' my_' + code = ( + "my_global_var = 1\n" "def func(self):\n" " my_local_var = 2\n" " my_" + ) result = self._assist(code) - proposals = sorted_proposals(result, scopepref=['global', 'local']) - self.assertEqual('my_global_var', proposals[0].name) - self.assertEqual('my_local_var', proposals[1].name) + proposals = sorted_proposals(result, scopepref=["global", "local"]) + self.assertEqual("my_global_var", proposals[0].name) + self.assertEqual("my_local_var", proposals[1].name) def test_proposals_sorter_and_type_prefs(self): - code = 'my_global_var = 1\n' \ - 'def my_global_func(self):\n' \ - ' pass\n' \ - 'my_' + code = "my_global_var = 1\n" "def my_global_func(self):\n" " pass\n" "my_" result = self._assist(code) - proposals = sorted_proposals(result, typepref=['instance', 'function']) - self.assertEqual('my_global_var', proposals[0].name) - self.assertEqual('my_global_func', proposals[1].name) + proposals = sorted_proposals(result, typepref=["instance", "function"]) + self.assertEqual("my_global_var", proposals[0].name) + self.assertEqual("my_global_func", proposals[1].name) def test_proposals_sorter_and_missing_type_in_typepref(self): - code = 'my_global_var = 1\n' \ - 'def my_global_func():\n' \ - ' pass\n' \ - 'my_' + code = "my_global_var = 1\n" "def my_global_func():\n" " pass\n" "my_" result = self._assist(code) - proposals = sorted_proposals(result, typepref=['function']) # noqa + proposals = sorted_proposals(result, typepref=["function"]) # noqa def test_get_pydoc_unicode(self): src = u'# coding: utf-8\ndef foo():\n u"юникод-объект"' - doc = get_doc(self.project, src, src.index('foo') + 1) + doc = get_doc(self.project, src, src.index("foo") + 1) self.assertTrue(isinstance(doc, unicode)) - self.assertTrue(u'юникод-объект' in doc) + self.assertTrue(u"юникод-объект" in doc) def test_get_pydoc_utf8_bytestring(self): src = u'# coding: utf-8\ndef foo():\n "байтстринг"' - doc = get_doc(self.project, src, src.index('foo') + 1) + doc = get_doc(self.project, src, src.index("foo") + 1) self.assertTrue(isinstance(doc, unicode)) - self.assertTrue(u'байтстринг' in doc) + self.assertTrue(u"байтстринг" in doc) def test_get_pydoc_for_functions(self): - src = 'def a_func():\n' \ - ' """a function"""\n' \ - ' a_var = 10\n' \ - 'a_func()' - self.assertTrue(get_doc(self.project, src, len(src) - 4). - endswith('a function')) - get_doc(self.project, src, len(src) - 4).index('a_func()') + src = "def a_func():\n" ' """a function"""\n' " a_var = 10\n" "a_func()" + self.assertTrue(get_doc(self.project, src, len(src) - 4).endswith("a function")) + get_doc(self.project, src, len(src) - 4).index("a_func()") def test_get_pydoc_for_classes(self): - src = 'class AClass(object):\n pass\n' - get_doc(self.project, src, src.index('AClass') + 1).index('AClass') + src = "class AClass(object):\n pass\n" + get_doc(self.project, src, src.index("AClass") + 1).index("AClass") def test_get_pydoc_for_classes_with_init(self): - src = 'class AClass(object):\n def __init__(self):\n pass\n' - get_doc(self.project, src, src.index('AClass') + 1).index('AClass') + src = "class AClass(object):\n def __init__(self):\n pass\n" + get_doc(self.project, src, src.index("AClass") + 1).index("AClass") def test_get_pydoc_for_modules(self): - mod = testutils.create_module(self.project, 'mod') + mod = testutils.create_module(self.project, "mod") mod.write('"""a module"""\n') - src = 'import mod\nmod' - self.assertEqual('a module', get_doc(self.project, src, len(src) - 1)) + src = "import mod\nmod" + self.assertEqual("a module", get_doc(self.project, src, len(src) - 1)) def test_get_pydoc_for_builtins(self): - src = 'print(object)\n' - self.assertTrue(get_doc(self.project, src, - src.index('obj')) is not None) + src = "print(object)\n" + self.assertTrue(get_doc(self.project, src, src.index("obj")) is not None) def test_get_pydoc_for_methods_should_include_class_name(self): - src = 'class AClass(object):\n' \ - ' def a_method(self):\n'\ - ' """hey"""\n' \ - ' pass\n' - doc = get_doc(self.project, src, src.index('a_method') + 1) - doc.index('AClass.a_method') - doc.index('hey') + src = ( + "class AClass(object):\n" + " def a_method(self):\n" + ' """hey"""\n' + " pass\n" + ) + doc = get_doc(self.project, src, src.index("a_method") + 1) + doc.index("AClass.a_method") + doc.index("hey") def test_get_pydoc_for_meths_should_inc_methods_from_super_classes(self): - src = 'class A(object):\n' \ - ' def a_method(self):\n' \ - ' """hey1"""\n' \ - ' pass\n' \ - 'class B(A):\n' \ - ' def a_method(self):\n' \ - ' """hey2"""\n' \ - ' pass\n' - doc = get_doc(self.project, src, src.rindex('a_method') + 1) - doc.index('A.a_method') - doc.index('hey1') - doc.index('B.a_method') - doc.index('hey2') + src = ( + "class A(object):\n" + " def a_method(self):\n" + ' """hey1"""\n' + " pass\n" + "class B(A):\n" + " def a_method(self):\n" + ' """hey2"""\n' + " pass\n" + ) + doc = get_doc(self.project, src, src.rindex("a_method") + 1) + doc.index("A.a_method") + doc.index("hey1") + doc.index("B.a_method") + doc.index("hey2") def test_get_pydoc_for_classes_should_name_super_classes(self): - src = 'class A(object):\n pass\n' \ - 'class B(A):\n pass\n' - doc = get_doc(self.project, src, src.rindex('B') + 1) - doc.index('B(A)') + src = "class A(object):\n pass\n" "class B(A):\n pass\n" + doc = get_doc(self.project, src, src.rindex("B") + 1) + doc.index("B(A)") def test_get_pydoc_for_builtin_functions(self): src = 's = "hey"\ns.replace\n' - doc = get_doc(self.project, src, src.rindex('replace') + 1) + doc = get_doc(self.project, src, src.rindex("replace") + 1) self.assertTrue(doc is not None) def test_commenting_errors_before_offset(self): src = 'lsjd lsjdf\ns = "hey"\ns.replace()\n' - doc = get_doc(self.project, src, src.rindex('replace') + 1) # noqa + doc = get_doc(self.project, src, src.rindex("replace") + 1) # noqa def test_proposing_variables_defined_till_the_end_of_scope(self): - code = 'if True:\n a_v\na_var = 10\n' - result = self._assist(code, code.index('a_v') + 3) - self.assert_completion_in_result('a_var', 'global', result) + code = "if True:\n a_v\na_var = 10\n" + result = self._assist(code, code.index("a_v") + 3) + self.assert_completion_in_result("a_var", "global", result) def test_completing_in_uncomplete_try_blocks(self): - code = 'try:\n a_var = 10\n a_' + code = "try:\n a_var = 10\n a_" result = self._assist(code) - self.assert_completion_in_result('a_var', 'global', result) + self.assert_completion_in_result("a_var", "global", result) def test_completing_in_uncomplete_try_blocks_in_functions(self): - code = 'def a_func():\n try:\n a_var = 10\n a_' + code = "def a_func():\n try:\n a_var = 10\n a_" result = self._assist(code) - self.assert_completion_in_result('a_var', 'local', result) + self.assert_completion_in_result("a_var", "local", result) def test_already_complete_try_blocks_with_finally(self): - code = 'def a_func():\n try:\n a_var = 10\n a_' + code = "def a_func():\n try:\n a_var = 10\n a_" result = self._assist(code) - self.assert_completion_in_result('a_var', 'local', result) + self.assert_completion_in_result("a_var", "local", result) def test_already_complete_try_blocks_with_finally2(self): - code = 'try:\n a_var = 10\n a_\nfinally:\n pass\n' - result = self._assist(code, code.rindex('a_') + 2) - self.assert_completion_in_result('a_var', 'global', result) + code = "try:\n a_var = 10\n a_\nfinally:\n pass\n" + result = self._assist(code, code.rindex("a_") + 2) + self.assert_completion_in_result("a_var", "global", result) def test_already_complete_try_blocks_with_except(self): - code = 'try:\n a_var = 10\n a_\nexcept Exception:\n pass\n' - result = self._assist(code, code.rindex('a_') + 2) - self.assert_completion_in_result('a_var', 'global', result) + code = "try:\n a_var = 10\n a_\nexcept Exception:\n pass\n" + result = self._assist(code, code.rindex("a_") + 2) + self.assert_completion_in_result("a_var", "global", result) def test_already_complete_try_blocks_with_except2(self): - code = 'a_var = 10\ntry:\n ' \ - 'another_var = a_\n another_var = 10\n' \ - 'except Exception:\n pass\n' - result = self._assist(code, code.rindex('a_') + 2) - self.assert_completion_in_result('a_var', 'global', result) + code = ( + "a_var = 10\ntry:\n " + "another_var = a_\n another_var = 10\n" + "except Exception:\n pass\n" + ) + result = self._assist(code, code.rindex("a_") + 2) + self.assert_completion_in_result("a_var", "global", result) def test_completing_ifs_in_uncomplete_try_blocks(self): - code = 'try:\n if True:\n a_var = 10\n a_' + code = "try:\n if True:\n a_var = 10\n a_" result = self._assist(code) - self.assert_completion_in_result('a_var', 'global', result) + self.assert_completion_in_result("a_var", "global", result) def test_completing_ifs_in_uncomplete_try_blocks2(self): - code = 'try:\n if True:\n a_var = 10\n a_' + code = "try:\n if True:\n a_var = 10\n a_" result = self._assist(code) - self.assert_completion_in_result('a_var', 'global', result) + self.assert_completion_in_result("a_var", "global", result) def test_completing_excepts_in_uncomplete_try_blocks(self): - code = 'try:\n pass\nexcept Exc' + code = "try:\n pass\nexcept Exc" result = self._assist(code) - self.assert_completion_in_result('Exception', 'builtin', result) + self.assert_completion_in_result("Exception", "builtin", result) def test_and_normal_complete_blocks_and_single_fixing(self): - code = 'try:\n range.\nexcept:\n pass\n' - result = self._assist(code, code.index('.'), maxfixes=1) # noqa + code = "try:\n range.\nexcept:\n pass\n" + result = self._assist(code, code.index("."), maxfixes=1) # noqa def test_nested_blocks(self): - code = 'a_var = 10\ntry:\n try:\n a_v' + code = "a_var = 10\ntry:\n try:\n a_v" result = self._assist(code) - self.assert_completion_in_result('a_var', 'global', result) + self.assert_completion_in_result("a_var", "global", result) def test_proposing_function_keywords_when_calling(self): - code = 'def f(p):\n pass\nf(p' + code = "def f(p):\n pass\nf(p" result = self._assist(code) - self.assert_completion_in_result('p=', 'parameter_keyword', result) + self.assert_completion_in_result("p=", "parameter_keyword", result) def test_proposing_function_keywords_when_calling_for_non_functions(self): - code = 'f = 1\nf(p' + code = "f = 1\nf(p" result = self._assist(code) # noqa def test_proposing_function_keywords_when_calling_extra_spaces(self): - code = 'def f(p):\n pass\nf( p' + code = "def f(p):\n pass\nf( p" result = self._assist(code) - self.assert_completion_in_result('p=', 'parameter_keyword', result) + self.assert_completion_in_result("p=", "parameter_keyword", result) def test_proposing_function_keywords_when_calling_on_second_argument(self): - code = 'def f(p1, p2):\n pass\nf(1, p' + code = "def f(p1, p2):\n pass\nf(1, p" result = self._assist(code) - self.assert_completion_in_result('p2=', 'parameter_keyword', result) + self.assert_completion_in_result("p2=", "parameter_keyword", result) def test_proposing_function_keywords_when_calling_not_proposing_args(self): - code = 'def f(p1, *args):\n pass\nf(1, a' + code = "def f(p1, *args):\n pass\nf(1, a" result = self._assist(code) - self.assert_completion_not_in_result('args=', 'parameter_keyword', - result) + self.assert_completion_not_in_result("args=", "parameter_keyword", result) def test_propos_function_kwrds_when_call_with_no_noth_after_parens(self): - code = 'def f(p):\n pass\nf(' + code = "def f(p):\n pass\nf(" result = self._assist(code) - self.assert_completion_in_result('p=', 'parameter_keyword', result) + self.assert_completion_in_result("p=", "parameter_keyword", result) def test_propos_function_kwrds_when_call_with_no_noth_after_parens2(self): - code = 'def f(p):\n pass\ndef g():\n h = f\n f(' + code = "def f(p):\n pass\ndef g():\n h = f\n f(" result = self._assist(code) - self.assert_completion_in_result('p=', 'parameter_keyword', result) + self.assert_completion_in_result("p=", "parameter_keyword", result) def test_codeassists_before_opening_of_parens(self): - code = 'def f(p):\n pass\na_var = 1\nf(1)\n' - result = self._assist(code, code.rindex('f') + 1) - self.assert_completion_not_in_result('a_var', 'global', result) + code = "def f(p):\n pass\na_var = 1\nf(1)\n" + result = self._assist(code, code.rindex("f") + 1) + self.assert_completion_not_in_result("a_var", "global", result) def test_codeassist_before_single_line_indents(self): - code = 'myvar = 1\nif True:\n (myv\nif True:\n pass\n' - result = self._assist(code, code.rindex('myv') + 3) - self.assert_completion_not_in_result('myvar', 'local', result) + code = "myvar = 1\nif True:\n (myv\nif True:\n pass\n" + result = self._assist(code, code.rindex("myv") + 3) + self.assert_completion_not_in_result("myvar", "local", result) def test_codeassist_before_line_indents_in_a_blank_line(self): - code = 'myvar = 1\nif True:\n \nif True:\n pass\n' - result = self._assist(code, code.rindex(' ') + 4) - self.assert_completion_not_in_result('myvar', 'local', result) + code = "myvar = 1\nif True:\n \nif True:\n pass\n" + result = self._assist(code, code.rindex(" ") + 4) + self.assert_completion_not_in_result("myvar", "local", result) def test_simple_get_calltips(self): - src = 'def f():\n pass\nvar = f()\n' - doc = get_calltip(self.project, src, src.rindex('f')) - self.assertEqual('f()', doc) + src = "def f():\n pass\nvar = f()\n" + doc = get_calltip(self.project, src, src.rindex("f")) + self.assertEqual("f()", doc) def test_get_calltips_for_classes(self): - src = 'class C(object):\n' \ - ' def __init__(self):\n pass\nC(' + src = "class C(object):\n" " def __init__(self):\n pass\nC(" doc = get_calltip(self.project, src, len(src) - 1) - self.assertEqual('C.__init__(self)', doc) + self.assertEqual("C.__init__(self)", doc) def test_get_calltips_for_objects_with_call(self): - src = 'class C(object):\n' \ - ' def __call__(self, p):\n pass\n' \ - 'c = C()\nc(1,' - doc = get_calltip(self.project, src, src.rindex('c')) - self.assertEqual('C.__call__(self, p)', doc) + src = ( + "class C(object):\n" + " def __call__(self, p):\n pass\n" + "c = C()\nc(1," + ) + doc = get_calltip(self.project, src, src.rindex("c")) + self.assertEqual("C.__call__(self, p)", doc) def test_get_calltips_and_including_module_name(self): - src = 'class C(object):\n' \ - ' def __call__(self, p):\n pass\n' \ - 'c = C()\nc(1,' - mod = testutils.create_module(self.project, 'mod') + src = ( + "class C(object):\n" + " def __call__(self, p):\n pass\n" + "c = C()\nc(1," + ) + mod = testutils.create_module(self.project, "mod") mod.write(src) - doc = get_calltip(self.project, src, src.rindex('c'), mod) - self.assertEqual('mod.C.__call__(self, p)', doc) + doc = get_calltip(self.project, src, src.rindex("c"), mod) + self.assertEqual("mod.C.__call__(self, p)", doc) def test_get_calltips_and_including_module_name_2(self): - src = 'range()\n' + src = "range()\n" doc = get_calltip(self.project, src, 1, ignore_unknown=True) self.assertTrue(doc is None) def test_removing_self_parameter(self): - src = 'class C(object):\n' \ - ' def f(self):\n'\ - ' pass\n' \ - 'C().f()' - doc = get_calltip(self.project, src, src.rindex('f'), remove_self=True) - self.assertEqual('C.f()', doc) + src = "class C(object):\n" " def f(self):\n" " pass\n" "C().f()" + doc = get_calltip(self.project, src, src.rindex("f"), remove_self=True) + self.assertEqual("C.f()", doc) def test_removing_self_parameter_and_more_than_one_parameter(self): - src = 'class C(object):\n' \ - ' def f(self, p1):\n'\ - ' pass\n' \ - 'C().f()' - doc = get_calltip(self.project, src, src.rindex('f'), remove_self=True) - self.assertEqual('C.f(p1)', doc) + src = "class C(object):\n" " def f(self, p1):\n" " pass\n" "C().f()" + doc = get_calltip(self.project, src, src.rindex("f"), remove_self=True) + self.assertEqual("C.f(p1)", doc) def test_lambda_calltip(self): - src = 'foo = lambda x, y=1: None\n' \ - 'foo()' - doc = get_calltip(self.project, src, src.rindex('f')) - self.assertEqual(doc, 'lambda(x, y)') + src = "foo = lambda x, y=1: None\n" "foo()" + doc = get_calltip(self.project, src, src.rindex("f")) + self.assertEqual(doc, "lambda(x, y)") def test_keyword_before_parens(self): - code = 'if (1).:\n pass' - result = self._assist(code, offset=len('if (1).')) + code = "if (1).:\n pass" + result = self._assist(code, offset=len("if (1).")) self.assertTrue(result) # TESTING PROPOSAL'S KINDS AND TYPES. # SEE RELATION MATRIX IN `CompletionProposal`'s DOCSTRING def test_local_variable_completion_proposal(self): - code = 'def foo():\n xvar = 5\n x' + code = "def foo():\n xvar = 5\n x" result = self._assist(code) - self.assert_completion_in_result('xvar', 'local', result, 'instance') + self.assert_completion_in_result("xvar", "local", result, "instance") def test_global_variable_completion_proposal(self): - code = 'yvar = 5\ny' + code = "yvar = 5\ny" result = self._assist(code) - self.assert_completion_in_result('yvar', 'global', result, 'instance') + self.assert_completion_in_result("yvar", "global", result, "instance") def test_builtin_variable_completion_proposal(self): - for varname in ('False', 'True'): + for varname in ("False", "True"): result = self._assist(varname[0]) - self.assert_completion_in_result(varname, 'builtin', result, - type='instance') + self.assert_completion_in_result( + varname, "builtin", result, type="instance" + ) def test_attribute_variable_completion_proposal(self): - code = 'class AClass(object):\n def foo(self):\n ' \ - 'self.bar = 1\n self.b' + code = ( + "class AClass(object):\n def foo(self):\n " "self.bar = 1\n self.b" + ) result = self._assist(code) - self.assert_completion_in_result('bar', 'attribute', result, - type='instance') + self.assert_completion_in_result("bar", "attribute", result, type="instance") def test_local_class_completion_proposal(self): - code = 'def foo():\n class LocalClass(object): pass\n Lo' + code = "def foo():\n class LocalClass(object): pass\n Lo" result = self._assist(code) - self.assert_completion_in_result('LocalClass', 'local', result, - type='class') + self.assert_completion_in_result("LocalClass", "local", result, type="class") def test_global_class_completion_proposal(self): - code = 'class GlobalClass(object): pass\nGl' + code = "class GlobalClass(object): pass\nGl" result = self._assist(code) - self.assert_completion_in_result('GlobalClass', 'global', result, - type='class') + self.assert_completion_in_result("GlobalClass", "global", result, type="class") def test_builtin_class_completion_proposal(self): - for varname in ('object', 'dict', 'file'): + for varname in ("object", "dict", "file"): result = self._assist(varname[0]) - self.assert_completion_in_result(varname, 'builtin', result, - type='class') + self.assert_completion_in_result(varname, "builtin", result, type="class") def test_attribute_class_completion_proposal(self): - code = 'class Outer(object):\n class Inner(object): pass\nOuter.' + code = "class Outer(object):\n class Inner(object): pass\nOuter." result = self._assist(code) - self.assert_completion_in_result('Inner', 'attribute', result, - type='class') + self.assert_completion_in_result("Inner", "attribute", result, type="class") def test_local_function_completion_proposal(self): - code = 'def outer():\n def inner(): pass\n in' + code = "def outer():\n def inner(): pass\n in" result = self._assist(code) - self.assert_completion_in_result('inner', 'local', result, - type='function') + self.assert_completion_in_result("inner", "local", result, type="function") def test_global_function_completion_proposal(self): - code = 'def foo(): pass\nf' + code = "def foo(): pass\nf" result = self._assist(code) - self.assert_completion_in_result('foo', 'global', result, - type='function') + self.assert_completion_in_result("foo", "global", result, type="function") def test_builtin_function_completion_proposal(self): - code = 'a' + code = "a" result = self._assist(code) - for expected in ('all', 'any', 'abs'): - self.assert_completion_in_result(expected, 'builtin', result, - type='function') + for expected in ("all", "any", "abs"): + self.assert_completion_in_result( + expected, "builtin", result, type="function" + ) def test_attribute_function_completion_proposal(self): - code = 'class Some(object):\n def method(self):\n self.' + code = "class Some(object):\n def method(self):\n self." result = self._assist(code) - self.assert_completion_in_result('method', 'attribute', result, - type='function') + self.assert_completion_in_result("method", "attribute", result, type="function") def test_local_module_completion_proposal(self): - code = 'def foo():\n import types\n t' + code = "def foo():\n import types\n t" result = self._assist(code) - self.assert_completion_in_result('types', 'imported', result, - type='module') + self.assert_completion_in_result("types", "imported", result, type="module") def test_global_module_completion_proposal(self): - code = 'import operator\no' + code = "import operator\no" result = self._assist(code) - self.assert_completion_in_result('operator', 'imported', result, - type='module') + self.assert_completion_in_result("operator", "imported", result, type="module") def test_attribute_module_completion_proposal(self): - code = 'class Some(object):\n import os\nSome.o' + code = "class Some(object):\n import os\nSome.o" result = self._assist(code) - self.assert_completion_in_result('os', 'imported', result, - type='module') + self.assert_completion_in_result("os", "imported", result, type="module") def test_builtin_exception_completion_proposal(self): - code = 'def blah():\n Z' + code = "def blah():\n Z" result = self._assist(code) - self.assert_completion_in_result('ZeroDivisionError', 'builtin', - result, type='class') + self.assert_completion_in_result( + "ZeroDivisionError", "builtin", result, type="class" + ) def test_keyword_completion_proposal(self): - code = 'f' + code = "f" result = self._assist(code) - self.assert_completion_in_result('for', 'keyword', result, type=None) - self.assert_completion_in_result('from', 'keyword', result, type=None) + self.assert_completion_in_result("for", "keyword", result, type=None) + self.assert_completion_in_result("from", "keyword", result, type=None) def test_parameter_keyword_completion_proposal(self): - code = 'def func(abc, aloha, alpha, amigo): pass\nfunc(a' + code = "def func(abc, aloha, alpha, amigo): pass\nfunc(a" result = self._assist(code) - for expected in ('abc=', 'aloha=', 'alpha=', 'amigo='): - self.assert_completion_in_result(expected, 'parameter_keyword', - result, type=None) + for expected in ("abc=", "aloha=", "alpha=", "amigo="): + self.assert_completion_in_result( + expected, "parameter_keyword", result, type=None + ) def test_object_path_global(self): - code = 'GLOBAL_VARIABLE = 42\n' - resource = testutils.create_module(self.project, 'mod') + code = "GLOBAL_VARIABLE = 42\n" + resource = testutils.create_module(self.project, "mod") resource.write(code) result = get_canonical_path(self.project, resource, 1) - mod_path = os.path.join(self.project.address, 'mod.py') + mod_path = os.path.join(self.project.address, "mod.py") self.assertEqual( - result, [(mod_path, 'MODULE'), - ('GLOBAL_VARIABLE', 'VARIABLE')]) + result, [(mod_path, "MODULE"), ("GLOBAL_VARIABLE", "VARIABLE")] + ) def test_object_path_attribute(self): - code = 'class Foo(object):\n' \ - ' attr = 42\n' - resource = testutils.create_module(self.project, 'mod') + code = "class Foo(object):\n" " attr = 42\n" + resource = testutils.create_module(self.project, "mod") resource.write(code) result = get_canonical_path(self.project, resource, 24) - mod_path = os.path.join(self.project.address, 'mod.py') + mod_path = os.path.join(self.project.address, "mod.py") self.assertEqual( - result, [(mod_path, 'MODULE'), ('Foo', 'CLASS'), - ('attr', 'VARIABLE')]) + result, [(mod_path, "MODULE"), ("Foo", "CLASS"), ("attr", "VARIABLE")] + ) def test_object_path_subclass(self): - code = 'class Foo(object):\n' \ - ' class Bar(object):\n' \ - ' pass\n' - resource = testutils.create_module(self.project, 'mod') + code = "class Foo(object):\n" " class Bar(object):\n" " pass\n" + resource = testutils.create_module(self.project, "mod") resource.write(code) result = get_canonical_path(self.project, resource, 30) - mod_path = os.path.join(self.project.address, 'mod.py') + mod_path = os.path.join(self.project.address, "mod.py") self.assertEqual( - result, [(mod_path, 'MODULE'), ('Foo', 'CLASS'), - ('Bar', 'CLASS')]) + result, [(mod_path, "MODULE"), ("Foo", "CLASS"), ("Bar", "CLASS")] + ) def test_object_path_method_parameter(self): - code = 'class Foo(object):\n' \ - ' def bar(self, a, b, c):\n' \ - ' pass\n' - resource = testutils.create_module(self.project, 'mod') + code = "class Foo(object):\n" " def bar(self, a, b, c):\n" " pass\n" + resource = testutils.create_module(self.project, "mod") resource.write(code) result = get_canonical_path(self.project, resource, 41) - mod_path = os.path.join(self.project.address, 'mod.py') + mod_path = os.path.join(self.project.address, "mod.py") self.assertEqual( - result, [(mod_path, 'MODULE'), ('Foo', 'CLASS'), - ('bar', 'FUNCTION'), ('b', 'PARAMETER')]) + result, + [ + (mod_path, "MODULE"), + ("Foo", "CLASS"), + ("bar", "FUNCTION"), + ("b", "PARAMETER"), + ], + ) def test_object_path_variable(self): - code = 'def bar(a):\n' \ - ' x = a + 42\n' - resource = testutils.create_module(self.project, 'mod') + code = "def bar(a):\n" " x = a + 42\n" + resource = testutils.create_module(self.project, "mod") resource.write(code) result = get_canonical_path(self.project, resource, 17) - mod_path = os.path.join(self.project.address, 'mod.py') + mod_path = os.path.join(self.project.address, "mod.py") self.assertEqual( - result, [(mod_path, 'MODULE'), ('bar', 'FUNCTION'), - ('x', 'VARIABLE')]) + result, [(mod_path, "MODULE"), ("bar", "FUNCTION"), ("x", "VARIABLE")] + ) class CodeAssistInProjectsTest(unittest.TestCase): @@ -996,16 +1022,17 @@ def setUp(self): super(CodeAssistInProjectsTest, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - samplemod = testutils.create_module(self.project, 'samplemod') - code = 'class SampleClass(object):\n' \ - ' def sample_method():\n pass\n\n' \ - 'def sample_func():\n pass\n' \ - 'sample_var = 10\n\n' \ - 'def _underlined_func():\n pass\n\n' + samplemod = testutils.create_module(self.project, "samplemod") + code = ( + "class SampleClass(object):\n" + " def sample_method():\n pass\n\n" + "def sample_func():\n pass\n" + "sample_var = 10\n\n" + "def _underlined_func():\n pass\n\n" + ) samplemod.write(code) - package = testutils.create_package(self.project, 'package') - nestedmod = testutils.create_module(self.project, # noqa - 'nestedmod', package) + package = testutils.create_package(self.project, "package") + nestedmod = testutils.create_module(self.project, "nestedmod", package) # noqa def tearDown(self): testutils.remove_project(self.project) @@ -1018,187 +1045,196 @@ def assert_completion_in_result(self, name, scope, result): for proposal in result: if proposal.name == name and proposal.scope == scope: return - self.fail('completion <%s> not proposed' % name) + self.fail("completion <%s> not proposed" % name) def assert_completion_not_in_result(self, name, scope, result): for proposal in result: if proposal.name == name and proposal.scope == scope: - self.fail('completion <%s> was proposed' % name) + self.fail("completion <%s> was proposed" % name) def test_simple_import(self): - code = 'import samplemod\nsample' + code = "import samplemod\nsample" result = self._assist(code) - self.assert_completion_in_result('samplemod', 'imported', result) + self.assert_completion_in_result("samplemod", "imported", result) def test_from_import_class(self): - code = 'from samplemod import SampleClass\nSample' + code = "from samplemod import SampleClass\nSample" result = self._assist(code) - self.assert_completion_in_result('SampleClass', 'imported', result) + self.assert_completion_in_result("SampleClass", "imported", result) def test_from_import_function(self): - code = 'from samplemod import sample_func\nsample' + code = "from samplemod import sample_func\nsample" result = self._assist(code) - self.assert_completion_in_result('sample_func', 'imported', result) + self.assert_completion_in_result("sample_func", "imported", result) def test_from_import_variable(self): - code = 'from samplemod import sample_var\nsample' + code = "from samplemod import sample_var\nsample" result = self._assist(code) - self.assert_completion_in_result('sample_var', 'imported', result) + self.assert_completion_in_result("sample_var", "imported", result) def test_from_imports_inside_functions(self): - code = 'def f():\n from samplemod import SampleClass\n Sample' + code = "def f():\n from samplemod import SampleClass\n Sample" result = self._assist(code) - self.assert_completion_in_result('SampleClass', 'imported', result) + self.assert_completion_in_result("SampleClass", "imported", result) def test_from_import_only_imports_imported(self): - code = 'from samplemod import sample_func\nSample' + code = "from samplemod import sample_func\nSample" result = self._assist(code) - self.assert_completion_not_in_result('SampleClass', 'global', result) + self.assert_completion_not_in_result("SampleClass", "global", result) def test_from_import_star(self): - code = 'from samplemod import *\nSample' + code = "from samplemod import *\nSample" result = self._assist(code) - self.assert_completion_in_result('SampleClass', 'imported', result) + self.assert_completion_in_result("SampleClass", "imported", result) def test_from_import_star2(self): - code = 'from samplemod import *\nsample' + code = "from samplemod import *\nsample" result = self._assist(code) - self.assert_completion_in_result('sample_func', 'imported', result) - self.assert_completion_in_result('sample_var', 'imported', result) + self.assert_completion_in_result("sample_func", "imported", result) + self.assert_completion_in_result("sample_var", "imported", result) def test_from_import_star_not_imporing_underlined(self): - code = 'from samplemod import *\n_under' + code = "from samplemod import *\n_under" result = self._assist(code) - self.assert_completion_not_in_result('_underlined_func', 'global', - result) + self.assert_completion_not_in_result("_underlined_func", "global", result) def test_from_package_import_mod(self): - code = 'from package import nestedmod\nnest' + code = "from package import nestedmod\nnest" result = self._assist(code) - self.assert_completion_in_result('nestedmod', 'imported', result) + self.assert_completion_in_result("nestedmod", "imported", result) def test_completing_after_dot(self): - code = 'class SampleClass(object):\n' \ - ' def sample_method(self):\n' \ - ' pass\n' \ - 'SampleClass.sam' + code = ( + "class SampleClass(object):\n" + " def sample_method(self):\n" + " pass\n" + "SampleClass.sam" + ) result = self._assist(code) - self.assert_completion_in_result('sample_method', 'attribute', result) + self.assert_completion_in_result("sample_method", "attribute", result) def test_completing_after_multiple_dots(self): - code = 'class Class1(object):\n' \ - ' class Class2(object):\n' \ - ' def sample_method(self):\n' \ - ' pass\n' \ - 'Class1.Class2.sam' + code = ( + "class Class1(object):\n" + " class Class2(object):\n" + " def sample_method(self):\n" + " pass\n" + "Class1.Class2.sam" + ) result = self._assist(code) - self.assert_completion_in_result('sample_method', 'attribute', result) + self.assert_completion_in_result("sample_method", "attribute", result) def test_completing_after_self_dot(self): - code = 'class Sample(object):\n' \ - ' def method1(self):\n' \ - ' pass\n' \ - ' def method2(self):\n' \ - ' self.m' + code = ( + "class Sample(object):\n" + " def method1(self):\n" + " pass\n" + " def method2(self):\n" + " self.m" + ) result = self._assist(code) - self.assert_completion_in_result('method1', 'attribute', result) + self.assert_completion_in_result("method1", "attribute", result) def test_result_start_offset_for_dotted_completions(self): - code = 'class Sample(object):\n' \ - ' def method1(self):\n' \ - ' pass\n' \ - 'Sample.me' + code = ( + "class Sample(object):\n" + " def method1(self):\n" + " pass\n" + "Sample.me" + ) self.assertEqual(len(code) - 2, starting_offset(code, len(code))) def test_backslash_after_dots(self): - code = 'class Sample(object):\n' \ - ' def a_method(self):\n' \ - ' pass\n' \ - 'Sample.\\\n a_m' + code = ( + "class Sample(object):\n" + " def a_method(self):\n" + " pass\n" + "Sample.\\\n a_m" + ) result = self._assist(code) - self.assert_completion_in_result('a_method', 'attribute', result) + self.assert_completion_in_result("a_method", "attribute", result) def test_not_proposing_global_names_after_dot(self): - code = 'class Sample(object):\n' \ - ' def a_method(self):\n' \ - ' pass\n' \ - 'Sample.' + code = ( + "class Sample(object):\n" + " def a_method(self):\n" + " pass\n" + "Sample." + ) result = self._assist(code) - self.assert_completion_not_in_result('Sample', 'global', result) + self.assert_completion_not_in_result("Sample", "global", result) def test_assist_on_relative_imports(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod2 = testutils.create_module(self.project, 'mod2', pkg) - mod1.write('def a_func():\n pass\n') - code = 'import mod1\nmod1.' + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod2 = testutils.create_module(self.project, "mod2", pkg) + mod1.write("def a_func():\n pass\n") + code = "import mod1\nmod1." result = self._assist(code, resource=mod2) - self.assert_completion_in_result('a_func', 'imported', result) + self.assert_completion_in_result("a_func", "imported", result) def test_get_location_on_relative_imports(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod2 = testutils.create_module(self.project, 'mod2', pkg) - mod1.write('def a_func():\n pass\n') - code = 'import mod1\nmod1.a_func\n' - result = get_definition_location(self.project, code, - len(code) - 2, mod2) + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod2 = testutils.create_module(self.project, "mod2", pkg) + mod1.write("def a_func():\n pass\n") + code = "import mod1\nmod1.a_func\n" + result = get_definition_location(self.project, code, len(code) - 2, mod2) self.assertEqual((mod1, 1), result) def test_get_definition_location_for_builtins(self): - code = 'import sys\n' - result = get_definition_location(self.project, code, - len(code) - 2) + code = "import sys\n" + result = get_definition_location(self.project, code, len(code) - 2) self.assertEqual((None, None), result) def test_get_doc_on_relative_imports(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod2 = testutils.create_module(self.project, 'mod2', pkg) + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod2 = testutils.create_module(self.project, "mod2", pkg) mod1.write('def a_func():\n """hey"""\n pass\n') - code = 'import mod1\nmod1.a_func\n' + code = "import mod1\nmod1.a_func\n" result = get_doc(self.project, code, len(code) - 2, mod2) - self.assertTrue(result.endswith('hey')) + self.assertTrue(result.endswith("hey")) def test_get_doc_on_from_import_module(self): - mod1 = testutils.create_module(self.project, 'mod1') + mod1 = testutils.create_module(self.project, "mod1") mod1.write('"""mod1 docs"""\nvar = 1\n') - code = 'from mod1 import var\n' - result = get_doc(self.project, code, code.index('mod1')) - result.index('mod1 docs') + code = "from mod1 import var\n" + result = get_doc(self.project, code, code.index("mod1")) + result.index("mod1 docs") def test_fixing_errors_with_maxfixes_in_resources(self): - mod = testutils.create_module(self.project, 'mod') - code = 'def f():\n sldj sldj\ndef g():\n ran' + mod = testutils.create_module(self.project, "mod") + code = "def f():\n sldj sldj\ndef g():\n ran" mod.write(code) result = self._assist(code, maxfixes=2, resource=mod) self.assertTrue(len(result) > 0) def test_completing_names_after_from_import(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('myvar = None\n') - result = self._assist('from mod1 import myva', resource=mod2) + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("myvar = None\n") + result = self._assist("from mod1 import myva", resource=mod2) self.assertTrue(len(result) > 0) - self.assert_completion_in_result('myvar', 'global', result) + self.assert_completion_in_result("myvar", "global", result) def test_completing_names_after_from_import_and_sorted_proposals(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('myvar = None\n') - result = self._assist('from mod1 import myva', resource=mod2) + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("myvar = None\n") + result = self._assist("from mod1 import myva", resource=mod2) result = sorted_proposals(result) self.assertTrue(len(result) > 0) - self.assert_completion_in_result('myvar', 'global', result) + self.assert_completion_in_result("myvar", "global", result) def test_completing_names_after_from_import2(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('myvar = None\n') - result = self._assist('from mod1 import ', resource=mod2) + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("myvar = None\n") + result = self._assist("from mod1 import ", resource=mod2) self.assertTrue(len(result) > 0) - self.assert_completion_in_result('myvar', 'global', result) + self.assert_completion_in_result("myvar", "global", result) def test_starting_expression(self): - code = 'l = list()\nl.app' - self.assertEqual('l.app', starting_expression(code, len(code))) + code = "l = list()\nl.app" + self.assertEqual("l.app", starting_expression(code, len(code))) diff --git a/ropetest/contrib/finderrorstest.py b/ropetest/contrib/finderrorstest.py index d78e7e494..f829cf1c9 100644 --- a/ropetest/contrib/finderrorstest.py +++ b/ropetest/contrib/finderrorstest.py @@ -9,43 +9,39 @@ class FindErrorsTest(unittest.TestCase): - def setUp(self): super(FindErrorsTest, self).setUp() self.project = testutils.sample_project() - self.mod = self.project.root.create_file('mod.py') + self.mod = self.project.root.create_file("mod.py") def tearDown(self): testutils.remove_project(self.project) super(FindErrorsTest, self).tearDown() def test_unresolved_variables(self): - self.mod.write('print(var)\n') + self.mod.write("print(var)\n") result = finderrors.find_errors(self.project, self.mod) self.assertEqual(1, len(result)) self.assertEqual(1, result[0].lineno) def test_defined_later(self): - self.mod.write('print(var)\nvar = 1\n') + self.mod.write("print(var)\nvar = 1\n") result = finderrors.find_errors(self.project, self.mod) self.assertEqual(1, len(result)) self.assertEqual(1, result[0].lineno) def test_ignoring_builtins(self): - self.mod.write('range(2)\n') + self.mod.write("range(2)\n") result = finderrors.find_errors(self.project, self.mod) self.assertEqual(0, len(result)) def test_ignoring_none(self): - self.mod.write('var = None\n') + self.mod.write("var = None\n") result = finderrors.find_errors(self.project, self.mod) self.assertEqual(0, len(result)) def test_bad_attributes(self): - code = 'class C(object):\n' \ - ' pass\n' \ - 'c = C()\n' \ - 'print(c.var)\n' + code = "class C(object):\n" " pass\n" "c = C()\n" "print(c.var)\n" self.mod.write(code) result = finderrors.find_errors(self.project, self.mod) self.assertEqual(1, len(result)) diff --git a/ropetest/contrib/findittest.py b/ropetest/contrib/findittest.py index 5e03a8f87..555562b3d 100644 --- a/ropetest/contrib/findittest.py +++ b/ropetest/contrib/findittest.py @@ -4,13 +4,11 @@ import unittest from rope.base import exceptions -from rope.contrib.findit import (find_occurrences, find_implementations, - find_definition) +from rope.contrib.findit import find_occurrences, find_implementations, find_definition from ropetest import testutils class FindItTest(unittest.TestCase): - def setUp(self): super(FindItTest, self).setUp() self.project = testutils.sample_project() @@ -20,95 +18,103 @@ def tearDown(self): super(FindItTest, self).tearDown() def test_finding_occurrences(self): - mod = testutils.create_module(self.project, 'mod') - mod.write('a_var = 1\n') + mod = testutils.create_module(self.project, "mod") + mod.write("a_var = 1\n") result = find_occurrences(self.project, mod, 1) self.assertEqual(mod, result[0].resource) self.assertEqual(0, result[0].offset) self.assertEqual(False, result[0].unsure) def test_finding_occurrences_in_more_than_one_module(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('a_var = 1\n') - mod2.write('import mod1\nmy_var = mod1.a_var') + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("a_var = 1\n") + mod2.write("import mod1\nmy_var = mod1.a_var") result = find_occurrences(self.project, mod1, 1) self.assertEqual(2, len(result)) modules = (result[0].resource, result[1].resource) self.assertTrue(mod1 in modules and mod2 in modules) def test_finding_occurrences_matching_when_unsure(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('class C(object):\n def a_func(self):\n pass\n' - 'def f(arg):\n arg.a_func()\n') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write( + "class C(object):\n def a_func(self):\n pass\n" + "def f(arg):\n arg.a_func()\n" + ) result = find_occurrences( - self.project, mod1, mod1.read().index('a_func'), unsure=True) + self.project, mod1, mod1.read().index("a_func"), unsure=True + ) self.assertEqual(2, len(result)) def test_find_occurrences_resources_parameter(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('a_var = 1\n') - mod2.write('import mod1\nmy_var = mod1.a_var') + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("a_var = 1\n") + mod2.write("import mod1\nmy_var = mod1.a_var") result = find_occurrences(self.project, mod1, 1, resources=[mod1]) self.assertEqual(1, len(result)) self.assertEqual((mod1, 0), (result[0].resource, result[0].offset)) def test_find_occurrences_and_class_hierarchies(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('class A(object):\n def f():\n pass\n' - 'class B(A):\n def f():\n pass\n') - offset = mod1.read().rindex('f') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write( + "class A(object):\n def f():\n pass\n" + "class B(A):\n def f():\n pass\n" + ) + offset = mod1.read().rindex("f") result1 = find_occurrences(self.project, mod1, offset) - result2 = find_occurrences(self.project, mod1, - offset, in_hierarchy=True) + result2 = find_occurrences(self.project, mod1, offset, in_hierarchy=True) self.assertEqual(1, len(result1)) self.assertEqual(2, len(result2)) def test_trivial_find_implementations(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('class A(object):\n def f(self):\n pass\n') - offset = mod1.read().rindex('f(') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("class A(object):\n def f(self):\n pass\n") + offset = mod1.read().rindex("f(") result = find_implementations(self.project, mod1, offset) self.assertEqual([], result) def test_find_implementations_and_not_returning_parents(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('class A(object):\n def f(self):\n pass\n' - 'class B(A):\n def f(self):\n pass\n') - offset = mod1.read().rindex('f(') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write( + "class A(object):\n def f(self):\n pass\n" + "class B(A):\n def f(self):\n pass\n" + ) + offset = mod1.read().rindex("f(") result = find_implementations(self.project, mod1, offset) self.assertEqual([], result) def test_find_implementations_real_implementation(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('class A(object):\n def f(self):\n pass\n' - 'class B(A):\n def f(self):\n pass\n') - offset = mod1.read().index('f(') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write( + "class A(object):\n def f(self):\n pass\n" + "class B(A):\n def f(self):\n pass\n" + ) + offset = mod1.read().index("f(") result = find_implementations(self.project, mod1, offset) self.assertEqual(1, len(result)) - self.assertEqual(mod1.read().rindex('f('), result[0].offset) + self.assertEqual(mod1.read().rindex("f("), result[0].offset) def test_find_implementations_real_implementation_simple(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('class A(object):\n pass\n') - offset = mod1.read().index('A') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("class A(object):\n pass\n") + offset = mod1.read().index("A") with self.assertRaises(exceptions.BadIdentifierError): find_implementations(self.project, mod1, offset) def test_trivial_find_definition(self): - code = 'def a_func():\n pass\na_func()' - result = find_definition(self.project, code, code.rindex('a_func')) - start = code.index('a_func') + code = "def a_func():\n pass\na_func()" + result = find_definition(self.project, code, code.rindex("a_func")) + start = code.index("a_func") self.assertEqual(start, result.offset) self.assertEqual(None, result.resource) self.assertEqual(1, result.lineno) - self.assertEqual((start, start + len('a_func')), result.region) + self.assertEqual((start, start + len("a_func")), result.region) def test_find_definition_in_other_modules(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('var = 1\n') - code = 'import mod1\nprint(mod1.var)\n' - result = find_definition(self.project, code, code.index('var')) + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("var = 1\n") + code = "import mod1\nprint(mod1.var)\n" + result = find_definition(self.project, code, code.index("var")) self.assertEqual(mod1, result.resource) self.assertEqual(0, result.offset) diff --git a/ropetest/contrib/fixmodnamestest.py b/ropetest/contrib/fixmodnamestest.py index 05fbec3ba..ac231ec7a 100644 --- a/ropetest/contrib/fixmodnamestest.py +++ b/ropetest/contrib/fixmodnamestest.py @@ -12,7 +12,6 @@ # HACK: for making this test work on case-insensitive file-systems, it # uses a name.replace('x', '_') fixer. class FixModuleNamesTest(unittest.TestCase): - def setUp(self): super(FixModuleNamesTest, self).setUp() self.project = testutils.sample_project() @@ -22,36 +21,36 @@ def tearDown(self): super(FixModuleNamesTest, self).tearDown() def test_simple_module_renaming(self): - mod = create_module(self.project, 'xod') + mod = create_module(self.project, "xod") self.project.do(FixModuleNames(self.project).get_changes(_fixer)) self.assertFalse(mod.exists()) - self.assertTrue(self.project.get_resource('_od.py').exists()) + self.assertTrue(self.project.get_resource("_od.py").exists()) def test_packages_module_renaming(self): - pkg = create_package(self.project, 'xkg') + pkg = create_package(self.project, "xkg") self.project.do(FixModuleNames(self.project).get_changes(_fixer)) self.assertFalse(pkg.exists()) - self.assertTrue(self.project.get_resource('_kg/__init__.py').exists()) + self.assertTrue(self.project.get_resource("_kg/__init__.py").exists()) def test_fixing_contents(self): - mod1 = create_module(self.project, 'xod1') - mod2 = create_module(self.project, 'xod2') - mod1.write('import xod2\n') - mod2.write('import xod1\n') + mod1 = create_module(self.project, "xod1") + mod2 = create_module(self.project, "xod2") + mod1.write("import xod2\n") + mod2.write("import xod1\n") self.project.do(FixModuleNames(self.project).get_changes(_fixer)) - newmod1 = self.project.get_resource('_od1.py') - newmod2 = self.project.get_resource('_od2.py') - self.assertEqual('import _od2\n', newmod1.read()) - self.assertEqual('import _od1\n', newmod2.read()) + newmod1 = self.project.get_resource("_od1.py") + newmod2 = self.project.get_resource("_od2.py") + self.assertEqual("import _od2\n", newmod1.read()) + self.assertEqual("import _od1\n", newmod2.read()) def test_handling_nested_modules(self): - pkg = create_package(self.project, 'xkg') - mod = create_module(self.project, 'xkg.xod') # noqa + pkg = create_package(self.project, "xkg") + mod = create_module(self.project, "xkg.xod") # noqa self.project.do(FixModuleNames(self.project).get_changes(_fixer)) self.assertFalse(pkg.exists()) - self.assertTrue(self.project.get_resource('_kg/__init__.py').exists()) - self.assertTrue(self.project.get_resource('_kg/_od.py').exists()) + self.assertTrue(self.project.get_resource("_kg/__init__.py").exists()) + self.assertTrue(self.project.get_resource("_kg/_od.py").exists()) def _fixer(name): - return name.replace('x', '_') + return name.replace("x", "_") diff --git a/ropetest/contrib/generatetest.py b/ropetest/contrib/generatetest.py index 3471ed634..f68dc37a5 100644 --- a/ropetest/contrib/generatetest.py +++ b/ropetest/contrib/generatetest.py @@ -9,14 +9,13 @@ class GenerateTest(unittest.TestCase): - def setUp(self): super(GenerateTest, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - self.mod = testutils.create_module(self.project, 'mod1') - self.mod2 = testutils.create_module(self.project, 'mod2') - self.pkg = testutils.create_package(self.project, 'pkg') + self.mod = testutils.create_module(self.project, "mod1") + self.mod2 = testutils.create_module(self.project, "mod2") + self.pkg = testutils.create_package(self.project, "pkg") def tearDown(self): testutils.remove_project(self.project) @@ -26,7 +25,9 @@ def _get_generate(self, offset): return generate.GenerateVariable(self.project, self.mod, offset) def _get_generate_class(self, offset, goal_mod=None): - return generate.GenerateClass(self.project, self.mod, offset, goal_resource=goal_mod) + return generate.GenerateClass( + self.project, self.mod, offset, goal_resource=goal_mod + ) def _get_generate_module(self, offset): return generate.GenerateModule(self.project, self.mod, offset) @@ -38,275 +39,272 @@ def _get_generate_function(self, offset): return generate.GenerateFunction(self.project, self.mod, offset) def test_getting_location(self): - code = 'a_var = name\n' + code = "a_var = name\n" self.mod.write(code) - generator = self._get_generate(code.index('name')) + generator = self._get_generate(code.index("name")) self.assertEqual((self.mod, 1), generator.get_location()) def test_generating_variable(self): - code = 'a_var = name\n' + code = "a_var = name\n" self.mod.write(code) - changes = self._get_generate(code.index('name')).get_changes() + changes = self._get_generate(code.index("name")).get_changes() self.project.do(changes) - self.assertEqual('name = None\n\n\na_var = name\n', self.mod.read()) + self.assertEqual("name = None\n\n\na_var = name\n", self.mod.read()) def test_generating_variable_inserting_before_statement(self): - code = 'c = 1\nc = b\n' + code = "c = 1\nc = b\n" self.mod.write(code) - changes = self._get_generate(code.index('b')).get_changes() + changes = self._get_generate(code.index("b")).get_changes() self.project.do(changes) - self.assertEqual('c = 1\nb = None\n\n\nc = b\n', self.mod.read()) + self.assertEqual("c = 1\nb = None\n\n\nc = b\n", self.mod.read()) def test_generating_variable_in_local_scopes(self): - code = 'def f():\n c = 1\n c = b\n' + code = "def f():\n c = 1\n c = b\n" self.mod.write(code) - changes = self._get_generate(code.index('b')).get_changes() + changes = self._get_generate(code.index("b")).get_changes() self.project.do(changes) - self.assertEqual('def f():\n c = 1\n b = None\n c = b\n', - self.mod.read()) + self.assertEqual( + "def f():\n c = 1\n b = None\n c = b\n", self.mod.read() + ) def test_generating_variable_in_other_modules(self): - code = 'import mod2\nc = mod2.b\n' + code = "import mod2\nc = mod2.b\n" self.mod.write(code) - generator = self._get_generate(code.index('b')) + generator = self._get_generate(code.index("b")) self.project.do(generator.get_changes()) self.assertEqual((self.mod2, 1), generator.get_location()) - self.assertEqual('b = None\n', self.mod2.read()) + self.assertEqual("b = None\n", self.mod2.read()) def test_generating_variable_in_classes(self): - code = 'class C(object):\n def f(self):\n pass\n' \ - 'c = C()\na_var = c.attr' + code = ( + "class C(object):\n def f(self):\n pass\n" + "c = C()\na_var = c.attr" + ) self.mod.write(code) - changes = self._get_generate(code.index('attr')).get_changes() + changes = self._get_generate(code.index("attr")).get_changes() self.project.do(changes) self.assertEqual( - 'class C(object):\n def f(self):\n ' - 'pass\n\n attr = None\n' - 'c = C()\na_var = c.attr', self.mod.read()) + "class C(object):\n def f(self):\n " + "pass\n\n attr = None\n" + "c = C()\na_var = c.attr", + self.mod.read(), + ) def test_generating_variable_in_classes_removing_pass(self): - code = 'class C(object):\n pass\nc = C()\na_var = c.attr' + code = "class C(object):\n pass\nc = C()\na_var = c.attr" self.mod.write(code) - changes = self._get_generate(code.index('attr')).get_changes() + changes = self._get_generate(code.index("attr")).get_changes() self.project.do(changes) - self.assertEqual('class C(object):\n\n attr = None\n' - 'c = C()\na_var = c.attr', self.mod.read()) + self.assertEqual( + "class C(object):\n\n attr = None\n" "c = C()\na_var = c.attr", + self.mod.read(), + ) def test_generating_variable_in_packages(self): - code = 'import pkg\na = pkg.a\n' + code = "import pkg\na = pkg.a\n" self.mod.write(code) - generator = self._get_generate(code.rindex('a')) + generator = self._get_generate(code.rindex("a")) self.project.do(generator.get_changes()) - init = self.pkg.get_child('__init__.py') + init = self.pkg.get_child("__init__.py") self.assertEqual((init, 1), generator.get_location()) - self.assertEqual('a = None\n', init.read()) + self.assertEqual("a = None\n", init.read()) def test_generating_classes(self): - code = 'c = C()\n' + code = "c = C()\n" self.mod.write(code) - changes = self._get_generate_class(code.index('C')).get_changes() + changes = self._get_generate_class(code.index("C")).get_changes() self.project.do(changes) - self.assertEqual('class C(object):\n pass\n\n\nc = C()\n', - self.mod.read()) + self.assertEqual("class C(object):\n pass\n\n\nc = C()\n", self.mod.read()) def test_generating_classes_in_other_module(self): - code = 'c = C()\n' + code = "c = C()\n" self.mod.write(code) - changes = self._get_generate_class(code.index('C'), self.mod2).get_changes() + changes = self._get_generate_class(code.index("C"), self.mod2).get_changes() self.project.do(changes) - self.assertEqual('class C(object):\n pass\n', - self.mod2.read()) - self.assertEqual('from mod2 import C\nc = C()\n', - self.mod.read()) + self.assertEqual("class C(object):\n pass\n", self.mod2.read()) + self.assertEqual("from mod2 import C\nc = C()\n", self.mod.read()) def test_generating_modules(self): - code = 'import pkg\npkg.mod\n' + code = "import pkg\npkg.mod\n" self.mod.write(code) - generator = self._get_generate_module(code.rindex('mod')) + generator = self._get_generate_module(code.rindex("mod")) self.project.do(generator.get_changes()) - mod = self.pkg.get_child('mod.py') + mod = self.pkg.get_child("mod.py") self.assertEqual((mod, 1), generator.get_location()) - self.assertEqual('import pkg.mod\npkg.mod\n', self.mod.read()) + self.assertEqual("import pkg.mod\npkg.mod\n", self.mod.read()) def test_generating_packages(self): - code = 'import pkg\npkg.pkg2\n' + code = "import pkg\npkg.pkg2\n" self.mod.write(code) - generator = self._get_generate_package(code.rindex('pkg2')) + generator = self._get_generate_package(code.rindex("pkg2")) self.project.do(generator.get_changes()) - pkg2 = self.pkg.get_child('pkg2') - init = pkg2.get_child('__init__.py') + pkg2 = self.pkg.get_child("pkg2") + init = pkg2.get_child("__init__.py") self.assertEqual((init, 1), generator.get_location()) - self.assertEqual('import pkg.pkg2\npkg.pkg2\n', self.mod.read()) + self.assertEqual("import pkg.pkg2\npkg.pkg2\n", self.mod.read()) def test_generating_function(self): - code = 'a_func()\n' + code = "a_func()\n" self.mod.write(code) - changes = self._get_generate_function( - code.index('a_func')).get_changes() + changes = self._get_generate_function(code.index("a_func")).get_changes() self.project.do(changes) - self.assertEqual('def a_func():\n pass\n\n\na_func()\n', - self.mod.read()) + self.assertEqual("def a_func():\n pass\n\n\na_func()\n", self.mod.read()) def test_generating_modules_with_empty_primary(self): - code = 'mod\n' + code = "mod\n" self.mod.write(code) - generator = self._get_generate_module(code.rindex('mod')) + generator = self._get_generate_module(code.rindex("mod")) self.project.do(generator.get_changes()) - mod = self.project.root.get_child('mod.py') + mod = self.project.root.get_child("mod.py") self.assertEqual((mod, 1), generator.get_location()) - self.assertEqual('import mod\nmod\n', self.mod.read()) + self.assertEqual("import mod\nmod\n", self.mod.read()) def test_generating_variable_already_exists(self): - code = 'b = 1\nc = b\n' + code = "b = 1\nc = b\n" self.mod.write(code) with self.assertRaises(exceptions.RefactoringError): - self._get_generate(code.index('b')).get_changes() + self._get_generate(code.index("b")).get_changes() def test_generating_variable_primary_cannot_be_determined(self): - code = 'c = can_not_be_found.b\n' + code = "c = can_not_be_found.b\n" self.mod.write(code) with self.assertRaises(exceptions.RefactoringError): - self._get_generate(code.rindex('b')).get_changes() + self._get_generate(code.rindex("b")).get_changes() def test_generating_modules_when_already_exists(self): - code = 'mod2\n' + code = "mod2\n" self.mod.write(code) - generator = self._get_generate_module(code.rindex('mod')) + generator = self._get_generate_module(code.rindex("mod")) with self.assertRaises(exceptions.RefactoringError): self.project.do(generator.get_changes()) def test_generating_static_methods(self): - code = 'class C(object):\n pass\nC.a_func()\n' + code = "class C(object):\n pass\nC.a_func()\n" self.mod.write(code) - changes = self._get_generate_function( - code.index('a_func')).get_changes() + changes = self._get_generate_function(code.index("a_func")).get_changes() self.project.do(changes) self.assertEqual( - 'class C(object):\n\n @staticmethod' - '\n def a_func():\n pass\nC.a_func()\n', - self.mod.read()) + "class C(object):\n\n @staticmethod" + "\n def a_func():\n pass\nC.a_func()\n", + self.mod.read(), + ) def test_generating_methods(self): - code = 'class C(object):\n pass\nc = C()\nc.a_func()\n' + code = "class C(object):\n pass\nc = C()\nc.a_func()\n" self.mod.write(code) - changes = self._get_generate_function( - code.index('a_func')).get_changes() + changes = self._get_generate_function(code.index("a_func")).get_changes() self.project.do(changes) self.assertEqual( - 'class C(object):\n\n def a_func(self):\n pass\n' - 'c = C()\nc.a_func()\n', - self.mod.read()) + "class C(object):\n\n def a_func(self):\n pass\n" + "c = C()\nc.a_func()\n", + self.mod.read(), + ) def test_generating_constructors(self): - code = 'class C(object):\n pass\nc = C()\n' + code = "class C(object):\n pass\nc = C()\n" self.mod.write(code) - changes = self._get_generate_function(code.rindex('C')).get_changes() + changes = self._get_generate_function(code.rindex("C")).get_changes() self.project.do(changes) self.assertEqual( - 'class C(object):\n\n def __init__(self):\n pass\n' - 'c = C()\n', - self.mod.read()) + "class C(object):\n\n def __init__(self):\n pass\n" "c = C()\n", + self.mod.read(), + ) def test_generating_calls(self): - code = 'class C(object):\n pass\nc = C()\nc()\n' + code = "class C(object):\n pass\nc = C()\nc()\n" self.mod.write(code) - changes = self._get_generate_function(code.rindex('c')).get_changes() + changes = self._get_generate_function(code.rindex("c")).get_changes() self.project.do(changes) self.assertEqual( - 'class C(object):\n\n def __call__(self):\n pass\n' - 'c = C()\nc()\n', - self.mod.read()) + "class C(object):\n\n def __call__(self):\n pass\n" + "c = C()\nc()\n", + self.mod.read(), + ) def test_generating_calls_in_other_modules(self): - self.mod2.write('class C(object):\n pass\n') - code = 'import mod2\nc = mod2.C()\nc()\n' + self.mod2.write("class C(object):\n pass\n") + code = "import mod2\nc = mod2.C()\nc()\n" self.mod.write(code) - changes = self._get_generate_function(code.rindex('c')).get_changes() + changes = self._get_generate_function(code.rindex("c")).get_changes() self.project.do(changes) self.assertEqual( - 'class C(object):\n\n def __call__(self):\n pass\n', - self.mod2.read()) + "class C(object):\n\n def __call__(self):\n pass\n", + self.mod2.read(), + ) def test_generating_function_handling_arguments(self): - code = 'a_func(1)\n' + code = "a_func(1)\n" self.mod.write(code) - changes = self._get_generate_function( - code.index('a_func')).get_changes() + changes = self._get_generate_function(code.index("a_func")).get_changes() self.project.do(changes) - self.assertEqual('def a_func(arg0):\n pass\n\n\na_func(1)\n', - self.mod.read()) + self.assertEqual( + "def a_func(arg0):\n pass\n\n\na_func(1)\n", self.mod.read() + ) def test_generating_function_handling_keyword_xarguments(self): - code = 'a_func(p=1)\n' + code = "a_func(p=1)\n" self.mod.write(code) - changes = self._get_generate_function( - code.index('a_func')).get_changes() + changes = self._get_generate_function(code.index("a_func")).get_changes() self.project.do(changes) - self.assertEqual('def a_func(p):\n pass\n\n\na_func(p=1)\n', - self.mod.read()) + self.assertEqual("def a_func(p):\n pass\n\n\na_func(p=1)\n", self.mod.read()) def test_generating_function_handling_arguments_better_naming(self): - code = 'a_var = 1\na_func(a_var)\n' + code = "a_var = 1\na_func(a_var)\n" self.mod.write(code) - changes = self._get_generate_function( - code.index('a_func')).get_changes() + changes = self._get_generate_function(code.index("a_func")).get_changes() self.project.do(changes) - self.assertEqual('a_var = 1\ndef a_func(a_var):' - '\n pass\n\n\na_func(a_var)\n', - self.mod.read()) + self.assertEqual( + "a_var = 1\ndef a_func(a_var):" "\n pass\n\n\na_func(a_var)\n", + self.mod.read(), + ) def test_generating_variable_in_other_modules2(self): - self.mod2.write('\n\n\nprint(1)\n') - code = 'import mod2\nc = mod2.b\n' + self.mod2.write("\n\n\nprint(1)\n") + code = "import mod2\nc = mod2.b\n" self.mod.write(code) - generator = self._get_generate(code.index('b')) + generator = self._get_generate(code.index("b")) self.project.do(generator.get_changes()) self.assertEqual((self.mod2, 5), generator.get_location()) - self.assertEqual('\n\n\nprint(1)\n\n\nb = None\n', self.mod2.read()) + self.assertEqual("\n\n\nprint(1)\n\n\nb = None\n", self.mod2.read()) def test_generating_function_in_a_suite(self): - code = 'if True:\n a_func()\n' + code = "if True:\n a_func()\n" self.mod.write(code) - changes = self._get_generate_function( - code.index('a_func')).get_changes() + changes = self._get_generate_function(code.index("a_func")).get_changes() self.project.do(changes) - self.assertEqual('def a_func():\n pass' - '\n\n\nif True:\n a_func()\n', - self.mod.read()) + self.assertEqual( + "def a_func():\n pass" "\n\n\nif True:\n a_func()\n", self.mod.read() + ) def test_generating_function_in_a_suite_in_a_function(self): - code = 'def f():\n a = 1\n if 1:\n g()\n' + code = "def f():\n a = 1\n if 1:\n g()\n" self.mod.write(code) - changes = self._get_generate_function(code.index('g()')).get_changes() + changes = self._get_generate_function(code.index("g()")).get_changes() self.project.do(changes) self.assertEqual( - 'def f():\n a = 1\n def g():\n pass\n' - ' if 1:\n g()\n', - self.mod.read()) + "def f():\n a = 1\n def g():\n pass\n" + " if 1:\n g()\n", + self.mod.read(), + ) def test_create_generate_class_with_goal_resource(self): - code = 'c = C()\n' + code = "c = C()\n" self.mod.write(code) result = generate.create_generate( - "class", - self.project, - self.mod, - code.index("C"), - goal_resource=self.mod2) + "class", self.project, self.mod, code.index("C"), goal_resource=self.mod2 + ) self.assertTrue(isinstance(result, generate.GenerateClass)) self.assertEqual(result.goal_resource, self.mod2) def test_create_generate_class_without_goal_resource(self): - code = 'c = C()\n' + code = "c = C()\n" self.mod.write(code) result = generate.create_generate( - "class", - self.project, - self.mod, - code.index("C")) + "class", self.project, self.mod, code.index("C") + ) self.assertTrue(isinstance(result, generate.GenerateClass)) self.assertIsNone(result.goal_resource) diff --git a/ropetest/doatest.py b/ropetest/doatest.py index 6b527b0c6..5f3c26db2 100644 --- a/ropetest/doatest.py +++ b/ropetest/doatest.py @@ -2,11 +2,13 @@ import hashlib import hmac import multiprocessing + try: import cPickle as pickle except ImportError: import pickle import socket + try: import unittest2 as unittest except ImportError: @@ -17,20 +19,20 @@ class DOATest(unittest.TestCase): - def try_CVE_2014_3539_exploit(self, receiver, payload): # Simulated attacker writing to the socket def attacker(data_port): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.connect(('127.0.0.1', data_port)) - s_file = s.makefile('wb') + s.connect(("127.0.0.1", data_port)) + s_file = s.makefile("wb") s_file.write(payload) s.close() # Assume the attacker guesses the port correctly; 3037 is used by # default if it is available. - attacker_proc = multiprocessing.Process(target=attacker, - args=(receiver.data_port,)) + attacker_proc = multiprocessing.Process( + target=attacker, args=(receiver.data_port,) + ) attacker_proc.start() received_objs = list(receiver.receive_data()) @@ -41,7 +43,7 @@ def test_CVE_2014_3539_no_encoding(self): # Attacker sends pickled data to the receiver socket. receiver = doa._SocketReceiver() - payload = pickle.dumps('def foo():\n return 123\n') + payload = pickle.dumps("def foo():\n return 123\n") received_objs = self.try_CVE_2014_3539_exploit(receiver, payload) # Make sure the exploit did not run @@ -51,11 +53,13 @@ def test_CVE_2014_3539_signature_mismatch(self): # Attacker sends well-formed data with an incorrect signature. receiver = doa._SocketReceiver() - pickled_data = pickle.dumps('def foo():\n return 123\n', - pickle.HIGHEST_PROTOCOL) - digest = hmac.new(b'invalid-key', pickled_data, hashlib.sha256).digest() - payload = (base64.b64encode(digest) + b':' + - base64.b64encode(pickled_data) + b'\n') + pickled_data = pickle.dumps( + "def foo():\n return 123\n", pickle.HIGHEST_PROTOCOL + ) + digest = hmac.new(b"invalid-key", pickled_data, hashlib.sha256).digest() + payload = ( + base64.b64encode(digest) + b":" + base64.b64encode(pickled_data) + b"\n" + ) received_objs = self.try_CVE_2014_3539_exploit(receiver, payload) # Make sure the exploit did not run @@ -66,17 +70,17 @@ def test_CVE_2014_3539_sanity(self): receiver = doa._SocketReceiver() pickled_data = base64.b64encode( - pickle.dumps('def foo():\n return 123\n', - pickle.HIGHEST_PROTOCOL)) + pickle.dumps("def foo():\n return 123\n", pickle.HIGHEST_PROTOCOL) + ) digest = hmac.new(receiver.key, pickled_data, hashlib.sha256).digest() - payload = (base64.b64encode(digest) + b':' + pickled_data + b'\n') + payload = base64.b64encode(digest) + b":" + pickled_data + b"\n" received_objs = self.try_CVE_2014_3539_exploit(receiver, payload) # Make sure the exploit did not run self.assertEqual(1, len(received_objs)) def test_compare_digest_compat(self): - self.assertTrue(doa._compat_compare_digest('', '')) - self.assertTrue(doa._compat_compare_digest('abc', 'abc')) - self.assertFalse(doa._compat_compare_digest('abc', 'abd')) - self.assertFalse(doa._compat_compare_digest('abc', 'abcd')) + self.assertTrue(doa._compat_compare_digest("", "")) + self.assertTrue(doa._compat_compare_digest("abc", "abc")) + self.assertFalse(doa._compat_compare_digest("abc", "abd")) + self.assertFalse(doa._compat_compare_digest("abc", "abcd")) diff --git a/ropetest/historytest.py b/ropetest/historytest.py index d0d601621..de9da02e3 100644 --- a/ropetest/historytest.py +++ b/ropetest/historytest.py @@ -10,7 +10,6 @@ class HistoryTest(unittest.TestCase): - def setUp(self): super(HistoryTest, self).setUp() self.project = testutils.sample_project() @@ -21,64 +20,63 @@ def tearDown(self): super(HistoryTest, self).tearDown() def test_undoing_writes(self): - my_file = self.project.root.create_file('my_file.txt') - my_file.write('text1') + my_file = self.project.root.create_file("my_file.txt") + my_file.write("text1") self.history.undo() - self.assertEqual('', my_file.read()) + self.assertEqual("", my_file.read()) def test_moving_files(self): - my_file = self.project.root.create_file('my_file.txt') - my_file.move('new_file.txt') + my_file = self.project.root.create_file("my_file.txt") + my_file.move("new_file.txt") self.history.undo() - self.assertEqual('', my_file.read()) + self.assertEqual("", my_file.read()) def test_moving_files_to_folders(self): - my_file = self.project.root.create_file('my_file.txt') - my_folder = self.project.root.create_folder('my_folder') + my_file = self.project.root.create_file("my_file.txt") + my_folder = self.project.root.create_folder("my_folder") my_file.move(my_folder.path) self.history.undo() - self.assertEqual('', my_file.read()) + self.assertEqual("", my_file.read()) def test_writing_files_that_does_not_change_contents(self): - my_file = self.project.root.create_file('my_file.txt') - my_file.write('') + my_file = self.project.root.create_file("my_file.txt") + my_file.write("") self.project.history.undo() self.assertFalse(my_file.exists()) class IsolatedHistoryTest(unittest.TestCase): - def setUp(self): super(IsolatedHistoryTest, self).setUp() self.project = testutils.sample_project() self.history = rope.base.history.History(self.project) - self.file1 = self.project.root.create_file('file1.txt') - self.file2 = self.project.root.create_file('file2.txt') + self.file1 = self.project.root.create_file("file1.txt") + self.file2 = self.project.root.create_file("file2.txt") def tearDown(self): testutils.remove_project(self.project) super(IsolatedHistoryTest, self).tearDown() def test_simple_undo(self): - change = rope.base.change.ChangeContents(self.file1, '1') + change = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change) - self.assertEqual('1', self.file1.read()) + self.assertEqual("1", self.file1.read()) self.history.undo() - self.assertEqual('', self.file1.read()) + self.assertEqual("", self.file1.read()) def test_tobe_undone(self): - change1 = rope.base.change.ChangeContents(self.file1, '1') + change1 = rope.base.change.ChangeContents(self.file1, "1") self.assertEqual(None, self.history.tobe_undone) self.history.do(change1) self.assertEqual(change1, self.history.tobe_undone) - change2 = rope.base.change.ChangeContents(self.file1, '2') + change2 = rope.base.change.ChangeContents(self.file1, "2") self.history.do(change2) self.assertEqual(change2, self.history.tobe_undone) self.history.undo() self.assertEqual(change1, self.history.tobe_undone) def test_tobe_redone(self): - change = rope.base.change.ChangeContents(self.file1, '1') + change = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change) self.assertEqual(None, self.history.tobe_redone) self.history.undo() @@ -86,43 +84,43 @@ def test_tobe_redone(self): def test_undo_limit(self): history = rope.base.history.History(self.project, maxundos=1) - history.do(rope.base.change.ChangeContents(self.file1, '1')) - history.do(rope.base.change.ChangeContents(self.file1, '2')) + history.do(rope.base.change.ChangeContents(self.file1, "1")) + history.do(rope.base.change.ChangeContents(self.file1, "2")) try: history.undo() with self.assertRaises(exceptions.HistoryError): history.undo() finally: - self.assertEqual('1', self.file1.read()) + self.assertEqual("1", self.file1.read()) def test_simple_redo(self): - change = rope.base.change.ChangeContents(self.file1, '1') + change = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change) self.history.undo() self.history.redo() - self.assertEqual('1', self.file1.read()) + self.assertEqual("1", self.file1.read()) def test_simple_re_undo(self): - change = rope.base.change.ChangeContents(self.file1, '1') + change = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change) self.history.undo() self.history.redo() self.history.undo() - self.assertEqual('', self.file1.read()) + self.assertEqual("", self.file1.read()) def test_multiple_undos(self): - change = rope.base.change.ChangeContents(self.file1, '1') + change = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change) - change = rope.base.change.ChangeContents(self.file1, '2') + change = rope.base.change.ChangeContents(self.file1, "2") self.history.do(change) self.history.undo() - self.assertEqual('1', self.file1.read()) - change = rope.base.change.ChangeContents(self.file1, '3') + self.assertEqual("1", self.file1.read()) + change = rope.base.change.ChangeContents(self.file1, "3") self.history.do(change) self.history.undo() - self.assertEqual('1', self.file1.read()) + self.assertEqual("1", self.file1.read()) self.history.redo() - self.assertEqual('3', self.file1.read()) + self.assertEqual("3", self.file1.read()) def test_undo_list_underflow(self): with self.assertRaises(exceptions.HistoryError): @@ -133,130 +131,131 @@ def test_redo_list_underflow(self): self.history.redo() def test_dropping_undone_changes(self): - self.file1.write('1') + self.file1.write("1") with self.assertRaises(exceptions.HistoryError): self.history.undo(drop=True) self.history.redo() def test_undoing_choosen_changes(self): - change = rope.base.change.ChangeContents(self.file1, '1') + change = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change) self.history.undo(change) - self.assertEqual('', self.file1.read()) + self.assertEqual("", self.file1.read()) self.assertFalse(self.history.undo_list) def test_undoing_choosen_changes2(self): - change1 = rope.base.change.ChangeContents(self.file1, '1') + change1 = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change1) - self.history.do(rope.base.change.ChangeContents(self.file1, '2')) + self.history.do(rope.base.change.ChangeContents(self.file1, "2")) self.history.undo(change1) - self.assertEqual('', self.file1.read()) + self.assertEqual("", self.file1.read()) self.assertFalse(self.history.undo_list) def test_undoing_choosen_changes_not_undoing_others(self): - change1 = rope.base.change.ChangeContents(self.file1, '1') + change1 = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change1) - self.history.do(rope.base.change.ChangeContents(self.file2, '2')) + self.history.do(rope.base.change.ChangeContents(self.file2, "2")) self.history.undo(change1) - self.assertEqual('', self.file1.read()) - self.assertEqual('2', self.file2.read()) + self.assertEqual("", self.file1.read()) + self.assertEqual("2", self.file2.read()) def test_undoing_writing_after_moving(self): - change1 = rope.base.change.ChangeContents(self.file1, '1') + change1 = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change1) - self.history.do(rope.base.change.MoveResource(self.file1, 'file3.txt')) - file3 = self.project.get_resource('file3.txt') + self.history.do(rope.base.change.MoveResource(self.file1, "file3.txt")) + file3 = self.project.get_resource("file3.txt") self.history.undo(change1) - self.assertEqual('', self.file1.read()) + self.assertEqual("", self.file1.read()) self.assertFalse(file3.exists()) def test_undoing_folder_movements_for_undoing_writes_inside_it(self): - folder = self.project.root.create_folder('folder') - file3 = folder.create_file('file3.txt') - change1 = rope.base.change.ChangeContents(file3, '1') + folder = self.project.root.create_folder("folder") + file3 = folder.create_file("file3.txt") + change1 = rope.base.change.ChangeContents(file3, "1") self.history.do(change1) - self.history.do(rope.base.change.MoveResource(folder, 'new_folder')) - new_folder = self.project.get_resource('new_folder') + self.history.do(rope.base.change.MoveResource(folder, "new_folder")) + new_folder = self.project.get_resource("new_folder") self.history.undo(change1) - self.assertEqual('', file3.read()) + self.assertEqual("", file3.read()) self.assertFalse(new_folder.exists()) def test_undoing_changes_that_depend_on_a_dependant_change(self): - change1 = rope.base.change.ChangeContents(self.file1, '1') + change1 = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change1) - changes = rope.base.change.ChangeSet('2nd change') - changes.add_change(rope.base.change.ChangeContents(self.file1, '2')) - changes.add_change(rope.base.change.ChangeContents(self.file2, '2')) + changes = rope.base.change.ChangeSet("2nd change") + changes.add_change(rope.base.change.ChangeContents(self.file1, "2")) + changes.add_change(rope.base.change.ChangeContents(self.file2, "2")) self.history.do(changes) - self.history.do(rope.base.change.MoveResource(self.file2, 'file3.txt')) - file3 = self.project.get_resource('file3.txt') + self.history.do(rope.base.change.MoveResource(self.file2, "file3.txt")) + file3 = self.project.get_resource("file3.txt") self.history.undo(change1) - self.assertEqual('', self.file1.read()) - self.assertEqual('', self.file2.read()) + self.assertEqual("", self.file1.read()) + self.assertEqual("", self.file2.read()) self.assertFalse(file3.exists()) def test_undoing_writes_for_undoing_folder_movements_containing_it(self): - folder = self.project.root.create_folder('folder') - old_file = folder.create_file('file3.txt') - change1 = rope.base.change.MoveResource(folder, 'new_folder') + folder = self.project.root.create_folder("folder") + old_file = folder.create_file("file3.txt") + change1 = rope.base.change.MoveResource(folder, "new_folder") self.history.do(change1) - new_file = self.project.get_resource('new_folder/file3.txt') - self.history.do(rope.base.change.ChangeContents(new_file, '1')) + new_file = self.project.get_resource("new_folder/file3.txt") + self.history.do(rope.base.change.ChangeContents(new_file, "1")) self.history.undo(change1) - self.assertEqual('', old_file.read()) + self.assertEqual("", old_file.read()) self.assertFalse(new_file.exists()) def test_undoing_not_available_change(self): - change = rope.base.change.ChangeContents(self.file1, '1') + change = rope.base.change.ChangeContents(self.file1, "1") with self.assertRaises(exceptions.HistoryError): self.history.undo(change) def test_ignoring_ignored_resources(self): - self.project.set('ignored_resources', ['ignored*']) - ignored = self.project.get_file('ignored.txt') + self.project.set("ignored_resources", ["ignored*"]) + ignored = self.project.get_file("ignored.txt") change = rope.base.change.CreateResource(ignored) self.history.do(change) self.assertTrue(ignored.exists()) self.assertEqual(0, len(self.history.undo_list)) def test_get_file_undo_list_simple(self): - change = rope.base.change.ChangeContents(self.file1, '1') + change = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change) - self.assertEqual(set([change]), - set(self.history.get_file_undo_list(self.file1))) + self.assertEqual( + set([change]), set(self.history.get_file_undo_list(self.file1)) + ) def test_get_file_undo_list_for_moves(self): - change = rope.base.change.MoveResource(self.file1, 'file2.txt') + change = rope.base.change.MoveResource(self.file1, "file2.txt") self.history.do(change) - self.assertEqual(set([change]), - set(self.history.get_file_undo_list(self.file1))) + self.assertEqual( + set([change]), set(self.history.get_file_undo_list(self.file1)) + ) # XXX: What happens for moves before the file is created? def xxx_test_get_file_undo_list_and_moving_its_contining_folder(self): - folder = self.project.root.create_folder('folder') - old_file = folder.create_file('file3.txt') - change1 = rope.base.change.MoveResource(folder, 'new_folder') + folder = self.project.root.create_folder("folder") + old_file = folder.create_file("file3.txt") + change1 = rope.base.change.MoveResource(folder, "new_folder") self.history.do(change1) - self.assertEqual(set([change1]), - set(self.history.get_file_undo_list(old_file))) + self.assertEqual(set([change1]), set(self.history.get_file_undo_list(old_file))) def test_clearing_redo_list_after_do(self): - change = rope.base.change.ChangeContents(self.file1, '1') + change = rope.base.change.ChangeContents(self.file1, "1") self.history.do(change) self.history.undo() self.history.do(change) self.assertEqual(0, len(self.history.redo_list)) def test_undoing_a_not_yet_performed_change(self): - change = rope.base.change.ChangeContents(self.file1, '1') + change = rope.base.change.ChangeContents(self.file1, "1") str(change) with self.assertRaises(exceptions.HistoryError): change.undo() def test_clearing_up_the_history(self): - change1 = rope.base.change.ChangeContents(self.file1, '1') - change2 = rope.base.change.ChangeContents(self.file1, '2') + change1 = rope.base.change.ChangeContents(self.file1, "1") + change2 = rope.base.change.ChangeContents(self.file1, "2") self.history.do(change1) self.history.do(change2) self.history.undo() @@ -265,20 +264,19 @@ def test_clearing_up_the_history(self): self.assertEqual(0, len(self.history.redo_list)) def test_redoing_choosen_changes_not_undoing_others(self): - change1 = rope.base.change.ChangeContents(self.file1, '1') - change2 = rope.base.change.ChangeContents(self.file2, '2') + change1 = rope.base.change.ChangeContents(self.file1, "1") + change2 = rope.base.change.ChangeContents(self.file2, "2") self.history.do(change1) self.history.do(change2) self.history.undo() self.history.undo() redone = self.history.redo(change2) self.assertEqual([change2], redone) - self.assertEqual('', self.file1.read()) - self.assertEqual('2', self.file2.read()) + self.assertEqual("", self.file1.read()) + self.assertEqual("2", self.file2.read()) class SavingHistoryTest(unittest.TestCase): - def setUp(self): super(SavingHistoryTest, self).setUp() self.project = testutils.sample_project() @@ -291,37 +289,37 @@ def tearDown(self): super(SavingHistoryTest, self).tearDown() def test_simple_set_saving(self): - data = self.to_data(rope.base.change.ChangeSet('testing')) + data = self.to_data(rope.base.change.ChangeSet("testing")) change = self.to_change(data) - self.assertEqual('testing', str(change)) + self.assertEqual("testing", str(change)) def test_simple_change_content_saving(self): - myfile = self.project.get_file('myfile.txt') + myfile = self.project.get_file("myfile.txt") myfile.create() - myfile.write('1') - data = self.to_data(rope.base.change.ChangeContents(myfile, '2')) + myfile.write("1") + data = self.to_data(rope.base.change.ChangeContents(myfile, "2")) change = self.to_change(data) self.history.do(change) - self.assertEqual('2', myfile.read()) + self.assertEqual("2", myfile.read()) self.history.undo() - self.assertEqual('1', change.old_contents) + self.assertEqual("1", change.old_contents) def test_move_resource_saving(self): - myfile = self.project.root.create_file('myfile.txt') - myfolder = self.project.root.create_folder('myfolder') - data = self.to_data(rope.base.change.MoveResource(myfile, 'myfolder')) + myfile = self.project.root.create_file("myfile.txt") + myfolder = self.project.root.create_folder("myfolder") + data = self.to_data(rope.base.change.MoveResource(myfile, "myfolder")) change = self.to_change(data) self.history.do(change) self.assertFalse(myfile.exists()) - self.assertTrue(myfolder.has_child('myfile.txt')) + self.assertTrue(myfolder.has_child("myfile.txt")) self.history.undo() self.assertTrue(myfile.exists()) - self.assertFalse(myfolder.has_child('myfile.txt')) + self.assertFalse(myfolder.has_child("myfile.txt")) def test_move_resource_saving_for_folders(self): - myfolder = self.project.root.create_folder('myfolder') - newfolder = self.project.get_folder('newfolder') - change = rope.base.change.MoveResource(myfolder, 'newfolder') + myfolder = self.project.root.create_folder("myfolder") + newfolder = self.project.get_folder("newfolder") + change = rope.base.change.MoveResource(myfolder, "newfolder") self.history.do(change) data = self.to_data(change) @@ -331,9 +329,10 @@ def test_move_resource_saving_for_folders(self): self.assertFalse(newfolder.exists()) def test_create_file_saving(self): - myfile = self.project.get_file('myfile.txt') - data = self.to_data(rope.base.change.CreateFile(self.project.root, - 'myfile.txt')) + myfile = self.project.get_file("myfile.txt") + data = self.to_data( + rope.base.change.CreateFile(self.project.root, "myfile.txt") + ) change = self.to_change(data) self.history.do(change) self.assertTrue(myfile.exists()) @@ -341,9 +340,10 @@ def test_create_file_saving(self): self.assertFalse(myfile.exists()) def test_create_folder_saving(self): - myfolder = self.project.get_folder('myfolder') - data = self.to_data(rope.base.change.CreateFolder(self.project.root, - 'myfolder')) + myfolder = self.project.get_folder("myfolder") + data = self.to_data( + rope.base.change.CreateFolder(self.project.root, "myfolder") + ) change = self.to_change(data) self.history.do(change) self.assertTrue(myfolder.exists()) @@ -351,7 +351,7 @@ def test_create_folder_saving(self): self.assertFalse(myfolder.exists()) def test_create_resource_saving(self): - myfile = self.project.get_file('myfile.txt') + myfile = self.project.get_file("myfile.txt") data = self.to_data(rope.base.change.CreateResource(myfile)) change = self.to_change(data) self.history.do(change) @@ -360,30 +360,30 @@ def test_create_resource_saving(self): self.assertFalse(myfile.exists()) def test_remove_resource_saving(self): - myfile = self.project.root.create_file('myfile.txt') + myfile = self.project.root.create_file("myfile.txt") data = self.to_data(rope.base.change.RemoveResource(myfile)) change = self.to_change(data) self.history.do(change) self.assertFalse(myfile.exists()) def test_change_set_saving(self): - change = rope.base.change.ChangeSet('testing') - myfile = self.project.get_file('myfile.txt') + change = rope.base.change.ChangeSet("testing") + myfile = self.project.get_file("myfile.txt") change.add_change(rope.base.change.CreateResource(myfile)) - change.add_change(rope.base.change.ChangeContents(myfile, '1')) + change.add_change(rope.base.change.ChangeContents(myfile, "1")) data = self.to_data(change) change = self.to_change(data) self.history.do(change) - self.assertEqual('1', myfile.read()) + self.assertEqual("1", myfile.read()) self.history.undo() self.assertFalse(myfile.exists()) def test_writing_and_reading_history(self): - history_file = self.project.get_file('history.pickle') # noqa - self.project.set('save_history', True) + history_file = self.project.get_file("history.pickle") # noqa + self.project.set("save_history", True) history = rope.base.history.History(self.project) - myfile = self.project.get_file('myfile.txt') + myfile = self.project.get_file("myfile.txt") history.do(rope.base.change.CreateResource(myfile)) history.write() @@ -392,10 +392,10 @@ def test_writing_and_reading_history(self): self.assertFalse(myfile.exists()) def test_writing_and_reading_history2(self): - history_file = self.project.get_file('history.pickle') # noqa - self.project.set('save_history', True) + history_file = self.project.get_file("history.pickle") # noqa + self.project.set("save_history", True) history = rope.base.history.History(self.project) - myfile = self.project.get_file('myfile.txt') + myfile = self.project.get_file("myfile.txt") history.do(rope.base.change.CreateResource(myfile)) history.undo() history.write() diff --git a/ropetest/objectdbtest.py b/ropetest/objectdbtest.py index b958d93fa..a4561b4d1 100644 --- a/ropetest/objectdbtest.py +++ b/ropetest/objectdbtest.py @@ -12,11 +12,11 @@ def _do_for_all_dbs(function): def called(self): for db in self.dbs: function(self, db) + return called class _MockValidation(object): - def is_value_valid(self, value): return value != -1 @@ -24,31 +24,29 @@ def is_more_valid(self, new, old): return new != -1 def is_file_valid(self, path): - return path != 'invalid' + return path != "invalid" def is_scope_valid(self, path, key): - return path != 'invalid' and key != 'invalid' + return path != "invalid" and key != "invalid" class _MockFileListObserver(object): - log = '' + log = "" def added(self, path): - self.log += 'added %s ' % path + self.log += "added %s " % path def removed(self, path): - self.log += 'removed %s ' % path + self.log += "removed %s " % path class ObjectDBTest(unittest.TestCase): - def setUp(self): super(ObjectDBTest, self).setUp() self.project = testutils.sample_project() validation = _MockValidation() - self.dbs = [ - objectdb.ObjectDB(memorydb.MemoryDB(self.project), validation)] + self.dbs = [objectdb.ObjectDB(memorydb.MemoryDB(self.project), validation)] def tearDown(self): for db in self.dbs: @@ -58,97 +56,97 @@ def tearDown(self): @_do_for_all_dbs def test_simple_per_name(self, db): - db.add_pername('file', 'key', 'name', 1) - self.assertEqual(1, db.get_pername('file', 'key', 'name')) + db.add_pername("file", "key", "name", 1) + self.assertEqual(1, db.get_pername("file", "key", "name")) @_do_for_all_dbs def test_simple_per_name_does_not_exist(self, db): - self.assertEqual(None, db.get_pername('file', 'key', 'name')) + self.assertEqual(None, db.get_pername("file", "key", "name")) @_do_for_all_dbs def test_simple_per_name_after_syncing(self, db): - db.add_pername('file', 'key', 'name', 1) + db.add_pername("file", "key", "name", 1) db.write() - self.assertEqual(1, db.get_pername('file', 'key', 'name')) + self.assertEqual(1, db.get_pername("file", "key", "name")) @_do_for_all_dbs def test_getting_returned(self, db): - db.add_callinfo('file', 'key', (1, 2), 3) - self.assertEqual(3, db.get_returned('file', 'key', (1, 2))) + db.add_callinfo("file", "key", (1, 2), 3) + self.assertEqual(3, db.get_returned("file", "key", (1, 2))) @_do_for_all_dbs def test_getting_returned_when_does_not_match(self, db): - db.add_callinfo('file', 'key', (1, 2), 3) - self.assertEqual(None, db.get_returned('file', 'key', (1, 1))) + db.add_callinfo("file", "key", (1, 2), 3) + self.assertEqual(None, db.get_returned("file", "key", (1, 1))) @_do_for_all_dbs def test_getting_call_info(self, db): - db.add_callinfo('file', 'key', (1, 2), 3) + db.add_callinfo("file", "key", (1, 2), 3) - call_infos = list(db.get_callinfos('file', 'key')) + call_infos = list(db.get_callinfos("file", "key")) self.assertEqual(1, len(call_infos)) self.assertEqual((1, 2), call_infos[0].get_parameters()) self.assertEqual(3, call_infos[0].get_returned()) @_do_for_all_dbs def test_invalid_per_name(self, db): - db.add_pername('file', 'key', 'name', -1) - self.assertEqual(None, db.get_pername('file', 'key', 'name')) + db.add_pername("file", "key", "name", -1) + self.assertEqual(None, db.get_pername("file", "key", "name")) @_do_for_all_dbs def test_overwriting_per_name(self, db): - db.add_pername('file', 'key', 'name', 1) - db.add_pername('file', 'key', 'name', 2) - self.assertEqual(2, db.get_pername('file', 'key', 'name')) + db.add_pername("file", "key", "name", 1) + db.add_pername("file", "key", "name", 2) + self.assertEqual(2, db.get_pername("file", "key", "name")) @_do_for_all_dbs def test_not_overwriting_with_invalid_per_name(self, db): - db.add_pername('file', 'key', 'name', 1) - db.add_pername('file', 'key', 'name', -1) - self.assertEqual(1, db.get_pername('file', 'key', 'name')) + db.add_pername("file", "key", "name", 1) + db.add_pername("file", "key", "name", -1) + self.assertEqual(1, db.get_pername("file", "key", "name")) @_do_for_all_dbs def test_getting_invalid_returned(self, db): - db.add_callinfo('file', 'key', (1, 2), -1) - self.assertEqual(None, db.get_returned('file', 'key', (1, 2))) + db.add_callinfo("file", "key", (1, 2), -1) + self.assertEqual(None, db.get_returned("file", "key", (1, 2))) @_do_for_all_dbs def test_not_overwriting_with_invalid_returned(self, db): - db.add_callinfo('file', 'key', (1, 2), 3) - db.add_callinfo('file', 'key', (1, 2), -1) - self.assertEqual(3, db.get_returned('file', 'key', (1, 2))) + db.add_callinfo("file", "key", (1, 2), 3) + db.add_callinfo("file", "key", (1, 2), -1) + self.assertEqual(3, db.get_returned("file", "key", (1, 2))) @_do_for_all_dbs def test_get_files(self, db): - db.add_callinfo('file1', 'key', (1, 2), 3) - db.add_callinfo('file2', 'key', (1, 2), 3) - self.assertEqual(set(['file1', 'file2']), set(db.get_files())) + db.add_callinfo("file1", "key", (1, 2), 3) + db.add_callinfo("file2", "key", (1, 2), 3) + self.assertEqual(set(["file1", "file2"]), set(db.get_files())) @_do_for_all_dbs def test_validating_files(self, db): - db.add_callinfo('invalid', 'key', (1, 2), 3) + db.add_callinfo("invalid", "key", (1, 2), 3) db.validate_files() self.assertEqual(0, len(db.get_files())) @_do_for_all_dbs def test_validating_file_for_scopes(self, db): - db.add_callinfo('file', 'invalid', (1, 2), 3) - db.validate_file('file') + db.add_callinfo("file", "invalid", (1, 2), 3) + db.validate_file("file") self.assertEqual(1, len(db.get_files())) - self.assertEqual(0, len(list(db.get_callinfos('file', 'invalid')))) + self.assertEqual(0, len(list(db.get_callinfos("file", "invalid")))) @_do_for_all_dbs def test_validating_file_moved(self, db): - db.add_callinfo('file', 'key', (1, 2), 3) + db.add_callinfo("file", "key", (1, 2), 3) - db.file_moved('file', 'newfile') + db.file_moved("file", "newfile") self.assertEqual(1, len(db.get_files())) - self.assertEqual(1, len(list(db.get_callinfos('newfile', 'key')))) + self.assertEqual(1, len(list(db.get_callinfos("newfile", "key")))) @_do_for_all_dbs def test_using_file_list_observer(self, db): - db.add_callinfo('invalid', 'key', (1, 2), 3) + db.add_callinfo("invalid", "key", (1, 2), 3) observer = _MockFileListObserver() db.add_file_list_observer(observer) db.validate_files() - self.assertEqual('removed invalid ', observer.log) + self.assertEqual("removed invalid ", observer.log) diff --git a/ropetest/objectinfertest.py b/ropetest/objectinfertest.py index e930bc4df..b4e49ef95 100644 --- a/ropetest/objectinfertest.py +++ b/ropetest/objectinfertest.py @@ -10,7 +10,6 @@ class ObjectInferTest(unittest.TestCase): - def setUp(self): super(ObjectInferTest, self).setUp() self.project = testutils.sample_project() @@ -20,305 +19,339 @@ def tearDown(self): super(ObjectInferTest, self).tearDown() def test_simple_type_inferencing(self): - code = 'class Sample(object):\n pass\na_var = Sample()\n' + code = "class Sample(object):\n pass\na_var = Sample()\n" scope = libutils.get_string_scope(self.project, code) - sample_class = scope['Sample'].get_object() - a_var = scope['a_var'].get_object() + sample_class = scope["Sample"].get_object() + a_var = scope["a_var"].get_object() self.assertEqual(sample_class, a_var.get_type()) def test_simple_type_inferencing_classes_defined_in_holding_scope(self): - code = 'class Sample(object):\n pass\n' \ - 'def a_func():\n a_var = Sample()\n' + code = ( + "class Sample(object):\n pass\n" "def a_func():\n a_var = Sample()\n" + ) scope = libutils.get_string_scope(self.project, code) - sample_class = scope['Sample'].get_object() - a_var = scope['a_func'].get_object().\ - get_scope()['a_var'].get_object() + sample_class = scope["Sample"].get_object() + a_var = scope["a_func"].get_object().get_scope()["a_var"].get_object() self.assertEqual(sample_class, a_var.get_type()) def test_simple_type_inferencing_classes_in_class_methods(self): - code = 'class Sample(object):\n pass\n' \ - 'class Another(object):\n' \ - ' def a_method():\n a_var = Sample()\n' + code = ( + "class Sample(object):\n pass\n" + "class Another(object):\n" + " def a_method():\n a_var = Sample()\n" + ) scope = libutils.get_string_scope(self.project, code) - sample_class = scope['Sample'].get_object() - another_class = scope['Another'].get_object() - a_var = another_class['a_method'].\ - get_object().get_scope()['a_var'].get_object() + sample_class = scope["Sample"].get_object() + another_class = scope["Another"].get_object() + a_var = another_class["a_method"].get_object().get_scope()["a_var"].get_object() self.assertEqual(sample_class, a_var.get_type()) def test_simple_type_inferencing_class_attributes(self): - code = 'class Sample(object):\n pass\n' \ - 'class Another(object):\n' \ - ' def __init__(self):\n self.a_var = Sample()\n' + code = ( + "class Sample(object):\n pass\n" + "class Another(object):\n" + " def __init__(self):\n self.a_var = Sample()\n" + ) scope = libutils.get_string_scope(self.project, code) - sample_class = scope['Sample'].get_object() - another_class = scope['Another'].get_object() - a_var = another_class['a_var'].get_object() + sample_class = scope["Sample"].get_object() + another_class = scope["Another"].get_object() + a_var = another_class["a_var"].get_object() self.assertEqual(sample_class, a_var.get_type()) def test_simple_type_inferencing_for_in_class_assignments(self): - code = 'class Sample(object):\n pass\n' \ - 'class Another(object):\n an_attr = Sample()\n' + code = ( + "class Sample(object):\n pass\n" + "class Another(object):\n an_attr = Sample()\n" + ) scope = libutils.get_string_scope(self.project, code) - sample_class = scope['Sample'].get_object() - another_class = scope['Another'].get_object() - an_attr = another_class['an_attr'].get_object() + sample_class = scope["Sample"].get_object() + another_class = scope["Another"].get_object() + an_attr = another_class["an_attr"].get_object() self.assertEqual(sample_class, an_attr.get_type()) def test_simple_type_inferencing_for_chained_assignments(self): - mod = 'class Sample(object):\n pass\n' \ - 'copied_sample = Sample' + mod = "class Sample(object):\n pass\n" "copied_sample = Sample" mod_scope = libutils.get_string_scope(self.project, mod) - sample_class = mod_scope['Sample'] - copied_sample = mod_scope['copied_sample'] - self.assertEqual(sample_class.get_object(), - copied_sample.get_object()) + sample_class = mod_scope["Sample"] + copied_sample = mod_scope["copied_sample"] + self.assertEqual(sample_class.get_object(), copied_sample.get_object()) def test_following_chained_assignments_avoiding_circles(self): - mod = 'class Sample(object):\n pass\n' \ - 'sample_class = Sample\n' \ - 'sample_class = sample_class\n' + mod = ( + "class Sample(object):\n pass\n" + "sample_class = Sample\n" + "sample_class = sample_class\n" + ) mod_scope = libutils.get_string_scope(self.project, mod) - sample_class = mod_scope['Sample'] - sample_class_var = mod_scope['sample_class'] - self.assertEqual(sample_class.get_object(), - sample_class_var.get_object()) + sample_class = mod_scope["Sample"] + sample_class_var = mod_scope["sample_class"] + self.assertEqual(sample_class.get_object(), sample_class_var.get_object()) def test_function_returned_object_static_type_inference1(self): - src = 'class Sample(object):\n pass\n' \ - 'def a_func():\n return Sample\n' \ - 'a_var = a_func()\n' + src = ( + "class Sample(object):\n pass\n" + "def a_func():\n return Sample\n" + "a_var = a_func()\n" + ) scope = libutils.get_string_scope(self.project, src) - sample_class = scope['Sample'] - a_var = scope['a_var'] + sample_class = scope["Sample"] + a_var = scope["a_var"] self.assertEqual(sample_class.get_object(), a_var.get_object()) def test_function_returned_object_static_type_inference2(self): - src = 'class Sample(object):\n pass\n' \ - 'def a_func():\n return Sample()\n' \ - 'a_var = a_func()\n' + src = ( + "class Sample(object):\n pass\n" + "def a_func():\n return Sample()\n" + "a_var = a_func()\n" + ) scope = libutils.get_string_scope(self.project, src) - sample_class = scope['Sample'].get_object() - a_var = scope['a_var'].get_object() + sample_class = scope["Sample"].get_object() + a_var = scope["a_var"].get_object() self.assertEqual(sample_class, a_var.get_type()) def test_recursive_function_returned_object_static_type_inference(self): - src = 'class Sample(object):\n pass\n' \ - 'def a_func():\n' \ - ' if True:\n return Sample()\n' \ - ' else:\n return a_func()\n' \ - 'a_var = a_func()\n' + src = ( + "class Sample(object):\n pass\n" + "def a_func():\n" + " if True:\n return Sample()\n" + " else:\n return a_func()\n" + "a_var = a_func()\n" + ) scope = libutils.get_string_scope(self.project, src) - sample_class = scope['Sample'].get_object() - a_var = scope['a_var'].get_object() + sample_class = scope["Sample"].get_object() + a_var = scope["a_var"].get_object() self.assertEqual(sample_class, a_var.get_type()) def test_func_returned_obj_using_call_spec_func_static_type_infer(self): - src = 'class Sample(object):\n' \ - ' def __call__(self):\n return Sample\n' \ - 'sample = Sample()\na_var = sample()' + src = ( + "class Sample(object):\n" + " def __call__(self):\n return Sample\n" + "sample = Sample()\na_var = sample()" + ) scope = libutils.get_string_scope(self.project, src) - sample_class = scope['Sample'] - a_var = scope['a_var'] + sample_class = scope["Sample"] + a_var = scope["a_var"] self.assertEqual(sample_class.get_object(), a_var.get_object()) def test_list_type_inferencing(self): - src = 'class Sample(object):\n pass\na_var = [Sample()]\n' + src = "class Sample(object):\n pass\na_var = [Sample()]\n" scope = libutils.get_string_scope(self.project, src) - sample_class = scope['Sample'].get_object() - a_var = scope['a_var'].get_object() + sample_class = scope["Sample"].get_object() + a_var = scope["a_var"].get_object() self.assertNotEqual(sample_class, a_var.get_type()) def test_attributed_object_inference(self): - src = 'class Sample(object):\n' \ - ' def __init__(self):\n self.a_var = None\n' \ - ' def set(self):\n self.a_var = Sample()\n' + src = ( + "class Sample(object):\n" + " def __init__(self):\n self.a_var = None\n" + " def set(self):\n self.a_var = Sample()\n" + ) scope = libutils.get_string_scope(self.project, src) - sample_class = scope['Sample'].get_object() - a_var = sample_class['a_var'].get_object() + sample_class = scope["Sample"].get_object() + a_var = sample_class["a_var"].get_object() self.assertEqual(sample_class, a_var.get_type()) def test_getting_property_attributes(self): - src = 'class A(object):\n pass\n' \ - 'def f(*args):\n return A()\n' \ - 'class B(object):\n p = property(f)\n' \ - 'a_var = B().p\n' + src = ( + "class A(object):\n pass\n" + "def f(*args):\n return A()\n" + "class B(object):\n p = property(f)\n" + "a_var = B().p\n" + ) pymod = libutils.get_string_module(self.project, src) - a_class = pymod['A'].get_object() - a_var = pymod['a_var'].get_object() + a_class = pymod["A"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(a_class, a_var.get_type()) def test_getting_property_attributes_with_method_getters(self): - src = 'class A(object):\n pass\n' \ - 'class B(object):\n def p_get(self):\n return A()\n' \ - ' p = property(p_get)\n' \ - 'a_var = B().p\n' + src = ( + "class A(object):\n pass\n" + "class B(object):\n def p_get(self):\n return A()\n" + " p = property(p_get)\n" + "a_var = B().p\n" + ) pymod = libutils.get_string_module(self.project, src) - a_class = pymod['A'].get_object() - a_var = pymod['a_var'].get_object() + a_class = pymod["A"].get_object() + a_var = pymod["a_var"].get_object() self.assertEqual(a_class, a_var.get_type()) def test_lambda_functions(self): - code = 'class C(object):\n pass\n' \ - 'l = lambda: C()\na_var = l()' + code = "class C(object):\n pass\n" "l = lambda: C()\na_var = l()" mod = libutils.get_string_module(self.project, code) - c_class = mod['C'].get_object() - a_var = mod['a_var'].get_object() + c_class = mod["C"].get_object() + a_var = mod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_mixing_subscript_with_tuple_assigns(self): - code = 'class C(object):\n attr = 0\n' \ - 'd = {}\nd[0], b = (0, C())\n' + code = "class C(object):\n attr = 0\n" "d = {}\nd[0], b = (0, C())\n" mod = libutils.get_string_module(self.project, code) - c_class = mod['C'].get_object() - a_var = mod['b'].get_object() + c_class = mod["C"].get_object() + a_var = mod["b"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_mixing_ass_attr_with_tuple_assignment(self): - code = 'class C(object):\n attr = 0\n' \ - 'c = C()\nc.attr, b = (0, C())\n' + code = "class C(object):\n attr = 0\n" "c = C()\nc.attr, b = (0, C())\n" mod = libutils.get_string_module(self.project, code) - c_class = mod['C'].get_object() - a_var = mod['b'].get_object() + c_class = mod["C"].get_object() + a_var = mod["b"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_mixing_slice_with_tuple_assigns(self): mod = libutils.get_string_module( self.project, - 'class C(object):\n attr = 0\n' - 'd = [None] * 3\nd[0:2], b = ((0,), C())\n') - c_class = mod['C'].get_object() - a_var = mod['b'].get_object() + "class C(object):\n attr = 0\n" + "d = [None] * 3\nd[0:2], b = ((0,), C())\n", + ) + c_class = mod["C"].get_object() + a_var = mod["b"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_nested_tuple_assignments(self): mod = libutils.get_string_module( self.project, - 'class C1(object):\n pass\nclass C2(object):\n pass\n' - 'a, (b, c) = (C1(), (C2(), C1()))\n') - c1_class = mod['C1'].get_object() - c2_class = mod['C2'].get_object() - a_var = mod['a'].get_object() - b_var = mod['b'].get_object() - c_var = mod['c'].get_object() + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "a, (b, c) = (C1(), (C2(), C1()))\n", + ) + c1_class = mod["C1"].get_object() + c2_class = mod["C2"].get_object() + a_var = mod["a"].get_object() + b_var = mod["b"].get_object() + c_var = mod["c"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) self.assertEqual(c1_class, c_var.get_type()) def test_empty_tuples(self): - mod = libutils.get_string_module( - self.project, 't = ()\na, b = t\n') - a = mod['a'].get_object() # noqa + mod = libutils.get_string_module(self.project, "t = ()\na, b = t\n") + a = mod["a"].get_object() # noqa def test_handling_generator_functions(self): - code = 'class C(object):\n pass\n' \ - 'def f():\n yield C()\n' \ - 'for c in f():\n a_var = c\n' + code = ( + "class C(object):\n pass\n" + "def f():\n yield C()\n" + "for c in f():\n a_var = c\n" + ) mod = libutils.get_string_module(self.project, code) - c_class = mod['C'].get_object() - a_var = mod['a_var'].get_object() + c_class = mod["C"].get_object() + a_var = mod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_handling_generator_functions_for_strs(self): - mod = testutils.create_module(self.project, 'mod') - mod.write('def f():\n yield ""\n' - 'for s in f():\n a_var = s\n') + mod = testutils.create_module(self.project, "mod") + mod.write('def f():\n yield ""\n' "for s in f():\n a_var = s\n") pymod = self.project.get_pymodule(mod) - a_var = pymod['a_var'].get_object() + a_var = pymod["a_var"].get_object() self.assertTrue(isinstance(a_var.get_type(), rope.base.builtins.Str)) def test_considering_nones_to_be_unknowns(self): - code = 'class C(object):\n pass\n' \ - 'a_var = None\na_var = C()\na_var = None\n' + code = ( + "class C(object):\n pass\n" "a_var = None\na_var = C()\na_var = None\n" + ) mod = libutils.get_string_module(self.project, code) - c_class = mod['C'].get_object() - a_var = mod['a_var'].get_object() + c_class = mod["C"].get_object() + a_var = mod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_basic_list_comprehensions(self): - code = 'class C(object):\n pass\n' \ - 'l = [C() for i in range(1)]\na_var = l[0]\n' + code = ( + "class C(object):\n pass\n" "l = [C() for i in range(1)]\na_var = l[0]\n" + ) mod = libutils.get_string_module(self.project, code) - c_class = mod['C'].get_object() - a_var = mod['a_var'].get_object() + c_class = mod["C"].get_object() + a_var = mod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_basic_generator_expressions(self): - code = 'class C(object):\n pass\n' \ - 'l = (C() for i in range(1))\na_var = list(l)[0]\n' + code = ( + "class C(object):\n pass\n" + "l = (C() for i in range(1))\na_var = list(l)[0]\n" + ) mod = libutils.get_string_module(self.project, code) - c_class = mod['C'].get_object() - a_var = mod['a_var'].get_object() + c_class = mod["C"].get_object() + a_var = mod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_list_comprehensions_and_loop_var(self): - code = 'class C(object):\n pass\n' \ - 'c_objects = [C(), C()]\n' \ - 'l = [c for c in c_objects]\na_var = l[0]\n' + code = ( + "class C(object):\n pass\n" + "c_objects = [C(), C()]\n" + "l = [c for c in c_objects]\na_var = l[0]\n" + ) mod = libutils.get_string_module(self.project, code) - c_class = mod['C'].get_object() - a_var = mod['a_var'].get_object() + c_class = mod["C"].get_object() + a_var = mod["a_var"].get_object() self.assertEqual(c_class, a_var.get_type()) def test_list_comprehensions_and_multiple_loop_var(self): - code = 'class C1(object):\n pass\n' \ - 'class C2(object):\n pass\n' \ - 'l = [(c1, c2) for c1 in [C1()] for c2 in [C2()]]\n' \ - 'a, b = l[0]\n' + code = ( + "class C1(object):\n pass\n" + "class C2(object):\n pass\n" + "l = [(c1, c2) for c1 in [C1()] for c2 in [C2()]]\n" + "a, b = l[0]\n" + ) mod = libutils.get_string_module(self.project, code) - c1_class = mod['C1'].get_object() - c2_class = mod['C2'].get_object() - a_var = mod['a'].get_object() - b_var = mod['b'].get_object() + c1_class = mod["C1"].get_object() + c2_class = mod["C2"].get_object() + a_var = mod["a"].get_object() + b_var = mod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_list_comprehensions_and_multiple_iters(self): mod = libutils.get_string_module( self.project, - 'class C1(object):\n pass\nclass C2(object):\n pass\n' - 'l = [(c1, c2) for c1, c2 in [(C1(), C2())]]\n' - 'a, b = l[0]\n') - c1_class = mod['C1'].get_object() - c2_class = mod['C2'].get_object() - a_var = mod['a'].get_object() - b_var = mod['b'].get_object() + "class C1(object):\n pass\nclass C2(object):\n pass\n" + "l = [(c1, c2) for c1, c2 in [(C1(), C2())]]\n" + "a, b = l[0]\n", + ) + c1_class = mod["C1"].get_object() + c2_class = mod["C2"].get_object() + a_var = mod["a"].get_object() + b_var = mod["b"].get_object() self.assertEqual(c1_class, a_var.get_type()) self.assertEqual(c2_class, b_var.get_type()) def test_we_know_the_type_of_catched_exceptions(self): - code = 'class MyError(Exception):\n pass\n' \ - 'try:\n raise MyError()\n' \ - 'except MyError as e:\n pass\n' + code = ( + "class MyError(Exception):\n pass\n" + "try:\n raise MyError()\n" + "except MyError as e:\n pass\n" + ) mod = libutils.get_string_module(self.project, code) - my_error = mod['MyError'].get_object() - e_var = mod['e'].get_object() + my_error = mod["MyError"].get_object() + e_var = mod["e"].get_object() self.assertEqual(my_error, e_var.get_type()) def test_we_know_the_type_of_catched_multiple_excepts(self): - code = 'class MyError(Exception):\n pass\n' \ - 'try:\n raise MyError()\n' \ - 'except (MyError, Exception) as e:\n pass\n' + code = ( + "class MyError(Exception):\n pass\n" + "try:\n raise MyError()\n" + "except (MyError, Exception) as e:\n pass\n" + ) mod = libutils.get_string_module(self.project, code) - my_error = mod['MyError'].get_object() - e_var = mod['e'].get_object() + my_error = mod["MyError"].get_object() + e_var = mod["e"].get_object() self.assertEqual(my_error, e_var.get_type()) def test_using_property_as_decorators(self): - code = 'class A(object):\n pass\n' \ - 'class B(object):\n' \ - ' @property\n def f(self):\n return A()\n' \ - 'b = B()\nvar = b.f\n' + code = ( + "class A(object):\n pass\n" + "class B(object):\n" + " @property\n def f(self):\n return A()\n" + "b = B()\nvar = b.f\n" + ) mod = libutils.get_string_module(self.project, code) - var = mod['var'].get_object() - a = mod['A'].get_object() + var = mod["var"].get_object() + a = mod["A"].get_object() self.assertEqual(a, var.get_type()) def test_using_property_as_decorators_and_passing_parameter(self): - code = 'class B(object):\n' \ - ' @property\n def f(self):\n return self\n' \ - 'b = B()\nvar = b.f\n' + code = ( + "class B(object):\n" + " @property\n def f(self):\n return self\n" + "b = B()\nvar = b.f\n" + ) mod = libutils.get_string_module(self.project, code) - var = mod['var'].get_object() - a = mod['B'].get_object() + var = mod["var"].get_object() + a = mod["B"].get_object() self.assertEqual(a, var.get_type()) diff --git a/ropetest/projecttest.py b/ropetest/projecttest.py index dda3ec72b..e2728d9e2 100644 --- a/ropetest/projecttest.py +++ b/ropetest/projecttest.py @@ -2,6 +2,7 @@ from textwrap import dedent import os.path import shutil + try: import unittest2 as unittest except ImportError: @@ -16,24 +17,24 @@ class ProjectTest(unittest.TestCase): - def setUp(self): unittest.TestCase.setUp(self) - self.project = testutils.sample_project(foldername='sampleproject', - ropefolder=None) + self.project = testutils.sample_project( + foldername="sampleproject", ropefolder=None + ) self.project_root = self.project.address self._make_sample_project() self.no_project = NoProject() def _make_sample_project(self): - self.sample_file = 'sample_file.txt' - self.sample_path = os.path.join(self.project_root, 'sample_file.txt') + self.sample_file = "sample_file.txt" + self.sample_path = os.path.join(self.project_root, "sample_file.txt") if not os.path.exists(self.project_root): os.mkdir(self.project_root) - self.sample_folder = 'sample_folder' + self.sample_folder = "sample_folder" os.mkdir(os.path.join(self.project_root, self.sample_folder)) - sample = open(self.sample_path, 'w') - sample.write('sample text\n') + sample = open(self.sample_path, "w") + sample.write("sample text\n") sample.close() def tearDown(self): @@ -41,8 +42,7 @@ def tearDown(self): unittest.TestCase.tearDown(self) def test_project_creation(self): - self.assertEqual(_realpath(self.project_root), - self.project.address) + self.assertEqual(_realpath(self.project_root), self.project.address) def test_getting_project_file(self): project_file = self.project.get_resource(self.sample_file) @@ -50,19 +50,19 @@ def test_getting_project_file(self): def test_project_file_reading(self): projectFile = self.project.get_resource(self.sample_file) - self.assertEqual('sample text\n', projectFile.read()) + self.assertEqual("sample text\n", projectFile.read()) def test_getting_not_existing_project_file(self): with self.assertRaises(ResourceNotFoundError): - self.project.get_resource('DoesNotExistFile.txt') + self.project.get_resource("DoesNotExistFile.txt") def test_writing_in_project_files(self): project_file = self.project.get_resource(self.sample_file) - project_file.write('another text\n') - self.assertEqual('another text\n', project_file.read()) + project_file.write("another text\n") + self.assertEqual("another text\n", project_file.read()) def test_creating_files(self): - project_file = 'newfile.txt' + project_file = "newfile.txt" self.project.root.create_file(project_file) newFile = self.project.get_resource(project_file) self.assertTrue(newFile is not None) @@ -72,66 +72,64 @@ def test_creating_files_that_already_exist(self): self.project.root.create_file(self.sample_file) def test_making_root_folder_if_it_does_not_exist(self): - project = Project('sampleproject2') + project = Project("sampleproject2") try: - self.assertTrue(os.path.exists('sampleproject2') and - os.path.isdir('sampleproject2')) + self.assertTrue( + os.path.exists("sampleproject2") and os.path.isdir("sampleproject2") + ) finally: testutils.remove_project(project) def test_failure_when_project_root_exists_and_is_a_file(self): - project_root = 'sampleproject2' + project_root = "sampleproject2" try: - open(project_root, 'w').close() + open(project_root, "w").close() with self.assertRaises(RopeError): Project(project_root) finally: testutils.remove_recursively(project_root) def test_creating_folders(self): - folderName = 'SampleFolder' + folderName = "SampleFolder" self.project.root.create_folder(folderName) folderPath = os.path.join(self.project.address, folderName) - self.assertTrue(os.path.exists(folderPath) and - os.path.isdir(folderPath)) + self.assertTrue(os.path.exists(folderPath) and os.path.isdir(folderPath)) def test_making_folder_that_already_exists(self): - folderName = 'SampleFolder' + folderName = "SampleFolder" with self.assertRaises(RopeError): self.project.root.create_folder(folderName) self.project.root.create_folder(folderName) def test_failing_if_creating_folder_while_file_already_exists(self): - folderName = 'SampleFolder' + folderName = "SampleFolder" with self.assertRaises(RopeError): self.project.root.create_file(folderName) self.project.root.create_folder(folderName) def test_creating_file_inside_folder(self): - folder_name = 'sampleFolder' - file_name = 'sample2.txt' - file_path = folder_name + '/' + file_name + folder_name = "sampleFolder" + file_name = "sample2.txt" + file_path = folder_name + "/" + file_name parent_folder = self.project.root.create_folder(folder_name) parent_folder.create_file(file_name) file = self.project.get_resource(file_path) - file.write('sample notes') + file.write("sample notes") self.assertEqual(file_path, file.path) - self.assertEqual('sample notes', - open(os.path.join(self.project.address, - file_path)).read()) + self.assertEqual( + "sample notes", open(os.path.join(self.project.address, file_path)).read() + ) def test_failing_when_creating_file_inside_non_existent_folder(self): with self.assertRaises(ResourceNotFoundError): - self.project.root.create_file('NonexistentFolder/SomeFile.txt') + self.project.root.create_file("NonexistentFolder/SomeFile.txt") def test_nested_directories(self): - folder_name = 'SampleFolder' + folder_name = "SampleFolder" parent = self.project.root.create_folder(folder_name) parent.create_folder(folder_name) - folder_path = os.path.join(self.project.address, - folder_name, folder_name) - self.assertTrue(os.path.exists(folder_path) and - os.path.isdir(folder_path)) + folder_path = os.path.join(self.project.address, folder_name, folder_name) + self.assertTrue(os.path.exists(folder_path) and os.path.isdir(folder_path)) def test_removing_files(self): self.assertTrue(os.path.exists(self.sample_path)) @@ -140,36 +138,40 @@ def test_removing_files(self): def test_removing_files_invalidating_in_project_resource_pool(self): root_folder = self.project.root - my_file = root_folder.create_file('my_file.txt') + my_file = root_folder.create_file("my_file.txt") my_file.remove() - self.assertFalse(root_folder.has_child('my_file.txt')) + self.assertFalse(root_folder.has_child("my_file.txt")) def test_removing_directories(self): - self.assertTrue(os.path.exists(os.path.join(self.project.address, - self.sample_folder))) + self.assertTrue( + os.path.exists(os.path.join(self.project.address, self.sample_folder)) + ) self.project.get_resource(self.sample_folder).remove() - self.assertFalse(os.path.exists(os.path.join(self.project.address, - self.sample_folder))) + self.assertFalse( + os.path.exists(os.path.join(self.project.address, self.sample_folder)) + ) def test_removing_non_existent_files(self): with self.assertRaises(ResourceNotFoundError): - self.project.get_resource('NonExistentFile.txt').remove() + self.project.get_resource("NonExistentFile.txt").remove() def test_removing_nested_files(self): - file_name = self.sample_folder + '/sample_file.txt' + file_name = self.sample_folder + "/sample_file.txt" self.project.root.create_file(file_name) self.project.get_resource(file_name).remove() - self.assertTrue(os.path.exists(os.path.join(self.project.address, - self.sample_folder))) - self.assertTrue(not os.path.exists(os.path.join(self.project.address, - file_name))) + self.assertTrue( + os.path.exists(os.path.join(self.project.address, self.sample_folder)) + ) + self.assertTrue( + not os.path.exists(os.path.join(self.project.address, file_name)) + ) def test_file_get_name(self): file = self.project.get_resource(self.sample_file) self.assertEqual(self.sample_file, file.name) - file_name = 'nestedFile.txt' + file_name = "nestedFile.txt" parent = self.project.get_resource(self.sample_folder) - filePath = self.sample_folder + '/' + file_name + filePath = self.sample_folder + "/" + file_name parent.create_file(file_name) nestedFile = self.project.get_resource(filePath) self.assertEqual(file_name, nestedFile.name) @@ -181,9 +183,9 @@ def test_folder_get_name(self): def test_file_get_path(self): file = self.project.get_resource(self.sample_file) self.assertEqual(self.sample_file, file.path) - fileName = 'nestedFile.txt' + fileName = "nestedFile.txt" parent = self.project.get_resource(self.sample_folder) - filePath = self.sample_folder + '/' + fileName + filePath = self.sample_folder + "/" + fileName parent.create_file(fileName) nestedFile = self.project.get_resource(filePath) self.assertEqual(filePath, nestedFile.path) @@ -193,18 +195,16 @@ def test_folder_get_path(self): self.assertEqual(self.sample_folder, folder.path) def test_is_folder(self): - self.assertTrue(self.project.get_resource( - self.sample_folder).is_folder()) - self.assertTrue(not self.project.get_resource( - self.sample_file).is_folder()) + self.assertTrue(self.project.get_resource(self.sample_folder).is_folder()) + self.assertTrue(not self.project.get_resource(self.sample_file).is_folder()) def testget_children(self): children = self.project.get_resource(self.sample_folder).get_children() self.assertEqual([], children) def test_nonempty_get_children(self): - file_name = 'nestedfile.txt' - filePath = self.sample_folder + '/' + file_name + file_name = "nestedfile.txt" + filePath = self.sample_folder + "/" + file_name parent = self.project.get_resource(self.sample_folder) parent.create_file(file_name) children = parent.get_children() @@ -212,19 +212,19 @@ def test_nonempty_get_children(self): self.assertEqual(filePath, children[0].path) def test_nonempty_get_children2(self): - file_name = 'nestedfile.txt' - folder_name = 'nestedfolder.txt' - filePath = self.sample_folder + '/' + file_name - folderPath = self.sample_folder + '/' + folder_name + file_name = "nestedfile.txt" + folder_name = "nestedfolder.txt" + filePath = self.sample_folder + "/" + file_name + folderPath = self.sample_folder + "/" + folder_name parent = self.project.get_resource(self.sample_folder) parent.create_file(file_name) parent.create_folder(folder_name) children = parent.get_children() self.assertEqual(2, len(children)) - self.assertTrue(filePath == children[0].path or - filePath == children[1].path) - self.assertTrue(folderPath == children[0].path or - folderPath == children[1].path) + self.assertTrue(filePath == children[0].path or filePath == children[1].path) + self.assertTrue( + folderPath == children[0].path or folderPath == children[1].path + ) def test_does_not_fail_for_permission_denied(self): bad_dir = os.path.join(self.sample_folder, "bad_dir") @@ -247,32 +247,31 @@ def test_getting_files(self): def test_getting_folders(self): folders = self.project.root.get_folders() self.assertEqual(1, len(folders)) - self.assertTrue(self.project.get_resource( - self.sample_folder) in folders) + self.assertTrue(self.project.get_resource(self.sample_folder) in folders) def test_nested_folder_get_files(self): - parent = self.project.root.create_folder('top') - parent.create_file('file1.txt') - parent.create_file('file2.txt') + parent = self.project.root.create_folder("top") + parent.create_file("file1.txt") + parent.create_file("file2.txt") files = parent.get_files() self.assertEqual(2, len(files)) - self.assertTrue(self.project.get_resource('top/file2.txt') in files) + self.assertTrue(self.project.get_resource("top/file2.txt") in files) self.assertEqual(0, len(parent.get_folders())) def test_nested_folder_get_folders(self): - parent = self.project.root.create_folder('top') - parent.create_folder('dir1') - parent.create_folder('dir2') + parent = self.project.root.create_folder("top") + parent.create_folder("dir1") + parent.create_folder("dir2") folders = parent.get_folders() self.assertEqual(2, len(folders)) - self.assertTrue(self.project.get_resource('top/dir1') in folders) + self.assertTrue(self.project.get_resource("top/dir1") in folders) self.assertEqual(0, len(parent.get_files())) def test_root_folder(self): root_folder = self.project.root self.assertEqual(2, len(root_folder.get_children())) - self.assertEqual('', root_folder.path) - self.assertEqual('', root_folder.name) + self.assertEqual("", root_folder.path) + self.assertEqual("", root_folder.name) def test_get_all_files(self): files = tuple(self.project.get_files()) @@ -281,15 +280,15 @@ def test_get_all_files(self): def test_get_all_files_after_changing(self): self.assertEqual(1, len(self.project.get_files())) - myfile = self.project.root.create_file('myfile.txt') + myfile = self.project.root.create_file("myfile.txt") self.assertEqual(2, len(self.project.get_files())) - myfile.move('newfile.txt') + myfile.move("newfile.txt") self.assertEqual(2, len(self.project.get_files())) - self.project.get_file('newfile.txt').remove() + self.project.get_file("newfile.txt").remove() self.assertEqual(1, len(self.project.get_files())) def test_multifile_get_all_files(self): - fileName = 'nestedFile.txt' + fileName = "nestedFile.txt" parent = self.project.get_resource(self.sample_folder) parent.create_file(fileName) files = list(self.project.get_files()) @@ -298,159 +297,165 @@ def test_multifile_get_all_files(self): def test_ignoring_dot_pyc_files_in_get_files(self): root = self.project.address - src_folder = os.path.join(root, 'src') + src_folder = os.path.join(root, "src") os.mkdir(src_folder) - test_pyc = os.path.join(src_folder, 'test.pyc') - open(test_pyc, 'w').close() + test_pyc = os.path.join(src_folder, "test.pyc") + open(test_pyc, "w").close() for x in self.project.get_files(): - self.assertNotEqual('src/test.pyc', x.path) + self.assertNotEqual("src/test.pyc", x.path) def test_folder_creating_files(self): - projectFile = 'NewFile.txt' + projectFile = "NewFile.txt" self.project.root.create_file(projectFile) new_file = self.project.get_resource(projectFile) self.assertTrue(new_file is not None and not new_file.is_folder()) def test_folder_creating_nested_files(self): - project_file = 'NewFile.txt' + project_file = "NewFile.txt" parent_folder = self.project.get_resource(self.sample_folder) parent_folder.create_file(project_file) - new_file = self.project.get_resource(self.sample_folder - + '/' + project_file) + new_file = self.project.get_resource(self.sample_folder + "/" + project_file) self.assertTrue(new_file is not None and not new_file.is_folder()) def test_folder_creating_files2(self): - projectFile = 'newfolder' + projectFile = "newfolder" self.project.root.create_folder(projectFile) new_folder = self.project.get_resource(projectFile) self.assertTrue(new_folder is not None and new_folder.is_folder()) def test_folder_creating_nested_files2(self): - project_file = 'newfolder' + project_file = "newfolder" parent_folder = self.project.get_resource(self.sample_folder) parent_folder.create_folder(project_file) - new_folder = self.project.get_resource(self.sample_folder - + '/' + project_file) + new_folder = self.project.get_resource(self.sample_folder + "/" + project_file) self.assertTrue(new_folder is not None and new_folder.is_folder()) def test_folder_get_child(self): folder = self.project.root - folder.create_file('myfile.txt') - folder.create_folder('myfolder') - self.assertEqual(self.project.get_resource('myfile.txt'), - folder.get_child('myfile.txt')) - self.assertEqual(self.project.get_resource('myfolder'), - folder.get_child('myfolder')) + folder.create_file("myfile.txt") + folder.create_folder("myfolder") + self.assertEqual( + self.project.get_resource("myfile.txt"), folder.get_child("myfile.txt") + ) + self.assertEqual( + self.project.get_resource("myfolder"), folder.get_child("myfolder") + ) def test_folder_get_child_nested(self): root = self.project.root - folder = root.create_folder('myfolder') - folder.create_file('myfile.txt') - folder.create_folder('myfolder') - self.assertEqual(self.project.get_resource('myfolder/myfile.txt'), - folder.get_child('myfile.txt')) - self.assertEqual(self.project.get_resource('myfolder/myfolder'), - folder.get_child('myfolder')) + folder = root.create_folder("myfolder") + folder.create_file("myfile.txt") + folder.create_folder("myfolder") + self.assertEqual( + self.project.get_resource("myfolder/myfile.txt"), + folder.get_child("myfile.txt"), + ) + self.assertEqual( + self.project.get_resource("myfolder/myfolder"), folder.get_child("myfolder") + ) def test_project_root_is_root_folder(self): - self.assertEqual('', self.project.root.path) + self.assertEqual("", self.project.root.path) def test_moving_files(self): root_folder = self.project.root - my_file = root_folder.create_file('my_file.txt') - my_file.move('my_other_file.txt') + my_file = root_folder.create_file("my_file.txt") + my_file.move("my_other_file.txt") self.assertFalse(my_file.exists()) - root_folder.get_child('my_other_file.txt') + root_folder.get_child("my_other_file.txt") def test_moving_folders(self): root_folder = self.project.root - my_folder = root_folder.create_folder('my_folder') - my_file = my_folder.create_file('my_file.txt') - my_folder.move('new_folder') - self.assertFalse(root_folder.has_child('my_folder')) + my_folder = root_folder.create_folder("my_folder") + my_file = my_folder.create_file("my_file.txt") + my_folder.move("new_folder") + self.assertFalse(root_folder.has_child("my_folder")) self.assertFalse(my_file.exists()) - self.assertTrue(root_folder.get_child('new_folder') is not None) + self.assertTrue(root_folder.get_child("new_folder") is not None) def test_moving_destination_folders(self): root_folder = self.project.root - my_folder = root_folder.create_folder('my_folder') - my_file = root_folder.create_file('my_file.txt') - my_file.move('my_folder') - self.assertFalse(root_folder.has_child('my_file.txt')) + my_folder = root_folder.create_folder("my_folder") + my_file = root_folder.create_file("my_file.txt") + my_file.move("my_folder") + self.assertFalse(root_folder.has_child("my_file.txt")) self.assertFalse(my_file.exists()) - my_folder.get_child('my_file.txt') + my_folder.get_child("my_file.txt") def test_moving_files_and_resource_objects(self): root_folder = self.project.root - my_file = root_folder.create_file('my_file.txt') + my_file = root_folder.create_file("my_file.txt") old_hash = hash(my_file) - my_file.move('my_other_file.txt') + my_file.move("my_other_file.txt") self.assertEqual(old_hash, hash(my_file)) def test_file_encoding_reading(self): - sample_file = self.project.root.create_file('my_file.txt') - contents = (b'# -*- coding: utf-8 -*-\n' + - br'#\N{LATIN SMALL LETTER I WITH DIAERESIS}\n').decode('utf8') - file = open(sample_file.real_path, 'wb') - file.write(contents.encode('utf-8')) + sample_file = self.project.root.create_file("my_file.txt") + contents = ( + b"# -*- coding: utf-8 -*-\n" + + br"#\N{LATIN SMALL LETTER I WITH DIAERESIS}\n" + ).decode("utf8") + file = open(sample_file.real_path, "wb") + file.write(contents.encode("utf-8")) file.close() self.assertEqual(contents, sample_file.read()) def test_file_encoding_writing(self): - sample_file = self.project.root.create_file('my_file.txt') - contents = (b'# -*- coding: utf-8 -*-\n' + - br'\N{LATIN SMALL LETTER I WITH DIAERESIS}\n').decode('utf8') + sample_file = self.project.root.create_file("my_file.txt") + contents = ( + b"# -*- coding: utf-8 -*-\n" + br"\N{LATIN SMALL LETTER I WITH DIAERESIS}\n" + ).decode("utf8") sample_file.write(contents) self.assertEqual(contents, sample_file.read()) def test_using_utf8_when_writing_in_case_of_errors(self): - sample_file = self.project.root.create_file('my_file.txt') - contents = br'\n\N{LATIN SMALL LETTER I WITH DIAERESIS}\n'.decode('utf8') + sample_file = self.project.root.create_file("my_file.txt") + contents = br"\n\N{LATIN SMALL LETTER I WITH DIAERESIS}\n".decode("utf8") sample_file.write(contents) self.assertEqual(contents, sample_file.read()) def test_encoding_declaration_in_the_second_line(self): - sample_file = self.project.root.create_file('my_file.txt') - contents = b'\n# -*- coding: latin-1 -*-\n\xa9\n' - file = open(sample_file.real_path, 'wb') + sample_file = self.project.root.create_file("my_file.txt") + contents = b"\n# -*- coding: latin-1 -*-\n\xa9\n" + file = open(sample_file.real_path, "wb") file.write(contents) file.close() - self.assertEqual(contents, sample_file.read().encode('latin-1')) + self.assertEqual(contents, sample_file.read().encode("latin-1")) def test_not_an_encoding_declaration(self): - sample_file = self.project.root.create_file('my_file.txt') + sample_file = self.project.root.create_file("my_file.txt") contents = b"def my_method(self, encoding='latin-1'):\n var = {}\n\xc2\xa9\n" - file = open(sample_file.real_path, 'wb') + file = open(sample_file.real_path, "wb") file.write(contents) file.close() - self.assertEqual(contents, sample_file.read().encode('utf-8')) - self.assertNotEqual(contents, sample_file.read().encode('latin-1')) + self.assertEqual(contents, sample_file.read().encode("utf-8")) + self.assertNotEqual(contents, sample_file.read().encode("latin-1")) def test_read_bytes(self): - sample_file = self.project.root.create_file('my_file.txt') - contents = b'\n# -*- coding: latin-1 -*-\n\xa9\n' - file = open(sample_file.real_path, 'wb') + sample_file = self.project.root.create_file("my_file.txt") + contents = b"\n# -*- coding: latin-1 -*-\n\xa9\n" + file = open(sample_file.real_path, "wb") file.write(contents) file.close() self.assertEqual(contents, sample_file.read_bytes()) # TODO: Detecting utf-16 encoding def xxx_test_using_utf16(self): - sample_file = self.project.root.create_file('my_file.txt') - contents = b'# -*- coding: utf-16 -*-\n# This is a sample file ...\n' - file = open(sample_file.real_path, 'w') - file.write(contents.encode('utf-16')) + sample_file = self.project.root.create_file("my_file.txt") + contents = b"# -*- coding: utf-16 -*-\n# This is a sample file ...\n" + file = open(sample_file.real_path, "w") + file.write(contents.encode("utf-16")) file.close() sample_file.write(contents) self.assertEqual(contents, sample_file.read()) # XXX: supporting utf_8_sig def xxx_test_file_encoding_reading_for_notepad_styles(self): - sample_file = self.project.root.create_file('my_file.txt') - contents = u'#\N{LATIN SMALL LETTER I WITH DIAERESIS}\n' - file = open(sample_file.real_path, 'w') + sample_file = self.project.root.create_file("my_file.txt") + contents = "#\N{LATIN SMALL LETTER I WITH DIAERESIS}\n" + file = open(sample_file.real_path, "w") # file.write('\xef\xbb\xbf') - file.write(contents.encode('utf-8-sig')) + file.write(contents.encode("utf-8-sig")) file.close() self.assertEqual(contents, sample_file.read()) @@ -459,14 +464,14 @@ def test_using_project_get_file(self): self.assertTrue(myfile.exists()) def test_using_file_create(self): - myfile = self.project.get_file('myfile.txt') + myfile = self.project.get_file("myfile.txt") self.assertFalse(myfile.exists()) myfile.create() self.assertTrue(myfile.exists()) self.assertFalse(myfile.is_folder()) def test_using_folder_create(self): - myfolder = self.project.get_folder('myfolder') + myfolder = self.project.get_folder("myfolder") self.assertFalse(myfolder.exists()) myfolder.create() self.assertTrue(myfolder.exists()) @@ -474,33 +479,33 @@ def test_using_folder_create(self): def test_exception_when_creating_twice(self): with self.assertRaises(RopeError): - myfile = self.project.get_file('myfile.txt') + myfile = self.project.get_file("myfile.txt") myfile.create() myfile.create() def test_exception_when_parent_does_not_exist(self): with self.assertRaises(ResourceNotFoundError): - myfile = self.project.get_file('myfolder/myfile.txt') + myfile = self.project.get_file("myfolder/myfile.txt") myfile.create() def test_simple_path_to_resource(self): - myfile = self.project.root.create_file('myfile.txt') - self.assertEqual(myfile, path_to_resource(self.project, - myfile.real_path)) - self.assertEqual(myfile, path_to_resource( - self.project, myfile.real_path, type='file')) - myfolder = self.project.root.create_folder('myfolder') - self.assertEqual(myfolder, path_to_resource(self.project, - myfolder.real_path)) - self.assertEqual(myfolder, path_to_resource( - self.project, myfolder.real_path, type='folder')) + myfile = self.project.root.create_file("myfile.txt") + self.assertEqual(myfile, path_to_resource(self.project, myfile.real_path)) + self.assertEqual( + myfile, path_to_resource(self.project, myfile.real_path, type="file") + ) + myfolder = self.project.root.create_folder("myfolder") + self.assertEqual(myfolder, path_to_resource(self.project, myfolder.real_path)) + self.assertEqual( + myfolder, path_to_resource(self.project, myfolder.real_path, type="folder") + ) @testutils.skipNotPOSIX() def test_ignoring_symlinks_inside_project(self): - project2 = testutils.sample_project(folder_name='sampleproject2') - mod = project2.root.create_file('mod.py') + project2 = testutils.sample_project(folder_name="sampleproject2") + mod = project2.root.create_file("mod.py") try: - path = os.path.join(self.project.address, 'linkedfile.txt') + path = os.path.join(self.project.address, "linkedfile.txt") os.symlink(mod.real_path, path) files = self.project.root.get_files() self.assertEqual(1, len(files)) @@ -511,57 +516,55 @@ def test_getting_empty_source_folders(self): self.assertEqual([], self.project.get_source_folders()) def test_root_source_folder(self): - self.project.root.create_file('sample.py') + self.project.root.create_file("sample.py") source_folders = self.project.get_source_folders() self.assertEqual(1, len(source_folders)) self.assertTrue(self.project.root in source_folders) def test_root_source_folder2(self): - self.project.root.create_file('mod1.py') - self.project.root.create_file('mod2.py') + self.project.root.create_file("mod1.py") + self.project.root.create_file("mod2.py") source_folders = self.project.get_source_folders() self.assertEqual(1, len(source_folders)) self.assertTrue(self.project.root in source_folders) def test_src_source_folder(self): - src = self.project.root.create_folder('src') - src.create_file('sample.py') + src = self.project.root.create_folder("src") + src.create_file("sample.py") source_folders = self.project.get_source_folders() self.assertEqual(1, len(source_folders)) - self.assertTrue(self.project.get_resource('src') in source_folders) + self.assertTrue(self.project.get_resource("src") in source_folders) def test_packages(self): - src = self.project.root.create_folder('src') - pkg = src.create_folder('package') - pkg.create_file('__init__.py') + src = self.project.root.create_folder("src") + pkg = src.create_folder("package") + pkg.create_file("__init__.py") source_folders = self.project.get_source_folders() self.assertEqual(1, len(source_folders)) self.assertTrue(src in source_folders) def test_multi_source_folders(self): - src = self.project.root.create_folder('src') - package = src.create_folder('package') - package.create_file('__init__.py') - test = self.project.root.create_folder('test') - test.create_file('alltests.py') + src = self.project.root.create_folder("src") + package = src.create_folder("package") + package.create_file("__init__.py") + test = self.project.root.create_folder("test") + test.create_file("alltests.py") source_folders = self.project.get_source_folders() self.assertEqual(2, len(source_folders)) self.assertTrue(src in source_folders) self.assertTrue(test in source_folders) def test_multi_source_folders2(self): - testutils.create_module(self.project, 'mod1') - src = self.project.root.create_folder('src') - package = testutils.create_package(self.project, 'package', src) - testutils.create_module(self.project, 'mod2', package) + testutils.create_module(self.project, "mod1") + src = self.project.root.create_folder("src") + package = testutils.create_package(self.project, "package", src) + testutils.create_module(self.project, "mod2", package) source_folders = self.project.get_source_folders() self.assertEqual(2, len(source_folders)) - self.assertTrue(self.project.root in source_folders and - src in source_folders) + self.assertTrue(self.project.root in source_folders and src in source_folders) class ResourceObserverTest(unittest.TestCase): - def setUp(self): super(ResourceObserverTest, self).setUp() self.project = testutils.sample_project() @@ -571,66 +574,69 @@ def tearDown(self): super(ResourceObserverTest, self).tearDown() def test_resource_change_observer(self): - sample_file = self.project.root.create_file('my_file.txt') - sample_file.write('a sample file version 1') + sample_file = self.project.root.create_file("my_file.txt") + sample_file.write("a sample file version 1") sample_observer = _SampleObserver() self.project.add_observer(sample_observer) - sample_file.write('a sample file version 2') + sample_file.write("a sample file version 2") self.assertEqual(1, sample_observer.change_count) self.assertEqual(sample_file, sample_observer.last_changed) def test_resource_change_observer_after_removal(self): - sample_file = self.project.root.create_file('my_file.txt') - sample_file.write('text') + sample_file = self.project.root.create_file("my_file.txt") + sample_file.write("text") sample_observer = _SampleObserver() - self.project.add_observer(FilteredResourceObserver(sample_observer, - [sample_file])) + self.project.add_observer( + FilteredResourceObserver(sample_observer, [sample_file]) + ) sample_file.remove() self.assertEqual(1, sample_observer.change_count) self.assertEqual(sample_file, sample_observer.last_removed) def test_resource_change_observer2(self): - sample_file = self.project.root.create_file('my_file.txt') + sample_file = self.project.root.create_file("my_file.txt") sample_observer = _SampleObserver() self.project.add_observer(sample_observer) self.project.remove_observer(sample_observer) - sample_file.write('a sample file version 2') + sample_file.write("a sample file version 2") self.assertEqual(0, sample_observer.change_count) def test_resource_change_observer_for_folders(self): root_folder = self.project.root - my_folder = root_folder.create_folder('my_folder') + my_folder = root_folder.create_folder("my_folder") my_folder_observer = _SampleObserver() root_folder_observer = _SampleObserver() self.project.add_observer( - FilteredResourceObserver(my_folder_observer, [my_folder])) + FilteredResourceObserver(my_folder_observer, [my_folder]) + ) self.project.add_observer( - FilteredResourceObserver(root_folder_observer, [root_folder])) - my_file = my_folder.create_file('my_file.txt') + FilteredResourceObserver(root_folder_observer, [root_folder]) + ) + my_file = my_folder.create_file("my_file.txt") self.assertEqual(1, my_folder_observer.change_count) - my_file.move('another_file.txt') + my_file.move("another_file.txt") self.assertEqual(2, my_folder_observer.change_count) self.assertEqual(1, root_folder_observer.change_count) - self.project.get_resource('another_file.txt').remove() + self.project.get_resource("another_file.txt").remove() self.assertEqual(2, my_folder_observer.change_count) self.assertEqual(2, root_folder_observer.change_count) def test_resource_change_observer_after_moving(self): - sample_file = self.project.root.create_file('my_file.txt') + sample_file = self.project.root.create_file("my_file.txt") sample_observer = _SampleObserver() self.project.add_observer(sample_observer) - sample_file.move('new_file.txt') + sample_file.move("new_file.txt") self.assertEqual(1, sample_observer.change_count) - self.assertEqual((sample_file, - self.project.get_resource('new_file.txt')), - sample_observer.last_moved) + self.assertEqual( + (sample_file, self.project.get_resource("new_file.txt")), + sample_observer.last_moved, + ) def test_revalidating_files(self): root = self.project.root - my_file = root.create_file('my_file.txt') + my_file = root.create_file("my_file.txt") sample_observer = _SampleObserver() - self.project.add_observer(FilteredResourceObserver(sample_observer, - [my_file])) + self.project.add_observer(FilteredResourceObserver(sample_observer, [my_file])) os.remove(my_file.real_path) self.project.validate(root) self.assertEqual(my_file, sample_observer.last_removed) @@ -638,46 +644,47 @@ def test_revalidating_files(self): def test_revalidating_files_and_no_changes2(self): root = self.project.root - my_file = root.create_file('my_file.txt') + my_file = root.create_file("my_file.txt") sample_observer = _SampleObserver() - self.project.add_observer(FilteredResourceObserver(sample_observer, - [my_file])) + self.project.add_observer(FilteredResourceObserver(sample_observer, [my_file])) self.project.validate(root) self.assertEqual(None, sample_observer.last_moved) self.assertEqual(0, sample_observer.change_count) def test_revalidating_folders(self): root = self.project.root - my_folder = root.create_folder('myfolder') - my_file = my_folder.create_file('myfile.txt') # noqa + my_folder = root.create_folder("myfolder") + my_file = my_folder.create_file("myfile.txt") # noqa sample_observer = _SampleObserver() - self.project.add_observer(FilteredResourceObserver(sample_observer, - [my_folder])) + self.project.add_observer( + FilteredResourceObserver(sample_observer, [my_folder]) + ) testutils.remove_recursively(my_folder.real_path) self.project.validate(root) self.assertEqual(my_folder, sample_observer.last_removed) self.assertEqual(1, sample_observer.change_count) def test_removing_and_adding_resources_to_filtered_observer(self): - my_file = self.project.root.create_file('my_file.txt') + my_file = self.project.root.create_file("my_file.txt") sample_observer = _SampleObserver() filtered_observer = FilteredResourceObserver(sample_observer) self.project.add_observer(filtered_observer) - my_file.write('1') + my_file.write("1") self.assertEqual(0, sample_observer.change_count) filtered_observer.add_resource(my_file) - my_file.write('2') + my_file.write("2") self.assertEqual(1, sample_observer.change_count) filtered_observer.remove_resource(my_file) - my_file.write('3') + my_file.write("3") self.assertEqual(1, sample_observer.change_count) def test_validation_and_changing_files(self): - my_file = self.project.root.create_file('my_file.txt') + my_file = self.project.root.create_file("my_file.txt") sample_observer = _SampleObserver() timekeeper = _MockChangeIndicator() filtered_observer = FilteredResourceObserver( - sample_observer, [my_file], timekeeper=timekeeper) + sample_observer, [my_file], timekeeper=timekeeper + ) self.project.add_observer(filtered_observer) self._write_file(my_file.real_path) timekeeper.set_indicator(my_file, 1) @@ -685,25 +692,26 @@ def test_validation_and_changing_files(self): self.assertEqual(1, sample_observer.change_count) def test_validation_and_changing_files2(self): - my_file = self.project.root.create_file('my_file.txt') + my_file = self.project.root.create_file("my_file.txt") sample_observer = _SampleObserver() timekeeper = _MockChangeIndicator() - self.project.add_observer(FilteredResourceObserver( - sample_observer, [my_file], - timekeeper=timekeeper)) + self.project.add_observer( + FilteredResourceObserver(sample_observer, [my_file], timekeeper=timekeeper) + ) timekeeper.set_indicator(my_file, 1) - my_file.write('hey') + my_file.write("hey") self.assertEqual(1, sample_observer.change_count) self.project.validate(self.project.root) self.assertEqual(1, sample_observer.change_count) def test_not_reporting_multiple_changes_to_folders(self): root = self.project.root - file1 = root.create_file('file1.txt') - file2 = root.create_file('file2.txt') + file1 = root.create_file("file1.txt") + file2 = root.create_file("file2.txt") sample_observer = _SampleObserver() - self.project.add_observer(FilteredResourceObserver( - sample_observer, [root, file1, file2])) + self.project.add_observer( + FilteredResourceObserver(sample_observer, [root, file1, file2]) + ) os.remove(file1.real_path) os.remove(file2.real_path) self.assertEqual(0, sample_observer.change_count) @@ -711,43 +719,42 @@ def test_not_reporting_multiple_changes_to_folders(self): self.assertEqual(3, sample_observer.change_count) def _write_file(self, path): - my_file = open(path, 'w') - my_file.write('\n') + my_file = open(path, "w") + my_file.write("\n") my_file.close() def test_moving_and_being_interested_about_a_folder_and_a_child(self): - my_folder = self.project.root.create_folder('my_folder') - my_file = my_folder.create_file('my_file.txt') + my_folder = self.project.root.create_folder("my_folder") + my_file = my_folder.create_file("my_file.txt") sample_observer = _SampleObserver() filtered_observer = FilteredResourceObserver( - sample_observer, [my_folder, my_file]) + sample_observer, [my_folder, my_file] + ) self.project.add_observer(filtered_observer) - my_folder.move('new_folder') + my_folder.move("new_folder") self.assertEqual(2, sample_observer.change_count) def test_contains_for_folders(self): - folder1 = self.project.root.create_folder('folder') - folder2 = self.project.root.create_folder('folder2') + folder1 = self.project.root.create_folder("folder") + folder2 = self.project.root.create_folder("folder2") self.assertFalse(folder1.contains(folder2)) def test_validating_when_created(self): root = self.project.root - my_file = self.project.get_file('my_file.txt') + my_file = self.project.get_file("my_file.txt") sample_observer = _SampleObserver() - self.project.add_observer(FilteredResourceObserver(sample_observer, - [my_file])) - open(my_file.real_path, 'w').close() + self.project.add_observer(FilteredResourceObserver(sample_observer, [my_file])) + open(my_file.real_path, "w").close() self.project.validate(root) self.assertEqual(my_file, sample_observer.last_created) self.assertEqual(1, sample_observer.change_count) def test_validating_twice_when_created(self): root = self.project.root - my_file = self.project.get_file('my_file.txt') + my_file = self.project.get_file("my_file.txt") sample_observer = _SampleObserver() - self.project.add_observer(FilteredResourceObserver(sample_observer, - [my_file])) - open(my_file.real_path, 'w').close() + self.project.add_observer(FilteredResourceObserver(sample_observer, [my_file])) + open(my_file.real_path, "w").close() self.project.validate(root) self.project.validate(root) self.assertEqual(my_file, sample_observer.last_created) @@ -755,12 +762,13 @@ def test_validating_twice_when_created(self): def test_changes_and_adding_resources(self): root = self.project.root # noqa - file1 = self.project.get_file('file1.txt') - file2 = self.project.get_file('file2.txt') + file1 = self.project.get_file("file1.txt") + file2 = self.project.get_file("file2.txt") file1.create() sample_observer = _SampleObserver() - self.project.add_observer(FilteredResourceObserver(sample_observer, - [file1, file2])) + self.project.add_observer( + FilteredResourceObserver(sample_observer, [file1, file2]) + ) file1.move(file2.path) self.assertEqual(2, sample_observer.change_count) self.assertEqual(file2, sample_observer.last_created) @@ -769,24 +777,23 @@ def test_changes_and_adding_resources(self): def test_validating_get_files_list(self): root = self.project.root # noqa self.assertEqual(0, len(self.project.get_files())) - file = open(os.path.join(self.project.address, 'myfile.txt'), 'w') + file = open(os.path.join(self.project.address, "myfile.txt"), "w") file.close() self.project.validate() self.assertEqual(1, len(self.project.get_files())) def test_clear_observered_resources_for_filtered_observers(self): - sample_file = self.project.root.create_file('myfile.txt') + sample_file = self.project.root.create_file("myfile.txt") sample_observer = _SampleObserver() filtered = FilteredResourceObserver(sample_observer) self.project.add_observer(filtered) filtered.add_resource(sample_file) filtered.clear_resources() - sample_file.write('1') + sample_file.write("1") self.assertEqual(0, sample_observer.change_count) class _MockChangeIndicator(object): - def __init__(self): self.times = {} @@ -798,7 +805,6 @@ def get_indicator(self, resource): class _SampleObserver(object): - def __init__(self): self.change_count = 0 self.last_changed = None @@ -824,10 +830,9 @@ def resource_removed(self, resource): class OutOfProjectTest(unittest.TestCase): - def setUp(self): super(OutOfProjectTest, self).setUp() - self.test_directory = 'temp_test_directory' + self.test_directory = "temp_test_directory" testutils.remove_recursively(self.test_directory) os.mkdir(self.test_directory) self.project = testutils.sample_project() @@ -839,80 +844,80 @@ def tearDown(self): super(OutOfProjectTest, self).tearDown() def test_simple_out_of_project_file(self): - sample_file_path = os.path.join(self.test_directory, 'sample.txt') - sample_file = open(sample_file_path, 'w') - sample_file.write('sample content\n') + sample_file_path = os.path.join(self.test_directory, "sample.txt") + sample_file = open(sample_file_path, "w") + sample_file.write("sample content\n") sample_file.close() sample_resource = self.no_project.get_resource(sample_file_path) - self.assertEqual('sample content\n', sample_resource.read()) + self.assertEqual("sample content\n", sample_resource.read()) def test_simple_out_of_project_folder(self): - sample_folder_path = os.path.join(self.test_directory, 'sample_folder') + sample_folder_path = os.path.join(self.test_directory, "sample_folder") os.mkdir(sample_folder_path) sample_folder = self.no_project.get_resource(sample_folder_path) self.assertEqual([], sample_folder.get_children()) - sample_file_path = os.path.join(sample_folder_path, 'sample.txt') - open(sample_file_path, 'w').close() + sample_file_path = os.path.join(sample_folder_path, "sample.txt") + open(sample_file_path, "w").close() sample_resource = self.no_project.get_resource(sample_file_path) self.assertEqual(sample_resource, sample_folder.get_children()[0]) def test_using_absolute_path(self): - sample_file_path = os.path.join(self.test_directory, 'sample.txt') - open(sample_file_path, 'w').close() + sample_file_path = os.path.join(self.test_directory, "sample.txt") + open(sample_file_path, "w").close() normal_sample_resource = self.no_project.get_resource(sample_file_path) - absolute_sample_resource = \ - self.no_project.get_resource(os.path.abspath(sample_file_path)) + absolute_sample_resource = self.no_project.get_resource( + os.path.abspath(sample_file_path) + ) self.assertEqual(normal_sample_resource, absolute_sample_resource) def test_folder_get_child(self): - sample_folder_path = os.path.join(self.test_directory, 'sample_folder') + sample_folder_path = os.path.join(self.test_directory, "sample_folder") os.mkdir(sample_folder_path) sample_folder = self.no_project.get_resource(sample_folder_path) self.assertEqual([], sample_folder.get_children()) - sample_file_path = os.path.join(sample_folder_path, 'sample.txt') - open(sample_file_path, 'w').close() + sample_file_path = os.path.join(sample_folder_path, "sample.txt") + open(sample_file_path, "w").close() sample_resource = self.no_project.get_resource(sample_file_path) - self.assertTrue(sample_folder.has_child('sample.txt')) - self.assertFalse(sample_folder.has_child('doesnothave.txt')) - self.assertEqual(sample_resource, - sample_folder.get_child('sample.txt')) + self.assertTrue(sample_folder.has_child("sample.txt")) + self.assertFalse(sample_folder.has_child("doesnothave.txt")) + self.assertEqual(sample_resource, sample_folder.get_child("sample.txt")) def test_out_of_project_files_and_path_to_resource(self): - sample_file_path = os.path.join(self.test_directory, 'sample.txt') - sample_file = open(sample_file_path, 'w') - sample_file.write('sample content\n') + sample_file_path = os.path.join(self.test_directory, "sample.txt") + sample_file = open(sample_file_path, "w") + sample_file.write("sample content\n") sample_file.close() sample_resource = self.no_project.get_resource(sample_file_path) - self.assertEqual(sample_resource, - path_to_resource(self.project, sample_file_path)) + self.assertEqual( + sample_resource, path_to_resource(self.project, sample_file_path) + ) class _MockFSCommands(object): def __init__(self): - self.log = '' + self.log = "" self.fscommands = FileSystemCommands() def create_file(self, path): - self.log += 'create_file ' + self.log += "create_file " self.fscommands.create_file(path) def create_folder(self, path): - self.log += 'create_folder ' + self.log += "create_folder " self.fscommands.create_folder(path) def move(self, path, new_location): - self.log += 'move ' + self.log += "move " self.fscommands.move(path, new_location) def remove(self, path): - self.log += 'remove ' + self.log += "remove " self.fscommands.remove(path) class RopeFolderTest(unittest.TestCase): - def setUp(self): super(RopeFolderTest, self).setUp() self.project = None @@ -927,110 +932,113 @@ def test_none_project_rope_folder(self): self.assertTrue(self.project.ropefolder is None) def test_getting_project_rope_folder(self): - self.project = testutils.sample_project(ropefolder='.ropeproject') + self.project = testutils.sample_project(ropefolder=".ropeproject") self.assertTrue(self.project.ropefolder.exists()) - self.assertTrue('.ropeproject', self.project.ropefolder.path) + self.assertTrue(".ropeproject", self.project.ropefolder.path) def test_setting_ignored_resources(self): - self.project = testutils.sample_project( - ignored_resources=['myfile.txt']) - myfile = self.project.get_file('myfile.txt') - file2 = self.project.get_file('file2.txt') + self.project = testutils.sample_project(ignored_resources=["myfile.txt"]) + myfile = self.project.get_file("myfile.txt") + file2 = self.project.get_file("file2.txt") self.assertTrue(self.project.is_ignored(myfile)) self.assertFalse(self.project.is_ignored(file2)) def test_ignored_folders(self): - self.project = testutils.sample_project(ignored_resources=['myfolder']) - myfolder = self.project.root.create_folder('myfolder') + self.project = testutils.sample_project(ignored_resources=["myfolder"]) + myfolder = self.project.root.create_folder("myfolder") self.assertTrue(self.project.is_ignored(myfolder)) - myfile = myfolder.create_file('myfile.txt') + myfile = myfolder.create_file("myfile.txt") self.assertTrue(self.project.is_ignored(myfile)) def test_ignored_resources_and_get_files(self): self.project = testutils.sample_project( - ignored_resources=['myfile.txt'], ropefolder=None) - myfile = self.project.get_file('myfile.txt') + ignored_resources=["myfile.txt"], ropefolder=None + ) + myfile = self.project.get_file("myfile.txt") self.assertEqual(0, len(self.project.get_files())) myfile.create() self.assertEqual(0, len(self.project.get_files())) def test_ignored_resources_and_get_files2(self): self.project = testutils.sample_project( - ignored_resources=['myfile.txt'], ropefolder=None) - myfile = self.project.root.create_file('myfile.txt') # noqa + ignored_resources=["myfile.txt"], ropefolder=None + ) + myfile = self.project.root.create_file("myfile.txt") # noqa self.assertEqual(0, len(self.project.get_files())) def test_setting_ignored_resources_patterns(self): - self.project = testutils.sample_project(ignored_resources=['m?file.*']) - myfile = self.project.get_file('myfile.txt') - file2 = self.project.get_file('file2.txt') + self.project = testutils.sample_project(ignored_resources=["m?file.*"]) + myfile = self.project.get_file("myfile.txt") + file2 = self.project.get_file("file2.txt") self.assertTrue(self.project.is_ignored(myfile)) self.assertFalse(self.project.is_ignored(file2)) def test_star_should_not_include_slashes(self): - self.project = testutils.sample_project(ignored_resources=['f*.txt']) - folder = self.project.root.create_folder('folder') - file1 = folder.create_file('myfile.txt') - file2 = folder.create_file('file2.txt') + self.project = testutils.sample_project(ignored_resources=["f*.txt"]) + folder = self.project.root.create_folder("folder") + file1 = folder.create_file("myfile.txt") + file2 = folder.create_file("file2.txt") self.assertFalse(self.project.is_ignored(file1)) self.assertTrue(self.project.is_ignored(file2)) def test_normal_fscommands(self): fscommands = _MockFSCommands() self.project = testutils.sample_project(fscommands=fscommands) - myfile = self.project.get_file('myfile.txt') + myfile = self.project.get_file("myfile.txt") myfile.create() - self.assertTrue('create_file ', fscommands.log) + self.assertTrue("create_file ", fscommands.log) def test_fscommands_and_ignored_resources(self): fscommands = _MockFSCommands() self.project = testutils.sample_project( - fscommands=fscommands, - ignored_resources=['myfile.txt'], ropefolder=None) - myfile = self.project.get_file('myfile.txt') + fscommands=fscommands, ignored_resources=["myfile.txt"], ropefolder=None + ) + myfile = self.project.get_file("myfile.txt") myfile.create() - self.assertEqual('', fscommands.log) + self.assertEqual("", fscommands.log) def test_ignored_resources_and_prefixes(self): - self.project = testutils.sample_project( - ignored_resources=['.hg']) - myfile = self.project.root.create_file('.hgignore') + self.project = testutils.sample_project(ignored_resources=[".hg"]) + myfile = self.project.root.create_file(".hgignore") self.assertFalse(self.project.is_ignored(myfile)) def test_loading_config_dot_py(self): - self.project = testutils.sample_project(ropefolder='.ropeproject') - config = self.project.get_file('.ropeproject/config.py') + self.project = testutils.sample_project(ropefolder=".ropeproject") + config = self.project.get_file(".ropeproject/config.py") if not config.exists(): config.create() - config.write('def set_prefs(prefs):\n' - ' prefs["ignored_resources"] = ["myfile.txt"]\n' - 'def project_opened(project):\n' - ' project.root.create_file("loaded")\n') + config.write( + "def set_prefs(prefs):\n" + ' prefs["ignored_resources"] = ["myfile.txt"]\n' + "def project_opened(project):\n" + ' project.root.create_file("loaded")\n' + ) self.project.close() - self.project = Project(self.project.address, ropefolder='.ropeproject') - self.assertTrue(self.project.get_file('loaded').exists()) - myfile = self.project.get_file('myfile.txt') + self.project = Project(self.project.address, ropefolder=".ropeproject") + self.assertTrue(self.project.get_file("loaded").exists()) + myfile = self.project.get_file("myfile.txt") self.assertTrue(self.project.is_ignored(myfile)) def test_ignoring_syntax_errors(self): - self.project = testutils.sample_project(ropefolder=None, - ignore_syntax_errors=True) - mod = testutils.create_module(self.project, 'mod') - mod.write('xyz print') + self.project = testutils.sample_project( + ropefolder=None, ignore_syntax_errors=True + ) + mod = testutils.create_module(self.project, "mod") + mod.write("xyz print") pymod = self.project.get_pymodule(mod) # noqa def test_compressed_history(self): self.project = testutils.sample_project(compress_history=True) - mod = testutils.create_module(self.project, 'mod') - mod.write('') + mod = testutils.create_module(self.project, "mod") + mod.write("") def test_compressed_objectdb(self): self.project = testutils.sample_project(compress_objectdb=True) - mod = testutils.create_module(self.project, 'mod') + mod = testutils.create_module(self.project, "mod") self.project.pycore.analyze_module(mod) def test_nested_dot_ropeproject_folder(self): - self.project = testutils.sample_project(ropefolder='.f1/f2') + self.project = testutils.sample_project(ropefolder=".f1/f2") ropefolder = self.project.ropefolder - self.assertEqual('.f1/f2', ropefolder.path) + self.assertEqual(".f1/f2", ropefolder.path) self.assertTrue(ropefolder.exists()) diff --git a/ropetest/pycoretest.py b/ropetest/pycoretest.py index b9d64f5e3..c0754e883 100644 --- a/ropetest/pycoretest.py +++ b/ropetest/pycoretest.py @@ -16,7 +16,6 @@ class PyCoreTest(unittest.TestCase): - def setUp(self): super(PyCoreTest, self).setUp() self.project = testutils.sample_project() @@ -27,1071 +26,1074 @@ def tearDown(self): super(PyCoreTest, self).tearDown() def test_simple_module(self): - testutils.create_module(self.project, 'mod') - result = self.project.get_module('mod') - self.assertEqual(get_base_type('Module'), result.type) + testutils.create_module(self.project, "mod") + result = self.project.get_module("mod") + self.assertEqual(get_base_type("Module"), result.type) self.assertEqual(0, len(result.get_attributes())) def test_nested_modules(self): - pkg = testutils.create_package(self.project, 'pkg') - mod = testutils.create_module(self.project, 'mod', pkg) # noqa - package = self.project.get_module('pkg') - self.assertEqual(get_base_type('Module'), package.get_type()) + pkg = testutils.create_package(self.project, "pkg") + mod = testutils.create_module(self.project, "mod", pkg) # noqa + package = self.project.get_module("pkg") + self.assertEqual(get_base_type("Module"), package.get_type()) self.assertEqual(1, len(package.get_attributes())) - module = package['mod'].get_object() - self.assertEqual(get_base_type('Module'), module.get_type()) + module = package["mod"].get_object() + self.assertEqual(get_base_type("Module"), module.get_type()) def test_package(self): - pkg = testutils.create_package(self.project, 'pkg') - mod = testutils.create_module(self.project, 'mod', pkg) # noqa - result = self.project.get_module('pkg') - self.assertEqual(get_base_type('Module'), result.type) + pkg = testutils.create_package(self.project, "pkg") + mod = testutils.create_module(self.project, "mod", pkg) # noqa + result = self.project.get_module("pkg") + self.assertEqual(get_base_type("Module"), result.type) def test_simple_class(self): - mod = testutils.create_module(self.project, 'mod') - mod.write('class SampleClass(object):\n pass\n') - mod_element = self.project.get_module('mod') - result = mod_element['SampleClass'].get_object() - self.assertEqual(get_base_type('Type'), result.get_type()) + mod = testutils.create_module(self.project, "mod") + mod.write("class SampleClass(object):\n pass\n") + mod_element = self.project.get_module("mod") + result = mod_element["SampleClass"].get_object() + self.assertEqual(get_base_type("Type"), result.get_type()) def test_simple_function(self): - mod = testutils.create_module(self.project, 'mod') - mod.write('def sample_function():\n pass\n') - mod_element = self.project.get_module('mod') - result = mod_element['sample_function'].get_object() - self.assertEqual(get_base_type('Function'), result.get_type()) + mod = testutils.create_module(self.project, "mod") + mod.write("def sample_function():\n pass\n") + mod_element = self.project.get_module("mod") + result = mod_element["sample_function"].get_object() + self.assertEqual(get_base_type("Function"), result.get_type()) def test_class_methods(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class SampleClass(object):\n' \ - ' def sample_method(self):\n' \ - ' pass\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class SampleClass(object):\n" + " def sample_method(self):\n" + " pass\n" + ) mod.write(code) - mod_element = self.project.get_module('mod') - sample_class = mod_element['SampleClass'].get_object() - self.assertTrue('sample_method' in sample_class) - method = sample_class['sample_method'].get_object() - self.assertEqual(get_base_type('Function'), method.get_type()) + mod_element = self.project.get_module("mod") + sample_class = mod_element["SampleClass"].get_object() + self.assertTrue("sample_method" in sample_class) + method = sample_class["sample_method"].get_object() + self.assertEqual(get_base_type("Function"), method.get_type()) def test_global_variable_without_type_annotation(self): - mod = testutils.create_module(self.project, 'mod') - mod.write('var = 10') - mod_element = self.project.get_module('mod') - var = mod_element['var'] + mod = testutils.create_module(self.project, "mod") + mod.write("var = 10") + mod_element = self.project.get_module("mod") + var = mod_element["var"] self.assertEqual(AssignedName, type(var)) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_global_variable_with_type_annotation(self): - mod = testutils.create_module(self.project, 'mod') - mod.write('py3_var: str = foo_bar') - mod_element = self.project.get_module('mod') - py3_var = mod_element['py3_var'] + mod = testutils.create_module(self.project, "mod") + mod.write("py3_var: str = foo_bar") + mod_element = self.project.get_module("mod") + py3_var = mod_element["py3_var"] self.assertEqual(AssignedName, type(py3_var)) def test_class_variables(self): - mod = testutils.create_module(self.project, 'mod') - mod.write('class SampleClass(object):\n var = 10\n') - mod_element = self.project.get_module('mod') - sample_class = mod_element['SampleClass'].get_object() - var = sample_class['var'] # noqa + mod = testutils.create_module(self.project, "mod") + mod.write("class SampleClass(object):\n var = 10\n") + mod_element = self.project.get_module("mod") + sample_class = mod_element["SampleClass"].get_object() + var = sample_class["var"] # noqa def test_class_attributes_set_in_init(self): - mod = testutils.create_module(self.project, 'mod') - mod.write('class C(object):\n' - ' def __init__(self):\n self.var = 20\n') - mod_element = self.project.get_module('mod') - sample_class = mod_element['C'].get_object() - var = sample_class['var'] # noqa + mod = testutils.create_module(self.project, "mod") + mod.write( + "class C(object):\n" " def __init__(self):\n self.var = 20\n" + ) + mod_element = self.project.get_module("mod") + sample_class = mod_element["C"].get_object() + var = sample_class["var"] # noqa def test_class_attributes_set_in_init_overwriting_a_defined(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C(object):\n' \ - ' def __init__(self):\n' \ - ' self.f = 20\n' \ - ' def f():\n' \ - ' pass\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C(object):\n" + " def __init__(self):\n" + " self.f = 20\n" + " def f():\n" + " pass\n" + ) mod.write(code) - mod_element = self.project.get_module('mod') - sample_class = mod_element['C'].get_object() - f = sample_class['f'].get_object() + mod_element = self.project.get_module("mod") + sample_class = mod_element["C"].get_object() + f = sample_class["f"].get_object() self.assertTrue(isinstance(f, AbstractFunction)) def test_classes_inside_other_classes(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class SampleClass(object):\n' \ - ' class InnerClass(object):\n' \ - ' pass\n\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class SampleClass(object):\n" + " class InnerClass(object):\n" + " pass\n\n" + ) mod.write(code) - mod_element = self.project.get_module('mod') - sample_class = mod_element['SampleClass'].get_object() - var = sample_class['InnerClass'].get_object() - self.assertEqual(get_base_type('Type'), var.get_type()) + mod_element = self.project.get_module("mod") + sample_class = mod_element["SampleClass"].get_object() + var = sample_class["InnerClass"].get_object() + self.assertEqual(get_base_type("Type"), var.get_type()) def test_non_existent_module(self): with self.assertRaises(exceptions.ModuleNotFoundError): - self.project.get_module('doesnotexistmodule') + self.project.get_module("doesnotexistmodule") def test_imported_names(self): - testutils.create_module(self.project, 'mod1') - mod = testutils.create_module(self.project, 'mod2') - mod.write('import mod1\n') - module = self.project.get_module('mod2') - imported_sys = module['mod1'].get_object() - self.assertEqual(get_base_type('Module'), imported_sys.get_type()) + testutils.create_module(self.project, "mod1") + mod = testutils.create_module(self.project, "mod2") + mod.write("import mod1\n") + module = self.project.get_module("mod2") + imported_sys = module["mod1"].get_object() + self.assertEqual(get_base_type("Module"), imported_sys.get_type()) def test_imported_as_names(self): - testutils.create_module(self.project, 'mod1') - mod = testutils.create_module(self.project, 'mod2') - mod.write('import mod1 as my_import\n') - module = self.project.get_module('mod2') - imported_mod = module['my_import'].get_object() - self.assertEqual(get_base_type('Module'), imported_mod.get_type()) + testutils.create_module(self.project, "mod1") + mod = testutils.create_module(self.project, "mod2") + mod.write("import mod1 as my_import\n") + module = self.project.get_module("mod2") + imported_mod = module["my_import"].get_object() + self.assertEqual(get_base_type("Module"), imported_mod.get_type()) def test_get_string_module(self): mod = libutils.get_string_module( - self.project, 'class Sample(object):\n pass\n') - sample_class = mod['Sample'].get_object() - self.assertEqual(get_base_type('Type'), sample_class.get_type()) + self.project, "class Sample(object):\n pass\n" + ) + sample_class = mod["Sample"].get_object() + self.assertEqual(get_base_type("Type"), sample_class.get_type()) def test_get_string_module_with_extra_spaces(self): - mod = libutils.get_string_module( - self.project, 'a = 10\n ') # noqa + mod = libutils.get_string_module(self.project, "a = 10\n ") # noqa def test_parameter_info_for_functions(self): - code = 'def func(param1, param2=10, *param3, **param4):\n pass' + code = "def func(param1, param2=10, *param3, **param4):\n pass" mod = libutils.get_string_module(self.project, code) - sample_function = mod['func'] - self.assertEqual(['param1', 'param2', 'param3', 'param4'], - sample_function.get_object().get_param_names()) + sample_function = mod["func"] + self.assertEqual( + ["param1", "param2", "param3", "param4"], + sample_function.get_object().get_param_names(), + ) # FIXME: Not found modules def xxx_test_not_found_module_is_module(self): - mod = libutils.get_string_module( - self.project, 'import doesnotexist\n') - self.assertEqual(get_base_type('Module'), - mod['doesnotexist']. - get_object().get_type()) + mod = libutils.get_string_module(self.project, "import doesnotexist\n") + self.assertEqual( + get_base_type("Module"), mod["doesnotexist"].get_object().get_type() + ) def test_mixing_scopes_and_objects_hierarchy(self): - mod = libutils.get_string_module(self.project, 'var = 200\n') + mod = libutils.get_string_module(self.project, "var = 200\n") scope = mod.get_scope() - self.assertTrue('var' in scope.get_names()) + self.assertTrue("var" in scope.get_names()) def test_inheriting_base_class_attributes(self): - code = 'class Base(object):\n' \ - ' def method(self):\n' \ - ' pass\n' \ - 'class Derived(Base):\n' \ - ' pass\n' + code = ( + "class Base(object):\n" + " def method(self):\n" + " pass\n" + "class Derived(Base):\n" + " pass\n" + ) mod = libutils.get_string_module(self.project, code) - derived = mod['Derived'].get_object() - self.assertTrue('method' in derived) - self.assertEqual(get_base_type('Function'), - derived['method'].get_object().get_type()) + derived = mod["Derived"].get_object() + self.assertTrue("method" in derived) + self.assertEqual( + get_base_type("Function"), derived["method"].get_object().get_type() + ) def test_inheriting_multiple_base_class_attributes(self): - code = 'class Base1(object):\n def method1(self):\n pass\n' \ - 'class Base2(object):\n def method2(self):\n pass\n' \ - 'class Derived(Base1, Base2):\n pass\n' + code = ( + "class Base1(object):\n def method1(self):\n pass\n" + "class Base2(object):\n def method2(self):\n pass\n" + "class Derived(Base1, Base2):\n pass\n" + ) mod = libutils.get_string_module(self.project, code) - derived = mod['Derived'].get_object() - self.assertTrue('method1' in derived) - self.assertTrue('method2' in derived) + derived = mod["Derived"].get_object() + self.assertTrue("method1" in derived) + self.assertTrue("method2" in derived) def test_inherit_multiple_base_class_attrs_with_the_same_name(self): - code = 'class Base1(object):\n def method(self):\n pass\n' \ - 'class Base2(object):\n def method(self):\n pass\n' \ - 'class Derived(Base1, Base2):\n pass\n' + code = ( + "class Base1(object):\n def method(self):\n pass\n" + "class Base2(object):\n def method(self):\n pass\n" + "class Derived(Base1, Base2):\n pass\n" + ) mod = libutils.get_string_module(self.project, code) - base1 = mod['Base1'].get_object() - derived = mod['Derived'].get_object() - self.assertEqual(base1['method'].get_object(), - derived['method'].get_object()) + base1 = mod["Base1"].get_object() + derived = mod["Derived"].get_object() + self.assertEqual(base1["method"].get_object(), derived["method"].get_object()) def test_inheriting_unknown_base_class(self): - code = 'class Derived(NotFound):\n' \ - ' def f(self):\n' \ - ' pass\n' + code = "class Derived(NotFound):\n" " def f(self):\n" " pass\n" mod = libutils.get_string_module(self.project, code) - derived = mod['Derived'].get_object() - self.assertTrue('f' in derived) + derived = mod["Derived"].get_object() + self.assertTrue("f" in derived) def test_module_creation(self): - new_module = testutils.create_module(self.project, 'module') + new_module = testutils.create_module(self.project, "module") self.assertFalse(new_module.is_folder()) - self.assertEqual(self.project.get_resource('module.py'), new_module) + self.assertEqual(self.project.get_resource("module.py"), new_module) def test_packaged_module_creation(self): - package = self.project.root.create_folder('package') # noqa - new_module = testutils.create_module(self.project, 'package.module') - self.assertEqual(self.project.get_resource('package/module.py'), - new_module) + package = self.project.root.create_folder("package") # noqa + new_module = testutils.create_module(self.project, "package.module") + self.assertEqual(self.project.get_resource("package/module.py"), new_module) def test_packaged_module_creation_with_nested_src(self): - src = self.project.root.create_folder('src') - src.create_folder('pkg') - new_module = testutils.create_module(self.project, 'pkg.mod', src) - self.assertEqual(self.project.get_resource('src/pkg/mod.py'), - new_module) + src = self.project.root.create_folder("src") + src.create_folder("pkg") + new_module = testutils.create_module(self.project, "pkg.mod", src) + self.assertEqual(self.project.get_resource("src/pkg/mod.py"), new_module) def test_package_creation(self): - new_package = testutils.create_package(self.project, 'pkg') + new_package = testutils.create_package(self.project, "pkg") self.assertTrue(new_package.is_folder()) - self.assertEqual(self.project.get_resource('pkg'), new_package) - self.assertEqual(self.project.get_resource('pkg/__init__.py'), - new_package.get_child('__init__.py')) + self.assertEqual(self.project.get_resource("pkg"), new_package) + self.assertEqual( + self.project.get_resource("pkg/__init__.py"), + new_package.get_child("__init__.py"), + ) def test_nested_package_creation(self): - testutils.create_package(self.project, 'pkg1') - nested_package = testutils.create_package(self.project, 'pkg1.pkg2') - self.assertEqual(self.project.get_resource('pkg1/pkg2'), - nested_package) + testutils.create_package(self.project, "pkg1") + nested_package = testutils.create_package(self.project, "pkg1.pkg2") + self.assertEqual(self.project.get_resource("pkg1/pkg2"), nested_package) def test_packaged_package_creation_with_nested_src(self): - src = self.project.root.create_folder('src') - testutils.create_package(self.project, 'pkg1', src) - nested_package = testutils.create_package(self.project, 'pkg1.pkg2', - src) - self.assertEqual(self.project.get_resource('src/pkg1/pkg2'), - nested_package) + src = self.project.root.create_folder("src") + testutils.create_package(self.project, "pkg1", src) + nested_package = testutils.create_package(self.project, "pkg1.pkg2", src) + self.assertEqual(self.project.get_resource("src/pkg1/pkg2"), nested_package) def test_find_module(self): - src = self.project.root.create_folder('src') - samplemod = testutils.create_module(self.project, 'samplemod', src) - found_module = self.project.find_module('samplemod') + src = self.project.root.create_folder("src") + samplemod = testutils.create_module(self.project, "samplemod", src) + found_module = self.project.find_module("samplemod") self.assertEqual(samplemod, found_module) def test_find_nested_module(self): - src = self.project.root.create_folder('src') - samplepkg = testutils.create_package(self.project, 'samplepkg', src) - samplemod = testutils.create_module(self.project, 'samplemod', - samplepkg) - found_module = self.project.find_module('samplepkg.samplemod') + src = self.project.root.create_folder("src") + samplepkg = testutils.create_package(self.project, "samplepkg", src) + samplemod = testutils.create_module(self.project, "samplemod", samplepkg) + found_module = self.project.find_module("samplepkg.samplemod") self.assertEqual(samplemod, found_module) def test_find_multiple_module(self): - src = self.project.root.create_folder('src') - samplemod1 = testutils.create_module(self.project, 'samplemod', src) - samplemod2 = testutils.create_module(self.project, 'samplemod') - test = self.project.root.create_folder('test') - samplemod3 = testutils.create_module(self.project, 'samplemod', test) - found_module = self.project.find_module('samplemod') - self.assertTrue(samplemod1 == found_module or - samplemod2 == found_module or - samplemod3 == found_module) + src = self.project.root.create_folder("src") + samplemod1 = testutils.create_module(self.project, "samplemod", src) + samplemod2 = testutils.create_module(self.project, "samplemod") + test = self.project.root.create_folder("test") + samplemod3 = testutils.create_module(self.project, "samplemod", test) + found_module = self.project.find_module("samplemod") + self.assertTrue( + samplemod1 == found_module + or samplemod2 == found_module + or samplemod3 == found_module + ) def test_find_module_packages(self): src = self.project.root - samplepkg = testutils.create_package(self.project, 'samplepkg', src) - found_module = self.project.find_module('samplepkg') + samplepkg = testutils.create_package(self.project, "samplepkg", src) + found_module = self.project.find_module("samplepkg") self.assertEqual(samplepkg, found_module) def test_find_module_when_module_and_package_with_the_same_name(self): src = self.project.root - testutils.create_module(self.project, 'sample', src) - samplepkg = testutils.create_package(self.project, 'sample', src) - found_module = self.project.find_module('sample') + testutils.create_module(self.project, "sample", src) + samplepkg = testutils.create_package(self.project, "sample", src) + found_module = self.project.find_module("sample") self.assertEqual(samplepkg, found_module) def test_source_folders_preference(self): - testutils.create_package(self.project, 'pkg1') - testutils.create_package(self.project, 'pkg1.src2') - lost = testutils.create_module(self.project, 'pkg1.src2.lost') - self.assertEqual(self.project.find_module('lost'), None) + testutils.create_package(self.project, "pkg1") + testutils.create_package(self.project, "pkg1.src2") + lost = testutils.create_module(self.project, "pkg1.src2.lost") + self.assertEqual(self.project.find_module("lost"), None) self.project.close() from rope.base.project import Project - self.project = Project(self.project.address, - source_folders=['pkg1/src2']) - self.assertEqual(self.project.find_module('lost'), lost) + + self.project = Project(self.project.address, source_folders=["pkg1/src2"]) + self.assertEqual(self.project.find_module("lost"), lost) def test_get_pyname_definition_location(self): - mod = libutils.get_string_module(self.project, 'a_var = 20\n') - a_var = mod['a_var'] + mod = libutils.get_string_module(self.project, "a_var = 20\n") + a_var = mod["a_var"] self.assertEqual((mod, 1), a_var.get_definition_location()) def test_get_pyname_definition_location_functions(self): - mod = libutils.get_string_module( - self.project, 'def a_func():\n pass\n') - a_func = mod['a_func'] + mod = libutils.get_string_module(self.project, "def a_func():\n pass\n") + a_func = mod["a_func"] self.assertEqual((mod, 1), a_func.get_definition_location()) def test_get_pyname_definition_location_class(self): - code = 'class AClass(object):\n pass\n\n' + code = "class AClass(object):\n pass\n\n" mod = libutils.get_string_module(self.project, code) - a_class = mod['AClass'] + a_class = mod["AClass"] self.assertEqual((mod, 1), a_class.get_definition_location()) def test_get_pyname_definition_location_local_variables(self): mod = libutils.get_string_module( - self.project, 'def a_func():\n a_var = 10\n') + self.project, "def a_func():\n a_var = 10\n" + ) a_func_scope = mod.get_scope().get_scopes()[0] - a_var = a_func_scope['a_var'] + a_var = a_func_scope["a_var"] self.assertEqual((mod, 2), a_var.get_definition_location()) def test_get_pyname_definition_location_reassigning(self): - mod = libutils.get_string_module( - self.project, 'a_var = 20\na_var=30\n') - a_var = mod['a_var'] + mod = libutils.get_string_module(self.project, "a_var = 20\na_var=30\n") + a_var = mod["a_var"] self.assertEqual((mod, 1), a_var.get_definition_location()) def test_get_pyname_definition_location_importes(self): - testutils.create_module(self.project, 'mod') - mod = libutils.get_string_module(self.project, 'import mod\n') - imported_module = self.project.get_module('mod') - module_pyname = mod['mod'] - self.assertEqual((imported_module, 1), - module_pyname.get_definition_location()) + testutils.create_module(self.project, "mod") + mod = libutils.get_string_module(self.project, "import mod\n") + imported_module = self.project.get_module("mod") + module_pyname = mod["mod"] + self.assertEqual((imported_module, 1), module_pyname.get_definition_location()) def test_get_pyname_definition_location_imports(self): - module_resource = testutils.create_module(self.project, 'mod') - module_resource.write('\ndef a_func():\n pass\n') - imported_module = self.project.get_module('mod') - mod = libutils.get_string_module( - self.project, 'from mod import a_func\n') - a_func = mod['a_func'] - self.assertEqual((imported_module, 2), - a_func.get_definition_location()) + module_resource = testutils.create_module(self.project, "mod") + module_resource.write("\ndef a_func():\n pass\n") + imported_module = self.project.get_module("mod") + mod = libutils.get_string_module(self.project, "from mod import a_func\n") + a_func = mod["a_func"] + self.assertEqual((imported_module, 2), a_func.get_definition_location()) def test_get_pyname_definition_location_parameters(self): - code = 'def a_func(param1, param2):\n a_var = param\n' + code = "def a_func(param1, param2):\n a_var = param\n" mod = libutils.get_string_module(self.project, code) a_func_scope = mod.get_scope().get_scopes()[0] - param1 = a_func_scope['param1'] + param1 = a_func_scope["param1"] self.assertEqual((mod, 1), param1.get_definition_location()) - param2 = a_func_scope['param2'] + param2 = a_func_scope["param2"] self.assertEqual((mod, 1), param2.get_definition_location()) def test_module_get_resource(self): - module_resource = testutils.create_module(self.project, 'mod') - module = self.project.get_module('mod') + module_resource = testutils.create_module(self.project, "mod") + module = self.project.get_module("mod") self.assertEqual(module_resource, module.get_resource()) string_module = libutils.get_string_module( - self.project, 'from mod import a_func\n') + self.project, "from mod import a_func\n" + ) self.assertEqual(None, string_module.get_resource()) def test_get_pyname_definition_location_class2(self): - code = 'class AClass(object):\n' \ - ' def __init__(self):\n' \ - ' self.an_attr = 10\n' + code = ( + "class AClass(object):\n" + " def __init__(self):\n" + " self.an_attr = 10\n" + ) mod = libutils.get_string_module(self.project, code) - a_class = mod['AClass'].get_object() - an_attr = a_class['an_attr'] + a_class = mod["AClass"].get_object() + an_attr = a_class["an_attr"] self.assertEqual((mod, 3), an_attr.get_definition_location()) def test_import_not_found_module_get_definition_location(self): - mod = libutils.get_string_module( - self.project, 'import doesnotexist\n') - does_not_exist = mod['doesnotexist'] - self.assertEqual((None, None), - does_not_exist.get_definition_location()) + mod = libutils.get_string_module(self.project, "import doesnotexist\n") + does_not_exist = mod["doesnotexist"] + self.assertEqual((None, None), does_not_exist.get_definition_location()) def test_from_not_found_module_get_definition_location(self): mod = libutils.get_string_module( - self.project, 'from doesnotexist import Sample\n') - sample = mod['Sample'] + self.project, "from doesnotexist import Sample\n" + ) + sample = mod["Sample"] self.assertEqual((None, None), sample.get_definition_location()) def test_from_package_import_module_get_definition_location(self): - pkg = testutils.create_package(self.project, 'pkg') - testutils.create_module(self.project, 'mod', pkg) - pkg_mod = self.project.get_module('pkg.mod') - mod = libutils.get_string_module( - self.project, 'from pkg import mod\n') - imported_mod = mod['mod'] - self.assertEqual((pkg_mod, 1), - imported_mod.get_definition_location()) + pkg = testutils.create_package(self.project, "pkg") + testutils.create_module(self.project, "mod", pkg) + pkg_mod = self.project.get_module("pkg.mod") + mod = libutils.get_string_module(self.project, "from pkg import mod\n") + imported_mod = mod["mod"] + self.assertEqual((pkg_mod, 1), imported_mod.get_definition_location()) def test_get_module_for_defined_pyobjects(self): mod = libutils.get_string_module( - self.project, 'class AClass(object):\n pass\n') - a_class = mod['AClass'].get_object() + self.project, "class AClass(object):\n pass\n" + ) + a_class = mod["AClass"].get_object() self.assertEqual(mod, a_class.get_module()) def test_get_definition_location_for_packages(self): - testutils.create_package(self.project, 'pkg') - init_module = self.project.get_module('pkg.__init__') - mod = libutils.get_string_module(self.project, 'import pkg\n') - pkg_pyname = mod['pkg'] - self.assertEqual((init_module, 1), - pkg_pyname.get_definition_location()) + testutils.create_package(self.project, "pkg") + init_module = self.project.get_module("pkg.__init__") + mod = libutils.get_string_module(self.project, "import pkg\n") + pkg_pyname = mod["pkg"] + self.assertEqual((init_module, 1), pkg_pyname.get_definition_location()) def test_get_definition_location_for_filtered_packages(self): - pkg = testutils.create_package(self.project, 'pkg') - testutils.create_module(self.project, 'mod', pkg) - init_module = self.project.get_module('pkg.__init__') - mod = libutils.get_string_module(self.project, 'import pkg.mod') - pkg_pyname = mod['pkg'] - self.assertEqual((init_module, 1), - pkg_pyname.get_definition_location()) + pkg = testutils.create_package(self.project, "pkg") + testutils.create_module(self.project, "mod", pkg) + init_module = self.project.get_module("pkg.__init__") + mod = libutils.get_string_module(self.project, "import pkg.mod") + pkg_pyname = mod["pkg"] + self.assertEqual((init_module, 1), pkg_pyname.get_definition_location()) def test_out_of_project_modules(self): scope = libutils.get_string_scope( - self.project, 'import rope.base.project as project\n') - imported_module = scope['project'].get_object() - self.assertTrue('Project' in imported_module) + self.project, "import rope.base.project as project\n" + ) + imported_module = scope["project"].get_object() + self.assertTrue("Project" in imported_module) def test_file_encoding_reading(self): - contents = u'# -*- coding: utf-8 -*-\n' + \ - u'#\N{LATIN SMALL LETTER I WITH DIAERESIS}\n' - mod = testutils.create_module(self.project, 'mod') + contents = ( + u"# -*- coding: utf-8 -*-\n" + u"#\N{LATIN SMALL LETTER I WITH DIAERESIS}\n" + ) + mod = testutils.create_module(self.project, "mod") mod.write(contents) - self.project.get_module('mod') + self.project.get_module("mod") def test_global_keyword(self): - contents = 'a_var = 1\ndef a_func():\n global a_var\n' + contents = "a_var = 1\ndef a_func():\n global a_var\n" mod = libutils.get_string_module(self.project, contents) - global_var = mod['a_var'] - func_scope = mod['a_func'].get_object().get_scope() - local_var = func_scope['a_var'] + global_var = mod["a_var"] + func_scope = mod["a_func"].get_object().get_scope() + local_var = func_scope["a_var"] self.assertEqual(global_var, local_var) def test_not_leaking_for_vars_inside_parent_scope(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C(object):\n' \ - ' def f(self):\n' \ - ' for my_var1, my_var2 in []:\n' \ - ' pass\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C(object):\n" + " def f(self):\n" + " for my_var1, my_var2 in []:\n" + " pass\n" + ) mod.write(code) pymod = self.pycore.resource_to_pyobject(mod) - c_class = pymod['C'].get_object() - self.assertFalse('my_var1' in c_class) - self.assertFalse('my_var2' in c_class) + c_class = pymod["C"].get_object() + self.assertFalse("my_var1" in c_class) + self.assertFalse("my_var2" in c_class) def test_not_leaking_for_vars_inside_parent_scope2(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C(object):\n' \ - ' def f(self):\n' \ - ' for my_var in []:\n' \ - ' pass\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C(object):\n" + " def f(self):\n" + " for my_var in []:\n" + " pass\n" + ) mod.write(code) pymod = self.pycore.resource_to_pyobject(mod) - c_class = pymod['C'].get_object() - self.assertFalse('my_var' in c_class) + c_class = pymod["C"].get_object() + self.assertFalse("my_var" in c_class) def test_variables_defined_in_excepts(self): - mod = testutils.create_module(self.project, 'mod') - code = 'try:\n' \ - ' myvar1 = 1\n' \ - 'except:\n' \ - ' myvar2 = 1\n' \ - 'finally:\n' \ - ' myvar3 = 1\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "try:\n" + " myvar1 = 1\n" + "except:\n" + " myvar2 = 1\n" + "finally:\n" + " myvar3 = 1\n" + ) mod.write(code) pymod = self.pycore.resource_to_pyobject(mod) - self.assertTrue('myvar1' in pymod) - self.assertTrue('myvar2' in pymod) - self.assertTrue('myvar3' in pymod) + self.assertTrue("myvar1" in pymod) + self.assertTrue("myvar2" in pymod) + self.assertTrue("myvar3" in pymod) def test_not_leaking_tuple_assigned_names_inside_parent_scope(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C(object):\n' \ - ' def f(self):\n' \ - ' var1, var2 = range(2)\n' + mod = testutils.create_module(self.project, "mod") + code = ( + "class C(object):\n" " def f(self):\n" " var1, var2 = range(2)\n" + ) mod.write(code) pymod = self.pycore.resource_to_pyobject(mod) - c_class = pymod['C'].get_object() - self.assertFalse('var1' in c_class) + c_class = pymod["C"].get_object() + self.assertFalse("var1" in c_class) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_with_statement_variables(self): - code = 'import threading\nwith threading.lock() as var: pass\n' + code = "import threading\nwith threading.lock() as var: pass\n" if sys.version_info < (2, 6, 0): - code = 'from __future__ import with_statement\n' + code + code = "from __future__ import with_statement\n" + code pymod = libutils.get_string_module(self.project, code) - self.assertTrue('var' in pymod) + self.assertTrue("var" in pymod) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_with_statement_variables_and_tuple_assignment(self): - code = 'class A(object):\n' \ - ' def __enter__(self):' \ - ' return (1, 2)\n' \ - ' def __exit__(self, type, value, tb):\n' \ - ' pass\n'\ - 'with A() as (a, b):\n' \ - ' pass\n' + code = ( + "class A(object):\n" + " def __enter__(self):" + " return (1, 2)\n" + " def __exit__(self, type, value, tb):\n" + " pass\n" + "with A() as (a, b):\n" + " pass\n" + ) if sys.version_info < (2, 6, 0): - code = 'from __future__ import with_statement\n' + code + code = "from __future__ import with_statement\n" + code pymod = libutils.get_string_module(self.project, code) - self.assertTrue('a' in pymod) - self.assertTrue('b' in pymod) + self.assertTrue("a" in pymod) + self.assertTrue("b" in pymod) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_with_statement_variable_type(self): - code = 'class A(object):\n' \ - ' def __enter__(self):\n' \ - ' return self\n'\ - ' def __exit__(self, type, value, tb):\n' \ - ' pass\n' \ - 'with A() as var:\n' \ - ' pass\n' + code = ( + "class A(object):\n" + " def __enter__(self):\n" + " return self\n" + " def __exit__(self, type, value, tb):\n" + " pass\n" + "with A() as var:\n" + " pass\n" + ) if sys.version_info < (2, 6, 0): - code = 'from __future__ import with_statement\n' + code + code = "from __future__ import with_statement\n" + code pymod = libutils.get_string_module(self.project, code) - a_class = pymod['A'].get_object() - var = pymod['var'].get_object() + a_class = pymod["A"].get_object() + var = pymod["var"].get_object() self.assertEqual(a_class, var.get_type()) - @testutils.only_for('2.7') + @testutils.only_for("2.7") def test_nested_with_statement_variable_type(self): - code = 'class A(object):\n' \ - ' def __enter__(self):\n' \ - ' return self\n'\ - ' def __exit__(self, type, value, tb):\n' \ - ' pass\n' \ - 'class B(object):\n' \ - ' def __enter__(self):\n' \ - ' return self\n'\ - ' def __exit__(self, type, value, tb):\n' \ - ' pass\n' \ - 'with A() as var_a, B() as var_b:\n' \ - ' pass\n' + code = ( + "class A(object):\n" + " def __enter__(self):\n" + " return self\n" + " def __exit__(self, type, value, tb):\n" + " pass\n" + "class B(object):\n" + " def __enter__(self):\n" + " return self\n" + " def __exit__(self, type, value, tb):\n" + " pass\n" + "with A() as var_a, B() as var_b:\n" + " pass\n" + ) if sys.version_info < (2, 6, 0): - code = 'from __future__ import with_statement\n' + code + code = "from __future__ import with_statement\n" + code pymod = libutils.get_string_module(self.project, code) - a_class = pymod['A'].get_object() - var_a = pymod['var_a'].get_object() + a_class = pymod["A"].get_object() + var_a = pymod["var_a"].get_object() self.assertEqual(a_class, var_a.get_type()) - b_class = pymod['B'].get_object() - var_b = pymod['var_b'].get_object() + b_class = pymod["B"].get_object() + var_b = pymod["var_b"].get_object() self.assertEqual(b_class, var_b.get_type()) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_with_statement_with_no_vars(self): code = 'with open("file"): pass\n' if sys.version_info < (2, 6, 0): - code = 'from __future__ import with_statement\n' + code + code = "from __future__ import with_statement\n" + code pymod = libutils.get_string_module(self.project, code) pymod.get_attributes() def test_with_statement(self): - code = 'a = 10\n' \ - 'with open("file") as f: pass\n' + code = "a = 10\n" 'with open("file") as f: pass\n' pymod = libutils.get_string_module(self.project, code) - assigned = pymod.get_attribute('a') + assigned = pymod.get_attribute("a") self.assertEqual(BuiltinClass, type(assigned.get_object().get_type())) - assigned = pymod.get_attribute('f') + assigned = pymod.get_attribute("f") self.assertEqual(File, type(assigned.get_object().get_type())) def test_check_for_else_block(self): - code = 'for i in range(10):\n' \ - ' pass\n' \ - 'else:\n' \ - ' myvar = 1\n' + code = "for i in range(10):\n" " pass\n" "else:\n" " myvar = 1\n" mod = libutils.get_string_module(self.project, code) - a_var = mod['myvar'] + a_var = mod["myvar"] self.assertEqual((mod, 4), a_var.get_definition_location()) def test_check_names_defined_in_whiles(self): - mod = libutils.get_string_module( - self.project, 'while False:\n myvar = 1\n') - a_var = mod['myvar'] + mod = libutils.get_string_module(self.project, "while False:\n myvar = 1\n") + a_var = mod["myvar"] self.assertEqual((mod, 2), a_var.get_definition_location()) def test_get_definition_location_in_tuple_assnames(self): mod = libutils.get_string_module( - self.project, 'def f(x):\n x.z, a = range(2)\n') - x = mod['f'].get_object().get_scope()['x'] - a = mod['f'].get_object().get_scope()['a'] + self.project, "def f(x):\n x.z, a = range(2)\n" + ) + x = mod["f"].get_object().get_scope()["x"] + a = mod["f"].get_object().get_scope()["a"] self.assertEqual((mod, 1), x.get_definition_location()) self.assertEqual((mod, 2), a.get_definition_location()) def test_syntax_errors_in_code(self): with self.assertRaises(exceptions.ModuleSyntaxError): - libutils.get_string_module(self.project, 'xyx print\n') + libutils.get_string_module(self.project, "xyx print\n") def test_holding_error_location_information(self): try: - libutils.get_string_module(self.project, 'xyx print\n') + libutils.get_string_module(self.project, "xyx print\n") except exceptions.ModuleSyntaxError as e: self.assertEqual(1, e.lineno) def test_no_exceptions_on_module_encoding_problems(self): - mod = testutils.create_module(self.project, 'mod') - contents = b'\nsdsdsd\n\xa9\n' - file = open(mod.real_path, 'wb') + mod = testutils.create_module(self.project, "mod") + contents = b"\nsdsdsd\n\xa9\n" + file = open(mod.real_path, "wb") file.write(contents) file.close() mod.read() def test_syntax_errors_when_cannot_decode_file2(self): - mod = testutils.create_module(self.project, 'mod') - contents = b'\n\xa9\n' - file = open(mod.real_path, 'wb') + mod = testutils.create_module(self.project, "mod") + contents = b"\n\xa9\n" + file = open(mod.real_path, "wb") file.write(contents) file.close() with self.assertRaises(exceptions.ModuleSyntaxError): self.pycore.resource_to_pyobject(mod) def test_syntax_errors_when_null_bytes(self): - mod = testutils.create_module(self.project, 'mod') - contents = b'\n\x00\n' - file = open(mod.real_path, 'wb') + mod = testutils.create_module(self.project, "mod") + contents = b"\n\x00\n" + file = open(mod.real_path, "wb") file.write(contents) file.close() with self.assertRaises(exceptions.ModuleSyntaxError): self.pycore.resource_to_pyobject(mod) def test_syntax_errors_when_bad_strs(self): - mod = testutils.create_module(self.project, 'mod') + mod = testutils.create_module(self.project, "mod") contents = b'\n"\\x0"\n' - file = open(mod.real_path, 'wb') + file = open(mod.real_path, "wb") file.write(contents) file.close() with self.assertRaises(exceptions.ModuleSyntaxError): self.pycore.resource_to_pyobject(mod) def test_not_reaching_maximum_recursions_with_from_star_imports(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('from mod2 import *\n') - mod2.write('from mod1 import *\n') + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("from mod2 import *\n") + mod2.write("from mod1 import *\n") pymod1 = self.pycore.resource_to_pyobject(mod1) pymod1.get_attributes() def test_not_reaching_maximum_recursions_when_importing_variables(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('from mod2 import myvar\n') - mod2.write('from mod1 import myvar\n') + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("from mod2 import myvar\n") + mod2.write("from mod1 import myvar\n") pymod1 = self.pycore.resource_to_pyobject(mod1) - pymod1['myvar'].get_object() + pymod1["myvar"].get_object() def test_not_reaching_maximum_recursions_when_importing_variables2(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('from mod1 import myvar\n') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("from mod1 import myvar\n") pymod1 = self.pycore.resource_to_pyobject(mod1) - pymod1['myvar'].get_object() + pymod1["myvar"].get_object() def test_pyobject_equality_should_compare_types(self): - mod1 = testutils.create_module(self.project, 'mod1') + mod1 = testutils.create_module(self.project, "mod1") mod1.write('var1 = ""\nvar2 = ""\n') pymod1 = self.pycore.resource_to_pyobject(mod1) - self.assertEqual(pymod1['var1'].get_object(), - pymod1['var2'].get_object()) + self.assertEqual(pymod1["var1"].get_object(), pymod1["var2"].get_object()) class PyCoreInProjectsTest(unittest.TestCase): - def setUp(self): super(self.__class__, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - samplemod = testutils.create_module(self.project, 'samplemod') - code = 'class SampleClass(object):\n' \ - ' def sample_method():\n' \ - ' pass\n\n' \ - 'def sample_func():\n' \ - ' pass\n' \ - 'sample_var = 10\n\n' \ - 'def _underlined_func():\n' \ - ' pass\n\n' + samplemod = testutils.create_module(self.project, "samplemod") + code = ( + "class SampleClass(object):\n" + " def sample_method():\n" + " pass\n\n" + "def sample_func():\n" + " pass\n" + "sample_var = 10\n\n" + "def _underlined_func():\n" + " pass\n\n" + ) samplemod.write(code) - package = testutils.create_package(self.project, 'package') - testutils.create_module(self.project, 'nestedmod', package) + package = testutils.create_package(self.project, "package") + testutils.create_module(self.project, "nestedmod", package) def tearDown(self): testutils.remove_project(self.project) super(self.__class__, self).tearDown() def test_simple_import(self): - mod = libutils.get_string_module( - self.project, 'import samplemod\n') - samplemod = mod['samplemod'].get_object() - self.assertEqual(get_base_type('Module'), samplemod.get_type()) + mod = libutils.get_string_module(self.project, "import samplemod\n") + samplemod = mod["samplemod"].get_object() + self.assertEqual(get_base_type("Module"), samplemod.get_type()) def test_from_import_class(self): mod = libutils.get_string_module( - self.project, 'from samplemod import SampleClass\n') - result = mod['SampleClass'].get_object() - self.assertEqual(get_base_type('Type'), result.get_type()) - self.assertTrue('sample_func' not in mod.get_attributes()) + self.project, "from samplemod import SampleClass\n" + ) + result = mod["SampleClass"].get_object() + self.assertEqual(get_base_type("Type"), result.get_type()) + self.assertTrue("sample_func" not in mod.get_attributes()) def test_from_import_star(self): - mod = libutils.get_string_module( - self.project, 'from samplemod import *\n') - self.assertEqual(get_base_type('Type'), - mod['SampleClass'].get_object().get_type()) - self.assertEqual(get_base_type('Function'), - mod['sample_func'].get_object().get_type()) - self.assertTrue(mod['sample_var'] is not None) + mod = libutils.get_string_module(self.project, "from samplemod import *\n") + self.assertEqual( + get_base_type("Type"), mod["SampleClass"].get_object().get_type() + ) + self.assertEqual( + get_base_type("Function"), mod["sample_func"].get_object().get_type() + ) + self.assertTrue(mod["sample_var"] is not None) def test_from_import_star_overwriting(self): - code = 'from samplemod import *\n' \ - 'class SampleClass(object):\n pass\n' + code = "from samplemod import *\n" "class SampleClass(object):\n pass\n" mod = libutils.get_string_module(self.project, code) - samplemod = self.project.get_module('samplemod') - sample_class = samplemod['SampleClass'].get_object() - self.assertNotEqual(sample_class, - mod.get_attributes()['SampleClass'].get_object()) + samplemod = self.project.get_module("samplemod") + sample_class = samplemod["SampleClass"].get_object() + self.assertNotEqual( + sample_class, mod.get_attributes()["SampleClass"].get_object() + ) def test_from_import_star_not_imporing_underlined(self): - mod = libutils.get_string_module( - self.project, 'from samplemod import *') - self.assertTrue('_underlined_func' not in mod.get_attributes()) + mod = libutils.get_string_module(self.project, "from samplemod import *") + self.assertTrue("_underlined_func" not in mod.get_attributes()) def test_from_import_star_imports_in_functions(self): mod = libutils.get_string_module( - self.project, 'def f():\n from os import *\n') - mod['f'].get_object().get_scope().get_names() + self.project, "def f():\n from os import *\n" + ) + mod["f"].get_object().get_scope().get_names() def test_from_package_import_mod(self): mod = libutils.get_string_module( - self.project, 'from package import nestedmod\n') - self.assertEqual(get_base_type('Module'), - mod['nestedmod'].get_object().get_type()) + self.project, "from package import nestedmod\n" + ) + self.assertEqual( + get_base_type("Module"), mod["nestedmod"].get_object().get_type() + ) # XXX: Deciding to import everything on import start from packages def xxx_test_from_package_import_star(self): - mod = libutils.get_string_module( - self.project, 'from package import *\n') - self.assertTrue('nestedmod' not in mod.get_attributes()) + mod = libutils.get_string_module(self.project, "from package import *\n") + self.assertTrue("nestedmod" not in mod.get_attributes()) def test_unknown_when_module_cannot_be_found(self): mod = libutils.get_string_module( - self.project, 'from doesnotexist import nestedmod\n') - self.assertTrue('nestedmod' in mod) + self.project, "from doesnotexist import nestedmod\n" + ) + self.assertTrue("nestedmod" in mod) def test_from_import_function(self): - code = 'def f():\n from samplemod import SampleClass\n' + code = "def f():\n from samplemod import SampleClass\n" scope = libutils.get_string_scope(self.project, code) - self.assertEqual(get_base_type('Type'), - scope.get_scopes()[0]['SampleClass']. - get_object().get_type()) + self.assertEqual( + get_base_type("Type"), + scope.get_scopes()[0]["SampleClass"].get_object().get_type(), + ) def test_circular_imports(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('import mod2\n') - mod2.write('import mod1\n') - self.project.get_module('mod1') + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("import mod2\n") + mod2.write("import mod1\n") + self.project.get_module("mod1") def test_circular_imports2(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write( - 'from mod2 import Sample2\nclass Sample1(object):\n pass\n') - mod2.write( - 'from mod1 import Sample1\nclass Sample2(object):\n pass\n') - self.project.get_module('mod1').get_attributes() + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("from mod2 import Sample2\nclass Sample1(object):\n pass\n") + mod2.write("from mod1 import Sample1\nclass Sample2(object):\n pass\n") + self.project.get_module("mod1").get_attributes() def test_multi_dot_imports(self): - pkg = testutils.create_package(self.project, 'pkg') - pkg_mod = testutils.create_module(self.project, 'mod', pkg) - pkg_mod.write('def sample_func():\n pass\n') - mod = libutils.get_string_module(self.project, 'import pkg.mod\n') - self.assertTrue('pkg' in mod) - self.assertTrue('sample_func' in mod['pkg'].get_object()['mod']. - get_object()) + pkg = testutils.create_package(self.project, "pkg") + pkg_mod = testutils.create_module(self.project, "mod", pkg) + pkg_mod.write("def sample_func():\n pass\n") + mod = libutils.get_string_module(self.project, "import pkg.mod\n") + self.assertTrue("pkg" in mod) + self.assertTrue("sample_func" in mod["pkg"].get_object()["mod"].get_object()) def test_multi_dot_imports2(self): - pkg = testutils.create_package(self.project, 'pkg') - testutils.create_module(self.project, 'mod1', pkg) - testutils.create_module(self.project, 'mod2', pkg) + pkg = testutils.create_package(self.project, "pkg") + testutils.create_module(self.project, "mod1", pkg) + testutils.create_module(self.project, "mod2", pkg) mod = libutils.get_string_module( - self.project, 'import pkg.mod1\nimport pkg.mod2\n') - package = mod['pkg'].get_object() + self.project, "import pkg.mod1\nimport pkg.mod2\n" + ) + package = mod["pkg"].get_object() self.assertEqual(2, len(package.get_attributes())) - self.assertTrue('mod1' in package and - 'mod2' in package) + self.assertTrue("mod1" in package and "mod2" in package) def test_multi_dot_imports3(self): - pkg1 = testutils.create_package(self.project, 'pkg1') - pkg2 = testutils.create_package(self.project, 'pkg2', pkg1) - testutils.create_module(self.project, 'mod1', pkg2) - testutils.create_module(self.project, 'mod2', pkg2) - code = 'import pkg1.pkg2.mod1\nimport pkg1.pkg2.mod2\n' + pkg1 = testutils.create_package(self.project, "pkg1") + pkg2 = testutils.create_package(self.project, "pkg2", pkg1) + testutils.create_module(self.project, "mod1", pkg2) + testutils.create_module(self.project, "mod2", pkg2) + code = "import pkg1.pkg2.mod1\nimport pkg1.pkg2.mod2\n" mod = libutils.get_string_module(self.project, code) - package1 = mod['pkg1'].get_object() - package2 = package1['pkg2'].get_object() + package1 = mod["pkg1"].get_object() + package2 = package1["pkg2"].get_object() self.assertEqual(2, len(package2.get_attributes())) - self.assertTrue('mod1' in package2 and 'mod2' in package2) + self.assertTrue("mod1" in package2 and "mod2" in package2) def test_multi_dot_imports_as(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod1.write('def f():\n pass\n') - mod = libutils.get_string_module( - self.project, 'import pkg.mod1 as mod1\n') - module = mod['mod1'].get_object() - self.assertTrue('f' in module) + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod1.write("def f():\n pass\n") + mod = libutils.get_string_module(self.project, "import pkg.mod1 as mod1\n") + module = mod["mod1"].get_object() + self.assertTrue("f" in module) # TODO: not showing unimported names as attributes of packages def xxx_test_from_package_import_package(self): - pkg1 = testutils.create_package(self.project, 'pkg1') - pkg2 = testutils.create_package(self.project, 'pkg2', pkg1) - testutils.create_module(self.project, 'mod', pkg2) - mod = libutils.get_string_module( - self.project, 'from pkg1 import pkg2\n') - package = mod['pkg2'] + pkg1 = testutils.create_package(self.project, "pkg1") + pkg2 = testutils.create_package(self.project, "pkg2", pkg1) + testutils.create_module(self.project, "mod", pkg2) + mod = libutils.get_string_module(self.project, "from pkg1 import pkg2\n") + package = mod["pkg2"] self.assertEqual(0, len(package.get_attributes())) def test_invalidating_cache_after_resource_change(self): - module = testutils.create_module(self.project, 'mod') - module.write('import sys\n') - mod1 = self.project.get_module('mod') - self.assertTrue('var' not in mod1.get_attributes()) - module.write('var = 10\n') - mod2 = self.project.get_module('mod') - self.assertTrue('var' in mod2) + module = testutils.create_module(self.project, "mod") + module.write("import sys\n") + mod1 = self.project.get_module("mod") + self.assertTrue("var" not in mod1.get_attributes()) + module.write("var = 10\n") + mod2 = self.project.get_module("mod") + self.assertTrue("var" in mod2) def test_invalidating_cache_after_resource_change_for_init_dot_pys(self): - pkg = testutils.create_package(self.project, 'pkg') - mod = testutils.create_module(self.project, 'mod') - init_dot_py = pkg.get_child('__init__.py') - init_dot_py.write('a_var = 10\n') - mod.write('import pkg\n') - pymod = self.project.get_module('mod') - self.assertTrue('a_var' in pymod['pkg'].get_object()) - init_dot_py.write('new_var = 10\n') - self.assertTrue('a_var' not in - pymod['pkg'].get_object().get_attributes()) + pkg = testutils.create_package(self.project, "pkg") + mod = testutils.create_module(self.project, "mod") + init_dot_py = pkg.get_child("__init__.py") + init_dot_py.write("a_var = 10\n") + mod.write("import pkg\n") + pymod = self.project.get_module("mod") + self.assertTrue("a_var" in pymod["pkg"].get_object()) + init_dot_py.write("new_var = 10\n") + self.assertTrue("a_var" not in pymod["pkg"].get_object().get_attributes()) def test_invalidating_cache_after_rsrc_chng_for_nested_init_dot_pys(self): - pkg1 = testutils.create_package(self.project, 'pkg1') - pkg2 = testutils.create_package(self.project, 'pkg2', pkg1) - mod = testutils.create_module(self.project, 'mod') - init_dot_py = pkg2.get_child('__init__.py') - init_dot_py.write('a_var = 10\n') - mod.write('import pkg1\n') - pymod = self.project.get_module('mod') - self.assertTrue('a_var' in - pymod['pkg1'].get_object()['pkg2'].get_object()) - init_dot_py.write('new_var = 10\n') - self.assertTrue('a_var' not in - pymod['pkg1'].get_object()['pkg2'].get_object()) + pkg1 = testutils.create_package(self.project, "pkg1") + pkg2 = testutils.create_package(self.project, "pkg2", pkg1) + mod = testutils.create_module(self.project, "mod") + init_dot_py = pkg2.get_child("__init__.py") + init_dot_py.write("a_var = 10\n") + mod.write("import pkg1\n") + pymod = self.project.get_module("mod") + self.assertTrue("a_var" in pymod["pkg1"].get_object()["pkg2"].get_object()) + init_dot_py.write("new_var = 10\n") + self.assertTrue("a_var" not in pymod["pkg1"].get_object()["pkg2"].get_object()) def test_from_import_nonexistent_module(self): - code = 'from doesnotexistmod import DoesNotExistClass\n' + code = "from doesnotexistmod import DoesNotExistClass\n" mod = libutils.get_string_module(self.project, code) - self.assertTrue('DoesNotExistClass' in mod) - self.assertEqual(get_base_type('Unknown'), - mod['DoesNotExistClass']. - get_object().get_type()) + self.assertTrue("DoesNotExistClass" in mod) + self.assertEqual( + get_base_type("Unknown"), mod["DoesNotExistClass"].get_object().get_type() + ) def test_from_import_nonexistent_name(self): - code = 'from samplemod import DoesNotExistClass\n' + code = "from samplemod import DoesNotExistClass\n" mod = libutils.get_string_module(self.project, code) - self.assertTrue('DoesNotExistClass' in mod) - self.assertEqual(get_base_type('Unknown'), - mod['DoesNotExistClass']. - get_object().get_type()) + self.assertTrue("DoesNotExistClass" in mod) + self.assertEqual( + get_base_type("Unknown"), mod["DoesNotExistClass"].get_object().get_type() + ) def test_not_considering_imported_names_as_sub_scopes(self): - code = 'from samplemod import SampleClass\n' + code = "from samplemod import SampleClass\n" scope = libutils.get_string_scope(self.project, code) self.assertEqual(0, len(scope.get_scopes())) def test_not_considering_imported_modules_as_sub_scopes(self): - scope = libutils.get_string_scope( - self.project, 'import samplemod\n') + scope = libutils.get_string_scope(self.project, "import samplemod\n") self.assertEqual(0, len(scope.get_scopes())) def test_inheriting_dotted_base_class(self): - code = 'import samplemod\n' \ - 'class Derived(samplemod.SampleClass):\n' \ - ' pass\n' + code = ( + "import samplemod\n" "class Derived(samplemod.SampleClass):\n" " pass\n" + ) mod = libutils.get_string_module(self.project, code) - derived = mod['Derived'].get_object() - self.assertTrue('sample_method' in derived) + derived = mod["Derived"].get_object() + self.assertTrue("sample_method" in derived) def test_self_in_methods(self): - code = 'class Sample(object):\n' \ - ' def func(self):\n' \ - ' pass\n' + code = "class Sample(object):\n" " def func(self):\n" " pass\n" scope = libutils.get_string_scope(self.project, code) - sample_class = scope['Sample'].get_object() + sample_class = scope["Sample"].get_object() func_scope = scope.get_scopes()[0].get_scopes()[0] - self.assertEqual(sample_class, - func_scope['self'].get_object().get_type()) - self.assertTrue('func' in func_scope['self'].get_object()) + self.assertEqual(sample_class, func_scope["self"].get_object().get_type()) + self.assertTrue("func" in func_scope["self"].get_object()) def test_none_assignments_in_classes(self): - code = 'class C(object):\n' \ - ' var = ""\n' \ - ' def f(self):\n' \ - ' self.var += "".join([])\n' + code = ( + "class C(object):\n" + ' var = ""\n' + " def f(self):\n" + ' self.var += "".join([])\n' + ) scope = libutils.get_string_scope(self.project, code) - c_class = scope['C'].get_object() - self.assertTrue('var' in c_class) + c_class = scope["C"].get_object() + self.assertTrue("var" in c_class) def test_self_in_methods_with_decorators(self): - code = 'class Sample(object):\n' \ - ' @staticmethod\n' \ - ' def func(self):\n' \ - ' pass\n' + code = ( + "class Sample(object):\n" + " @staticmethod\n" + " def func(self):\n" + " pass\n" + ) scope = libutils.get_string_scope(self.project, code) - sample_class = scope['Sample'].get_object() + sample_class = scope["Sample"].get_object() func_scope = scope.get_scopes()[0].get_scopes()[0] - self.assertNotEqual(sample_class, - func_scope['self'].get_object().get_type()) + self.assertNotEqual(sample_class, func_scope["self"].get_object().get_type()) def test_location_of_imports_when_importing(self): - mod = testutils.create_module(self.project, 'mod') - mod.write('from samplemod import SampleClass\n') - scope = libutils.get_string_scope( - self.project, 'from mod import SampleClass\n') - sample_class = scope['SampleClass'] - samplemod = self.project.get_module('samplemod') - self.assertEqual((samplemod, 1), - sample_class.get_definition_location()) + mod = testutils.create_module(self.project, "mod") + mod.write("from samplemod import SampleClass\n") + scope = libutils.get_string_scope(self.project, "from mod import SampleClass\n") + sample_class = scope["SampleClass"] + samplemod = self.project.get_module("samplemod") + self.assertEqual((samplemod, 1), sample_class.get_definition_location()) def test_nested_modules(self): - pkg = testutils.create_package(self.project, 'pkg') - testutils.create_module(self.project, 'mod', pkg) - imported_module = self.project.get_module('pkg.mod') - scope = libutils.get_string_scope(self.project, 'import pkg.mod\n') - mod_pyobject = scope['pkg'].get_object()['mod'] - self.assertEqual((imported_module, 1), - mod_pyobject.get_definition_location()) + pkg = testutils.create_package(self.project, "pkg") + testutils.create_module(self.project, "mod", pkg) + imported_module = self.project.get_module("pkg.mod") + scope = libutils.get_string_scope(self.project, "import pkg.mod\n") + mod_pyobject = scope["pkg"].get_object()["mod"] + self.assertEqual((imported_module, 1), mod_pyobject.get_definition_location()) def test_reading_init_dot_py(self): - pkg = testutils.create_package(self.project, 'pkg') - init_dot_py = pkg.get_child('__init__.py') - init_dot_py.write('a_var = 1\n') - pkg_object = self.project.get_module('pkg') - self.assertTrue('a_var' in pkg_object) + pkg = testutils.create_package(self.project, "pkg") + init_dot_py = pkg.get_child("__init__.py") + init_dot_py.write("a_var = 1\n") + pkg_object = self.project.get_module("pkg") + self.assertTrue("a_var" in pkg_object) def test_relative_imports(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod2 = testutils.create_module(self.project, 'mod2', pkg) - mod2.write('import mod1\n') + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod2 = testutils.create_module(self.project, "mod2", pkg) + mod2.write("import mod1\n") mod1_object = self.pycore.resource_to_pyobject(mod1) mod2_object = self.pycore.resource_to_pyobject(mod2) - self.assertEqual(mod1_object, - mod2_object.get_attributes()['mod1'].get_object()) + self.assertEqual(mod1_object, mod2_object.get_attributes()["mod1"].get_object()) def test_relative_froms(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod2 = testutils.create_module(self.project, 'mod2', pkg) - mod1.write('def a_func():\n pass\n') - mod2.write('from mod1 import a_func\n') + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod2 = testutils.create_module(self.project, "mod2", pkg) + mod1.write("def a_func():\n pass\n") + mod2.write("from mod1 import a_func\n") mod1_object = self.pycore.resource_to_pyobject(mod1) mod2_object = self.pycore.resource_to_pyobject(mod2) - self.assertEqual(mod1_object['a_func'].get_object(), - mod2_object['a_func'].get_object()) + self.assertEqual( + mod1_object["a_func"].get_object(), mod2_object["a_func"].get_object() + ) def test_relative_imports_for_string_modules(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod2 = testutils.create_module(self.project, 'mod2', pkg) - mod2.write('import mod1\n') + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod2 = testutils.create_module(self.project, "mod2", pkg) + mod2.write("import mod1\n") mod1_object = self.pycore.resource_to_pyobject(mod1) - mod2_object = libutils.get_string_module( - self.project, mod2.read(), mod2) - self.assertEqual(mod1_object, mod2_object['mod1'].get_object()) + mod2_object = libutils.get_string_module(self.project, mod2.read(), mod2) + self.assertEqual(mod1_object, mod2_object["mod1"].get_object()) def test_relative_imports_for_string_scopes(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod2 = testutils.create_module(self.project, 'mod2', pkg) - mod2.write('import mod1\n') + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod2 = testutils.create_module(self.project, "mod2", pkg) + mod2.write("import mod1\n") mod1_object = self.pycore.resource_to_pyobject(mod1) - mod2_scope = libutils.get_string_scope(self.project, mod2.read(), - mod2) - self.assertEqual(mod1_object, mod2_scope['mod1'].get_object()) + mod2_scope = libutils.get_string_scope(self.project, mod2.read(), mod2) + self.assertEqual(mod1_object, mod2_scope["mod1"].get_object()) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_new_style_relative_imports(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod2 = testutils.create_module(self.project, 'mod2', pkg) - mod2.write('from . import mod1\n') + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod2 = testutils.create_module(self.project, "mod2", pkg) + mod2.write("from . import mod1\n") mod1_object = self.pycore.resource_to_pyobject(mod1) mod2_object = self.pycore.resource_to_pyobject(mod2) - self.assertEqual(mod1_object, mod2_object['mod1'].get_object()) + self.assertEqual(mod1_object, mod2_object["mod1"].get_object()) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_new_style_relative_imports2(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2', pkg) - mod1.write('def a_func():\n pass\n') - mod2.write('from ..mod1 import a_func\n') + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2", pkg) + mod1.write("def a_func():\n pass\n") + mod2.write("from ..mod1 import a_func\n") mod1_object = self.pycore.resource_to_pyobject(mod1) mod2_object = self.pycore.resource_to_pyobject(mod2) - self.assertEqual(mod1_object['a_func'].get_object(), - mod2_object['a_func'].get_object()) + self.assertEqual( + mod1_object["a_func"].get_object(), mod2_object["a_func"].get_object() + ) def test_invalidating_cache_for_from_imports_after_resource_change(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('def a_func():\n print(1)\n') - mod1.write('from mod2 import a_func\na_func()\n') - - pymod1 = self.project.get_module('mod1') - pymod2 = self.project.get_module('mod2') - self.assertEqual(pymod1['a_func'].get_object(), - pymod2['a_func'].get_object()) - mod2.write(mod2.read() + '\n') - pymod2 = self.project.get_module('mod2') - self.assertEqual(pymod1['a_func'].get_object(), - pymod2['a_func'].get_object()) + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("def a_func():\n print(1)\n") + mod1.write("from mod2 import a_func\na_func()\n") + + pymod1 = self.project.get_module("mod1") + pymod2 = self.project.get_module("mod2") + self.assertEqual(pymod1["a_func"].get_object(), pymod2["a_func"].get_object()) + mod2.write(mod2.read() + "\n") + pymod2 = self.project.get_module("mod2") + self.assertEqual(pymod1["a_func"].get_object(), pymod2["a_func"].get_object()) def test_invalidating_superclasses_after_change(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('class A(object):\n def func1(self):\n pass\n') - mod2.write('import mod1\nclass B(mod1.A):\n pass\n') + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("class A(object):\n def func1(self):\n pass\n") + mod2.write("import mod1\nclass B(mod1.A):\n pass\n") - b_class = self.project.get_module('mod2')['B'].get_object() - self.assertTrue('func1' in b_class) + b_class = self.project.get_module("mod2")["B"].get_object() + self.assertTrue("func1" in b_class) - mod1.write('class A(object):\n def func2(self):\n pass\n') - self.assertTrue('func2' in b_class) + mod1.write("class A(object):\n def func2(self):\n pass\n") + self.assertTrue("func2" in b_class) def test_caching_pymodule_with_syntax_errors(self): - self.project.prefs['ignore_syntax_errors'] = True - self.project.prefs['automatic_soa'] = True + self.project.prefs["ignore_syntax_errors"] = True + self.project.prefs["automatic_soa"] = True self.project.pycore._init_automatic_soa() - source = 'import sys\nab cd' - mod = testutils.create_module(self.project, 'mod') + source = "import sys\nab cd" + mod = testutils.create_module(self.project, "mod") mod.write(source) from rope.contrib import fixsyntax + fixer = fixsyntax.FixSyntax(self.project, source, mod, 10) pymodule = fixer.get_pymodule() - self.assertTrue(pymodule.source_code.startswith('import sys\npass\n')) + self.assertTrue(pymodule.source_code.startswith("import sys\npass\n")) class TextChangeDetectorTest(unittest.TestCase): - def test_trivial_case(self): - detector = _TextChangeDetector('\n', '\n') + detector = _TextChangeDetector("\n", "\n") self.assertFalse(detector.is_changed(1, 1)) def test_one_line_change(self): - detector = _TextChangeDetector('1\n2\n', '1\n3\n') + detector = _TextChangeDetector("1\n2\n", "1\n3\n") self.assertFalse(detector.is_changed(1, 1)) self.assertTrue(detector.is_changed(2, 2)) def test_line_expansion(self): - detector = _TextChangeDetector('1\n2\n', '1\n3\n4\n2\n') + detector = _TextChangeDetector("1\n2\n", "1\n3\n4\n2\n") self.assertFalse(detector.is_changed(1, 1)) self.assertFalse(detector.is_changed(2, 2)) def test_line_removals(self): - detector = _TextChangeDetector('1\n3\n4\n2\n', '1\n2\n') + detector = _TextChangeDetector("1\n3\n4\n2\n", "1\n2\n") self.assertFalse(detector.is_changed(1, 1)) self.assertTrue(detector.is_changed(2, 3)) self.assertFalse(detector.is_changed(4, 4)) def test_multi_line_checks(self): - detector = _TextChangeDetector('1\n2\n', '1\n3\n') + detector = _TextChangeDetector("1\n2\n", "1\n3\n") self.assertTrue(detector.is_changed(1, 2)) def test_consume_change(self): - detector = _TextChangeDetector('1\n2\n', '1\n3\n') + detector = _TextChangeDetector("1\n2\n", "1\n3\n") self.assertTrue(detector.is_changed(1, 2)) self.assertTrue(detector.consume_changes(1, 2)) self.assertFalse(detector.is_changed(1, 2)) class PyCoreProjectConfigsTest(unittest.TestCase): - def setUp(self): super(PyCoreProjectConfigsTest, self).setUp() self.project = None @@ -1102,26 +1104,28 @@ def tearDown(self): super(PyCoreProjectConfigsTest, self).tearDown() def test_python_files_config(self): - self.project = testutils.sample_project(python_files=['myscript']) - myscript = self.project.root.create_file('myscript') + self.project = testutils.sample_project(python_files=["myscript"]) + myscript = self.project.root.create_file("myscript") self.assertTrue(self.project.pycore.is_python_file(myscript)) def test_ignore_bad_imports(self): self.project = testutils.sample_project(ignore_bad_imports=True) pymod = libutils.get_string_module( - self.project, 'import some_nonexistent_module\n') - self.assertFalse('some_nonexistent_module' in pymod) + self.project, "import some_nonexistent_module\n" + ) + self.assertFalse("some_nonexistent_module" in pymod) def test_ignore_bad_imports_for_froms(self): self.project = testutils.sample_project(ignore_bad_imports=True) pymod = libutils.get_string_module( - self.project, 'from some_nonexistent_module import var\n') - self.assertFalse('var' in pymod) + self.project, "from some_nonexistent_module import var\n" + ) + self.assertFalse("var" in pymod) def test_reporting_syntax_errors_with_force_errors(self): self.project = testutils.sample_project(ignore_syntax_errors=True) - mod = testutils.create_module(self.project, 'mod') - mod.write('syntax error ...\n') + mod = testutils.create_module(self.project, "mod") + mod.write("syntax error ...\n") with self.assertRaises(exceptions.ModuleSyntaxError): self.project.pycore.resource_to_pyobject(mod, force_errors=True) @@ -1129,15 +1133,16 @@ def test_reporting_syntax_errors_in_strings_with_force_errors(self): self.project = testutils.sample_project(ignore_syntax_errors=True) with self.assertRaises(exceptions.ModuleSyntaxError): libutils.get_string_module( - self.project, 'syntax error ...', force_errors=True) + self.project, "syntax error ...", force_errors=True + ) def test_not_raising_errors_for_strings_with_ignore_errors(self): self.project = testutils.sample_project(ignore_syntax_errors=True) - libutils.get_string_module(self.project, 'syntax error ...') + libutils.get_string_module(self.project, "syntax error ...") def test_reporting_syntax_errors_with_force_errors_for_packages(self): self.project = testutils.sample_project(ignore_syntax_errors=True) - pkg = testutils.create_package(self.project, 'pkg') - pkg.get_child('__init__.py').write('syntax error ...\n') + pkg = testutils.create_package(self.project, "pkg") + pkg.get_child("__init__.py").write("syntax error ...\n") with self.assertRaises(exceptions.ModuleSyntaxError): self.project.pycore.resource_to_pyobject(pkg, force_errors=True) diff --git a/ropetest/pyscopestest.py b/ropetest/pyscopestest.py index 742426645..ead7b03fe 100644 --- a/ropetest/pyscopestest.py +++ b/ropetest/pyscopestest.py @@ -9,7 +9,6 @@ class PyCoreScopesTest(unittest.TestCase): - def setUp(self): super(PyCoreScopesTest, self).setUp() self.project = testutils.sample_project() @@ -21,13 +20,15 @@ def tearDown(self): def test_simple_scope(self): scope = libutils.get_string_scope( - self.project, 'def sample_func():\n pass\n') - sample_func = scope['sample_func'].get_object() - self.assertEqual(get_base_type('Function'), sample_func.get_type()) + self.project, "def sample_func():\n pass\n" + ) + sample_func = scope["sample_func"].get_object() + self.assertEqual(get_base_type("Function"), sample_func.get_type()) def test_simple_function_scope(self): scope = libutils.get_string_scope( - self.project, 'def sample_func():\n a = 10\n') + self.project, "def sample_func():\n a = 10\n" + ) self.assertEqual(1, len(scope.get_scopes())) sample_func_scope = scope.get_scopes()[0] self.assertEqual(1, len(sample_func_scope.get_names())) @@ -36,94 +37,105 @@ def test_simple_function_scope(self): def test_classes_inside_function_scopes(self): scope = libutils.get_string_scope( self.project, - 'def sample_func():\n' - ' class SampleClass(object):\n pass\n') + "def sample_func():\n" " class SampleClass(object):\n pass\n", + ) self.assertEqual(1, len(scope.get_scopes())) sample_func_scope = scope.get_scopes()[0] # noqa - self.assertEqual(get_base_type('Type'), - scope.get_scopes()[0]['SampleClass']. - get_object().get_type()) + self.assertEqual( + get_base_type("Type"), + scope.get_scopes()[0]["SampleClass"].get_object().get_type(), + ) def test_list_comprehension_scope_inside_assignment(self): scope = libutils.get_string_scope( - self.project, 'a_var = [b_var + d_var for b_var, c_var in e_var]\n') + self.project, "a_var = [b_var + d_var for b_var, c_var in e_var]\n" + ) self.assertEqual( list(sorted(scope.get_defined_names())), - ['a_var', 'b_var', 'c_var'], + ["a_var", "b_var", "c_var"], ) def test_list_comprehension_scope(self): scope = libutils.get_string_scope( - self.project, '[b_var + d_var for b_var, c_var in e_var]\n') + self.project, "[b_var + d_var for b_var, c_var in e_var]\n" + ) self.assertEqual( list(sorted(scope.get_defined_names())), - ['b_var', 'c_var'], + ["b_var", "c_var"], ) def test_set_comprehension_scope(self): scope = libutils.get_string_scope( - self.project, '{b_var + d_var for b_var, c_var in e_var}\n') + self.project, "{b_var + d_var for b_var, c_var in e_var}\n" + ) self.assertEqual( list(sorted(scope.get_defined_names())), - ['b_var', 'c_var'], + ["b_var", "c_var"], ) def test_generator_comprehension_scope(self): scope = libutils.get_string_scope( - self.project, '(b_var + d_var for b_var, c_var in e_var)\n') + self.project, "(b_var + d_var for b_var, c_var in e_var)\n" + ) self.assertEqual( list(sorted(scope.get_defined_names())), - ['b_var', 'c_var'], + ["b_var", "c_var"], ) def test_dict_comprehension_scope(self): scope = libutils.get_string_scope( - self.project, '{b_var: d_var for b_var, c_var in e_var}\n') + self.project, "{b_var: d_var for b_var, c_var in e_var}\n" + ) self.assertEqual( list(sorted(scope.get_defined_names())), - ['b_var', 'c_var'], + ["b_var", "c_var"], ) - @testutils.only_for_versions_higher('3.8') + @testutils.only_for_versions_higher("3.8") def test_inline_assignment_in_comprehensions(self): scope = libutils.get_string_scope( - self.project, '''[ + self.project, + """[ (a_var := b_var + (f_var := g_var)) for b_var in [(j_var := i_var) for i_var in c_var] if a_var + (h_var := d_var) - ]''') + ]""", + ) self.assertEqual( list(sorted(scope.get_defined_names())), - ['a_var', 'b_var', 'f_var', 'h_var', 'i_var', 'j_var'], + ["a_var", "b_var", "f_var", "h_var", "i_var", "j_var"], ) def test_nested_comprehension(self): scope = libutils.get_string_scope( - self.project, '''[ + self.project, + """[ b_var + d_var for b_var, c_var in [ e_var for e_var in f_var ] - ]\n''') + ]\n""", + ) self.assertEqual( list(sorted(scope.get_defined_names())), - ['b_var', 'c_var', 'e_var'], + ["b_var", "c_var", "e_var"], ) def test_simple_class_scope(self): scope = libutils.get_string_scope( self.project, - 'class SampleClass(object):\n' - ' def f(self):\n var = 10\n') + "class SampleClass(object):\n" " def f(self):\n var = 10\n", + ) self.assertEqual(1, len(scope.get_scopes())) sample_class_scope = scope.get_scopes()[0] - self.assertTrue('f' in sample_class_scope) + self.assertTrue("f" in sample_class_scope) self.assertEqual(1, len(sample_class_scope.get_scopes())) f_in_class = sample_class_scope.get_scopes()[0] - self.assertTrue('var' in f_in_class) + self.assertTrue("var" in f_in_class) def test_get_lineno(self): scope = libutils.get_string_scope( - self.project, '\ndef sample_func():\n a = 10\n') + self.project, "\ndef sample_func():\n a = 10\n" + ) self.assertEqual(1, len(scope.get_scopes())) sample_func_scope = scope.get_scopes()[0] self.assertEqual(1, scope.get_start()) @@ -132,165 +144,169 @@ def test_get_lineno(self): def test_scope_kind(self): scope = libutils.get_string_scope( self.project, - 'class SampleClass(object):\n pass\n' - 'def sample_func():\n pass\n') + "class SampleClass(object):\n pass\n" "def sample_func():\n pass\n", + ) sample_class_scope = scope.get_scopes()[0] sample_func_scope = scope.get_scopes()[1] - self.assertEqual('Module', scope.get_kind()) - self.assertEqual('Class', sample_class_scope.get_kind()) - self.assertEqual('Function', sample_func_scope.get_kind()) + self.assertEqual("Module", scope.get_kind()) + self.assertEqual("Class", sample_class_scope.get_kind()) + self.assertEqual("Function", sample_func_scope.get_kind()) def test_function_parameters_in_scope_names(self): scope = libutils.get_string_scope( - self.project, 'def sample_func(param):\n a = 10\n') + self.project, "def sample_func(param):\n a = 10\n" + ) sample_func_scope = scope.get_scopes()[0] - self.assertTrue('param' in sample_func_scope) + self.assertTrue("param" in sample_func_scope) def test_get_names_contains_only_names_defined_in_a_scope(self): scope = libutils.get_string_scope( - self.project, - 'var1 = 10\ndef sample_func(param):\n var2 = 20\n') + self.project, "var1 = 10\ndef sample_func(param):\n var2 = 20\n" + ) sample_func_scope = scope.get_scopes()[0] - self.assertTrue('var1' not in sample_func_scope) + self.assertTrue("var1" not in sample_func_scope) def test_scope_lookup(self): scope = libutils.get_string_scope( - self.project, - 'var1 = 10\ndef sample_func(param):\n var2 = 20\n') - self.assertTrue(scope.lookup('var2') is None) - self.assertEqual(get_base_type('Function'), - scope.lookup('sample_func').get_object().get_type()) + self.project, "var1 = 10\ndef sample_func(param):\n var2 = 20\n" + ) + self.assertTrue(scope.lookup("var2") is None) + self.assertEqual( + get_base_type("Function"), + scope.lookup("sample_func").get_object().get_type(), + ) sample_func_scope = scope.get_scopes()[0] - self.assertTrue(sample_func_scope.lookup('var1') is not None) + self.assertTrue(sample_func_scope.lookup("var1") is not None) def test_function_scopes(self): - scope = libutils.get_string_scope( - self.project, 'def func():\n var = 10\n') + scope = libutils.get_string_scope(self.project, "def func():\n var = 10\n") func_scope = scope.get_scopes()[0] - self.assertTrue('var' in func_scope) + self.assertTrue("var" in func_scope) def test_function_scopes_classes(self): scope = libutils.get_string_scope( - self.project, - 'def func():\n class Sample(object):\n pass\n') + self.project, "def func():\n class Sample(object):\n pass\n" + ) func_scope = scope.get_scopes()[0] - self.assertTrue('Sample' in func_scope) + self.assertTrue("Sample" in func_scope) def test_function_getting_scope(self): - mod = libutils.get_string_module( - self.project, 'def func(): var = 10\n') - func_scope = mod['func'].get_object().get_scope() - self.assertTrue('var' in func_scope) + mod = libutils.get_string_module(self.project, "def func(): var = 10\n") + func_scope = mod["func"].get_object().get_scope() + self.assertTrue("var" in func_scope) def test_scopes_in_function_scopes(self): scope = libutils.get_string_scope( - self.project, - 'def func():\n def inner():\n var = 10\n') + self.project, "def func():\n def inner():\n var = 10\n" + ) func_scope = scope.get_scopes()[0] inner_scope = func_scope.get_scopes()[0] - self.assertTrue('var' in inner_scope) + self.assertTrue("var" in inner_scope) def test_for_variables_in_scopes(self): scope = libutils.get_string_scope( - self.project, 'for a_var in range(10):\n pass\n') - self.assertTrue('a_var' in scope) + self.project, "for a_var in range(10):\n pass\n" + ) + self.assertTrue("a_var" in scope) def test_assists_inside_fors(self): scope = libutils.get_string_scope( - self.project, 'for i in range(10):\n a_var = i\n') - self.assertTrue('a_var' in scope) + self.project, "for i in range(10):\n a_var = i\n" + ) + self.assertTrue("a_var" in scope) def test_first_parameter_of_a_method(self): - code = 'class AClass(object):\n' \ - ' def a_func(self, param):\n pass\n' - a_class = libutils.get_string_module(self.project, code)['AClass'].\ - get_object() - function_scope = a_class['a_func'].get_object().get_scope() - self.assertEqual(a_class, - function_scope['self'].get_object().get_type()) - self.assertNotEqual(a_class, function_scope['param']. - get_object().get_type()) + code = "class AClass(object):\n" " def a_func(self, param):\n pass\n" + a_class = libutils.get_string_module(self.project, code)["AClass"].get_object() + function_scope = a_class["a_func"].get_object().get_scope() + self.assertEqual(a_class, function_scope["self"].get_object().get_type()) + self.assertNotEqual(a_class, function_scope["param"].get_object().get_type()) def test_first_parameter_of_static_methods(self): - code = 'class AClass(object):\n' \ - ' @staticmethod\n def a_func(param):\n pass\n' - a_class = libutils.get_string_module(self.project, code)['AClass'].\ - get_object() - function_scope = a_class['a_func'].\ - get_object().get_scope() - self.assertNotEqual(a_class, - function_scope['param'].get_object().get_type()) + code = ( + "class AClass(object):\n" + " @staticmethod\n def a_func(param):\n pass\n" + ) + a_class = libutils.get_string_module(self.project, code)["AClass"].get_object() + function_scope = a_class["a_func"].get_object().get_scope() + self.assertNotEqual(a_class, function_scope["param"].get_object().get_type()) def test_first_parameter_of_class_methods(self): - code = 'class AClass(object):\n' \ - ' @classmethod\n def a_func(cls):\n pass\n' - a_class = libutils.get_string_module(self.project, code)['AClass'].\ - get_object() - function_scope = a_class['a_func'].get_object().get_scope() - self.assertEqual(a_class, function_scope['cls'].get_object()) + code = ( + "class AClass(object):\n" + " @classmethod\n def a_func(cls):\n pass\n" + ) + a_class = libutils.get_string_module(self.project, code)["AClass"].get_object() + function_scope = a_class["a_func"].get_object().get_scope() + self.assertEqual(a_class, function_scope["cls"].get_object()) def test_first_parameter_with_self_as_name_and_unknown_decorator(self): - code = 'def my_decorator(func):\n return func\n'\ - 'class AClass(object):\n' \ - ' @my_decorator\n def a_func(self):\n pass\n' - a_class = libutils.get_string_module(self.project, code)['AClass'].\ - get_object() - function_scope = a_class['a_func'].get_object().get_scope() - self.assertEqual(a_class, function_scope['self']. - get_object().get_type()) + code = ( + "def my_decorator(func):\n return func\n" + "class AClass(object):\n" + " @my_decorator\n def a_func(self):\n pass\n" + ) + a_class = libutils.get_string_module(self.project, code)["AClass"].get_object() + function_scope = a_class["a_func"].get_object().get_scope() + self.assertEqual(a_class, function_scope["self"].get_object().get_type()) def test_inside_class_scope_attribute_lookup(self): scope = libutils.get_string_scope( self.project, - 'class C(object):\n' - ' an_attr = 1\n' - ' def a_func(self):\n pass') + "class C(object):\n" + " an_attr = 1\n" + " def a_func(self):\n pass", + ) self.assertEqual(1, len(scope.get_scopes())) c_scope = scope.get_scopes()[0] - self.assertTrue('an_attr'in c_scope.get_names()) - self.assertTrue(c_scope.lookup('an_attr') is not None) + self.assertTrue("an_attr" in c_scope.get_names()) + self.assertTrue(c_scope.lookup("an_attr") is not None) f_in_c = c_scope.get_scopes()[0] - self.assertTrue(f_in_c.lookup('an_attr') is None) + self.assertTrue(f_in_c.lookup("an_attr") is None) def test_inside_class_scope_attribute_lookup2(self): scope = libutils.get_string_scope( self.project, - 'class C(object):\n' - ' def __init__(self):\n self.an_attr = 1\n' - ' def a_func(self):\n pass') + "class C(object):\n" + " def __init__(self):\n self.an_attr = 1\n" + " def a_func(self):\n pass", + ) self.assertEqual(1, len(scope.get_scopes())) c_scope = scope.get_scopes()[0] f_in_c = c_scope.get_scopes()[0] - self.assertTrue(f_in_c.lookup('an_attr') is None) + self.assertTrue(f_in_c.lookup("an_attr") is None) def test_get_inner_scope_for_staticmethods(self): scope = libutils.get_string_scope( self.project, - 'class C(object):\n' - ' @staticmethod\n' - ' def a_func(self):\n pass\n') + "class C(object):\n" + " @staticmethod\n" + " def a_func(self):\n pass\n", + ) c_scope = scope.get_scopes()[0] f_in_c = c_scope.get_scopes()[0] self.assertEqual(f_in_c, scope.get_inner_scope_for_line(4)) def test_getting_overwritten_scopes(self): scope = libutils.get_string_scope( - self.project, 'def f():\n pass\ndef f():\n pass\n') + self.project, "def f():\n pass\ndef f():\n pass\n" + ) self.assertEqual(2, len(scope.get_scopes())) f1_scope = scope.get_scopes()[0] f2_scope = scope.get_scopes()[1] self.assertNotEqual(f1_scope, f2_scope) def test_assigning_builtin_names(self): - mod = libutils.get_string_module(self.project, 'range = 1\n') - range = mod.get_scope().lookup('range') + mod = libutils.get_string_module(self.project, "range = 1\n") + range = mod.get_scope().lookup("range") self.assertEqual((mod, 1), range.get_definition_location()) def test_get_inner_scope_and_logical_lines(self): scope = libutils.get_string_scope( self.project, - 'class C(object):\n' - ' def f():\n s = """\n1\n2\n"""\n a = 1\n') + "class C(object):\n" + ' def f():\n s = """\n1\n2\n"""\n a = 1\n', + ) c_scope = scope.get_scopes()[0] f_in_c = c_scope.get_scopes()[0] self.assertEqual(f_in_c, scope.get_inner_scope_for_line(7)) @@ -298,19 +314,19 @@ def test_get_inner_scope_and_logical_lines(self): def test_getting_defined_names_for_classes(self): scope = libutils.get_string_scope( self.project, - 'class A(object):\n def a(self):\n pass\n' - 'class B(A):\n def b(self):\n pass\n') - a_scope = scope['A'].get_object().get_scope() # noqa - b_scope = scope['B'].get_object().get_scope() - self.assertTrue('a' in b_scope.get_names()) - self.assertTrue('b' in b_scope.get_names()) - self.assertTrue('a' not in b_scope.get_defined_names()) - self.assertTrue('b' in b_scope.get_defined_names()) + "class A(object):\n def a(self):\n pass\n" + "class B(A):\n def b(self):\n pass\n", + ) + a_scope = scope["A"].get_object().get_scope() # noqa + b_scope = scope["B"].get_object().get_scope() + self.assertTrue("a" in b_scope.get_names()) + self.assertTrue("b" in b_scope.get_names()) + self.assertTrue("a" not in b_scope.get_defined_names()) + self.assertTrue("b" in b_scope.get_defined_names()) def test_getting_defined_names_for_modules(self): - scope = libutils.get_string_scope( - self.project, 'class A(object):\n pass\n') - self.assertTrue('open' in scope.get_names()) - self.assertTrue('A' in scope.get_names()) - self.assertTrue('open' not in scope.get_defined_names()) - self.assertTrue('A' in scope.get_defined_names()) + scope = libutils.get_string_scope(self.project, "class A(object):\n pass\n") + self.assertTrue("open" in scope.get_names()) + self.assertTrue("A" in scope.get_names()) + self.assertTrue("open" not in scope.get_defined_names()) + self.assertTrue("A" in scope.get_defined_names()) diff --git a/ropetest/refactor/__init__.py b/ropetest/refactor/__init__.py index 38b6310ec..8ebe9dd81 100644 --- a/ropetest/refactor/__init__.py +++ b/ropetest/refactor/__init__.py @@ -25,132 +25,144 @@ class MethodObjectTest(unittest.TestCase): - def setUp(self): super(MethodObjectTest, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - self.mod = testutils.create_module(self.project, 'mod') + self.mod = testutils.create_module(self.project, "mod") def tearDown(self): testutils.remove_project(self.project) super(MethodObjectTest, self).tearDown() def test_empty_method(self): - code = 'def func():\n pass\n' + code = "def func():\n pass\n" self.mod.write(code) - replacer = MethodObject(self.project, self.mod, code.index('func')) + replacer = MethodObject(self.project, self.mod, code.index("func")) self.assertEqual( - 'class _New(object):\n\n def __call__(self):\n pass\n', - replacer.get_new_class('_New')) + "class _New(object):\n\n def __call__(self):\n pass\n", + replacer.get_new_class("_New"), + ) def test_trivial_return(self): - code = 'def func():\n return 1\n' + code = "def func():\n return 1\n" self.mod.write(code) - replacer = MethodObject(self.project, self.mod, code.index('func')) + replacer = MethodObject(self.project, self.mod, code.index("func")) self.assertEqual( - 'class _New(object):\n\n def __call__(self):' - '\n return 1\n', - replacer.get_new_class('_New')) + "class _New(object):\n\n def __call__(self):" "\n return 1\n", + replacer.get_new_class("_New"), + ) def test_multi_line_header(self): - code = 'def func(\n ):\n return 1\n' + code = "def func(\n ):\n return 1\n" self.mod.write(code) - replacer = MethodObject(self.project, self.mod, code.index('func')) + replacer = MethodObject(self.project, self.mod, code.index("func")) self.assertEqual( - 'class _New(object):\n\n def __call__(self):' - '\n return 1\n', - replacer.get_new_class('_New')) + "class _New(object):\n\n def __call__(self):" "\n return 1\n", + replacer.get_new_class("_New"), + ) def test_a_single_parameter(self): - code = 'def func(param):\n return 1\n' + code = "def func(param):\n return 1\n" self.mod.write(code) - replacer = MethodObject(self.project, self.mod, code.index('func')) + replacer = MethodObject(self.project, self.mod, code.index("func")) self.assertEqual( - 'class _New(object):\n\n' - ' def __init__(self, param):\n self.param = param\n\n' - ' def __call__(self):\n return 1\n', - replacer.get_new_class('_New')) + "class _New(object):\n\n" + " def __init__(self, param):\n self.param = param\n\n" + " def __call__(self):\n return 1\n", + replacer.get_new_class("_New"), + ) def test_self_parameter(self): - code = 'def func(self):\n return 1\n' + code = "def func(self):\n return 1\n" self.mod.write(code) - replacer = MethodObject(self.project, self.mod, code.index('func')) + replacer = MethodObject(self.project, self.mod, code.index("func")) self.assertEqual( - 'class _New(object):\n\n' - ' def __init__(self, host):\n self.self = host\n\n' - ' def __call__(self):\n return 1\n', - replacer.get_new_class('_New')) + "class _New(object):\n\n" + " def __init__(self, host):\n self.self = host\n\n" + " def __call__(self):\n return 1\n", + replacer.get_new_class("_New"), + ) def test_simple_using_passed_parameters(self): - code = 'def func(param):\n return param\n' + code = "def func(param):\n return param\n" self.mod.write(code) - replacer = MethodObject(self.project, self.mod, code.index('func')) + replacer = MethodObject(self.project, self.mod, code.index("func")) self.assertEqual( - 'class _New(object):\n\n' - ' def __init__(self, param):\n self.param = param\n\n' - ' def __call__(self):\n return self.param\n', - replacer.get_new_class('_New')) + "class _New(object):\n\n" + " def __init__(self, param):\n self.param = param\n\n" + " def __call__(self):\n return self.param\n", + replacer.get_new_class("_New"), + ) def test_self_keywords_and_args_parameters(self): - code = 'def func(arg, *args, **kwds):\n' \ - ' result = arg + args[0] + kwds[arg]\n' \ - ' return result\n' + code = ( + "def func(arg, *args, **kwds):\n" + " result = arg + args[0] + kwds[arg]\n" + " return result\n" + ) self.mod.write(code) - replacer = MethodObject(self.project, self.mod, code.index('func')) - expected = 'class _New(object):\n\n' \ - ' def __init__(self, arg, args, kwds):\n' \ - ' self.arg = arg\n' \ - ' self.args = args\n' \ - ' self.kwds = kwds\n\n' \ - ' def __call__(self):\n' \ - ' result = self.arg + ' \ - 'self.args[0] + self.kwds[self.arg]\n' \ - ' return result\n' - self.assertEqual(expected, replacer.get_new_class('_New')) + replacer = MethodObject(self.project, self.mod, code.index("func")) + expected = ( + "class _New(object):\n\n" + " def __init__(self, arg, args, kwds):\n" + " self.arg = arg\n" + " self.args = args\n" + " self.kwds = kwds\n\n" + " def __call__(self):\n" + " result = self.arg + " + "self.args[0] + self.kwds[self.arg]\n" + " return result\n" + ) + self.assertEqual(expected, replacer.get_new_class("_New")) def test_performing_on_not_a_function(self): - code = 'my_var = 10\n' + code = "my_var = 10\n" self.mod.write(code) with self.assertRaises(RefactoringError): - MethodObject(self.project, self.mod, code.index('my_var')) + MethodObject(self.project, self.mod, code.index("my_var")) def test_changing_the_module(self): - code = 'def func():\n return 1\n' + code = "def func():\n return 1\n" self.mod.write(code) - replacer = MethodObject(self.project, self.mod, code.index('func')) - self.project.do(replacer.get_changes('_New')) - expected = 'def func():\n' \ - ' return _New()()\n\n\n' \ - 'class _New(object):\n\n' \ - ' def __call__(self):\n' \ - ' return 1\n' + replacer = MethodObject(self.project, self.mod, code.index("func")) + self.project.do(replacer.get_changes("_New")) + expected = ( + "def func():\n" + " return _New()()\n\n\n" + "class _New(object):\n\n" + " def __call__(self):\n" + " return 1\n" + ) self.assertEqual(expected, self.mod.read()) def test_changing_the_module_and_class_methods(self): - code = 'class C(object):\n\n' \ - ' def a_func(self):\n' \ - ' return 1\n\n' \ - ' def another_func(self):\n' \ - ' pass\n' + code = ( + "class C(object):\n\n" + " def a_func(self):\n" + " return 1\n\n" + " def another_func(self):\n" + " pass\n" + ) self.mod.write(code) - replacer = MethodObject(self.project, self.mod, code.index('func')) - self.project.do(replacer.get_changes('_New')) - expected = 'class C(object):\n\n' \ - ' def a_func(self):\n' \ - ' return _New(self)()\n\n' \ - ' def another_func(self):\n' \ - ' pass\n\n\n' \ - 'class _New(object):\n\n' \ - ' def __init__(self, host):\n' \ - ' self.self = host\n\n' \ - ' def __call__(self):\n' \ - ' return 1\n' + replacer = MethodObject(self.project, self.mod, code.index("func")) + self.project.do(replacer.get_changes("_New")) + expected = ( + "class C(object):\n\n" + " def a_func(self):\n" + " return _New(self)()\n\n" + " def another_func(self):\n" + " pass\n\n\n" + "class _New(object):\n\n" + " def __init__(self, host):\n" + " self.self = host\n\n" + " def __call__(self):\n" + " return 1\n" + ) self.assertEqual(expected, self.mod.read()) class IntroduceFactoryTest(unittest.TestCase): - def setUp(self): super(IntroduceFactoryTest, self).setUp() self.project = testutils.sample_project() @@ -161,302 +173,328 @@ def tearDown(self): super(IntroduceFactoryTest, self).tearDown() def _introduce_factory(self, resource, offset, *args, **kwds): - factory_introducer = IntroduceFactory(self.project, - resource, offset) + factory_introducer = IntroduceFactory(self.project, resource, offset) changes = factory_introducer.get_changes(*args, **kwds) self.project.do(changes) def test_adding_the_method(self): - code = 'class AClass(object):\n an_attr = 10\n' - mod = testutils.create_module(self.project, 'mod') + code = "class AClass(object):\n an_attr = 10\n" + mod = testutils.create_module(self.project, "mod") mod.write(code) - expected = 'class AClass(object):\n' \ - ' an_attr = 10\n\n' \ - ' @staticmethod\n' \ - ' def create(*args, **kwds):\n' \ - ' return AClass(*args, **kwds)\n' - self._introduce_factory(mod, mod.read().index('AClass') + 1, 'create') + expected = ( + "class AClass(object):\n" + " an_attr = 10\n\n" + " @staticmethod\n" + " def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + ) + self._introduce_factory(mod, mod.read().index("AClass") + 1, "create") self.assertEqual(expected, mod.read()) def test_changing_occurances_in_the_main_module(self): - code = 'class AClass(object):\n' \ - ' an_attr = 10\n' \ - 'a_var = AClass()' - mod = testutils.create_module(self.project, 'mod') + code = "class AClass(object):\n" " an_attr = 10\n" "a_var = AClass()" + mod = testutils.create_module(self.project, "mod") mod.write(code) - expected = 'class AClass(object):\n' \ - ' an_attr = 10\n\n' \ - ' @staticmethod\n' \ - ' def create(*args, **kwds):\n' \ - ' return AClass(*args, **kwds)\n'\ - 'a_var = AClass.create()' - self._introduce_factory(mod, mod.read().index('AClass') + 1, 'create') + expected = ( + "class AClass(object):\n" + " an_attr = 10\n\n" + " @staticmethod\n" + " def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + "a_var = AClass.create()" + ) + self._introduce_factory(mod, mod.read().index("AClass") + 1, "create") self.assertEqual(expected, mod.read()) def test_changing_occurances_with_arguments(self): - code = 'class AClass(object):\n' \ - ' def __init__(self, arg):\n' \ - ' pass\n' \ - 'a_var = AClass(10)\n' - mod = testutils.create_module(self.project, 'mod') + code = ( + "class AClass(object):\n" + " def __init__(self, arg):\n" + " pass\n" + "a_var = AClass(10)\n" + ) + mod = testutils.create_module(self.project, "mod") mod.write(code) - expected = 'class AClass(object):\n' \ - ' def __init__(self, arg):\n' \ - ' pass\n\n' \ - ' @staticmethod\n' \ - ' def create(*args, **kwds):\n' \ - ' return AClass(*args, **kwds)\n' \ - 'a_var = AClass.create(10)\n' - self._introduce_factory(mod, mod.read().index('AClass') + 1, 'create') + expected = ( + "class AClass(object):\n" + " def __init__(self, arg):\n" + " pass\n\n" + " @staticmethod\n" + " def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + "a_var = AClass.create(10)\n" + ) + self._introduce_factory(mod, mod.read().index("AClass") + 1, "create") self.assertEqual(expected, mod.read()) def test_changing_occurances_in_other_modules(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('class AClass(object):\n an_attr = 10\n') - mod2.write('import mod1\na_var = mod1.AClass()\n') - self._introduce_factory(mod1, mod1.read().index('AClass') + 1, - 'create') - expected1 = 'class AClass(object):\n' \ - ' an_attr = 10\n\n' \ - ' @staticmethod\n' \ - ' def create(*args, **kwds):\n' \ - ' return AClass(*args, **kwds)\n' - expected2 = 'import mod1\n' \ - 'a_var = mod1.AClass.create()\n' + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("class AClass(object):\n an_attr = 10\n") + mod2.write("import mod1\na_var = mod1.AClass()\n") + self._introduce_factory(mod1, mod1.read().index("AClass") + 1, "create") + expected1 = ( + "class AClass(object):\n" + " an_attr = 10\n\n" + " @staticmethod\n" + " def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + ) + expected2 = "import mod1\n" "a_var = mod1.AClass.create()\n" self.assertEqual(expected1, mod1.read()) self.assertEqual(expected2, mod2.read()) def test_raising_exception_for_non_classes(self): - mod = testutils.create_module(self.project, 'mod') - mod.write('def a_func():\n pass\n') + mod = testutils.create_module(self.project, "mod") + mod.write("def a_func():\n pass\n") with self.assertRaises(RefactoringError): - self._introduce_factory(mod, mod.read().index('a_func') + 1, - 'create') + self._introduce_factory(mod, mod.read().index("a_func") + 1, "create") def test_undoing_introduce_factory(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - code1 = 'class AClass(object):\n an_attr = 10\n' + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + code1 = "class AClass(object):\n an_attr = 10\n" mod1.write(code1) - code2 = 'from mod1 import AClass\na_var = AClass()\n' + code2 = "from mod1 import AClass\na_var = AClass()\n" mod2.write(code2) - self._introduce_factory(mod1, mod1.read().index('AClass') + 1, - 'create') + self._introduce_factory(mod1, mod1.read().index("AClass") + 1, "create") self.project.history.undo() self.assertEqual(code1, mod1.read()) self.assertEqual(code2, mod2.read()) def test_using_on_an_occurance_outside_the_main_module(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('class AClass(object):\n an_attr = 10\n') - mod2.write('import mod1\na_var = mod1.AClass()\n') - self._introduce_factory(mod2, mod2.read().index('AClass') + 1, - 'create') - expected1 = 'class AClass(object):\n' \ - ' an_attr = 10\n\n' \ - ' @staticmethod\n' \ - ' def create(*args, **kwds):\n' \ - ' return AClass(*args, **kwds)\n' - expected2 = 'import mod1\n' \ - 'a_var = mod1.AClass.create()\n' + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("class AClass(object):\n an_attr = 10\n") + mod2.write("import mod1\na_var = mod1.AClass()\n") + self._introduce_factory(mod2, mod2.read().index("AClass") + 1, "create") + expected1 = ( + "class AClass(object):\n" + " an_attr = 10\n\n" + " @staticmethod\n" + " def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + ) + expected2 = "import mod1\n" "a_var = mod1.AClass.create()\n" self.assertEqual(expected1, mod1.read()) self.assertEqual(expected2, mod2.read()) def test_introduce_factory_in_nested_scopes(self): - code = 'def create_var():\n'\ - ' class AClass(object):\n'\ - ' an_attr = 10\n'\ - ' return AClass()\n' - mod = testutils.create_module(self.project, 'mod') + code = ( + "def create_var():\n" + " class AClass(object):\n" + " an_attr = 10\n" + " return AClass()\n" + ) + mod = testutils.create_module(self.project, "mod") mod.write(code) - expected = 'def create_var():\n'\ - ' class AClass(object):\n'\ - ' an_attr = 10\n\n'\ - ' @staticmethod\n ' \ - 'def create(*args, **kwds):\n'\ - ' return AClass(*args, **kwds)\n'\ - ' return AClass.create()\n' - self._introduce_factory(mod, mod.read().index('AClass') + 1, 'create') + expected = ( + "def create_var():\n" + " class AClass(object):\n" + " an_attr = 10\n\n" + " @staticmethod\n " + "def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + " return AClass.create()\n" + ) + self._introduce_factory(mod, mod.read().index("AClass") + 1, "create") self.assertEqual(expected, mod.read()) def test_adding_factory_for_global_factories(self): - code = 'class AClass(object):\n an_attr = 10\n' - mod = testutils.create_module(self.project, 'mod') + code = "class AClass(object):\n an_attr = 10\n" + mod = testutils.create_module(self.project, "mod") mod.write(code) - expected = 'class AClass(object):\n' \ - ' an_attr = 10\n\n' \ - 'def create(*args, **kwds):\n' \ - ' return AClass(*args, **kwds)\n' - self._introduce_factory(mod, mod.read().index('AClass') + 1, - 'create', global_factory=True) + expected = ( + "class AClass(object):\n" + " an_attr = 10\n\n" + "def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + ) + self._introduce_factory( + mod, mod.read().index("AClass") + 1, "create", global_factory=True + ) self.assertEqual(expected, mod.read()) def test_get_name_for_factories(self): - code = 'class C(object):\n pass\n' - mod = testutils.create_module(self.project, 'mod') + code = "class C(object):\n pass\n" + mod = testutils.create_module(self.project, "mod") mod.write(code) - factory = IntroduceFactory(self.project, mod, - mod.read().index('C') + 1) - self.assertEqual('C', factory.get_name()) + factory = IntroduceFactory(self.project, mod, mod.read().index("C") + 1) + self.assertEqual("C", factory.get_name()) def test_raising_exception_for_global_factory_for_nested_classes(self): - code = 'def create_var():\n'\ - ' class AClass(object):\n'\ - ' an_attr = 10\n'\ - ' return AClass()\n' - mod = testutils.create_module(self.project, 'mod') + code = ( + "def create_var():\n" + " class AClass(object):\n" + " an_attr = 10\n" + " return AClass()\n" + ) + mod = testutils.create_module(self.project, "mod") mod.write(code) with self.assertRaises(RefactoringError): - self._introduce_factory(mod, mod.read().index('AClass') + 1, - 'create', global_factory=True) + self._introduce_factory( + mod, mod.read().index("AClass") + 1, "create", global_factory=True + ) def test_changing_occurances_in_the_main_module_for_global_factories(self): - code = 'class AClass(object):\n' \ - ' an_attr = 10\n' \ - 'a_var = AClass()' - mod = testutils.create_module(self.project, 'mod') + code = "class AClass(object):\n" " an_attr = 10\n" "a_var = AClass()" + mod = testutils.create_module(self.project, "mod") mod.write(code) - expected = 'class AClass(object):\n an_attr = 10\n\n' \ - 'def create(*args, **kwds):\n' \ - ' return AClass(*args, **kwds)\n'\ - 'a_var = create()' - self._introduce_factory(mod, mod.read().index('AClass') + 1, - 'create', global_factory=True) + expected = ( + "class AClass(object):\n an_attr = 10\n\n" + "def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + "a_var = create()" + ) + self._introduce_factory( + mod, mod.read().index("AClass") + 1, "create", global_factory=True + ) self.assertEqual(expected, mod.read()) def test_changing_occurances_in_other_modules_for_global_factories(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('class AClass(object):\n an_attr = 10\n') - mod2.write('import mod1\na_var = mod1.AClass()\n') - self._introduce_factory(mod1, mod1.read().index('AClass') + 1, - 'create', global_factory=True) - expected1 = 'class AClass(object):\n' \ - ' an_attr = 10\n\n' \ - 'def create(*args, **kwds):\n' \ - ' return AClass(*args, **kwds)\n' - expected2 = 'import mod1\n' \ - 'a_var = mod1.create()\n' + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("class AClass(object):\n an_attr = 10\n") + mod2.write("import mod1\na_var = mod1.AClass()\n") + self._introduce_factory( + mod1, mod1.read().index("AClass") + 1, "create", global_factory=True + ) + expected1 = ( + "class AClass(object):\n" + " an_attr = 10\n\n" + "def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + ) + expected2 = "import mod1\n" "a_var = mod1.create()\n" self.assertEqual(expected1, mod1.read()) self.assertEqual(expected2, mod2.read()) def test_import_if_necessary_in_other_mods_for_global_factories(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('class AClass(object):\n an_attr = 10\n') - mod2.write('from mod1 import AClass\npair = AClass(), AClass\n') - self._introduce_factory(mod1, mod1.read().index('AClass') + 1, - 'create', global_factory=True) - expected1 = 'class AClass(object):\n' \ - ' an_attr = 10\n\n' \ - 'def create(*args, **kwds):\n' \ - ' return AClass(*args, **kwds)\n' - expected2 = 'from mod1 import AClass, create\n' \ - 'pair = create(), AClass\n' + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("class AClass(object):\n an_attr = 10\n") + mod2.write("from mod1 import AClass\npair = AClass(), AClass\n") + self._introduce_factory( + mod1, mod1.read().index("AClass") + 1, "create", global_factory=True + ) + expected1 = ( + "class AClass(object):\n" + " an_attr = 10\n\n" + "def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + ) + expected2 = "from mod1 import AClass, create\n" "pair = create(), AClass\n" self.assertEqual(expected1, mod1.read()) self.assertEqual(expected2, mod2.read()) def test_changing_occurances_for_renamed_classes(self): - code = 'class AClass(object):\n an_attr = 10' \ - '\na_class = AClass\na_var = a_class()' - mod = testutils.create_module(self.project, 'mod') + code = ( + "class AClass(object):\n an_attr = 10" + "\na_class = AClass\na_var = a_class()" + ) + mod = testutils.create_module(self.project, "mod") mod.write(code) - expected = 'class AClass(object):\n' \ - ' an_attr = 10\n\n' \ - ' @staticmethod\n' \ - ' def create(*args, **kwds):\n' \ - ' return AClass(*args, **kwds)\n' \ - 'a_class = AClass\n' \ - 'a_var = a_class()' - self._introduce_factory(mod, mod.read().index('a_class') + 1, 'create') + expected = ( + "class AClass(object):\n" + " an_attr = 10\n\n" + " @staticmethod\n" + " def create(*args, **kwds):\n" + " return AClass(*args, **kwds)\n" + "a_class = AClass\n" + "a_var = a_class()" + ) + self._introduce_factory(mod, mod.read().index("a_class") + 1, "create") self.assertEqual(expected, mod.read()) def test_changing_occurrs_in_the_same_module_with_conflict_ranges(self): - mod = testutils.create_module(self.project, 'mod') - code = 'class C(object):\n' \ - ' def create(self):\n' \ - ' return C()\n' + mod = testutils.create_module(self.project, "mod") + code = "class C(object):\n" " def create(self):\n" " return C()\n" mod.write(code) - self._introduce_factory(mod, mod.read().index('C'), 'create_c', True) - expected = 'class C(object):\n' \ - ' def create(self):\n' \ - ' return create_c()\n' + self._introduce_factory(mod, mod.read().index("C"), "create_c", True) + expected = ( + "class C(object):\n" " def create(self):\n" " return create_c()\n" + ) self.assertTrue(mod.read().startswith(expected)) def _transform_module_to_package(self, resource): - self.project.do(rope.refactor.ModuleToPackage( - self.project, resource).get_changes()) + self.project.do( + rope.refactor.ModuleToPackage(self.project, resource).get_changes() + ) def test_transform_module_to_package(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('import mod2\nfrom mod2 import AClass\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('class AClass(object):\n pass\n') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("import mod2\nfrom mod2 import AClass\n") + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("class AClass(object):\n pass\n") self._transform_module_to_package(mod2) - mod2 = self.project.get_resource('mod2') + mod2 = self.project.get_resource("mod2") root_folder = self.project.root - self.assertFalse(root_folder.has_child('mod2.py')) - self.assertEqual('class AClass(object):\n pass\n', - root_folder.get_child('mod2'). - get_child('__init__.py').read()) + self.assertFalse(root_folder.has_child("mod2.py")) + self.assertEqual( + "class AClass(object):\n pass\n", + root_folder.get_child("mod2").get_child("__init__.py").read(), + ) def test_transform_module_to_package_undoing(self): - pkg = testutils.create_package(self.project, 'pkg') - mod = testutils.create_module(self.project, 'mod', pkg) + pkg = testutils.create_package(self.project, "pkg") + mod = testutils.create_module(self.project, "mod", pkg) self._transform_module_to_package(mod) - self.assertFalse(pkg.has_child('mod.py')) - self.assertTrue(pkg.get_child('mod').has_child('__init__.py')) + self.assertFalse(pkg.has_child("mod.py")) + self.assertTrue(pkg.get_child("mod").has_child("__init__.py")) self.project.history.undo() - self.assertTrue(pkg.has_child('mod.py')) - self.assertFalse(pkg.has_child('mod')) + self.assertTrue(pkg.has_child("mod.py")) + self.assertFalse(pkg.has_child("mod")) def test_transform_module_to_package_with_relative_imports(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod1.write('import mod2\nfrom mod2 import AClass\n') - mod2 = testutils.create_module(self.project, 'mod2', pkg) - mod2.write('class AClass(object):\n pass\n') + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod1.write("import mod2\nfrom mod2 import AClass\n") + mod2 = testutils.create_module(self.project, "mod2", pkg) + mod2.write("class AClass(object):\n pass\n") self._transform_module_to_package(mod1) - new_init = self.project.get_resource('pkg/mod1/__init__.py') - self.assertEqual('import pkg.mod2\nfrom pkg.mod2 import AClass\n', - new_init.read()) + new_init = self.project.get_resource("pkg/mod1/__init__.py") + self.assertEqual( + "import pkg.mod2\nfrom pkg.mod2 import AClass\n", new_init.read() + ) def test_resources_parameter(self): - code = 'class A(object):\n an_attr = 10\n' - code1 = 'import mod\na = mod.A()\n' - mod = testutils.create_module(self.project, 'mod') - mod1 = testutils.create_module(self.project, 'mod1') + code = "class A(object):\n an_attr = 10\n" + code1 = "import mod\na = mod.A()\n" + mod = testutils.create_module(self.project, "mod") + mod1 = testutils.create_module(self.project, "mod1") mod.write(code) mod1.write(code1) - expected = 'class A(object):\n' \ - ' an_attr = 10\n\n' \ - ' @staticmethod\n' \ - ' def create(*args, **kwds):\n' \ - ' return A(*args, **kwds)\n' - self._introduce_factory(mod, mod.read().index('A') + 1, - 'create', resources=[mod]) + expected = ( + "class A(object):\n" + " an_attr = 10\n\n" + " @staticmethod\n" + " def create(*args, **kwds):\n" + " return A(*args, **kwds)\n" + ) + self._introduce_factory( + mod, mod.read().index("A") + 1, "create", resources=[mod] + ) self.assertEqual(expected, mod.read()) self.assertEqual(code1, mod1.read()) class EncapsulateFieldTest(unittest.TestCase): - def setUp(self): super(EncapsulateFieldTest, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - self.mod = testutils.create_module(self.project, 'mod') - self.mod1 = testutils.create_module(self.project, 'mod1') - self.a_class = 'class A(object):\n' \ - ' def __init__(self):\n' \ - ' self.attr = 1\n' - self.added_methods = '\n' \ - ' def get_attr(self):\n' \ - ' return self.attr\n\n' \ - ' def set_attr(self, value):\n' \ - ' self.attr = value\n' + self.mod = testutils.create_module(self.project, "mod") + self.mod1 = testutils.create_module(self.project, "mod1") + self.a_class = ( + "class A(object):\n" " def __init__(self):\n" " self.attr = 1\n" + ) + self.added_methods = ( + "\n" + " def get_attr(self):\n" + " return self.attr\n\n" + " def set_attr(self, value):\n" + " self.attr = value\n" + ) self.encapsulated = self.a_class + self.added_methods def tearDown(self): @@ -464,275 +502,256 @@ def tearDown(self): super(EncapsulateFieldTest, self).tearDown() def _encapsulate(self, resource, offset, **args): - changes = EncapsulateField(self.project, resource, offset).\ - get_changes(**args) + changes = EncapsulateField(self.project, resource, offset).get_changes(**args) self.project.do(changes) def test_adding_getters_and_setters(self): code = self.a_class self.mod.write(code) - self._encapsulate(self.mod, code.index('attr') + 1) + self._encapsulate(self.mod, code.index("attr") + 1) self.assertEqual(self.encapsulated, self.mod.read()) def test_changing_getters_in_other_modules(self): - code = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'range(a_var.attr)\n' + code = "import mod\n" "a_var = mod.A()\n" "range(a_var.attr)\n" self.mod1.write(code) self.mod.write(self.a_class) - self._encapsulate(self.mod, self.mod.read().index('attr') + 1) - expected = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'range(a_var.get_attr())\n' + self._encapsulate(self.mod, self.mod.read().index("attr") + 1) + expected = "import mod\n" "a_var = mod.A()\n" "range(a_var.get_attr())\n" self.assertEqual(expected, self.mod1.read()) def test_changing_setters_in_other_modules(self): - code = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a_var.attr = 1\n' + code = "import mod\n" "a_var = mod.A()\n" "a_var.attr = 1\n" self.mod1.write(code) self.mod.write(self.a_class) - self._encapsulate(self.mod, self.mod.read().index('attr') + 1) - expected = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a_var.set_attr(1)\n' + self._encapsulate(self.mod, self.mod.read().index("attr") + 1) + expected = "import mod\n" "a_var = mod.A()\n" "a_var.set_attr(1)\n" self.assertEqual(expected, self.mod1.read()) def test_changing_getters_in_setters(self): - code = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a_var.attr = 1 + a_var.attr\n' + code = "import mod\n" "a_var = mod.A()\n" "a_var.attr = 1 + a_var.attr\n" self.mod1.write(code) self.mod.write(self.a_class) - self._encapsulate(self.mod, self.mod.read().index('attr') + 1) - expected = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a_var.set_attr(1 + a_var.get_attr())\n' + self._encapsulate(self.mod, self.mod.read().index("attr") + 1) + expected = ( + "import mod\n" "a_var = mod.A()\n" "a_var.set_attr(1 + a_var.get_attr())\n" + ) self.assertEqual(expected, self.mod1.read()) def test_appending_to_class_end(self): - self.mod1.write(self.a_class + 'a_var = A()\n') - self._encapsulate(self.mod1, self.mod1.read().index('attr') + 1) - self.assertEqual(self.encapsulated + 'a_var = A()\n', - self.mod1.read()) + self.mod1.write(self.a_class + "a_var = A()\n") + self._encapsulate(self.mod1, self.mod1.read().index("attr") + 1) + self.assertEqual(self.encapsulated + "a_var = A()\n", self.mod1.read()) def test_performing_in_other_modules(self): - code = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'range(a_var.attr)\n' + code = "import mod\n" "a_var = mod.A()\n" "range(a_var.attr)\n" self.mod1.write(code) self.mod.write(self.a_class) - self._encapsulate(self.mod1, self.mod1.read().index('attr') + 1) + self._encapsulate(self.mod1, self.mod1.read().index("attr") + 1) self.assertEqual(self.encapsulated, self.mod.read()) - expected = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'range(a_var.get_attr())\n' + expected = "import mod\n" "a_var = mod.A()\n" "range(a_var.get_attr())\n" self.assertEqual(expected, self.mod1.read()) def test_changing_main_module_occurances(self): - code = self.a_class + \ - 'a_var = A()\n' \ - 'a_var.attr = a_var.attr * 2\n' + code = self.a_class + "a_var = A()\n" "a_var.attr = a_var.attr * 2\n" self.mod1.write(code) - self._encapsulate(self.mod1, self.mod1.read().index('attr') + 1) - expected = self.encapsulated + \ - 'a_var = A()\n' \ - 'a_var.set_attr(a_var.get_attr() * 2)\n' + self._encapsulate(self.mod1, self.mod1.read().index("attr") + 1) + expected = ( + self.encapsulated + "a_var = A()\n" "a_var.set_attr(a_var.get_attr() * 2)\n" + ) self.assertEqual(expected, self.mod1.read()) def test_raising_exception_when_performed_on_non_attributes(self): - self.mod1.write('attr = 10') + self.mod1.write("attr = 10") with self.assertRaises(RefactoringError): - self._encapsulate(self.mod1, self.mod1.read().index('attr') + 1) + self._encapsulate(self.mod1, self.mod1.read().index("attr") + 1) def test_raising_exception_on_tuple_assignments(self): self.mod.write(self.a_class) - code = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a_var.attr = 1\n' \ - 'a_var.attr, b = 1, 2\n' + code = ( + "import mod\n" + "a_var = mod.A()\n" + "a_var.attr = 1\n" + "a_var.attr, b = 1, 2\n" + ) self.mod1.write(code) with self.assertRaises(RefactoringError): - self._encapsulate(self.mod1, self.mod1.read().index('attr') + 1) + self._encapsulate(self.mod1, self.mod1.read().index("attr") + 1) def test_raising_exception_on_tuple_assignments2(self): self.mod.write(self.a_class) - code = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a_var.attr = 1\n' \ - 'b, a_var.attr = 1, 2\n' + code = ( + "import mod\n" + "a_var = mod.A()\n" + "a_var.attr = 1\n" + "b, a_var.attr = 1, 2\n" + ) self.mod1.write(code) with self.assertRaises(RefactoringError): - self._encapsulate(self.mod1, self.mod1.read().index('attr') + 1) + self._encapsulate(self.mod1, self.mod1.read().index("attr") + 1) def test_tuple_assignments_and_function_calls(self): - code = 'import mod\n' \ - 'def func(a1=0, a2=0):\n' \ - ' pass\n' \ - 'a_var = mod.A()\n' \ - 'func(a_var.attr, a2=2)\n' + code = ( + "import mod\n" + "def func(a1=0, a2=0):\n" + " pass\n" + "a_var = mod.A()\n" + "func(a_var.attr, a2=2)\n" + ) self.mod1.write(code) self.mod.write(self.a_class) - self._encapsulate(self.mod, self.mod.read().index('attr') + 1) - expected = 'import mod\n' \ - 'def func(a1=0, a2=0):\n' \ - ' pass\n' \ - 'a_var = mod.A()\n' \ - 'func(a_var.get_attr(), a2=2)\n' + self._encapsulate(self.mod, self.mod.read().index("attr") + 1) + expected = ( + "import mod\n" + "def func(a1=0, a2=0):\n" + " pass\n" + "a_var = mod.A()\n" + "func(a_var.get_attr(), a2=2)\n" + ) self.assertEqual(expected, self.mod1.read()) def test_tuple_assignments(self): - code = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a, b = a_var.attr, 1\n' + code = "import mod\n" "a_var = mod.A()\n" "a, b = a_var.attr, 1\n" self.mod1.write(code) self.mod.write(self.a_class) - self._encapsulate(self.mod, self.mod.read().index('attr') + 1) - expected = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a, b = a_var.get_attr(), 1\n' + self._encapsulate(self.mod, self.mod.read().index("attr") + 1) + expected = "import mod\n" "a_var = mod.A()\n" "a, b = a_var.get_attr(), 1\n" self.assertEqual(expected, self.mod1.read()) def test_changing_augmented_assignments(self): - code = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a_var.attr += 1\n' + code = "import mod\n" "a_var = mod.A()\n" "a_var.attr += 1\n" self.mod1.write(code) self.mod.write(self.a_class) - self._encapsulate(self.mod, self.mod.read().index('attr') + 1) - expected = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a_var.set_attr(a_var.get_attr() + 1)\n' + self._encapsulate(self.mod, self.mod.read().index("attr") + 1) + expected = ( + "import mod\n" "a_var = mod.A()\n" "a_var.set_attr(a_var.get_attr() + 1)\n" + ) self.assertEqual(expected, self.mod1.read()) def test_changing_augmented_assignments2(self): - code = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a_var.attr <<= 1\n' + code = "import mod\n" "a_var = mod.A()\n" "a_var.attr <<= 1\n" self.mod1.write(code) self.mod.write(self.a_class) - self._encapsulate(self.mod, self.mod.read().index('attr') + 1) - expected = 'import mod\n' \ - 'a_var = mod.A()\n' \ - 'a_var.set_attr(a_var.get_attr() << 1)\n' + self._encapsulate(self.mod, self.mod.read().index("attr") + 1) + expected = ( + "import mod\n" "a_var = mod.A()\n" "a_var.set_attr(a_var.get_attr() << 1)\n" + ) self.assertEqual(expected, self.mod1.read()) def test_changing_occurrences_inside_the_class(self): - new_class = self.a_class + '\n' \ - ' def a_func(self):\n' \ - ' self.attr = 1\n' + new_class = ( + self.a_class + "\n" " def a_func(self):\n" " self.attr = 1\n" + ) self.mod.write(new_class) - self._encapsulate(self.mod, self.mod.read().index('attr') + 1) - expected = self.a_class + '\n' \ - ' def a_func(self):\n' \ - ' self.set_attr(1)\n' + \ - self.added_methods + self._encapsulate(self.mod, self.mod.read().index("attr") + 1) + expected = ( + self.a_class + "\n" + " def a_func(self):\n" + " self.set_attr(1)\n" + self.added_methods + ) self.assertEqual(expected, self.mod.read()) def test_getter_and_setter_parameters(self): self.mod.write(self.a_class) - self._encapsulate(self.mod, self.mod.read().index('attr') + 1, - getter='getAttr', setter='setAttr') - new_methods = self.added_methods.replace('get_attr', 'getAttr').\ - replace('set_attr', 'setAttr') + self._encapsulate( + self.mod, + self.mod.read().index("attr") + 1, + getter="getAttr", + setter="setAttr", + ) + new_methods = self.added_methods.replace("get_attr", "getAttr").replace( + "set_attr", "setAttr" + ) expected = self.a_class + new_methods self.assertEqual(expected, self.mod.read()) def test_using_resources_parameter(self): - self.mod1.write('import mod\na = mod.A()\nvar = a.attr\n') + self.mod1.write("import mod\na = mod.A()\nvar = a.attr\n") self.mod.write(self.a_class) - self._encapsulate(self.mod, self.mod.read().index('attr') + 1, - resources=[self.mod]) - self.assertEqual('import mod\na = mod.A()\nvar = a.attr\n', - self.mod1.read()) + self._encapsulate( + self.mod, self.mod.read().index("attr") + 1, resources=[self.mod] + ) + self.assertEqual("import mod\na = mod.A()\nvar = a.attr\n", self.mod1.read()) expected = self.a_class + self.added_methods self.assertEqual(expected, self.mod.read()) class LocalToFieldTest(unittest.TestCase): - def setUp(self): super(LocalToFieldTest, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - self.mod = testutils.create_module(self.project, 'mod') + self.mod = testutils.create_module(self.project, "mod") def tearDown(self): testutils.remove_project(self.project) super(LocalToFieldTest, self).tearDown() def _perform_convert_local_variable_to_field(self, resource, offset): - changes = LocalToField( - self.project, resource, offset).get_changes() + changes = LocalToField(self.project, resource, offset).get_changes() self.project.do(changes) def test_simple_local_to_field(self): - code = 'class A(object):\n' \ - ' def a_func(self):\n' \ - ' var = 10\n' + code = "class A(object):\n" " def a_func(self):\n" " var = 10\n" self.mod.write(code) - self._perform_convert_local_variable_to_field(self.mod, - code.index('var') + 1) - expected = 'class A(object):\n' \ - ' def a_func(self):\n' \ - ' self.var = 10\n' + self._perform_convert_local_variable_to_field(self.mod, code.index("var") + 1) + expected = ( + "class A(object):\n" " def a_func(self):\n" " self.var = 10\n" + ) self.assertEqual(expected, self.mod.read()) def test_raising_exception_when_performed_on_a_global_var(self): - self.mod.write('var = 10\n') + self.mod.write("var = 10\n") with self.assertRaises(RefactoringError): self._perform_convert_local_variable_to_field( - self.mod, self.mod.read().index('var') + 1) + self.mod, self.mod.read().index("var") + 1 + ) def test_raising_exception_when_performed_on_field(self): - code = 'class A(object):\n' \ - ' def a_func(self):\n' \ - ' self.var = 10\n' + code = "class A(object):\n" " def a_func(self):\n" " self.var = 10\n" self.mod.write(code) with self.assertRaises(RefactoringError): self._perform_convert_local_variable_to_field( - self.mod, self.mod.read().index('var') + 1) + self.mod, self.mod.read().index("var") + 1 + ) def test_raising_exception_when_performed_on_a_parameter(self): - code = 'class A(object):\n' \ - ' def a_func(self, var):\n' \ - ' a = var\n' + code = "class A(object):\n" " def a_func(self, var):\n" " a = var\n" self.mod.write(code) with self.assertRaises(RefactoringError): self._perform_convert_local_variable_to_field( - self.mod, self.mod.read().index('var') + 1) + self.mod, self.mod.read().index("var") + 1 + ) # NOTE: This situation happens alot and is normally not an error - #@testutils.assert_raises(RefactoringError) + # @testutils.assert_raises(RefactoringError) def test_not_rais_exception_when_there_is_a_field_with_the_same_name(self): - code = 'class A(object):\n' \ - ' def __init__(self):\n' \ - ' self.var = 1\n' \ - ' def a_func(self):\n var = 10\n' + code = ( + "class A(object):\n" + " def __init__(self):\n" + " self.var = 1\n" + " def a_func(self):\n var = 10\n" + ) self.mod.write(code) self._perform_convert_local_variable_to_field( - self.mod, self.mod.read().rindex('var') + 1) + self.mod, self.mod.read().rindex("var") + 1 + ) def test_local_to_field_with_self_renamed(self): - code = 'class A(object):\n' \ - ' def a_func(myself):\n' \ - ' var = 10\n' + code = "class A(object):\n" " def a_func(myself):\n" " var = 10\n" self.mod.write(code) - self._perform_convert_local_variable_to_field(self.mod, - code.index('var') + 1) - expected = 'class A(object):\n' \ - ' def a_func(myself):\n' \ - ' myself.var = 10\n' + self._perform_convert_local_variable_to_field(self.mod, code.index("var") + 1) + expected = ( + "class A(object):\n" " def a_func(myself):\n" " myself.var = 10\n" + ) self.assertEqual(expected, self.mod.read()) class IntroduceParameterTest(unittest.TestCase): - def setUp(self): super(IntroduceParameterTest, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - self.mod = testutils.create_module(self.project, 'mod') + self.mod = testutils.create_module(self.project, "mod") def tearDown(self): testutils.remove_project(self.project) @@ -740,78 +759,67 @@ def tearDown(self): def _introduce_parameter(self, offset, name): rope.refactor.introduce_parameter.IntroduceParameter( - self.project, self.mod, offset).get_changes(name).do() + self.project, self.mod, offset + ).get_changes(name).do() def test_simple_case(self): - code = 'var = 1\n' \ - 'def f():\n' \ - ' b = var\n' + code = "var = 1\n" "def f():\n" " b = var\n" self.mod.write(code) - offset = self.mod.read().rindex('var') - self._introduce_parameter(offset, 'var') - expected = 'var = 1\n' \ - 'def f(var=var):\n' \ - ' b = var\n' + offset = self.mod.read().rindex("var") + self._introduce_parameter(offset, "var") + expected = "var = 1\n" "def f(var=var):\n" " b = var\n" self.assertEqual(expected, self.mod.read()) def test_changing_function_body(self): - code = 'var = 1\n' \ - 'def f():\n' \ - ' b = var\n' + code = "var = 1\n" "def f():\n" " b = var\n" self.mod.write(code) - offset = self.mod.read().rindex('var') - self._introduce_parameter(offset, 'p1') - expected = 'var = 1\n' \ - 'def f(p1=var):\n' \ - ' b = p1\n' + offset = self.mod.read().rindex("var") + self._introduce_parameter(offset, "p1") + expected = "var = 1\n" "def f(p1=var):\n" " b = p1\n" self.assertEqual(expected, self.mod.read()) def test_unknown_variables(self): - self.mod.write('def f():\n b = var + c\n') - offset = self.mod.read().rindex('var') + self.mod.write("def f():\n b = var + c\n") + offset = self.mod.read().rindex("var") with self.assertRaises(RefactoringError): - self._introduce_parameter(offset, 'p1') - self.assertEqual('def f(p1=var):\n b = p1 + c\n', - self.mod.read()) + self._introduce_parameter(offset, "p1") + self.assertEqual("def f(p1=var):\n b = p1 + c\n", self.mod.read()) def test_failing_when_not_inside(self): - self.mod.write('var = 10\nb = var\n') - offset = self.mod.read().rindex('var') + self.mod.write("var = 10\nb = var\n") + offset = self.mod.read().rindex("var") with self.assertRaises(RefactoringError): - self._introduce_parameter(offset, 'p1') + self._introduce_parameter(offset, "p1") def test_attribute_accesses(self): - code = 'class C(object):\n' \ - ' a = 10\nc = C()\n' \ - 'def f():\n' \ - ' b = c.a\n' + code = "class C(object):\n" " a = 10\nc = C()\n" "def f():\n" " b = c.a\n" self.mod.write(code) - offset = self.mod.read().rindex('a') - self._introduce_parameter(offset, 'p1') - expected = 'class C(object):\n' \ - ' a = 10\n' \ - 'c = C()\n' \ - 'def f(p1=c.a):\n' \ - ' b = p1\n' + offset = self.mod.read().rindex("a") + self._introduce_parameter(offset, "p1") + expected = ( + "class C(object):\n" + " a = 10\n" + "c = C()\n" + "def f(p1=c.a):\n" + " b = p1\n" + ) self.assertEqual(expected, self.mod.read()) def test_introducing_parameters_for_methods(self): - code = 'var = 1\n' \ - 'class C(object):\n' \ - ' def f(self):\n' \ - ' b = var\n' + code = "var = 1\n" "class C(object):\n" " def f(self):\n" " b = var\n" self.mod.write(code) - offset = self.mod.read().rindex('var') - self._introduce_parameter(offset, 'p1') - expected = 'var = 1\n' \ - 'class C(object):\n' \ - ' def f(self, p1=var):\n' \ - ' b = p1\n' + offset = self.mod.read().rindex("var") + self._introduce_parameter(offset, "p1") + expected = ( + "var = 1\n" + "class C(object):\n" + " def f(self, p1=var):\n" + " b = p1\n" + ) self.assertEqual(expected, self.mod.read()) class _MockTaskObserver(object): - def __init__(self): self.called = 0 @@ -820,7 +828,6 @@ def __call__(self): class TaskHandleTest(unittest.TestCase): - def test_trivial_case(self): handle = rope.base.taskhandle.TaskHandle() self.assertFalse(handle.is_stopped()) @@ -837,8 +844,8 @@ def test_job_sets(self): def test_starting_and_finishing_jobs(self): handle = rope.base.taskhandle.TaskHandle() - jobs = handle.create_jobset(name='test job set', count=1) - jobs.started_job('job1') + jobs = handle.create_jobset(name="test job set", count=1) + jobs.started_job("job1") jobs.finished_job() def test_test_checking_status(self): @@ -853,7 +860,7 @@ def test_test_checking_status_when_starting(self): jobs = handle.create_jobset() handle.stop() with self.assertRaises(InterruptedTaskError): - jobs.started_job('job1') + jobs.started_job("job1") def test_calling_the_observer_after_stopping(self): handle = rope.base.taskhandle.TaskHandle() @@ -873,26 +880,26 @@ def test_calling_the_observer_when_starting_and_finishing_jobs(self): handle = rope.base.taskhandle.TaskHandle() observer = _MockTaskObserver() handle.add_observer(observer) - jobs = handle.create_jobset(name='test job set', count=1) - jobs.started_job('job1') + jobs = handle.create_jobset(name="test job set", count=1) + jobs.started_job("job1") jobs.finished_job() self.assertEqual(3, observer.called) def test_job_set_get_percent_done(self): handle = rope.base.taskhandle.TaskHandle() - jobs = handle.create_jobset(name='test job set', count=2) + jobs = handle.create_jobset(name="test job set", count=2) self.assertEqual(0, jobs.get_percent_done()) - jobs.started_job('job1') + jobs.started_job("job1") jobs.finished_job() self.assertEqual(50, jobs.get_percent_done()) - jobs.started_job('job2') + jobs.started_job("job2") jobs.finished_job() self.assertEqual(100, jobs.get_percent_done()) def test_getting_job_name(self): handle = rope.base.taskhandle.TaskHandle() - jobs = handle.create_jobset(name='test job set', count=1) - self.assertEqual('test job set', jobs.get_name()) + jobs = handle.create_jobset(name="test job set", count=1) + self.assertEqual("test job set", jobs.get_name()) self.assertEqual(None, jobs.get_active_job_name()) - jobs.started_job('job1') - self.assertEqual('job1', jobs.get_active_job_name()) + jobs.started_job("job1") + self.assertEqual("job1", jobs.get_active_job_name()) diff --git a/ropetest/refactor/change_signature_test.py b/ropetest/refactor/change_signature_test.py index 9fcf5ef1f..0c692a494 100644 --- a/ropetest/refactor/change_signature_test.py +++ b/ropetest/refactor/change_signature_test.py @@ -9,434 +9,465 @@ class ChangeSignatureTest(unittest.TestCase): - def setUp(self): super(ChangeSignatureTest, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - self.mod = testutils.create_module(self.project, 'mod') + self.mod = testutils.create_module(self.project, "mod") def tearDown(self): testutils.remove_project(self.project) super(ChangeSignatureTest, self).tearDown() def test_normalizing_parameters_for_trivial_case(self): - code = 'def a_func():\n pass\na_func()' + code = "def a_func():\n pass\na_func()" self.mod.write(code) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentNormalizer()])) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentNormalizer()])) self.assertEqual(code, self.mod.read()) def test_normalizing_parameters_for_trivial_case2(self): - code = 'def a_func(param):\n pass\na_func(2)' + code = "def a_func(param):\n pass\na_func(2)" self.mod.write(code) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentNormalizer()])) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentNormalizer()])) self.assertEqual(code, self.mod.read()) def test_normalizing_parameters_for_unneeded_keyword(self): - self.mod.write('def a_func(param):\n pass\na_func(param=1)') + self.mod.write("def a_func(param):\n pass\na_func(param=1)") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentNormalizer()])) - self.assertEqual('def a_func(param):\n pass\na_func(1)', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentNormalizer()])) + self.assertEqual("def a_func(param):\n pass\na_func(1)", self.mod.read()) def test_normalizing_parameters_for_unneeded_keyword_for_methods(self): - code = 'class A(object):\n' \ - ' def a_func(self, param):\n' \ - ' pass\n' \ - 'a_var = A()\n' \ - 'a_var.a_func(param=1)\n' + code = ( + "class A(object):\n" + " def a_func(self, param):\n" + " pass\n" + "a_var = A()\n" + "a_var.a_func(param=1)\n" + ) self.mod.write(code) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentNormalizer()])) - expected = 'class A(object):\n' \ - ' def a_func(self, param):\n' \ - ' pass\n' \ - 'a_var = A()\n' \ - 'a_var.a_func(1)\n' + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentNormalizer()])) + expected = ( + "class A(object):\n" + " def a_func(self, param):\n" + " pass\n" + "a_var = A()\n" + "a_var.a_func(1)\n" + ) self.assertEqual(expected, self.mod.read()) def test_normalizing_parameters_for_unsorted_keyword(self): - self.mod.write('def a_func(p1, p2):\n pass\na_func(p2=2, p1=1)') + self.mod.write("def a_func(p1, p2):\n pass\na_func(p2=2, p1=1)") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentNormalizer()])) - self.assertEqual('def a_func(p1, p2):\n pass\na_func(1, 2)', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentNormalizer()])) + self.assertEqual("def a_func(p1, p2):\n pass\na_func(1, 2)", self.mod.read()) def test_raising_exceptions_for_non_functions(self): - self.mod.write('a_var = 10') + self.mod.write("a_var = 10") with self.assertRaises(rope.base.exceptions.RefactoringError): change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_var') + 1) + self.project, self.mod, self.mod.read().index("a_var") + 1 + ) def test_normalizing_parameters_for_args_parameter(self): - self.mod.write('def a_func(*arg):\n pass\na_func(1, 2)\n') + self.mod.write("def a_func(*arg):\n pass\na_func(1, 2)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentNormalizer()])) - self.assertEqual('def a_func(*arg):\n pass\na_func(1, 2)\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentNormalizer()])) + self.assertEqual("def a_func(*arg):\n pass\na_func(1, 2)\n", self.mod.read()) def test_normalizing_parameters_for_args_parameter_and_keywords(self): - self.mod.write( - 'def a_func(param, *args):\n pass\na_func(*[1, 2, 3])\n') + self.mod.write("def a_func(param, *args):\n pass\na_func(*[1, 2, 3])\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentNormalizer()])) - self.assertEqual('def a_func(param, *args):\n pass\n' - 'a_func(*[1, 2, 3])\n', self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentNormalizer()])) + self.assertEqual( + "def a_func(param, *args):\n pass\n" "a_func(*[1, 2, 3])\n", + self.mod.read(), + ) def test_normalizing_functions_from_other_modules(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('def a_func(param):\n pass\n') - self.mod.write('import mod1\nmod1.a_func(param=1)\n') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("def a_func(param):\n pass\n") + self.mod.write("import mod1\nmod1.a_func(param=1)\n") signature = change_signature.ChangeSignature( - self.project, mod1, mod1.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentNormalizer()])) - self.assertEqual('import mod1\nmod1.a_func(1)\n', self.mod.read()) + self.project, mod1, mod1.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentNormalizer()])) + self.assertEqual("import mod1\nmod1.a_func(1)\n", self.mod.read()) def test_normalizing_parameters_for_keyword_parameters(self): - self.mod.write('def a_func(p1, **kwds):\n pass\n' - 'a_func(p2=2, p1=1)\n') + self.mod.write("def a_func(p1, **kwds):\n pass\n" "a_func(p2=2, p1=1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentNormalizer()])) - self.assertEqual('def a_func(p1, **kwds):\n pass\n' - 'a_func(1, p2=2)\n', self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentNormalizer()])) + self.assertEqual( + "def a_func(p1, **kwds):\n pass\n" "a_func(1, p2=2)\n", self.mod.read() + ) def test_removing_arguments(self): - self.mod.write('def a_func(p1):\n pass\na_func(1)\n') + self.mod.write("def a_func(p1):\n pass\na_func(1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentRemover(0)])) - self.assertEqual('def a_func():\n pass\na_func()\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentRemover(0)])) + self.assertEqual("def a_func():\n pass\na_func()\n", self.mod.read()) def test_removing_arguments_with_multiple_args(self): - self.mod.write('def a_func(p1, p2):\n pass\na_func(1, 2)\n') + self.mod.write("def a_func(p1, p2):\n pass\na_func(1, 2)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentRemover(0)])) - self.assertEqual('def a_func(p2):\n pass\na_func(2)\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentRemover(0)])) + self.assertEqual("def a_func(p2):\n pass\na_func(2)\n", self.mod.read()) def test_removing_arguments_passed_as_keywords(self): - self.mod.write('def a_func(p1):\n pass\na_func(p1=1)\n') + self.mod.write("def a_func(p1):\n pass\na_func(p1=1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentRemover(0)])) - self.assertEqual('def a_func():\n pass\na_func()\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentRemover(0)])) + self.assertEqual("def a_func():\n pass\na_func()\n", self.mod.read()) def test_removing_arguments_with_defaults(self): - self.mod.write('def a_func(p1=1):\n pass\na_func(1)\n') + self.mod.write("def a_func(p1=1):\n pass\na_func(1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentRemover(0)])) - self.assertEqual('def a_func():\n pass\na_func()\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentRemover(0)])) + self.assertEqual("def a_func():\n pass\na_func()\n", self.mod.read()) def test_removing_arguments_star_args(self): - self.mod.write('def a_func(p1, *args):\n pass\na_func(1)\n') + self.mod.write("def a_func(p1, *args):\n pass\na_func(1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentRemover(1)])) - self.assertEqual('def a_func(p1):\n pass\na_func(1)\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentRemover(1)])) + self.assertEqual("def a_func(p1):\n pass\na_func(1)\n", self.mod.read()) def test_removing_keyword_arg(self): - self.mod.write('def a_func(p1, **kwds):\n pass\na_func(1)\n') + self.mod.write("def a_func(p1, **kwds):\n pass\na_func(1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentRemover(1)])) - self.assertEqual('def a_func(p1):\n pass\na_func(1)\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentRemover(1)])) + self.assertEqual("def a_func(p1):\n pass\na_func(1)\n", self.mod.read()) def test_removing_keyword_arg2(self): - self.mod.write('def a_func(p1, *args, **kwds):\n pass\na_func(1)\n') + self.mod.write("def a_func(p1, *args, **kwds):\n pass\na_func(1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentRemover(2)])) - self.assertEqual('def a_func(p1, *args):\n pass\na_func(1)\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentRemover(2)])) + self.assertEqual( + "def a_func(p1, *args):\n pass\na_func(1)\n", self.mod.read() + ) # XXX: What to do here for star args? @unittest.skip("How to deal with start args?") def xxx_test_removing_arguments_star_args2(self): - self.mod.write('def a_func(p1, *args):\n pass\n' - 'a_func(2, 3, p1=1)\n') + self.mod.write("def a_func(p1, *args):\n pass\n" "a_func(2, 3, p1=1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentRemover(1)])) - self.assertEqual('def a_func(p1):\n pass\na_func(p1=1)\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentRemover(1)])) + self.assertEqual("def a_func(p1):\n pass\na_func(p1=1)\n", self.mod.read()) # XXX: What to do here for star args? def xxx_test_removing_arguments_star_args3(self): - self.mod.write('def a_func(p1, *args):\n pass\n' - 'a_func(*[1, 2, 3])\n') + self.mod.write("def a_func(p1, *args):\n pass\n" "a_func(*[1, 2, 3])\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentRemover(1)])) - self.assertEqual('def a_func(p1):\n pass\na_func(*[1, 2, 3])\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentRemover(1)])) + self.assertEqual( + "def a_func(p1):\n pass\na_func(*[1, 2, 3])\n", self.mod.read() + ) def test_adding_arguments_for_normal_args_changing_definition(self): - self.mod.write('def a_func():\n pass\n') + self.mod.write("def a_func():\n pass\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentAdder(0, 'p1')])) - self.assertEqual('def a_func(p1):\n pass\n', self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do( + signature.get_changes([change_signature.ArgumentAdder(0, "p1")]) + ) + self.assertEqual("def a_func(p1):\n pass\n", self.mod.read()) def test_adding_arguments_for_normal_args_with_defaults(self): - self.mod.write('def a_func():\n pass\na_func()\n') + self.mod.write("def a_func():\n pass\na_func()\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - adder = change_signature.ArgumentAdder(0, 'p1', 'None') + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + adder = change_signature.ArgumentAdder(0, "p1", "None") self.project.do(signature.get_changes([adder])) - self.assertEqual('def a_func(p1=None):\n pass\na_func()\n', - self.mod.read()) + self.assertEqual("def a_func(p1=None):\n pass\na_func()\n", self.mod.read()) def test_adding_arguments_for_normal_args_changing_calls(self): - self.mod.write('def a_func():\n pass\na_func()\n') + self.mod.write("def a_func():\n pass\na_func()\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - adder = change_signature.ArgumentAdder(0, 'p1', 'None', '1') + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + adder = change_signature.ArgumentAdder(0, "p1", "None", "1") self.project.do(signature.get_changes([adder])) - self.assertEqual('def a_func(p1=None):\n pass\na_func(1)\n', - self.mod.read()) + self.assertEqual("def a_func(p1=None):\n pass\na_func(1)\n", self.mod.read()) def test_adding_arguments_for_norm_args_chang_calls_with_kwords(self): - self.mod.write('def a_func(p1=0):\n pass\na_func()\n') + self.mod.write("def a_func(p1=0):\n pass\na_func()\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - adder = change_signature.ArgumentAdder(1, 'p2', '0', '1') + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + adder = change_signature.ArgumentAdder(1, "p2", "0", "1") self.project.do(signature.get_changes([adder])) - self.assertEqual('def a_func(p1=0, p2=0):\n pass\na_func(p2=1)\n', - self.mod.read()) + self.assertEqual( + "def a_func(p1=0, p2=0):\n pass\na_func(p2=1)\n", self.mod.read() + ) def test_adding_arguments_for_norm_args_chang_calls_with_no_value(self): - self.mod.write('def a_func(p2=0):\n pass\na_func(1)\n') + self.mod.write("def a_func(p2=0):\n pass\na_func(1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - adder = change_signature.ArgumentAdder(0, 'p1', '0', None) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + adder = change_signature.ArgumentAdder(0, "p1", "0", None) self.project.do(signature.get_changes([adder])) - self.assertEqual('def a_func(p1=0, p2=0):\n pass\na_func(p2=1)\n', - self.mod.read()) + self.assertEqual( + "def a_func(p1=0, p2=0):\n pass\na_func(p2=1)\n", self.mod.read() + ) def test_adding_duplicate_parameter_and_raising_exceptions(self): - self.mod.write('def a_func(p1):\n pass\n') + self.mod.write("def a_func(p1):\n pass\n") with self.assertRaises(rope.base.exceptions.RefactoringError): signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentAdder(1, 'p1')])) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do( + signature.get_changes([change_signature.ArgumentAdder(1, "p1")]) + ) def test_inlining_default_arguments(self): - self.mod.write('def a_func(p1=0):\n pass\na_func()\n') + self.mod.write("def a_func(p1=0):\n pass\na_func()\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentDefaultInliner(0)])) - self.assertEqual('def a_func(p1=0):\n pass\n' - 'a_func(0)\n', self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do( + signature.get_changes([change_signature.ArgumentDefaultInliner(0)]) + ) + self.assertEqual("def a_func(p1=0):\n pass\n" "a_func(0)\n", self.mod.read()) def test_inlining_default_arguments2(self): - self.mod.write('def a_func(p1=0):\n pass\na_func(1)\n') + self.mod.write("def a_func(p1=0):\n pass\na_func(1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentDefaultInliner(0)])) - self.assertEqual('def a_func(p1=0):\n pass\n' - 'a_func(1)\n', self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do( + signature.get_changes([change_signature.ArgumentDefaultInliner(0)]) + ) + self.assertEqual("def a_func(p1=0):\n pass\n" "a_func(1)\n", self.mod.read()) def test_preserving_args_and_keywords_order(self): - self.mod.write('def a_func(*args, **kwds):\n pass\n' - 'a_func(3, 1, 2, a=1, c=3, b=2)\n') + self.mod.write( + "def a_func(*args, **kwds):\n pass\n" "a_func(3, 1, 2, a=1, c=3, b=2)\n" + ) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentNormalizer()])) - self.assertEqual('def a_func(*args, **kwds):\n pass\n' - 'a_func(3, 1, 2, a=1, c=3, b=2)\n', self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentNormalizer()])) + self.assertEqual( + "def a_func(*args, **kwds):\n pass\n" "a_func(3, 1, 2, a=1, c=3, b=2)\n", + self.mod.read(), + ) def test_change_order_for_only_one_parameter(self): - self.mod.write('def a_func(p1):\n pass\na_func(1)\n') + self.mod.write("def a_func(p1):\n pass\na_func(1)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentReorderer([0])])) - self.assertEqual('def a_func(p1):\n pass\na_func(1)\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do( + signature.get_changes([change_signature.ArgumentReorderer([0])]) + ) + self.assertEqual("def a_func(p1):\n pass\na_func(1)\n", self.mod.read()) def test_change_order_for_two_parameter(self): - self.mod.write('def a_func(p1, p2):\n pass\na_func(1, 2)\n') + self.mod.write("def a_func(p1, p2):\n pass\na_func(1, 2)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentReorderer([1, 0])])) - self.assertEqual('def a_func(p2, p1):\n pass\na_func(2, 1)\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do( + signature.get_changes([change_signature.ArgumentReorderer([1, 0])]) + ) + self.assertEqual( + "def a_func(p2, p1):\n pass\na_func(2, 1)\n", self.mod.read() + ) def test_reordering_multi_line_function_headers(self): - self.mod.write('def a_func(p1,\n p2):\n pass\na_func(1, 2)\n') + self.mod.write("def a_func(p1,\n p2):\n pass\na_func(1, 2)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentReorderer([1, 0])])) - self.assertEqual('def a_func(p2, p1):\n pass\na_func(2, 1)\n', - self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do( + signature.get_changes([change_signature.ArgumentReorderer([1, 0])]) + ) + self.assertEqual( + "def a_func(p2, p1):\n pass\na_func(2, 1)\n", self.mod.read() + ) def test_changing_order_with_static_params(self): - self.mod.write('def a_func(p1, p2=0, p3=0):\n pass\na_func(1, 2)\n') + self.mod.write("def a_func(p1, p2=0, p3=0):\n pass\na_func(1, 2)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentReorderer([0, 2, 1])])) - self.assertEqual('def a_func(p1, p3=0, p2=0):\n pass\n' - 'a_func(1, p2=2)\n', self.mod.read()) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do( + signature.get_changes([change_signature.ArgumentReorderer([0, 2, 1])]) + ) + self.assertEqual( + "def a_func(p1, p3=0, p2=0):\n pass\n" "a_func(1, p2=2)\n", + self.mod.read(), + ) def test_doing_multiple_changes(self): changers = [] - self.mod.write('def a_func(p1):\n pass\na_func(1)\n') + self.mod.write("def a_func(p1):\n pass\na_func(1)\n") changers.append(change_signature.ArgumentRemover(0)) - changers.append(change_signature.ArgumentAdder(0, 'p2', None, None)) + changers.append(change_signature.ArgumentAdder(0, "p2", None, None)) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) signature.get_changes(changers).do() - self.assertEqual('def a_func(p2):\n pass\na_func()\n', - self.mod.read()) + self.assertEqual("def a_func(p2):\n pass\na_func()\n", self.mod.read()) def test_doing_multiple_changes2(self): changers = [] - self.mod.write('def a_func(p1, p2):\n pass\na_func(p2=2)\n') - changers.append(change_signature.ArgumentAdder(2, 'p3', None, '3')) + self.mod.write("def a_func(p1, p2):\n pass\na_func(p2=2)\n") + changers.append(change_signature.ArgumentAdder(2, "p3", None, "3")) changers.append(change_signature.ArgumentReorderer([1, 0, 2])) changers.append(change_signature.ArgumentRemover(1)) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) signature.get_changes(changers).do() - self.assertEqual('def a_func(p2, p3):\n pass\na_func(2, 3)\n', - self.mod.read()) + self.assertEqual( + "def a_func(p2, p3):\n pass\na_func(2, 3)\n", self.mod.read() + ) def test_changing_signature_in_subclasses(self): self.mod.write( - 'class A(object):\n def a_method(self):\n pass\n' - 'class B(A):\n def a_method(self):\n pass\n') + "class A(object):\n def a_method(self):\n pass\n" + "class B(A):\n def a_method(self):\n pass\n" + ) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_method') + 1) - signature.get_changes([change_signature.ArgumentAdder(1, 'p1')], - in_hierarchy=True).do() + self.project, self.mod, self.mod.read().index("a_method") + 1 + ) + signature.get_changes( + [change_signature.ArgumentAdder(1, "p1")], in_hierarchy=True + ).do() self.assertEqual( - 'class A(object):\n def a_method(self, p1):\n pass\n' - 'class B(A):\n def a_method(self, p1):\n pass\n', - self.mod.read()) + "class A(object):\n def a_method(self, p1):\n pass\n" + "class B(A):\n def a_method(self, p1):\n pass\n", + self.mod.read(), + ) def test_differentiating_class_accesses_from_instance_accesses(self): self.mod.write( - 'class A(object):\n def a_func(self, param):\n pass\n' - 'a_var = A()\nA.a_func(a_var, param=1)') + "class A(object):\n def a_func(self, param):\n pass\n" + "a_var = A()\nA.a_func(a_var, param=1)" + ) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('a_func') + 1) - self.project.do(signature.get_changes( - [change_signature.ArgumentRemover(1)])) + self.project, self.mod, self.mod.read().index("a_func") + 1 + ) + self.project.do(signature.get_changes([change_signature.ArgumentRemover(1)])) self.assertEqual( - 'class A(object):\n def a_func(self):\n pass\n' - 'a_var = A()\nA.a_func(a_var)', self.mod.read()) + "class A(object):\n def a_func(self):\n pass\n" + "a_var = A()\nA.a_func(a_var)", + self.mod.read(), + ) def test_changing_signature_for_constructors(self): self.mod.write( - 'class C(object):\n def __init__(self, p):\n pass\n' - 'c = C(1)\n') + "class C(object):\n def __init__(self, p):\n pass\n" "c = C(1)\n" + ) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('C') + 1) + self.project, self.mod, self.mod.read().index("C") + 1 + ) signature.get_changes([change_signature.ArgumentRemover(1)]).do() self.assertEqual( - 'class C(object):\n def __init__(self):\n pass\n' - 'c = C()\n', - self.mod.read()) + "class C(object):\n def __init__(self):\n pass\n" "c = C()\n", + self.mod.read(), + ) def test_changing_signature_for_constructors2(self): self.mod.write( - 'class C(object):\n def __init__(self, p):\n pass\n' - 'c = C(1)\n') + "class C(object):\n def __init__(self, p):\n pass\n" "c = C(1)\n" + ) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('__init__') + 1) + self.project, self.mod, self.mod.read().index("__init__") + 1 + ) signature.get_changes([change_signature.ArgumentRemover(1)]).do() self.assertEqual( - 'class C(object):\n def __init__(self):\n pass\n' - 'c = C()\n', - self.mod.read()) + "class C(object):\n def __init__(self):\n pass\n" "c = C()\n", + self.mod.read(), + ) def test_changing_signature_for_constructors_when_using_super(self): self.mod.write( - 'class A(object):\n def __init__(self, p):\n pass\n' - 'class B(A):\n ' - 'def __init__(self, p):\n super(B, self).__init__(p)\n') + "class A(object):\n def __init__(self, p):\n pass\n" + "class B(A):\n " + "def __init__(self, p):\n super(B, self).__init__(p)\n" + ) signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().index('__init__') + 1) + self.project, self.mod, self.mod.read().index("__init__") + 1 + ) signature.get_changes([change_signature.ArgumentRemover(1)]).do() self.assertEqual( - 'class A(object):\n def __init__(self):\n pass\n' - 'class B(A):\n ' - 'def __init__(self, p):\n super(B, self).__init__()\n', - self.mod.read()) + "class A(object):\n def __init__(self):\n pass\n" + "class B(A):\n " + "def __init__(self, p):\n super(B, self).__init__()\n", + self.mod.read(), + ) def test_redordering_arguments_reported_by_mft(self): - self.mod.write('def f(a, b, c):\n pass\nf(1, 2, 3)\n') + self.mod.write("def f(a, b, c):\n pass\nf(1, 2, 3)\n") signature = change_signature.ChangeSignature( - self.project, self.mod, self.mod.read().rindex('f')) - signature.get_changes( - [change_signature.ArgumentReorderer([1, 2, 0])]).do() - self.assertEqual('def f(b, c, a):\n pass\nf(2, 3, 1)\n', - self.mod.read()) + self.project, self.mod, self.mod.read().rindex("f") + ) + signature.get_changes([change_signature.ArgumentReorderer([1, 2, 0])]).do() + self.assertEqual("def f(b, c, a):\n pass\nf(2, 3, 1)\n", self.mod.read()) def test_resources_parameter(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('def a_func(param):\n pass\n') - self.mod.write('import mod1\nmod1.a_func(1)\n') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("def a_func(param):\n pass\n") + self.mod.write("import mod1\nmod1.a_func(1)\n") signature = change_signature.ChangeSignature( - self.project, mod1, mod1.read().index('a_func') + 1) - signature.get_changes([change_signature.ArgumentRemover(0)], - resources=[mod1]).do() - self.assertEqual('import mod1\nmod1.a_func(1)\n', self.mod.read()) - self.assertEqual('def a_func():\n pass\n', mod1.read()) + self.project, mod1, mod1.read().index("a_func") + 1 + ) + signature.get_changes( + [change_signature.ArgumentRemover(0)], resources=[mod1] + ).do() + self.assertEqual("import mod1\nmod1.a_func(1)\n", self.mod.read()) + self.assertEqual("def a_func():\n pass\n", mod1.read()) def test_reordering_and_automatic_defaults(self): - code = 'def f(p1, p2=2):\n' \ - ' pass\n' \ - 'f(1, 2)\n' + code = "def f(p1, p2=2):\n" " pass\n" "f(1, 2)\n" self.mod.write(code) signature = change_signature.ChangeSignature( - self.project, self.mod, code.index('f(')) - reorder = change_signature.ArgumentReorderer([1, 0], autodef='1') + self.project, self.mod, code.index("f(") + ) + reorder = change_signature.ArgumentReorderer([1, 0], autodef="1") signature.get_changes([reorder]).do() - expected = 'def f(p2=2, p1=1):\n' \ - ' pass\n' \ - 'f(2, 1)\n' + expected = "def f(p2=2, p1=1):\n" " pass\n" "f(2, 1)\n" self.assertEqual(expected, self.mod.read()) diff --git a/ropetest/refactor/extracttest.py b/ropetest/refactor/extracttest.py index b3bbbf7d2..8db4acd82 100644 --- a/ropetest/refactor/extracttest.py +++ b/ropetest/refactor/extracttest.py @@ -1,4 +1,5 @@ from textwrap import dedent + try: import unittest2 as unittest except ImportError: @@ -12,7 +13,6 @@ class ExtractMethodTest(unittest.TestCase): - def setUp(self): super(ExtractMethodTest, self).setUp() self.project = testutils.sample_project() @@ -23,15 +23,14 @@ def tearDown(self): super(ExtractMethodTest, self).tearDown() def do_extract_method(self, source_code, start, end, extracted, **kwds): - testmod = testutils.create_module(self.project, 'testmod') + testmod = testutils.create_module(self.project, "testmod") testmod.write(source_code) - extractor = extract.ExtractMethod( - self.project, testmod, start, end) + extractor = extract.ExtractMethod(self.project, testmod, start, end) self.project.do(extractor.get_changes(extracted, **kwds)) return testmod.read() def do_extract_variable(self, source_code, start, end, extracted, **kwds): - testmod = testutils.create_module(self.project, 'testmod') + testmod = testutils.create_module(self.project, "testmod") testmod.write(source_code) extractor = extract.ExtractVariable(self.project, testmod, start, end) self.project.do(extractor.get_changes(extracted, **kwds)) @@ -48,15 +47,15 @@ def a_func(): print('two') """) start, end = self._convert_line_range_to_offset(code, 2, 2) - refactored = self.do_extract_method(code, start, end, 'extracted') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "extracted") + expected = dedent("""\ def a_func(): extracted() print('two') def extracted(): print('one') - ''') + """) self.assertEqual(expected, refactored) def test_simple_extract_function_one_line(self): @@ -67,383 +66,437 @@ def a_func(): """) selected = "'one'" start, end = code.index(selected), code.index(selected) + len(selected) - refactored = self.do_extract_method(code, start, end, 'extracted') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "extracted") + expected = dedent("""\ def a_func(): resp = extracted() print(resp) def extracted(): return 'one' - ''') + """) self.assertEqual(expected, refactored) def test_extract_function_at_the_end_of_file(self): code = "def a_func():\n print('one')" start, end = self._convert_line_range_to_offset(code, 2, 2) - refactored = self.do_extract_method(code, start, end, 'extracted') - expected = "def a_func():\n extracted()\n" \ - "def extracted():\n print('one')\n" + refactored = self.do_extract_method(code, start, end, "extracted") + expected = ( + "def a_func():\n extracted()\n" "def extracted():\n print('one')\n" + ) self.assertEqual(expected, refactored) def test_extract_function_after_scope(self): - code = "def a_func():\n print('one')\n print('two')" \ - "\n\nprint('hey')\n" + code = "def a_func():\n print('one')\n print('two')" "\n\nprint('hey')\n" start, end = self._convert_line_range_to_offset(code, 2, 2) - refactored = self.do_extract_method(code, start, end, 'extracted') - expected = "def a_func():\n extracted()\n print('two')\n\n" \ - "def extracted():\n print('one')\n\nprint('hey')\n" + refactored = self.do_extract_method(code, start, end, "extracted") + expected = ( + "def a_func():\n extracted()\n print('two')\n\n" + "def extracted():\n print('one')\n\nprint('hey')\n" + ) self.assertEqual(expected, refactored) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_extract_function_containing_dict_generalized_unpacking(self): - code = dedent('''\ + code = dedent("""\ def a_func(dict1): dict2 = {} a_var = {a: b, **dict1, **dict2} - ''') - start = code.index('{a') - end = code.index('2}') + len('2}') - refactored = self.do_extract_method(code, start, end, 'extracted') - expected = dedent('''\ + """) + start = code.index("{a") + end = code.index("2}") + len("2}") + refactored = self.do_extract_method(code, start, end, "extracted") + expected = dedent("""\ def a_func(dict1): dict2 = {} a_var = extracted(dict1, dict2) def extracted(dict1, dict2): return {a: b, **dict1, **dict2} - ''') + """) self.assertEqual(expected, refactored) def test_simple_extract_function_with_parameter(self): code = "def a_func():\n a_var = 10\n print(a_var)\n" start, end = self._convert_line_range_to_offset(code, 3, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = "def a_func():\n a_var = 10\n new_func(a_var)\n\n" \ - "def new_func(a_var):\n print(a_var)\n" + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n a_var = 10\n new_func(a_var)\n\n" + "def new_func(a_var):\n print(a_var)\n" + ) self.assertEqual(expected, refactored) def test_not_unread_variables_as_parameter(self): code = "def a_func():\n a_var = 10\n print('hey')\n" start, end = self._convert_line_range_to_offset(code, 3, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = "def a_func():\n a_var = 10\n new_func()\n\n" \ - "def new_func():\n print('hey')\n" + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n a_var = 10\n new_func()\n\n" + "def new_func():\n print('hey')\n" + ) self.assertEqual(expected, refactored) def test_simple_extract_function_with_two_parameter(self): - code = 'def a_func():\n a_var = 10\n another_var = 20\n' \ - ' third_var = a_var + another_var\n' + code = ( + "def a_func():\n a_var = 10\n another_var = 20\n" + " third_var = a_var + another_var\n" + ) start, end = self._convert_line_range_to_offset(code, 4, 4) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n a_var = 10\n another_var = 20\n' \ - ' new_func(a_var, another_var)\n\n' \ - 'def new_func(a_var, another_var):\n' \ - ' third_var = a_var + another_var\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n a_var = 10\n another_var = 20\n" + " new_func(a_var, another_var)\n\n" + "def new_func(a_var, another_var):\n" + " third_var = a_var + another_var\n" + ) self.assertEqual(expected, refactored) def test_simple_extract_function_with_return_value(self): - code = 'def a_func():\n a_var = 10\n print(a_var)\n' + code = "def a_func():\n a_var = 10\n print(a_var)\n" start, end = self._convert_line_range_to_offset(code, 2, 2) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n a_var = new_func()' \ - '\n print(a_var)\n\n' \ - 'def new_func():\n a_var = 10\n return a_var\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n a_var = new_func()" + "\n print(a_var)\n\n" + "def new_func():\n a_var = 10\n return a_var\n" + ) self.assertEqual(expected, refactored) def test_extract_function_with_multiple_return_values(self): - code = 'def a_func():\n a_var = 10\n another_var = 20\n' \ - ' third_var = a_var + another_var\n' + code = ( + "def a_func():\n a_var = 10\n another_var = 20\n" + " third_var = a_var + another_var\n" + ) start, end = self._convert_line_range_to_offset(code, 2, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n a_var, another_var = new_func()\n' \ - ' third_var = a_var + another_var\n\n' \ - 'def new_func():\n a_var = 10\n another_var = 20\n' \ - ' return a_var, another_var\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n a_var, another_var = new_func()\n" + " third_var = a_var + another_var\n\n" + "def new_func():\n a_var = 10\n another_var = 20\n" + " return a_var, another_var\n" + ) self.assertEqual(expected, refactored) def test_simple_extract_method(self): - code = 'class AClass(object):\n\n' \ - ' def a_func(self):\n print(1)\n print(2)\n' + code = ( + "class AClass(object):\n\n" + " def a_func(self):\n print(1)\n print(2)\n" + ) start, end = self._convert_line_range_to_offset(code, 4, 4) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'class AClass(object):\n\n' \ - ' def a_func(self):\n' \ - ' self.new_func()\n' \ - ' print(2)\n\n' \ - ' def new_func(self):\n print(1)\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "class AClass(object):\n\n" + " def a_func(self):\n" + " self.new_func()\n" + " print(2)\n\n" + " def new_func(self):\n print(1)\n" + ) self.assertEqual(expected, refactored) def test_extract_method_with_args_and_returns(self): - code = 'class AClass(object):\n' \ - ' def a_func(self):\n' \ - ' a_var = 10\n' \ - ' another_var = a_var * 3\n' \ - ' third_var = a_var + another_var\n' + code = ( + "class AClass(object):\n" + " def a_func(self):\n" + " a_var = 10\n" + " another_var = a_var * 3\n" + " third_var = a_var + another_var\n" + ) start, end = self._convert_line_range_to_offset(code, 4, 4) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'class AClass(object):\n' \ - ' def a_func(self):\n' \ - ' a_var = 10\n' \ - ' another_var = self.new_func(a_var)\n' \ - ' third_var = a_var + another_var\n\n' \ - ' def new_func(self, a_var):\n' \ - ' another_var = a_var * 3\n' \ - ' return another_var\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "class AClass(object):\n" + " def a_func(self):\n" + " a_var = 10\n" + " another_var = self.new_func(a_var)\n" + " third_var = a_var + another_var\n\n" + " def new_func(self, a_var):\n" + " another_var = a_var * 3\n" + " return another_var\n" + ) self.assertEqual(expected, refactored) def test_extract_method_with_self_as_argument(self): - code = 'class AClass(object):\n' \ - ' def a_func(self):\n' \ - ' print(self)\n' + code = ( + "class AClass(object):\n" " def a_func(self):\n" " print(self)\n" + ) start, end = self._convert_line_range_to_offset(code, 3, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'class AClass(object):\n' \ - ' def a_func(self):\n' \ - ' self.new_func()\n\n' \ - ' def new_func(self):\n' \ - ' print(self)\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "class AClass(object):\n" + " def a_func(self):\n" + " self.new_func()\n\n" + " def new_func(self):\n" + " print(self)\n" + ) self.assertEqual(expected, refactored) def test_extract_method_with_no_self_as_argument(self): - code = 'class AClass(object):\n' \ - ' def a_func():\n' \ - ' print(1)\n' + code = "class AClass(object):\n" " def a_func():\n" " print(1)\n" start, end = self._convert_line_range_to_offset(code, 3, 3) with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_extract_method_with_multiple_methods(self): - code = 'class AClass(object):\n' \ - ' def a_func(self):\n' \ - ' print(self)\n\n' \ - ' def another_func(self):\n' \ - ' pass\n' + code = ( + "class AClass(object):\n" + " def a_func(self):\n" + " print(self)\n\n" + " def another_func(self):\n" + " pass\n" + ) start, end = self._convert_line_range_to_offset(code, 3, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'class AClass(object):\n' \ - ' def a_func(self):\n' \ - ' self.new_func()\n\n' \ - ' def new_func(self):\n' \ - ' print(self)\n\n' \ - ' def another_func(self):\n' \ - ' pass\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "class AClass(object):\n" + " def a_func(self):\n" + " self.new_func()\n\n" + " def new_func(self):\n" + " print(self)\n\n" + " def another_func(self):\n" + " pass\n" + ) self.assertEqual(expected, refactored) def test_extract_function_with_function_returns(self): - code = 'def a_func():\n def inner_func():\n pass\n' \ - ' inner_func()\n' + code = ( + "def a_func():\n def inner_func():\n pass\n" " inner_func()\n" + ) start, end = self._convert_line_range_to_offset(code, 2, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n' \ - ' inner_func = new_func()\n inner_func()\n\n' \ - 'def new_func():\n' \ - ' def inner_func():\n pass\n' \ - ' return inner_func\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n" + " inner_func = new_func()\n inner_func()\n\n" + "def new_func():\n" + " def inner_func():\n pass\n" + " return inner_func\n" + ) self.assertEqual(expected, refactored) def test_simple_extract_global_function(self): code = "print('one')\nprint('two')\nprint('three')\n" start, end = self._convert_line_range_to_offset(code, 2, 2) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = "print('one')\n\ndef new_func():\n print('two')\n" \ - "\nnew_func()\nprint('three')\n" + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "print('one')\n\ndef new_func():\n print('two')\n" + "\nnew_func()\nprint('three')\n" + ) self.assertEqual(expected, refactored) def test_extract_global_function_inside_ifs(self): - code = 'if True:\n a = 10\n' + code = "if True:\n a = 10\n" start, end = self._convert_line_range_to_offset(code, 2, 2) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = '\ndef new_func():\n a = 10\n\nif True:\n' \ - ' new_func()\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = "\ndef new_func():\n a = 10\n\nif True:\n" " new_func()\n" self.assertEqual(expected, refactored) def test_extract_function_while_inner_function_reads(self): - code = 'def a_func():\n a_var = 10\n' \ - ' def inner_func():\n print(a_var)\n' \ - ' return inner_func\n' + code = ( + "def a_func():\n a_var = 10\n" + " def inner_func():\n print(a_var)\n" + " return inner_func\n" + ) start, end = self._convert_line_range_to_offset(code, 3, 4) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n a_var = 10\n' \ - ' inner_func = new_func(a_var)' \ - '\n return inner_func\n\n' \ - 'def new_func(a_var):\n' \ - ' def inner_func():\n print(a_var)\n' \ - ' return inner_func\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n a_var = 10\n" + " inner_func = new_func(a_var)" + "\n return inner_func\n\n" + "def new_func(a_var):\n" + " def inner_func():\n print(a_var)\n" + " return inner_func\n" + ) self.assertEqual(expected, refactored) def test_extract_method_bad_range(self): code = "def a_func():\n pass\na_var = 10\n" start, end = self._convert_line_range_to_offset(code, 2, 3) with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_extract_method_bad_range2(self): code = "class AClass(object):\n pass\n" start, end = self._convert_line_range_to_offset(code, 1, 1) with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_extract_method_containing_return(self): - code = 'def a_func(arg):\n if arg:\n return arg * 2' \ - '\n return 1' + code = "def a_func(arg):\n if arg:\n return arg * 2" "\n return 1" start, end = self._convert_line_range_to_offset(code, 2, 4) with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_extract_method_containing_yield(self): code = "def a_func(arg):\n yield arg * 2\n" start, end = self._convert_line_range_to_offset(code, 2, 2) with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_extract_method_containing_uncomplete_lines(self): - code = 'a_var = 20\nanother_var = 30\n' - start = code.index('20') - end = code.index('30') + 2 + code = "a_var = 20\nanother_var = 30\n" + start = code.index("20") + end = code.index("30") + 2 with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_extract_method_containing_uncomplete_lines2(self): - code = 'a_var = 20\nanother_var = 30\n' - start = code.index('20') - end = code.index('another') + 5 + code = "a_var = 20\nanother_var = 30\n" + start = code.index("20") + end = code.index("another") + 5 with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_extract_function_and_argument_as_paramenter(self): - code = 'def a_func(arg):\n print(arg)\n' + code = "def a_func(arg):\n print(arg)\n" start, end = self._convert_line_range_to_offset(code, 2, 2) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func(arg):\n new_func(arg)\n\n' \ - 'def new_func(arg):\n print(arg)\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func(arg):\n new_func(arg)\n\n" + "def new_func(arg):\n print(arg)\n" + ) self.assertEqual(expected, refactored) def test_extract_function_and_end_as_the_start_of_a_line(self): code = 'print("hey")\nif True:\n pass\n' start = 0 - end = code.index('\n') + 1 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = '\ndef new_func():\n print("hey")\n\n' \ - 'new_func()\nif True:\n pass\n' + end = code.index("\n") + 1 + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + '\ndef new_func():\n print("hey")\n\n' "new_func()\nif True:\n pass\n" + ) self.assertEqual(expected, refactored) def test_extract_function_and_indented_blocks(self): - code = 'def a_func(arg):\n if True:\n' \ - ' if True:\n print(arg)\n' + code = ( + "def a_func(arg):\n if True:\n" + " if True:\n print(arg)\n" + ) start, end = self._convert_line_range_to_offset(code, 3, 4) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func(arg):\n ' \ - 'if True:\n new_func(arg)\n\n' \ - 'def new_func(arg):\n if True:\n print(arg)\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func(arg):\n " + "if True:\n new_func(arg)\n\n" + "def new_func(arg):\n if True:\n print(arg)\n" + ) self.assertEqual(expected, refactored) def test_extract_method_and_multi_line_headers(self): - code = 'def a_func(\n arg):\n print(arg)\n' + code = "def a_func(\n arg):\n print(arg)\n" start, end = self._convert_line_range_to_offset(code, 3, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func(\n arg):\n new_func(arg)\n\n' \ - 'def new_func(arg):\n print(arg)\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func(\n arg):\n new_func(arg)\n\n" + "def new_func(arg):\n print(arg)\n" + ) self.assertEqual(expected, refactored) def test_single_line_extract_function(self): - code = 'a_var = 10 + 20\n' - start = code.index('10') - end = code.index('20') + 2 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = "\ndef new_func():\n " \ - "return 10 + 20\n\na_var = new_func()\n" + code = "a_var = 10 + 20\n" + start = code.index("10") + end = code.index("20") + 2 + refactored = self.do_extract_method(code, start, end, "new_func") + expected = "\ndef new_func():\n " "return 10 + 20\n\na_var = new_func()\n" self.assertEqual(expected, refactored) def test_single_line_extract_function2(self): - code = 'def a_func():\n a = 10\n b = a * 20\n' - start = code.rindex('a') - end = code.index('20') + 2 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n a = 10\n b = new_func(a)\n' \ - '\ndef new_func(a):\n return a * 20\n' + code = "def a_func():\n a = 10\n b = a * 20\n" + start = code.rindex("a") + end = code.index("20") + 2 + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n a = 10\n b = new_func(a)\n" + "\ndef new_func(a):\n return a * 20\n" + ) self.assertEqual(expected, refactored) def test_single_line_extract_method_and_logical_lines(self): - code = 'a_var = 10 +\\\n 20\n' - start = code.index('10') - end = code.index('20') + 2 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = '\ndef new_func():\n ' \ - 'return 10 + 20\n\na_var = new_func()\n' + code = "a_var = 10 +\\\n 20\n" + start = code.index("10") + end = code.index("20") + 2 + refactored = self.do_extract_method(code, start, end, "new_func") + expected = "\ndef new_func():\n " "return 10 + 20\n\na_var = new_func()\n" self.assertEqual(expected, refactored) def test_single_line_extract_method_and_logical_lines2(self): - code = 'a_var = (10,\\\n 20)\n' - start = code.index('10') - 1 - end = code.index('20') + 3 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = '\ndef new_func():\n' \ - ' return (10, 20)\n\na_var = new_func()\n' + code = "a_var = (10,\\\n 20)\n" + start = code.index("10") - 1 + end = code.index("20") + 3 + refactored = self.do_extract_method(code, start, end, "new_func") + expected = "\ndef new_func():\n" " return (10, 20)\n\na_var = new_func()\n" self.assertEqual(expected, refactored) def test_single_line_extract_method(self): - code = "class AClass(object):\n\n" \ - " def a_func(self):\n a = 10\n b = a * a\n" - start = code.rindex('=') + 2 - end = code.rindex('a') + 1 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'class AClass(object):\n\n' \ - ' def a_func(self):\n' \ - ' a = 10\n b = self.new_func(a)\n\n' \ - ' def new_func(self, a):\n return a * a\n' + code = ( + "class AClass(object):\n\n" + " def a_func(self):\n a = 10\n b = a * a\n" + ) + start = code.rindex("=") + 2 + end = code.rindex("a") + 1 + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "class AClass(object):\n\n" + " def a_func(self):\n" + " a = 10\n b = self.new_func(a)\n\n" + " def new_func(self, a):\n return a * a\n" + ) self.assertEqual(expected, refactored) def test_single_line_extract_function_if_condition(self): - code = 'if True:\n pass\n' - start = code.index('True') - end = code.index('True') + 4 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = "\ndef new_func():\n return True\n\nif new_func():" \ - "\n pass\n" + code = "if True:\n pass\n" + start = code.index("True") + end = code.index("True") + 4 + refactored = self.do_extract_method(code, start, end, "new_func") + expected = "\ndef new_func():\n return True\n\nif new_func():" "\n pass\n" self.assertEqual(expected, refactored) def test_unneeded_params(self): - code = 'class A(object):\n ' \ - 'def a_func(self):\n a_var = 10\n a_var += 2\n' - start = code.rindex('2') - end = code.rindex('2') + 1 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'class A(object):\n' \ - ' def a_func(self):\n a_var = 10\n' \ - ' a_var += self.new_func()\n\n' \ - ' def new_func(self):\n return 2\n' + code = ( + "class A(object):\n " + "def a_func(self):\n a_var = 10\n a_var += 2\n" + ) + start = code.rindex("2") + end = code.rindex("2") + 1 + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "class A(object):\n" + " def a_func(self):\n a_var = 10\n" + " a_var += self.new_func()\n\n" + " def new_func(self):\n return 2\n" + ) self.assertEqual(expected, refactored) def test_breaks_and_continues_inside_loops(self): - code = 'def a_func():\n for i in range(10):\n continue\n' - start = code.index('for') + code = "def a_func():\n for i in range(10):\n continue\n" + start = code.index("for") end = len(code) - 1 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n new_func()\n\n' \ - 'def new_func():\n' \ - ' for i in range(10):\n continue\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n new_func()\n\n" + "def new_func():\n" + " for i in range(10):\n continue\n" + ) self.assertEqual(expected, refactored) def test_breaks_and_continues_outside_loops(self): - code = 'def a_func():\n' \ - ' for i in range(10):\n a = i\n continue\n' - start = code.index('a = i') + code = ( + "def a_func():\n" + " for i in range(10):\n a = i\n continue\n" + ) + start = code.index("a = i") end = len(code) - 1 with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_for_loop_variable_scope(self): - code = dedent('''\ + code = dedent("""\ def my_func(): i = 0 for dummy in range(10): i += 1 print(i) - ''') + """) start, end = self._convert_line_range_to_offset(code, 4, 5) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def my_func(): i = 0 for dummy in range(10): @@ -453,20 +506,20 @@ def new_func(i): i += 1 print(i) return i - ''') + """) self.assertEqual(expected, refactored) def test_for_loop_variable_scope_read_then_write(self): - code = dedent('''\ + code = dedent("""\ def my_func(): i = 0 for dummy in range(10): a = i + 1 i = a + 1 - ''') + """) start, end = self._convert_line_range_to_offset(code, 4, 5) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def my_func(): i = 0 for dummy in range(10): @@ -476,20 +529,20 @@ def new_func(i): a = i + 1 i = a + 1 return i - ''') + """) self.assertEqual(expected, refactored) def test_for_loop_variable_scope_write_then_read(self): - code = dedent('''\ + code = dedent("""\ def my_func(): i = 0 for dummy in range(10): i = 'hello' print(i) - ''') + """) start, end = self._convert_line_range_to_offset(code, 4, 5) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def my_func(): i = 0 for dummy in range(10): @@ -498,20 +551,20 @@ def my_func(): def new_func(): i = 'hello' print(i) - ''') + """) self.assertEqual(expected, refactored) def test_for_loop_variable_scope_write_only(self): - code = dedent('''\ + code = dedent("""\ def my_func(): i = 0 for num in range(10): i = 'hello' + num print(i) - ''') + """) start, end = self._convert_line_range_to_offset(code, 4, 4) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def my_func(): i = 0 for num in range(10): @@ -521,215 +574,223 @@ def my_func(): def new_func(num): i = 'hello' + num return i - ''') + """) self.assertEqual(expected, refactored) def test_variable_writes_followed_by_variable_reads_after_extraction(self): - code = 'def a_func():\n a = 1\n a = 2\n b = a\n' - start = code.index('a = 1') - end = code.index('a = 2') - 1 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n new_func()\n a = 2\n b = a\n\n' \ - 'def new_func():\n a = 1\n' + code = "def a_func():\n a = 1\n a = 2\n b = a\n" + start = code.index("a = 1") + end = code.index("a = 2") - 1 + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n new_func()\n a = 2\n b = a\n\n" + "def new_func():\n a = 1\n" + ) self.assertEqual(expected, refactored) def test_var_writes_followed_by_var_reads_inside_extraction(self): - code = 'def a_func():\n a = 1\n a = 2\n b = a\n' - start = code.index('a = 2') + code = "def a_func():\n a = 1\n a = 2\n b = a\n" + start = code.index("a = 2") end = len(code) - 1 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n a = 1\n new_func()\n\n' \ - 'def new_func():\n a = 2\n b = a\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n a = 1\n new_func()\n\n" + "def new_func():\n a = 2\n b = a\n" + ) self.assertEqual(expected, refactored) def test_extract_variable(self): - code = 'a_var = 10 + 20\n' - start = code.index('10') - end = code.index('20') + 2 - refactored = self.do_extract_variable(code, start, end, 'new_var') - expected = 'new_var = 10 + 20\na_var = new_var\n' + code = "a_var = 10 + 20\n" + start = code.index("10") + end = code.index("20") + 2 + refactored = self.do_extract_variable(code, start, end, "new_var") + expected = "new_var = 10 + 20\na_var = new_var\n" self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_extract_variable_f_string(self): - code = dedent('''\ + code = dedent("""\ foo(f"abc {a_var} def", 10) - ''') + """) start = code.index('f"') end = code.index('def"') + 4 - refactored = self.do_extract_variable(code, start, end, 'new_var') - expected = dedent('''\ + refactored = self.do_extract_variable(code, start, end, "new_var") + expected = dedent("""\ new_var = f"abc {a_var} def" foo(new_var, 10) - ''') + """) self.assertEqual(expected, refactored) def test_extract_variable_multiple_lines(self): - code = 'a = 1\nb = 2\n' - start = code.index('1') - end = code.index('1') + 1 - refactored = self.do_extract_variable(code, start, end, 'c') - expected = 'c = 1\na = c\nb = 2\n' + code = "a = 1\nb = 2\n" + start = code.index("1") + end = code.index("1") + 1 + refactored = self.do_extract_variable(code, start, end, "c") + expected = "c = 1\na = c\nb = 2\n" self.assertEqual(expected, refactored) def test_extract_variable_in_the_middle_of_statements(self): - code = 'a = 1 + 2\n' - start = code.index('1') - end = code.index('1') + 1 - refactored = self.do_extract_variable(code, start, end, 'c') - expected = 'c = 1\na = c + 2\n' + code = "a = 1 + 2\n" + start = code.index("1") + end = code.index("1") + 1 + refactored = self.do_extract_variable(code, start, end, "c") + expected = "c = 1\na = c + 2\n" self.assertEqual(expected, refactored) def test_extract_variable_for_a_tuple(self): - code = 'a = 1, 2\n' - start = code.index('1') - end = code.index('2') + 1 - refactored = self.do_extract_variable(code, start, end, 'c') - expected = 'c = 1, 2\na = c\n' + code = "a = 1, 2\n" + start = code.index("1") + end = code.index("2") + 1 + refactored = self.do_extract_variable(code, start, end, "c") + expected = "c = 1, 2\na = c\n" self.assertEqual(expected, refactored) def test_extract_variable_for_a_string(self): code = 'def a_func():\n a = "hey!"\n' start = code.index('"') end = code.rindex('"') + 1 - refactored = self.do_extract_variable(code, start, end, 'c') + refactored = self.do_extract_variable(code, start, end, "c") expected = 'def a_func():\n c = "hey!"\n a = c\n' self.assertEqual(expected, refactored) def test_extract_variable_inside_ifs(self): - code = 'if True:\n a = 1 + 2\n' - start = code.index('1') - end = code.rindex('2') + 1 - refactored = self.do_extract_variable(code, start, end, 'b') - expected = 'if True:\n b = 1 + 2\n a = b\n' + code = "if True:\n a = 1 + 2\n" + start = code.index("1") + end = code.rindex("2") + 1 + refactored = self.do_extract_variable(code, start, end, "b") + expected = "if True:\n b = 1 + 2\n a = b\n" self.assertEqual(expected, refactored) def test_extract_variable_inside_ifs_and_logical_lines(self): - code = 'if True:\n a = (3 + \n(1 + 2))\n' - start = code.index('1') - end = code.index('2') + 1 - refactored = self.do_extract_variable(code, start, end, 'b') - expected = 'if True:\n b = 1 + 2\n a = (3 + \n(b))\n' + code = "if True:\n a = (3 + \n(1 + 2))\n" + start = code.index("1") + end = code.index("2") + 1 + refactored = self.do_extract_variable(code, start, end, "b") + expected = "if True:\n b = 1 + 2\n a = (3 + \n(b))\n" self.assertEqual(expected, refactored) # TODO: Handle when extracting a subexpression def xxx_test_extract_variable_for_a_subexpression(self): - code = 'a = 3 + 1 + 2\n' - start = code.index('1') - end = code.index('2') + 1 - refactored = self.do_extract_variable(code, start, end, 'b') - expected = 'b = 1 + 2\na = 3 + b\n' + code = "a = 3 + 1 + 2\n" + start = code.index("1") + end = code.index("2") + 1 + refactored = self.do_extract_variable(code, start, end, "b") + expected = "b = 1 + 2\na = 3 + b\n" self.assertEqual(expected, refactored) def test_extract_variable_starting_from_the_start_of_the_line(self): - code = 'a_dict = {1: 1}\na_dict.values().count(1)\n' - start = code.rindex('a_dict') - end = code.index('count') - 1 - refactored = self.do_extract_variable(code, start, end, 'values') - expected = 'a_dict = {1: 1}\n' \ - 'values = a_dict.values()\nvalues.count(1)\n' + code = "a_dict = {1: 1}\na_dict.values().count(1)\n" + start = code.rindex("a_dict") + end = code.index("count") - 1 + refactored = self.do_extract_variable(code, start, end, "values") + expected = "a_dict = {1: 1}\n" "values = a_dict.values()\nvalues.count(1)\n" self.assertEqual(expected, refactored) def test_extract_variable_on_the_last_line_of_a_function(self): - code = 'def f():\n a_var = {}\n a_var.keys()\n' - start = code.rindex('a_var') - end = code.index('.keys') - refactored = self.do_extract_variable(code, start, end, 'new_var') - expected = 'def f():\n a_var = {}\n ' \ - 'new_var = a_var\n new_var.keys()\n' + code = "def f():\n a_var = {}\n a_var.keys()\n" + start = code.rindex("a_var") + end = code.index(".keys") + refactored = self.do_extract_variable(code, start, end, "new_var") + expected = ( + "def f():\n a_var = {}\n " "new_var = a_var\n new_var.keys()\n" + ) self.assertEqual(expected, refactored) def test_extract_variable_on_the_indented_function_statement(self): - code = 'def f():\n if True:\n a_var = 1 + 2\n' - start = code.index('1') - end = code.index('2') + 1 - refactored = self.do_extract_variable(code, start, end, 'new_var') - expected = 'def f():\n if True:\n' \ - ' new_var = 1 + 2\n a_var = new_var\n' + code = "def f():\n if True:\n a_var = 1 + 2\n" + start = code.index("1") + end = code.index("2") + 1 + refactored = self.do_extract_variable(code, start, end, "new_var") + expected = ( + "def f():\n if True:\n" + " new_var = 1 + 2\n a_var = new_var\n" + ) self.assertEqual(expected, refactored) def test_extract_method_on_the_last_line_of_a_function(self): - code = 'def f():\n a_var = {}\n a_var.keys()\n' - start = code.rindex('a_var') - end = code.index('.keys') - refactored = self.do_extract_method(code, start, end, 'new_f') - expected = 'def f():\n a_var = {}\n new_f(a_var).keys()\n\n' \ - 'def new_f(a_var):\n return a_var\n' + code = "def f():\n a_var = {}\n a_var.keys()\n" + start = code.rindex("a_var") + end = code.index(".keys") + refactored = self.do_extract_method(code, start, end, "new_f") + expected = ( + "def f():\n a_var = {}\n new_f(a_var).keys()\n\n" + "def new_f(a_var):\n return a_var\n" + ) self.assertEqual(expected, refactored) def test_raising_exception_when_on_incomplete_variables(self): - code = 'a_var = 10 + 20\n' - start = code.index('10') + 1 - end = code.index('20') + 2 + code = "a_var = 10 + 20\n" + start = code.index("10") + 1 + end = code.index("20") + 2 with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_raising_exception_when_on_incomplete_variables_on_end(self): - code = 'a_var = 10 + 20\n' - start = code.index('10') - end = code.index('20') + 1 + code = "a_var = 10 + 20\n" + start = code.index("10") + end = code.index("20") + 1 with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_raising_exception_on_bad_parens(self): - code = 'a_var = (10 + 20) + 30\n' - start = code.index('20') - end = code.index('30') + 2 + code = "a_var = (10 + 20) + 30\n" + start = code.index("20") + end = code.index("30") + 2 with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_raising_exception_on_bad_operators(self): - code = 'a_var = 10 + 20 + 30\n' - start = code.index('10') - end = code.rindex('+') + 1 + code = "a_var = 10 + 20 + 30\n" + start = code.index("10") + end = code.rindex("+") + 1 with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") # FIXME: Extract method should be more intelligent about bad ranges def xxx_test_raising_exception_on_function_parens(self): - code = 'a = range(10)' - start = code.index('(') - end = code.rindex(')') + 1 + code = "a = range(10)" + start = code.index("(") + end = code.rindex(")") + 1 with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_extract_method_and_extra_blank_lines(self): - code = '\nprint(1)\n' - refactored = self.do_extract_method(code, 0, len(code), 'new_f') - expected = '\n\ndef new_f():\n print(1)\n\nnew_f()\n' + code = "\nprint(1)\n" + refactored = self.do_extract_method(code, 0, len(code), "new_f") + expected = "\n\ndef new_f():\n print(1)\n\nnew_f()\n" self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_extract_method_f_string_extract_method(self): - code = dedent('''\ + code = dedent("""\ def func(a_var): foo(f"abc {a_var}", 10) - ''') + """) start = code.index('f"') end = code.index('}"') + 2 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def func(a_var): foo(new_func(a_var), 10) def new_func(a_var): return f"abc {a_var}" - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_extract_method_f_string_extract_method_complex_expression(self): - code = dedent('''\ + code = dedent("""\ def func(a_var): b_var = int c_var = 10 fill = 10 foo(f"abc {a_var + f'{b_var(a_var)}':{fill}16}" f"{c_var}", 10) - ''') + """) start = code.index('f"') end = code.index('c_var}"') + 7 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def func(a_var): b_var = int c_var = 10 @@ -738,85 +799,83 @@ def func(a_var): def new_func(a_var, b_var, c_var, fill): return f"abc {a_var + f'{b_var(a_var)}':{fill}16}" f"{c_var}" - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_extract_method_f_string_false_comment(self): - code = dedent('''\ + code = dedent("""\ def func(a_var): foo(f"abc {a_var} # ", 10) - ''') + """) start = code.index('f"') end = code.index('# "') + 3 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def func(a_var): foo(new_func(a_var), 10) def new_func(a_var): return f"abc {a_var} # " - ''') + """) self.assertEqual(expected, refactored) @unittest.expectedFailure - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_extract_method_f_string_false_format_value_in_regular_string(self): - code = dedent('''\ + code = dedent("""\ def func(a_var): b_var = 1 foo(f"abc {a_var} " "{b_var}" f"{b_var} def", 10) - ''') + """) start = code.index('f"') end = code.index('def"') + 4 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def func(a_var): b_var = 1 foo(new_func(a_var, b_var), 10) def new_func(a_var, b_var): return f"abc {a_var} " "{b_var}" f"{b_var} def" - ''') + """) self.assertEqual(expected, refactored) def test_variable_writes_in_the_same_line_as_variable_read(self): - code = 'a = 1\na = 1 + a\n' - start = code.index('\n') + 1 + code = "a = 1\na = 1 + a\n" + start = code.index("\n") + 1 end = len(code) - refactored = self.do_extract_method(code, start, end, 'new_f', - global_=True) - expected = 'a = 1\n\ndef new_f(a):\n a = 1 + a\n\nnew_f(a)\n' + refactored = self.do_extract_method(code, start, end, "new_f", global_=True) + expected = "a = 1\n\ndef new_f(a):\n a = 1 + a\n\nnew_f(a)\n" self.assertEqual(expected, refactored) def test_variable_writes_in_the_same_line_as_variable_read2(self): - code = dedent('''\ + code = dedent("""\ a = 1 a += 1 - ''') - start = code.index('\n') + 1 + """) + start = code.index("\n") + 1 end = len(code) - refactored = self.do_extract_method(code, start, end, 'new_f', - global_=True) - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_f", global_=True) + expected = dedent("""\ a = 1 def new_f(a): a += 1 new_f(a) - ''') + """) self.assertEqual(expected, refactored) def test_variable_writes_in_the_same_line_as_variable_read3(self): - code = dedent('''\ + code = dedent("""\ a = 1 a += 1 print(a) - ''') + """) start, end = self._convert_line_range_to_offset(code, 2, 2) - refactored = self.do_extract_method(code, start, end, 'new_f') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_f") + expected = dedent("""\ a = 1 def new_f(a): @@ -825,17 +884,17 @@ def new_f(a): a = new_f(a) print(a) - ''') + """) self.assertEqual(expected, refactored) def test_variable_writes_only(self): - code = dedent('''\ + code = dedent("""\ i = 1 print(i) - ''') + """) start, end = self._convert_line_range_to_offset(code, 1, 1) - refactored = self.do_extract_method(code, start, end, 'new_f') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_f") + expected = dedent("""\ def new_f(): i = 1 @@ -843,299 +902,343 @@ def new_f(): i = new_f() print(i) - ''') + """) self.assertEqual(expected, refactored) def test_variable_and_similar_expressions(self): - code = 'a = 1\nb = 1\n' - start = code.index('1') + code = "a = 1\nb = 1\n" + start = code.index("1") end = start + 1 - refactored = self.do_extract_variable(code, start, end, - 'one', similar=True) - expected = 'one = 1\na = one\nb = one\n' + refactored = self.do_extract_variable(code, start, end, "one", similar=True) + expected = "one = 1\na = one\nb = one\n" self.assertEqual(expected, refactored) def test_definition_should_appear_before_the_first_use(self): - code = 'a = 1\nb = 1\n' - start = code.rindex('1') + code = "a = 1\nb = 1\n" + start = code.rindex("1") end = start + 1 - refactored = self.do_extract_variable(code, start, end, - 'one', similar=True) - expected = 'one = 1\na = one\nb = one\n' + refactored = self.do_extract_variable(code, start, end, "one", similar=True) + expected = "one = 1\na = one\nb = one\n" self.assertEqual(expected, refactored) def test_extract_method_and_similar_expressions(self): - code = 'a = 1\nb = 1\n' - start = code.index('1') + code = "a = 1\nb = 1\n" + start = code.index("1") end = start + 1 - refactored = self.do_extract_method(code, start, end, - 'one', similar=True) - expected = '\ndef one():\n return 1\n\na = one()\nb = one()\n' + refactored = self.do_extract_method(code, start, end, "one", similar=True) + expected = "\ndef one():\n return 1\n\na = one()\nb = one()\n" self.assertEqual(expected, refactored) def test_simple_extract_method_and_similar_statements(self): - code = 'class AClass(object):\n\n' \ - ' def func1(self):\n a = 1 + 2\n b = a\n' \ - ' def func2(self):\n a = 1 + 2\n b = a\n' + code = ( + "class AClass(object):\n\n" + " def func1(self):\n a = 1 + 2\n b = a\n" + " def func2(self):\n a = 1 + 2\n b = a\n" + ) start, end = self._convert_line_range_to_offset(code, 4, 4) - refactored = self.do_extract_method(code, start, end, - 'new_func', similar=True) - expected = 'class AClass(object):\n\n' \ - ' def func1(self):\n' \ - ' a = self.new_func()\n b = a\n\n' \ - ' def new_func(self):\n' \ - ' a = 1 + 2\n return a\n' \ - ' def func2(self):\n' \ - ' a = self.new_func()\n b = a\n' + refactored = self.do_extract_method(code, start, end, "new_func", similar=True) + expected = ( + "class AClass(object):\n\n" + " def func1(self):\n" + " a = self.new_func()\n b = a\n\n" + " def new_func(self):\n" + " a = 1 + 2\n return a\n" + " def func2(self):\n" + " a = self.new_func()\n b = a\n" + ) self.assertEqual(expected, refactored) def test_extract_method_and_similar_statements2(self): - code = 'class AClass(object):\n\n' \ - ' def func1(self, p1):\n a = p1 + 2\n' \ - ' def func2(self, p2):\n a = p2 + 2\n' - start = code.rindex('p1') - end = code.index('2\n') + 1 - refactored = self.do_extract_method(code, start, end, - 'new_func', similar=True) - expected = 'class AClass(object):\n\n' \ - ' def func1(self, p1):\n ' \ - 'a = self.new_func(p1)\n\n' \ - ' def new_func(self, p1):\n return p1 + 2\n' \ - ' def func2(self, p2):\n a = self.new_func(p2)\n' + code = ( + "class AClass(object):\n\n" + " def func1(self, p1):\n a = p1 + 2\n" + " def func2(self, p2):\n a = p2 + 2\n" + ) + start = code.rindex("p1") + end = code.index("2\n") + 1 + refactored = self.do_extract_method(code, start, end, "new_func", similar=True) + expected = ( + "class AClass(object):\n\n" + " def func1(self, p1):\n " + "a = self.new_func(p1)\n\n" + " def new_func(self, p1):\n return p1 + 2\n" + " def func2(self, p2):\n a = self.new_func(p2)\n" + ) self.assertEqual(expected, refactored) def test_extract_method_and_similar_sttemnts_return_is_different(self): - code = 'class AClass(object):\n\n' \ - ' def func1(self, p1):\n a = p1 + 2\n' \ - ' def func2(self, p2):\n self.attr = p2 + 2\n' - start = code.rindex('p1') - end = code.index('2\n') + 1 - refactored = self.do_extract_method(code, start, end, - 'new_func', similar=True) - expected = 'class AClass(object):\n\n' \ - ' def func1(self, p1):' \ - '\n a = self.new_func(p1)\n\n' \ - ' def new_func(self, p1):\n return p1 + 2\n' \ - ' def func2(self, p2):\n' \ - ' self.attr = self.new_func(p2)\n' + code = ( + "class AClass(object):\n\n" + " def func1(self, p1):\n a = p1 + 2\n" + " def func2(self, p2):\n self.attr = p2 + 2\n" + ) + start = code.rindex("p1") + end = code.index("2\n") + 1 + refactored = self.do_extract_method(code, start, end, "new_func", similar=True) + expected = ( + "class AClass(object):\n\n" + " def func1(self, p1):" + "\n a = self.new_func(p1)\n\n" + " def new_func(self, p1):\n return p1 + 2\n" + " def func2(self, p2):\n" + " self.attr = self.new_func(p2)\n" + ) self.assertEqual(expected, refactored) def test_extract_method_and_similar_sttemnts_overlapping_regions(self): - code = 'def func(p):\n' \ - ' a = p\n' \ - ' b = a\n' \ - ' c = b\n' \ - ' d = c\n' \ - ' return d' - start = code.index('a') - end = code.rindex('a') + 1 - refactored = self.do_extract_method( - code, start, end, 'new_func', similar=True) - expected = 'def func(p):\n' \ - ' b = new_func(p)\n' \ - ' d = new_func(b)\n' \ - ' return d\n' \ - 'def new_func(p):\n' \ - ' a = p\n' \ - ' b = a\n' \ - ' return b\n' + code = ( + "def func(p):\n" + " a = p\n" + " b = a\n" + " c = b\n" + " d = c\n" + " return d" + ) + start = code.index("a") + end = code.rindex("a") + 1 + refactored = self.do_extract_method(code, start, end, "new_func", similar=True) + expected = ( + "def func(p):\n" + " b = new_func(p)\n" + " d = new_func(b)\n" + " return d\n" + "def new_func(p):\n" + " a = p\n" + " b = a\n" + " return b\n" + ) self.assertEqual(expected, refactored) def test_definition_should_appear_where_it_is_visible(self): - code = 'if True:\n a = 1\nelse:\n b = 1\n' - start = code.rindex('1') + code = "if True:\n a = 1\nelse:\n b = 1\n" + start = code.rindex("1") end = start + 1 - refactored = self.do_extract_variable(code, start, end, - 'one', similar=True) - expected = 'one = 1\nif True:\n a = one\nelse:\n b = one\n' + refactored = self.do_extract_variable(code, start, end, "one", similar=True) + expected = "one = 1\nif True:\n a = one\nelse:\n b = one\n" self.assertEqual(expected, refactored) def test_extract_variable_and_similar_statements_in_classes(self): - code = 'class AClass(object):\n\n' \ - ' def func1(self):\n a = 1\n' \ - ' def func2(self):\n b = 1\n' - start = code.index(' 1') + 1 - refactored = self.do_extract_variable(code, start, start + 1, - 'one', similar=True) - expected = 'class AClass(object):\n\n' \ - ' def func1(self):\n one = 1\n a = one\n' \ - ' def func2(self):\n b = 1\n' + code = ( + "class AClass(object):\n\n" + " def func1(self):\n a = 1\n" + " def func2(self):\n b = 1\n" + ) + start = code.index(" 1") + 1 + refactored = self.do_extract_variable( + code, start, start + 1, "one", similar=True + ) + expected = ( + "class AClass(object):\n\n" + " def func1(self):\n one = 1\n a = one\n" + " def func2(self):\n b = 1\n" + ) self.assertEqual(expected, refactored) def test_extract_method_in_staticmethods(self): - code = 'class AClass(object):\n\n' \ - ' @staticmethod\n def func2():\n b = 1\n' - start = code.index(' 1') + 1 - refactored = self.do_extract_method(code, start, start + 1, - 'one', similar=True) - expected = 'class AClass(object):\n\n' \ - ' @staticmethod\n def func2():\n' \ - ' b = AClass.one()\n\n' \ - ' @staticmethod\n def one():\n' \ - ' return 1\n' + code = ( + "class AClass(object):\n\n" + " @staticmethod\n def func2():\n b = 1\n" + ) + start = code.index(" 1") + 1 + refactored = self.do_extract_method(code, start, start + 1, "one", similar=True) + expected = ( + "class AClass(object):\n\n" + " @staticmethod\n def func2():\n" + " b = AClass.one()\n\n" + " @staticmethod\n def one():\n" + " return 1\n" + ) self.assertEqual(expected, refactored) def test_extract_normal_method_with_staticmethods(self): - code = 'class AClass(object):\n\n' \ - ' @staticmethod\n def func1():\n b = 1\n' \ - ' def func2(self):\n b = 1\n' - start = code.rindex(' 1') + 1 - refactored = self.do_extract_method(code, start, start + 1, - 'one', similar=True) - expected = 'class AClass(object):\n\n' \ - ' @staticmethod\n def func1():\n b = 1\n' \ - ' def func2(self):\n b = self.one()\n\n' \ - ' def one(self):\n return 1\n' + code = ( + "class AClass(object):\n\n" + " @staticmethod\n def func1():\n b = 1\n" + " def func2(self):\n b = 1\n" + ) + start = code.rindex(" 1") + 1 + refactored = self.do_extract_method(code, start, start + 1, "one", similar=True) + expected = ( + "class AClass(object):\n\n" + " @staticmethod\n def func1():\n b = 1\n" + " def func2(self):\n b = self.one()\n\n" + " def one(self):\n return 1\n" + ) self.assertEqual(expected, refactored) def test_extract_variable_with_no_new_lines_at_the_end(self): - code = 'a_var = 10' - start = code.index('10') + code = "a_var = 10" + start = code.index("10") end = start + 2 - refactored = self.do_extract_variable(code, start, end, 'new_var') - expected = 'new_var = 10\na_var = new_var' + refactored = self.do_extract_variable(code, start, end, "new_var") + expected = "new_var = 10\na_var = new_var" self.assertEqual(expected, refactored) def test_extract_method_containing_return_in_functions(self): - code = 'def f(arg):\n return arg\nprint(f(1))\n' + code = "def f(arg):\n return arg\nprint(f(1))\n" start, end = self._convert_line_range_to_offset(code, 1, 3) - refactored = self.do_extract_method(code, start, end, 'a_func') - expected = '\ndef a_func():\n def f(arg):\n return arg\n' \ - ' print(f(1))\n\na_func()\n' + refactored = self.do_extract_method(code, start, end, "a_func") + expected = ( + "\ndef a_func():\n def f(arg):\n return arg\n" + " print(f(1))\n\na_func()\n" + ) self.assertEqual(expected, refactored) def test_extract_method_and_varying_first_parameter(self): - code = 'class C(object):\n' \ - ' def f1(self):\n print(str(self))\n' \ - ' def f2(self):\n print(str(1))\n' - start = code.index('print(') + 6 - end = code.index('))\n') + 1 - refactored = self.do_extract_method(code, start, end, - 'to_str', similar=True) - expected = 'class C(object):\n' \ - ' def f1(self):\n print(self.to_str())\n\n' \ - ' def to_str(self):\n return str(self)\n' \ - ' def f2(self):\n print(str(1))\n' + code = ( + "class C(object):\n" + " def f1(self):\n print(str(self))\n" + " def f2(self):\n print(str(1))\n" + ) + start = code.index("print(") + 6 + end = code.index("))\n") + 1 + refactored = self.do_extract_method(code, start, end, "to_str", similar=True) + expected = ( + "class C(object):\n" + " def f1(self):\n print(self.to_str())\n\n" + " def to_str(self):\n return str(self)\n" + " def f2(self):\n print(str(1))\n" + ) self.assertEqual(expected, refactored) def test_extract_method_when_an_attribute_exists_in_function_scope(self): - code = 'class A(object):\n def func(self):\n pass\n' \ - 'a = A()\n' \ - 'def f():\n' \ - ' func = a.func()\n' \ - ' print(func)\n' + code = ( + "class A(object):\n def func(self):\n pass\n" + "a = A()\n" + "def f():\n" + " func = a.func()\n" + " print(func)\n" + ) start, end = self._convert_line_range_to_offset(code, 6, 6) - refactored = self.do_extract_method(code, start, end, 'g') - refactored = refactored[refactored.index('A()') + 4:] - expected = 'def f():\n func = g()\n print(func)\n\n' \ - 'def g():\n func = a.func()\n return func\n' + refactored = self.do_extract_method(code, start, end, "g") + refactored = refactored[refactored.index("A()") + 4 :] + expected = ( + "def f():\n func = g()\n print(func)\n\n" + "def g():\n func = a.func()\n return func\n" + ) self.assertEqual(expected, refactored) def test_global_option_for_extract_method(self): - code = 'def a_func():\n print(1)\n' + code = "def a_func():\n print(1)\n" start, end = self._convert_line_range_to_offset(code, 2, 2) - refactored = self.do_extract_method(code, start, end, - 'extracted', global_=True) - expected = 'def a_func():\n extracted()\n\n' \ - 'def extracted():\n print(1)\n' + refactored = self.do_extract_method(code, start, end, "extracted", global_=True) + expected = ( + "def a_func():\n extracted()\n\n" "def extracted():\n print(1)\n" + ) self.assertEqual(expected, refactored) def test_global_extract_method(self): - code = 'class AClass(object):\n\n' \ - ' def a_func(self):\n print(1)\n' + code = "class AClass(object):\n\n" " def a_func(self):\n print(1)\n" start, end = self._convert_line_range_to_offset(code, 4, 4) - refactored = self.do_extract_method(code, start, end, - 'new_func', global_=True) - expected = 'class AClass(object):\n\n' \ - ' def a_func(self):\n new_func()\n\n' \ - 'def new_func():\n print(1)\n' + refactored = self.do_extract_method(code, start, end, "new_func", global_=True) + expected = ( + "class AClass(object):\n\n" + " def a_func(self):\n new_func()\n\n" + "def new_func():\n print(1)\n" + ) self.assertEqual(expected, refactored) def test_extract_method_with_multiple_methods(self): # noqa - code = 'class AClass(object):\n' \ - ' def a_func(self):\n' \ - ' print(1)\n\n' \ - ' def another_func(self):\n' \ - ' pass\n' + code = ( + "class AClass(object):\n" + " def a_func(self):\n" + " print(1)\n\n" + " def another_func(self):\n" + " pass\n" + ) start, end = self._convert_line_range_to_offset(code, 3, 3) - refactored = self.do_extract_method(code, start, end, - 'new_func', global_=True) - expected = 'class AClass(object):\n' \ - ' def a_func(self):\n' \ - ' new_func()\n\n' \ - ' def another_func(self):\n' \ - ' pass\n\n' \ - 'def new_func():\n' \ - ' print(1)\n' + refactored = self.do_extract_method(code, start, end, "new_func", global_=True) + expected = ( + "class AClass(object):\n" + " def a_func(self):\n" + " new_func()\n\n" + " def another_func(self):\n" + " pass\n\n" + "def new_func():\n" + " print(1)\n" + ) self.assertEqual(expected, refactored) def test_where_to_seach_when_extracting_global_names(self): - code = 'def a():\n return 1\ndef b():\n return 1\nb = 1\n' - start = code.index('1') + code = "def a():\n return 1\ndef b():\n return 1\nb = 1\n" + start = code.index("1") end = start + 1 - refactored = self.do_extract_variable(code, start, end, 'one', - similar=True, global_=True) - expected = 'def a():\n return one\none = 1\n' \ - 'def b():\n return one\nb = one\n' + refactored = self.do_extract_variable( + code, start, end, "one", similar=True, global_=True + ) + expected = ( + "def a():\n return one\none = 1\n" "def b():\n return one\nb = one\n" + ) self.assertEqual(expected, refactored) def test_extracting_pieces_with_distinct_temp_names(self): - code = 'a = 1\nprint(a)\nb = 1\nprint(b)\n' - start = code.index('a') - end = code.index('\nb') - refactored = self.do_extract_method(code, start, end, 'f', - similar=True, global_=True) - expected = '\ndef f():\n a = 1\n print(a)\n\nf()\nf()\n' + code = "a = 1\nprint(a)\nb = 1\nprint(b)\n" + start = code.index("a") + end = code.index("\nb") + refactored = self.do_extract_method( + code, start, end, "f", similar=True, global_=True + ) + expected = "\ndef f():\n a = 1\n print(a)\n\nf()\nf()\n" self.assertEqual(expected, refactored) def test_extract_methods_in_glob_funcs_should_be_glob(self): - code = 'def f():\n a = 1\ndef g():\n b = 1\n' - start = code.rindex('1') - refactored = self.do_extract_method(code, start, start + 1, 'one', - similar=True, global_=False) - expected = 'def f():\n a = one()\ndef g():\n b = one()\n\n' \ - 'def one():\n return 1\n' + code = "def f():\n a = 1\ndef g():\n b = 1\n" + start = code.rindex("1") + refactored = self.do_extract_method( + code, start, start + 1, "one", similar=True, global_=False + ) + expected = ( + "def f():\n a = one()\ndef g():\n b = one()\n\n" + "def one():\n return 1\n" + ) self.assertEqual(expected, refactored) def test_extract_methods_in_glob_funcs_should_be_glob_2(self): - code = 'if 1:\n var = 2\n' - start = code.rindex('2') - refactored = self.do_extract_method(code, start, start + 1, 'two', - similar=True, global_=False) - expected = '\ndef two():\n return 2\n\nif 1:\n var = two()\n' + code = "if 1:\n var = 2\n" + start = code.rindex("2") + refactored = self.do_extract_method( + code, start, start + 1, "two", similar=True, global_=False + ) + expected = "\ndef two():\n return 2\n\nif 1:\n var = two()\n" self.assertEqual(expected, refactored) def test_extract_method_and_try_blocks(self): - code = 'def f():\n try:\n pass\n' \ - ' except Exception:\n pass\n' + code = ( + "def f():\n try:\n pass\n" " except Exception:\n pass\n" + ) start, end = self._convert_line_range_to_offset(code, 2, 5) - refactored = self.do_extract_method(code, start, end, 'g') - expected = 'def f():\n g()\n\ndef g():\n try:\n pass\n' \ - ' except Exception:\n pass\n' + refactored = self.do_extract_method(code, start, end, "g") + expected = ( + "def f():\n g()\n\ndef g():\n try:\n pass\n" + " except Exception:\n pass\n" + ) self.assertEqual(expected, refactored) def test_extract_and_not_passing_global_functions(self): - code = 'def next(p):\n return p + 1\nvar = next(1)\n' - start = code.rindex('next') - refactored = self.do_extract_method(code, start, len(code) - 1, 'two') - expected = 'def next(p):\n return p + 1\n' \ - '\ndef two():\n return next(1)\n\nvar = two()\n' + code = "def next(p):\n return p + 1\nvar = next(1)\n" + start = code.rindex("next") + refactored = self.do_extract_method(code, start, len(code) - 1, "two") + expected = ( + "def next(p):\n return p + 1\n" + "\ndef two():\n return next(1)\n\nvar = two()\n" + ) self.assertEqual(expected, refactored) def test_extracting_with_only_one_return(self): - code = 'def f():\n var = 1\n return var\n' + code = "def f():\n var = 1\n return var\n" start, end = self._convert_line_range_to_offset(code, 2, 3) - refactored = self.do_extract_method(code, start, end, 'g') - expected = 'def f():\n return g()\n\n' \ - 'def g():\n var = 1\n return var\n' + refactored = self.do_extract_method(code, start, end, "g") + expected = ( + "def f():\n return g()\n\n" "def g():\n var = 1\n return var\n" + ) self.assertEqual(expected, refactored) def test_extracting_variable_and_implicit_continuations(self): code = 's = ("1"\n "2")\n' start = code.index('"') end = code.rindex('"') + 1 - refactored = self.do_extract_variable(code, start, end, 's2') + refactored = self.do_extract_variable(code, start, end, "s2") expected = 's2 = "1" "2"\ns = (s2)\n' self.assertEqual(expected, refactored) @@ -1143,23 +1246,22 @@ def test_extracting_method_and_implicit_continuations(self): code = 's = ("1"\n "2")\n' start = code.index('"') end = code.rindex('"') + 1 - refactored = self.do_extract_method(code, start, end, 'f') + refactored = self.do_extract_method(code, start, end, "f") expected = '\ndef f():\n return "1" "2"\n\ns = (f())\n' self.assertEqual(expected, refactored) def test_passing_conditional_updated_vars_in_extracted(self): - code = 'def f(a):\n' \ - ' if 0:\n' \ - ' a = 1\n' \ - ' print(a)\n' + code = "def f(a):\n" " if 0:\n" " a = 1\n" " print(a)\n" start, end = self._convert_line_range_to_offset(code, 2, 4) - refactored = self.do_extract_method(code, start, end, 'g') - expected = 'def f(a):\n' \ - ' g(a)\n\n' \ - 'def g(a):\n' \ - ' if 0:\n' \ - ' a = 1\n' \ - ' print(a)\n' + refactored = self.do_extract_method(code, start, end, "g") + expected = ( + "def f(a):\n" + " g(a)\n\n" + "def g(a):\n" + " if 0:\n" + " a = 1\n" + " print(a)\n" + ) self.assertEqual(expected, refactored) def test_returning_conditional_updated_vars_in_extracted(self): @@ -1170,7 +1272,7 @@ def f(a): print(a) """) start, end = self._convert_line_range_to_offset(code, 2, 3) - refactored = self.do_extract_method(code, start, end, 'g') + refactored = self.do_extract_method(code, start, end, "g") expected = dedent("""\ def f(a): a = g(a) @@ -1184,41 +1286,44 @@ def g(a): self.assertEqual(expected, refactored) def test_extract_method_with_variables_possibly_written_to(self): - code = "def a_func(b):\n" \ - " if b > 0:\n" \ - " a = 2\n" \ - " print(a)\n" + code = "def a_func(b):\n" " if b > 0:\n" " a = 2\n" " print(a)\n" start, end = self._convert_line_range_to_offset(code, 2, 3) - refactored = self.do_extract_method(code, start, end, 'extracted') - expected = "def a_func(b):\n" \ - " a = extracted(b)\n" \ - " print(a)\n\n" \ - "def extracted(b):\n" \ - " if b > 0:\n" \ - " a = 2\n" \ - " return a\n" + refactored = self.do_extract_method(code, start, end, "extracted") + expected = ( + "def a_func(b):\n" + " a = extracted(b)\n" + " print(a)\n\n" + "def extracted(b):\n" + " if b > 0:\n" + " a = 2\n" + " return a\n" + ) self.assertEqual(expected, refactored) def test_extract_method_with_list_comprehension(self): - code = "def foo():\n" \ - " x = [e for e in []]\n" \ - " f = 23\n" \ - "\n" \ - " for e, f in []:\n" \ - " def bar():\n" \ - " e[42] = 1\n" + code = ( + "def foo():\n" + " x = [e for e in []]\n" + " f = 23\n" + "\n" + " for e, f in []:\n" + " def bar():\n" + " e[42] = 1\n" + ) start, end = self._convert_line_range_to_offset(code, 4, 7) - refactored = self.do_extract_method(code, start, end, 'baz') - expected = "def foo():\n" \ - " x = [e for e in []]\n" \ - " f = 23\n" \ - "\n" \ - " baz()\n" \ - "\n" \ - "def baz():\n" \ - " for e, f in []:\n" \ - " def bar():\n" \ - " e[42] = 1\n" + refactored = self.do_extract_method(code, start, end, "baz") + expected = ( + "def foo():\n" + " x = [e for e in []]\n" + " f = 23\n" + "\n" + " baz()\n" + "\n" + "def baz():\n" + " for e, f in []:\n" + " def bar():\n" + " e[42] = 1\n" + ) self.assertEqual(expected, refactored) def test_extract_method_with_list_comprehension_and_iter(self): @@ -1232,7 +1337,7 @@ def bar(): x[42] = 1 """) start, end = self._convert_line_range_to_offset(code, 4, 7) - refactored = self.do_extract_method(code, start, end, 'baz') + refactored = self.do_extract_method(code, start, end, "baz") expected = dedent("""\ def foo(): x = [e for e in []] @@ -1248,89 +1353,101 @@ def bar(): self.assertEqual(expected, refactored) def test_extract_method_with_list_comprehension_and_orelse(self): - code = "def foo():\n" \ - " x = [e for e in []]\n" \ - " f = 23\n" \ - "\n" \ - " for e, f in []:\n" \ - " def bar():\n" \ - " e[42] = 1\n" + code = ( + "def foo():\n" + " x = [e for e in []]\n" + " f = 23\n" + "\n" + " for e, f in []:\n" + " def bar():\n" + " e[42] = 1\n" + ) start, end = self._convert_line_range_to_offset(code, 4, 7) - refactored = self.do_extract_method(code, start, end, 'baz') - expected = "def foo():\n" \ - " x = [e for e in []]\n" \ - " f = 23\n" \ - "\n" \ - " baz()\n" \ - "\n" \ - "def baz():\n" \ - " for e, f in []:\n" \ - " def bar():\n" \ - " e[42] = 1\n" + refactored = self.do_extract_method(code, start, end, "baz") + expected = ( + "def foo():\n" + " x = [e for e in []]\n" + " f = 23\n" + "\n" + " baz()\n" + "\n" + "def baz():\n" + " for e, f in []:\n" + " def bar():\n" + " e[42] = 1\n" + ) self.assertEqual(expected, refactored) def test_extract_function_with_for_else_statemant(self): - code = 'def a_func():\n for i in range(10):\n a = i\n ' \ - 'else:\n a = None\n' - start = code.index('for') + code = ( + "def a_func():\n for i in range(10):\n a = i\n " + "else:\n a = None\n" + ) + start = code.index("for") end = len(code) - 1 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n new_func()\n\n' \ - 'def new_func():\n' \ - ' for i in range(10):\n a = i\n else:\n' \ - ' a = None\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n new_func()\n\n" + "def new_func():\n" + " for i in range(10):\n a = i\n else:\n" + " a = None\n" + ) self.assertEqual(expected, refactored) def test_extract_function_with_for_else_statemant_more(self): """TODO: fixed code to test passed """ - code = 'def a_func():\n'\ - ' for i in range(10):\n'\ - ' a = i\n'\ - ' else:\n'\ - ' for i in range(5):\n'\ - ' b = i\n'\ - ' else:\n'\ - ' b = None\n'\ - ' a = None\n' - - start = code.index('for') + code = ( + "def a_func():\n" + " for i in range(10):\n" + " a = i\n" + " else:\n" + " for i in range(5):\n" + " b = i\n" + " else:\n" + " b = None\n" + " a = None\n" + ) + + start = code.index("for") end = len(code) - 1 - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = 'def a_func():\n new_func()\n\n' \ - 'def new_func():\n' \ - ' for i in range(10):\n'\ - ' a = i\n'\ - ' else:\n'\ - ' for i in range(5):\n'\ - ' b = i\n'\ - ' else:\n'\ - ' b = None\n'\ - ' a = None\n' + refactored = self.do_extract_method(code, start, end, "new_func") + expected = ( + "def a_func():\n new_func()\n\n" + "def new_func():\n" + " for i in range(10):\n" + " a = i\n" + " else:\n" + " for i in range(5):\n" + " b = i\n" + " else:\n" + " b = None\n" + " a = None\n" + ) self.assertEqual(expected, refactored) def test_extract_function_with_for_else_statemant_outside_loops(self): - code = dedent('''\ + code = dedent("""\ def a_func(): for i in range(10): a = i else: a=None - ''') - start = code.index('a = i') + """) + start = code.index("a = i") end = len(code) - 1 with self.assertRaises(rope.base.exceptions.RefactoringError): - self.do_extract_method(code, start, end, 'new_func') + self.do_extract_method(code, start, end, "new_func") def test_extract_function_with_inline_assignment_in_method(self): - code = dedent('''\ + code = dedent("""\ def foo(): i = 1 i += 1 print(i) - ''') + """) start, end = self._convert_line_range_to_offset(code, 3, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def foo(): i = 1 i = new_func(i) @@ -1339,20 +1456,20 @@ def foo(): def new_func(i): i += 1 return i - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.8') + @testutils.only_for_versions_higher("3.8") def test_extract_function_statement_with_inline_assignment_in_condition(self): - code = dedent('''\ + code = dedent("""\ def foo(a): if i := a == 5: i += 1 print(i) - ''') + """) start, end = self._convert_line_range_to_offset(code, 2, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def foo(a): i = new_func(a) print(i) @@ -1361,21 +1478,23 @@ def new_func(a): if i := a == 5: i += 1 return i - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.8') + @testutils.only_for_versions_higher("3.8") def test_extract_function_expression_with_inline_assignment_in_condition(self): - code = dedent('''\ + code = dedent("""\ def foo(a): if i := a == 5: i += 1 print(i) - ''') - extract_target = 'i := a == 5' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + """) + extract_target = "i := a == 5" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def foo(a): if i := new_func(a): i += 1 @@ -1383,22 +1502,24 @@ def foo(a): def new_func(a): return (i := a == 5) - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.8') + @testutils.only_for_versions_higher("3.8") def test_extract_function_expression_with_inline_assignment_complex(self): - code = dedent('''\ + code = dedent("""\ def foo(a): if i := a == (c := 5): i += 1 c += 1 print(i) - ''') - extract_target = 'i := a == (c := 5)' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + """) + extract_target = "i := a == (c := 5)" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def foo(a): if i, c := new_func(a): i += 1 @@ -1407,64 +1528,71 @@ def foo(a): def new_func(a): return (i := a == (c := 5)) - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.8') - def test_extract_function_expression_with_inline_assignment_in_inner_expression(self): - code = dedent('''\ + @testutils.only_for_versions_higher("3.8") + def test_extract_function_expression_with_inline_assignment_in_inner_expression( + self, + ): + code = dedent("""\ def foo(a): if a == (c := 5): c += 1 print(i) - ''') - extract_target = 'a == (c := 5)' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - with self.assertRaisesRegexp(rope.base.exceptions.RefactoringError, 'Extracted piece cannot contain named expression \\(:= operator\\).'): - self.do_extract_method(code, start, end, 'new_func') + """) + extract_target = "a == (c := 5)" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + with self.assertRaisesRegexp( + rope.base.exceptions.RefactoringError, + "Extracted piece cannot contain named expression \\(:= operator\\).", + ): + self.do_extract_method(code, start, end, "new_func") def test_extract_exec(self): - code = dedent('''\ + code = dedent("""\ exec("def f(): pass", {}) - ''') + """) start, end = self._convert_line_range_to_offset(code, 1, 1) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def new_func(): exec("def f(): pass", {}) new_func() - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_extract_exec_statement(self): - code = dedent('''\ + code = dedent("""\ exec "def f(): pass" in {} - ''') + """) start, end = self._convert_line_range_to_offset(code, 1, 1) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def new_func(): exec "def f(): pass" in {} new_func() - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.5') + @testutils.only_for_versions_higher("3.5") def test_extract_async_function(self): - code = dedent('''\ + code = dedent("""\ async def my_func(my_list): for x in my_list: var = x + 1 return var - ''') + """) start, end = self._convert_line_range_to_offset(code, 3, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ async def my_func(my_list): for x in my_list: var = new_func(x) @@ -1473,21 +1601,21 @@ async def my_func(my_list): def new_func(x): var = x + 1 return var - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.5') + @testutils.only_for_versions_higher("3.5") def test_extract_inner_async_function(self): - code = dedent('''\ + code = dedent("""\ def my_func(my_list): async def inner_func(my_list): for x in my_list: var = x + 1 return inner_func - ''') + """) start, end = self._convert_line_range_to_offset(code, 2, 4) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def my_func(my_list): inner_func = new_func(my_list) return inner_func @@ -1497,21 +1625,21 @@ async def inner_func(my_list): for x in my_list: var = x + 1 return inner_func - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.5') + @testutils.only_for_versions_higher("3.5") def test_extract_around_inner_async_function(self): - code = dedent('''\ + code = dedent("""\ def my_func(lst): async def inner_func(obj): for x in obj: var = x + 1 return map(inner_func, lst) - ''') + """) start, end = self._convert_line_range_to_offset(code, 5, 5) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ def my_func(lst): async def inner_func(obj): for x in obj: @@ -1520,20 +1648,20 @@ async def inner_func(obj): def new_func(inner_func, lst): return map(inner_func, lst) - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.5') + @testutils.only_for_versions_higher("3.5") def test_extract_refactor_around_async_for_loop(self): - code = dedent('''\ + code = dedent("""\ async def my_func(my_list): async for x in my_list: var = x + 1 return var - ''') + """) start, end = self._convert_line_range_to_offset(code, 3, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ async def my_func(my_list): async for x in my_list: var = new_func(x) @@ -1542,38 +1670,41 @@ async def my_func(my_list): def new_func(x): var = x + 1 return var - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.5') - @testutils.only_for_versions_lower('3.8') + @testutils.only_for_versions_higher("3.5") + @testutils.only_for_versions_lower("3.8") def test_extract_refactor_containing_async_for_loop_should_error_before_py38(self): """ Refactoring async/await syntaxes is only supported in Python 3.8 and higher because support for ast.PyCF_ALLOW_TOP_LEVEL_AWAIT was only added to the standard library in Python 3.8. """ - code = dedent('''\ + code = dedent("""\ async def my_func(my_list): async for x in my_list: var = x + 1 return var - ''') + """) start, end = self._convert_line_range_to_offset(code, 2, 3) - with self.assertRaisesRegexp(rope.base.exceptions.RefactoringError, 'Extracted piece can only have async/await statements if Rope is running on Python 3.8 or higher'): - self.do_extract_method(code, start, end, 'new_func') + with self.assertRaisesRegexp( + rope.base.exceptions.RefactoringError, + "Extracted piece can only have async/await statements if Rope is running on Python 3.8 or higher", + ): + self.do_extract_method(code, start, end, "new_func") - @testutils.only_for_versions_higher('3.8') + @testutils.only_for_versions_higher("3.8") def test_extract_refactor_containing_async_for_loop_is_supported_after_py38(self): - code = dedent('''\ + code = dedent("""\ async def my_func(my_list): async for x in my_list: var = x + 1 return var - ''') + """) start, end = self._convert_line_range_to_offset(code, 2, 3) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ async def my_func(my_list): var = new_func(my_list) return var @@ -1582,21 +1713,21 @@ def new_func(my_list): async for x in my_list: var = x + 1 return var - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.5') + @testutils.only_for_versions_higher("3.5") def test_extract_await_expression(self): - code = dedent('''\ + code = dedent("""\ async def my_func(my_list): for url in my_list: resp = await request(url) return resp - ''') - selected = 'request(url)' + """) + selected = "request(url)" start, end = code.index(selected), code.index(selected) + len(selected) - refactored = self.do_extract_method(code, start, end, 'new_func') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "new_func") + expected = dedent("""\ async def my_func(my_list): for url in my_list: resp = await new_func(url) @@ -1604,20 +1735,24 @@ async def my_func(my_list): def new_func(url): return request(url) - ''') + """) self.assertEqual(expected, refactored) def test_extract_to_staticmethod(self): - code = dedent('''\ + code = dedent("""\ class A: def first_method(self): a_var = 1 b_var = a_var + 1 - ''') - extract_target = 'a_var + 1' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, 'second_method', kind="staticmethod") - expected = dedent('''\ + """) + extract_target = "a_var + 1" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method( + code, start, end, "second_method", kind="staticmethod" + ) + expected = dedent("""\ class A: def first_method(self): a_var = 1 @@ -1626,20 +1761,24 @@ def first_method(self): @staticmethod def second_method(a_var): return a_var + 1 - ''') + """) self.assertEqual(expected, refactored) def test_extract_to_staticmethod_when_self_in_body(self): - code = dedent('''\ + code = dedent("""\ class A: def first_method(self): a_var = 1 b_var = self.a_var + 1 - ''') - extract_target = 'self.a_var + 1' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, 'second_method', kind="staticmethod") - expected = dedent('''\ + """) + extract_target = "self.a_var + 1" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method( + code, start, end, "second_method", kind="staticmethod" + ) + expected = dedent("""\ class A: def first_method(self): a_var = 1 @@ -1648,31 +1787,37 @@ def first_method(self): @staticmethod def second_method(self): return self.a_var + 1 - ''') + """) self.assertEqual(expected, refactored) def test_extract_from_function_to_staticmethod_raises_exception(self): - code = dedent('''\ + code = dedent("""\ def first_method(): a_var = 1 b_var = a_var + 1 - ''') - extract_target = 'a_var + 1' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - with self.assertRaisesRegexp(rope.base.exceptions.RefactoringError, "Cannot extract to staticmethod/classmethod outside class"): - self.do_extract_method(code, start, end, 'second_method', kind="staticmethod") + """) + extract_target = "a_var + 1" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + with self.assertRaisesRegexp( + rope.base.exceptions.RefactoringError, + "Cannot extract to staticmethod/classmethod outside class", + ): + self.do_extract_method( + code, start, end, "second_method", kind="staticmethod" + ) def test_extract_method_in_classmethods(self): - code = dedent('''\ + code = dedent("""\ class AClass(object): @classmethod def func2(cls): b = 1 - ''') - start = code.index(' 1') + 1 - refactored = self.do_extract_method(code, start, start + 1, - 'one', similar=True) - expected = dedent('''\ + """) + start = code.index(" 1") + 1 + refactored = self.do_extract_method(code, start, start + 1, "one", similar=True) + expected = dedent("""\ class AClass(object): @classmethod def func2(cls): @@ -1681,31 +1826,42 @@ def func2(cls): @classmethod def one(cls): return 1 - ''') + """) self.assertEqual(expected, refactored) def test_extract_from_function_to_classmethod_raises_exception(self): - code = dedent('''\ + code = dedent("""\ def first_method(): a_var = 1 b_var = a_var + 1 - ''') - extract_target = 'a_var + 1' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - with self.assertRaisesRegexp(rope.base.exceptions.RefactoringError, "Cannot extract to staticmethod/classmethod outside class"): - self.do_extract_method(code, start, end, 'second_method', kind="classmethod") + """) + extract_target = "a_var + 1" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + with self.assertRaisesRegexp( + rope.base.exceptions.RefactoringError, + "Cannot extract to staticmethod/classmethod outside class", + ): + self.do_extract_method( + code, start, end, "second_method", kind="classmethod" + ) def test_extract_to_classmethod_when_self_in_body(self): - code = dedent('''\ + code = dedent("""\ class A: def first_method(self): a_var = 1 b_var = self.a_var + 1 - ''') - extract_target = 'self.a_var + 1' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, 'second_method', kind="classmethod") - expected = dedent('''\ + """) + extract_target = "self.a_var + 1" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method( + code, start, end, "second_method", kind="classmethod" + ) + expected = dedent("""\ class A: def first_method(self): a_var = 1 @@ -1714,20 +1870,24 @@ def first_method(self): @classmethod def second_method(cls, self): return self.a_var + 1 - ''') + """) self.assertEqual(expected, refactored) def test_extract_to_classmethod(self): - code = dedent('''\ + code = dedent("""\ class A: def first_method(self): a_var = 1 b_var = a_var + 1 - ''') - extract_target = 'a_var + 1' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, 'second_method', kind="classmethod") - expected = dedent('''\ + """) + extract_target = "a_var + 1" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method( + code, start, end, "second_method", kind="classmethod" + ) + expected = dedent("""\ class A: def first_method(self): a_var = 1 @@ -1736,20 +1896,22 @@ def first_method(self): @classmethod def second_method(cls, a_var): return a_var + 1 - ''') + """) self.assertEqual(expected, refactored) def test_extract_to_classmethod_when_name_starts_with_at_sign(self): - code = dedent('''\ + code = dedent("""\ class A: def first_method(self): a_var = 1 b_var = a_var + 1 - ''') - extract_target = 'a_var + 1' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, '@second_method') - expected = dedent('''\ + """) + extract_target = "a_var + 1" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method(code, start, end, "@second_method") + expected = dedent("""\ class A: def first_method(self): a_var = 1 @@ -1758,20 +1920,22 @@ def first_method(self): @classmethod def second_method(cls, a_var): return a_var + 1 - ''') + """) self.assertEqual(expected, refactored) def test_extract_to_staticmethod_when_name_starts_with_dollar_sign(self): - code = dedent('''\ + code = dedent("""\ class A: def first_method(self): a_var = 1 b_var = a_var + 1 - ''') - extract_target = 'a_var + 1' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, '$second_method') - expected = dedent('''\ + """) + extract_target = "a_var + 1" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method(code, start, end, "$second_method") + expected = dedent("""\ class A: def first_method(self): a_var = 1 @@ -1780,25 +1944,29 @@ def first_method(self): @staticmethod def second_method(a_var): return a_var + 1 - ''') + """) self.assertEqual(expected, refactored) def test_raises_exception_when_sign_in_name_and_kind_mismatch(self): - with self.assertRaisesRegexp(rope.base.exceptions.RefactoringError, "Kind and shortcut in name mismatch"): - self.do_extract_method("code", 0,1, '$second_method', kind="classmethod") + with self.assertRaisesRegexp( + rope.base.exceptions.RefactoringError, "Kind and shortcut in name mismatch" + ): + self.do_extract_method("code", 0, 1, "$second_method", kind="classmethod") def test_extracting_from_static_with_function_arg(self): - code = dedent('''\ + code = dedent("""\ class A: @staticmethod def first_method(someargs): b_var = someargs + 1 - ''') + """) - extract_target = 'someargs + 1' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, 'second_method') - expected = dedent('''\ + extract_target = "someargs + 1" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method(code, start, end, "second_method") + expected = dedent("""\ class A: @staticmethod def first_method(someargs): @@ -1807,21 +1975,24 @@ def first_method(someargs): @staticmethod def second_method(someargs): return someargs + 1 - ''') + """) self.assertEqual(expected, refactored) def test_extract_function_expression_with_assignment_to_attribute(self): - code = dedent('''\ + code = dedent("""\ class A(object): def func(self): self.var_a = 1 var_bb = self.var_a - ''') - extract_target = '= self.var_a' - start, end = code.index(extract_target)+2, code.index(extract_target)+2 + len(extract_target) - 2 - refactored = self.do_extract_method(code, start, end, 'new_func', similar=True) - expected = dedent('''\ + """) + extract_target = "= self.var_a" + start, end = ( + code.index(extract_target) + 2, + code.index(extract_target) + 2 + len(extract_target) - 2, + ) + refactored = self.do_extract_method(code, start, end, "new_func", similar=True) + expected = dedent("""\ class A(object): def func(self): self.var_a = 1 @@ -1829,21 +2000,24 @@ def func(self): def new_func(self): return self.var_a - ''') + """) self.assertEqual(expected, refactored) def test_extract_function_expression_with_assignment_index(self): - code = dedent('''\ + code = dedent("""\ class A(object): def func(self, val): self[val] = 1 var_bb = self[val] - ''') - extract_target = '= self[val]' - start, end = code.index(extract_target)+2, code.index(extract_target)+2 + len(extract_target) - 2 - refactored = self.do_extract_method(code, start, end, 'new_func', similar=True) - expected = dedent('''\ + """) + extract_target = "= self[val]" + start, end = ( + code.index(extract_target) + 2, + code.index(extract_target) + 2 + len(extract_target) - 2, + ) + refactored = self.do_extract_method(code, start, end, "new_func", similar=True) + expected = dedent("""\ class A(object): def func(self, val): self[val] = 1 @@ -1851,12 +2025,12 @@ def func(self, val): def new_func(self, val): return self[val] - ''') + """) self.assertEqual(expected, refactored) def test_extraction_method_with_global_variable(self): - code = dedent('''\ + code = dedent("""\ g = None def f(): @@ -1866,11 +2040,13 @@ def f(): f() print(g) - ''') - extract_target = 'g = 2' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, '_g') - expected = dedent('''\ + """) + extract_target = "g = 2" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method(code, start, end, "_g") + expected = dedent("""\ g = None def f(): @@ -1884,11 +2060,11 @@ def _g(): f() print(g) - ''') + """) self.assertEqual(expected, refactored) def test_extraction_method_with_global_variable_and_global_declaration(self): - code = dedent('''\ + code = dedent("""\ g = None def f(): @@ -1898,10 +2074,10 @@ def f(): f() print(g) - ''') + """) start, end = 23, 42 - refactored = self.do_extract_method(code, start, end, '_g') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "_g") + expected = dedent("""\ g = None def f(): @@ -1914,11 +2090,11 @@ def _g(): f() print(g) - ''') + """) self.assertEqual(expected, refactored) def test_extraction_one_line_with_global_variable_read_only(self): - code = dedent('''\ + code = dedent("""\ g = None def f(): @@ -1928,11 +2104,11 @@ def f(): f() print(g) - ''') - extract_target = '= g' + """) + extract_target = "= g" start, end = code.index(extract_target) + 2, code.index(extract_target) + 3 - refactored = self.do_extract_method(code, start, end, '_g') - expected = dedent('''\ + refactored = self.do_extract_method(code, start, end, "_g") + expected = dedent("""\ g = None def f(): @@ -1945,12 +2121,12 @@ def _g(): f() print(g) - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.8') + @testutils.only_for_versions_higher("3.8") def test_extraction_one_line_with_global_variable(self): - code = dedent('''\ + code = dedent("""\ g = None def f(): @@ -1961,11 +2137,13 @@ def f(): f() print(g) - ''') - extract_target = 'g := 4' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, '_g') - expected = dedent('''\ + """) + extract_target = "g := 4" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method(code, start, end, "_g") + expected = dedent("""\ g = None def f(): @@ -1980,12 +2158,12 @@ def _g(): f() print(g) - ''') + """) self.assertEqual(expected, refactored) - @testutils.only_for_versions_higher('3.8') + @testutils.only_for_versions_higher("3.8") def test_extraction_one_line_with_global_variable_has_postread(self): - code = dedent('''\ + code = dedent("""\ g = None def f(): @@ -1996,11 +2174,13 @@ def f(): f() print(g) - ''') - extract_target = 'g := 4' - start, end = code.index(extract_target), code.index(extract_target) + len(extract_target) - refactored = self.do_extract_method(code, start, end, '_g') - expected = dedent('''\ + """) + extract_target = "g := 4" + start, end = code.index(extract_target), code.index(extract_target) + len( + extract_target + ) + refactored = self.do_extract_method(code, start, end, "_g") + expected = dedent("""\ g = None def f(): @@ -2015,5 +2195,5 @@ def _g(): f() print(g) - ''') + """) self.assertEqual(expected, refactored) diff --git a/ropetest/refactor/importutilstest.py b/ropetest/refactor/importutilstest.py index 6ae907cba..c84e90507 100644 --- a/ropetest/refactor/importutilstest.py +++ b/ropetest/refactor/importutilstest.py @@ -10,950 +10,1034 @@ class ImportUtilsTest(unittest.TestCase): - def setUp(self): super(ImportUtilsTest, self).setUp() self.project = testutils.sample_project() self.import_tools = ImportTools(self.project) - self.mod = testutils.create_module(self.project, 'mod') - self.pkg1 = testutils.create_package(self.project, 'pkg1') - self.mod1 = testutils.create_module(self.project, 'mod1', self.pkg1) - self.pkg2 = testutils.create_package(self.project, 'pkg2') - self.mod2 = testutils.create_module(self.project, 'mod2', self.pkg2) - self.mod3 = testutils.create_module(self.project, 'mod3', self.pkg2) - p1 = testutils.create_package(self.project, 'p1') - p2 = testutils.create_package(self.project, 'p2', p1) - p3 = testutils.create_package(self.project, 'p3', p2) - m1 = testutils.create_module(self.project, 'm1', p3) # noqa - l = testutils.create_module(self.project, 'l', p3) # noqa + self.mod = testutils.create_module(self.project, "mod") + self.pkg1 = testutils.create_package(self.project, "pkg1") + self.mod1 = testutils.create_module(self.project, "mod1", self.pkg1) + self.pkg2 = testutils.create_package(self.project, "pkg2") + self.mod2 = testutils.create_module(self.project, "mod2", self.pkg2) + self.mod3 = testutils.create_module(self.project, "mod3", self.pkg2) + p1 = testutils.create_package(self.project, "p1") + p2 = testutils.create_package(self.project, "p2", p1) + p3 = testutils.create_package(self.project, "p3", p2) + m1 = testutils.create_module(self.project, "m1", p3) # noqa + l = testutils.create_module(self.project, "l", p3) # noqa def tearDown(self): testutils.remove_project(self.project) super(ImportUtilsTest, self).tearDown() def test_get_import_for_module(self): - mod = self.project.find_module('mod') + mod = self.project.find_module("mod") import_statement = self.import_tools.get_import(mod) - self.assertEqual('import mod', - import_statement.get_import_statement()) + self.assertEqual("import mod", import_statement.get_import_statement()) def test_get_import_for_module_in_nested_modules(self): - mod = self.project.find_module('pkg1.mod1') + mod = self.project.find_module("pkg1.mod1") import_statement = self.import_tools.get_import(mod) - self.assertEqual('import pkg1.mod1', - import_statement.get_import_statement()) + self.assertEqual("import pkg1.mod1", import_statement.get_import_statement()) def test_get_import_for_module_in_init_dot_py(self): - init_dot_py = self.pkg1.get_child('__init__.py') + init_dot_py = self.pkg1.get_child("__init__.py") import_statement = self.import_tools.get_import(init_dot_py) - self.assertEqual('import pkg1', - import_statement.get_import_statement()) + self.assertEqual("import pkg1", import_statement.get_import_statement()) def test_get_from_import_for_module(self): - mod = self.project.find_module('mod') - import_statement = self.import_tools.get_from_import(mod, 'a_func') - self.assertEqual('from mod import a_func', - import_statement.get_import_statement()) + mod = self.project.find_module("mod") + import_statement = self.import_tools.get_from_import(mod, "a_func") + self.assertEqual( + "from mod import a_func", import_statement.get_import_statement() + ) def test_get_from_import_for_module_in_nested_modules(self): - mod = self.project.find_module('pkg1.mod1') - import_statement = self.import_tools.get_from_import(mod, 'a_func') - self.assertEqual('from pkg1.mod1 import a_func', - import_statement.get_import_statement()) + mod = self.project.find_module("pkg1.mod1") + import_statement = self.import_tools.get_from_import(mod, "a_func") + self.assertEqual( + "from pkg1.mod1 import a_func", import_statement.get_import_statement() + ) def test_get_from_import_for_module_in_init_dot_py(self): - init_dot_py = self.pkg1.get_child('__init__.py') - import_statement = self.import_tools.\ - get_from_import(init_dot_py, 'a_func') - self.assertEqual('from pkg1 import a_func', - import_statement.get_import_statement()) + init_dot_py = self.pkg1.get_child("__init__.py") + import_statement = self.import_tools.get_from_import(init_dot_py, "a_func") + self.assertEqual( + "from pkg1 import a_func", import_statement.get_import_statement() + ) def test_get_import_statements(self): - self.mod.write('import pkg1\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.imports - self.assertEqual('import pkg1', - imports[0].import_info.get_import_statement()) + self.assertEqual("import pkg1", imports[0].import_info.get_import_statement()) def test_get_import_statements_with_alias(self): - self.mod.write('import pkg1.mod1 as mod1\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1.mod1 as mod1\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.imports - self.assertEqual('import pkg1.mod1 as mod1', - imports[0].import_info.get_import_statement()) + self.assertEqual( + "import pkg1.mod1 as mod1", imports[0].import_info.get_import_statement() + ) def test_get_import_statements_for_froms(self): - self.mod.write('from pkg1 import mod1\n') - pymod = self.project.get_module('mod') + self.mod.write("from pkg1 import mod1\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.imports - self.assertEqual('from pkg1 import mod1', - imports[0].import_info.get_import_statement()) + self.assertEqual( + "from pkg1 import mod1", imports[0].import_info.get_import_statement() + ) def test_get_multi_line_import_statements_for_froms(self): - self.mod.write('from pkg1 \\\n import mod1\n') - pymod = self.project.get_module('mod') + self.mod.write("from pkg1 \\\n import mod1\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.imports - self.assertEqual('from pkg1 import mod1', - imports[0].import_info.get_import_statement()) + self.assertEqual( + "from pkg1 import mod1", imports[0].import_info.get_import_statement() + ) def test_get_import_statements_for_from_star(self): - self.mod.write('from pkg1 import *\n') - pymod = self.project.get_module('mod') + self.mod.write("from pkg1 import *\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.imports - self.assertEqual('from pkg1 import *', - imports[0].import_info.get_import_statement()) + self.assertEqual( + "from pkg1 import *", imports[0].import_info.get_import_statement() + ) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_get_import_statements_for_new_relatives(self): - self.mod2.write('from .mod3 import x\n') - pymod = self.project.get_module('pkg2.mod2') + self.mod2.write("from .mod3 import x\n") + pymod = self.project.get_module("pkg2.mod2") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.imports - self.assertEqual('from .mod3 import x', - imports[0].import_info.get_import_statement()) + self.assertEqual( + "from .mod3 import x", imports[0].import_info.get_import_statement() + ) def test_ignoring_indented_imports(self): - self.mod.write('if True:\n import pkg1\n') - pymod = self.project.get_module('mod') + self.mod.write("if True:\n import pkg1\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.imports self.assertEqual(0, len(imports)) def test_import_get_names(self): - self.mod.write('import pkg1 as pkg\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1 as pkg\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.imports context = importinfo.ImportContext(self.project, self.project.root) - self.assertEqual(['pkg'], - imports[0].import_info.get_imported_names(context)) + self.assertEqual(["pkg"], imports[0].import_info.get_imported_names(context)) def test_import_get_names_with_alias(self): - self.mod.write('import pkg1.mod1\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1.mod1\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.imports context = importinfo.ImportContext(self.project, self.project.root) - self.assertEqual(['pkg1'], - imports[0].import_info.get_imported_names(context)) + self.assertEqual(["pkg1"], imports[0].import_info.get_imported_names(context)) def test_import_get_names_with_alias2(self): - self.mod1.write('def a_func():\n pass\n') - self.mod.write('from pkg1.mod1 import *\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n") + self.mod.write("from pkg1.mod1 import *\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.imports context = importinfo.ImportContext(self.project, self.project.root) - self.assertEqual(['a_func'], - imports[0].import_info.get_imported_names(context)) + self.assertEqual(["a_func"], imports[0].import_info.get_imported_names(context)) def test_empty_getting_used_imports(self): - self.mod.write('') - pymod = self.project.get_module('mod') + self.mod.write("") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.get_used_imports(pymod) self.assertEqual(0, len(imports)) def test_empty_getting_used_imports2(self): - self.mod.write('import pkg\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.get_used_imports(pymod) self.assertEqual(0, len(imports)) def test_simple_getting_used_imports(self): - self.mod.write('import pkg\nprint(pkg)\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg\nprint(pkg)\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.get_used_imports(pymod) self.assertEqual(1, len(imports)) - self.assertEqual('import pkg', imports[0].get_import_statement()) + self.assertEqual("import pkg", imports[0].get_import_statement()) def test_simple_getting_used_imports2(self): - self.mod.write('import pkg\ndef a_func():\n print(pkg)\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg\ndef a_func():\n print(pkg)\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.get_used_imports(pymod) self.assertEqual(1, len(imports)) - self.assertEqual('import pkg', imports[0].get_import_statement()) + self.assertEqual("import pkg", imports[0].get_import_statement()) def test_getting_used_imports_for_nested_scopes(self): - self.mod.write('import pkg1\nprint(pkg1)\n' - 'def a_func():\n pass\nprint(pkg1)\n') - pymod = self.project.get_module('mod') + self.mod.write( + "import pkg1\nprint(pkg1)\n" "def a_func():\n pass\nprint(pkg1)\n" + ) + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) - imports = module_with_imports.get_used_imports( - pymod['a_func'].get_object()) + imports = module_with_imports.get_used_imports(pymod["a_func"].get_object()) self.assertEqual(0, len(imports)) def test_getting_used_imports_for_nested_scopes2(self): - self.mod.write('from pkg1 import mod1\ndef a_func():' - '\n print(mod1)\n') - pymod = self.project.get_module('mod') + self.mod.write("from pkg1 import mod1\ndef a_func():" "\n print(mod1)\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) - imports = module_with_imports.get_used_imports( - pymod['a_func'].get_object()) + imports = module_with_imports.get_used_imports(pymod["a_func"].get_object()) self.assertEqual(1, len(imports)) - self.assertEqual('from pkg1 import mod1', - imports[0].get_import_statement()) + self.assertEqual("from pkg1 import mod1", imports[0].get_import_statement()) def test_empty_removing_unused_imports(self): - self.mod.write('import pkg1\nprint(pkg1)\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1\nprint(pkg1)\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('import pkg1\nprint(pkg1)\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "import pkg1\nprint(pkg1)\n", module_with_imports.get_changed_source() + ) def test_simple_removing_unused_imports(self): - self.mod.write('import pkg1\n\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1\n\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('', module_with_imports.get_changed_source()) + self.assertEqual("", module_with_imports.get_changed_source()) def test_simple_removing_unused_imports_for_froms(self): - self.mod1.write('def a_func():\n pass' - '\ndef another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import a_func, ' - 'another_func\n\na_func()\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass" "\ndef another_func():\n pass\n") + self.mod.write("from pkg1.mod1 import a_func, " "another_func\n\na_func()\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('from pkg1.mod1 import a_func\n\na_func()\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1.mod1 import a_func\n\na_func()\n", + module_with_imports.get_changed_source(), + ) def test_simple_removing_unused_imports_for_from_stars(self): - self.mod.write('from pkg1.mod1 import *\n\n') - pymod = self.project.get_module('mod') + self.mod.write("from pkg1.mod1 import *\n\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('', module_with_imports.get_changed_source()) + self.assertEqual("", module_with_imports.get_changed_source()) def test_simple_removing_unused_imports_for_nested_modules(self): - self.mod1.write('def a_func():\n pass\n') - self.mod.write('import pkg1.mod1\npkg1.mod1.a_func()') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n") + self.mod.write("import pkg1.mod1\npkg1.mod1.a_func()") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('import pkg1.mod1\npkg1.mod1.a_func()', - module_with_imports.get_changed_source()) + self.assertEqual( + "import pkg1.mod1\npkg1.mod1.a_func()", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_and_functions_of_the_same_name(self): - self.mod.write('def a_func():\n pass\ndef a_func():\n pass\n') - pymod = self.project.get_module('mod') + self.mod.write("def a_func():\n pass\ndef a_func():\n pass\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('def a_func():\n pass\ndef a_func():\n pass\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "def a_func():\n pass\ndef a_func():\n pass\n", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_for_from_import_with_as(self): - self.mod.write('a_var = 1\n') - self.mod1.write('from mod import a_var as myvar\na_var = myvar\n') + self.mod.write("a_var = 1\n") + self.mod1.write("from mod import a_var as myvar\na_var = myvar\n") pymod = self.project.get_pymodule(self.mod1) module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('from mod import a_var as myvar\na_var = myvar\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "from mod import a_var as myvar\na_var = myvar\n", + module_with_imports.get_changed_source(), + ) def test_not_removing_imports_that_conflict_with_class_names(self): - code = 'import pkg1\nclass A(object):\n pkg1 = 0\n' \ - ' def f(self):\n a_var = pkg1\n' + code = ( + "import pkg1\nclass A(object):\n pkg1 = 0\n" + " def f(self):\n a_var = pkg1\n" + ) self.mod.write(code) - pymod = self.project.get_module('mod') + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() self.assertEqual(code, module_with_imports.get_changed_source()) def test_adding_imports(self): - self.mod.write('\n') - pymod = self.project.get_module('mod') + self.mod.write("\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) new_import = self.import_tools.get_import(self.mod1) module_with_imports.add_import(new_import) - self.assertEqual('import pkg1.mod1\n', - module_with_imports.get_changed_source()) + self.assertEqual("import pkg1.mod1\n", module_with_imports.get_changed_source()) def test_adding_imports_no_pull_to_top(self): - self.mod.write('import pkg2.mod3\nclass A(object):\n pass\n\n' - 'import pkg2.mod2\n') - pymod = self.project.get_module('mod') - self.project.prefs['pull_imports_to_top'] = False + self.mod.write( + "import pkg2.mod3\nclass A(object):\n pass\n\n" "import pkg2.mod2\n" + ) + pymod = self.project.get_module("mod") + self.project.prefs["pull_imports_to_top"] = False module_with_imports = self.import_tools.module_imports(pymod) new_import = self.import_tools.get_import(self.mod1) module_with_imports.add_import(new_import) - self.assertEqual('import pkg2.mod3\nclass A(object):\n pass\n\n' - 'import pkg2.mod2\nimport pkg1.mod1\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "import pkg2.mod3\nclass A(object):\n pass\n\n" + "import pkg2.mod2\nimport pkg1.mod1\n", + module_with_imports.get_changed_source(), + ) def test_adding_from_imports(self): - self.mod1.write('def a_func():\n pass\n' - 'def another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import a_func\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n" "def another_func():\n pass\n") + self.mod.write("from pkg1.mod1 import a_func\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) - new_import = self.import_tools.get_from_import( - self.mod1, 'another_func') + new_import = self.import_tools.get_from_import(self.mod1, "another_func") module_with_imports.add_import(new_import) - self.assertEqual('from pkg1.mod1 import a_func, another_func\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1.mod1 import a_func, another_func\n", + module_with_imports.get_changed_source(), + ) def test_adding_to_star_imports(self): - self.mod1.write('def a_func():\n pass' - '\ndef another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import *\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass" "\ndef another_func():\n pass\n") + self.mod.write("from pkg1.mod1 import *\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) - new_import = self.import_tools.get_from_import( - self.mod1, 'another_func') + new_import = self.import_tools.get_from_import(self.mod1, "another_func") module_with_imports.add_import(new_import) - self.assertEqual('from pkg1.mod1 import *\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1.mod1 import *\n", module_with_imports.get_changed_source() + ) def test_adding_star_imports(self): - self.mod1.write('def a_func():\n pass\n' - 'def another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import a_func\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n" "def another_func():\n pass\n") + self.mod.write("from pkg1.mod1 import a_func\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) - new_import = self.import_tools.get_from_import(self.mod1, '*') + new_import = self.import_tools.get_from_import(self.mod1, "*") module_with_imports.add_import(new_import) - self.assertEqual('from pkg1.mod1 import *\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1.mod1 import *\n", module_with_imports.get_changed_source() + ) def test_adding_imports_and_preserving_spaces_after_imports(self): - self.mod.write('import pkg1\n\n\nprint(pkg1)\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1\n\n\nprint(pkg1)\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) new_import = self.import_tools.get_import(self.pkg2) module_with_imports.add_import(new_import) - self.assertEqual('import pkg1\nimport pkg2\n\n\nprint(pkg1)\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "import pkg1\nimport pkg2\n\n\nprint(pkg1)\n", + module_with_imports.get_changed_source(), + ) def test_not_changing_the_format_of_unchanged_imports(self): - self.mod1.write('def a_func():\n pass\n' - 'def another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import (a_func,\n another_func)\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n" "def another_func():\n pass\n") + self.mod.write("from pkg1.mod1 import (a_func,\n another_func)\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) self.assertEqual( - 'from pkg1.mod1 import (a_func,\n another_func)\n', - module_with_imports.get_changed_source()) + "from pkg1.mod1 import (a_func,\n another_func)\n", + module_with_imports.get_changed_source(), + ) def test_not_changing_the_format_of_unchanged_imports2(self): - self.mod1.write('def a_func():\n pass\n' - 'def another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import (a_func)\na_func()\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n" "def another_func():\n pass\n") + self.mod.write("from pkg1.mod1 import (a_func)\na_func()\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('from pkg1.mod1 import (a_func)\na_func()\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1.mod1 import (a_func)\na_func()\n", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_and_reoccuring_names(self): - self.mod1.write('def a_func():\n pass\n' - 'def another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import *\n' - 'from pkg1.mod1 import a_func\na_func()\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n" "def another_func():\n pass\n") + self.mod.write( + "from pkg1.mod1 import *\n" "from pkg1.mod1 import a_func\na_func()\n" + ) + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('from pkg1.mod1 import *\na_func()\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1.mod1 import *\na_func()\n", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_and_reoccuring_names2(self): - self.mod.write('import pkg2.mod2\nimport pkg2.mod3\n' - 'print(pkg2.mod2, pkg2.mod3)') - pymod = self.project.get_module('mod') + self.mod.write( + "import pkg2.mod2\nimport pkg2.mod3\n" "print(pkg2.mod2, pkg2.mod3)" + ) + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() self.assertEqual( - 'import pkg2.mod2\nimport pkg2.mod3\nprint(pkg2.mod2, pkg2.mod3)', - module_with_imports.get_changed_source()) + "import pkg2.mod2\nimport pkg2.mod3\nprint(pkg2.mod2, pkg2.mod3)", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_and_common_packages(self): - self.mod.write('import pkg1.mod1\nimport pkg1' - '\nprint(pkg1, pkg1.mod1)\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1.mod1\nimport pkg1" "\nprint(pkg1, pkg1.mod1)\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('import pkg1.mod1\nprint(pkg1, pkg1.mod1)\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "import pkg1.mod1\nprint(pkg1, pkg1.mod1)\n", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_and_common_packages_reversed(self): - self.mod.write('import pkg1\nimport pkg1.mod1' - '\nprint(pkg1, pkg1.mod1)\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1\nimport pkg1.mod1" "\nprint(pkg1, pkg1.mod1)\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_duplicates() - self.assertEqual('import pkg1.mod1\nprint(pkg1, pkg1.mod1)\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "import pkg1.mod1\nprint(pkg1, pkg1.mod1)\n", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_and_common_packages2(self): - self.mod.write('import pkg1.mod1\nimport pkg1.mod2\nprint(pkg1)\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1.mod1\nimport pkg1.mod2\nprint(pkg1)\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('import pkg1.mod1\nprint(pkg1)\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "import pkg1.mod1\nprint(pkg1)\n", module_with_imports.get_changed_source() + ) def test_removing_unused_imports_and_froms(self): - self.mod1.write('def func1():\n pass\n') - self.mod.write('from pkg1.mod1 import func1\n') - pymod = self.project.get_module('mod') + self.mod1.write("def func1():\n pass\n") + self.mod.write("from pkg1.mod1 import func1\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('', module_with_imports.get_changed_source()) + self.assertEqual("", module_with_imports.get_changed_source()) def test_removing_unused_imports_and_froms2(self): - self.mod1.write('def func1():\n pass\n') - self.mod.write('from pkg1.mod1 import func1\nfunc1()') - pymod = self.project.get_module('mod') + self.mod1.write("def func1():\n pass\n") + self.mod.write("from pkg1.mod1 import func1\nfunc1()") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('from pkg1.mod1 import func1\nfunc1()', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1.mod1 import func1\nfunc1()", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_and_froms3(self): - self.mod1.write('def func1():\n pass\n') - self.mod.write('from pkg1.mod1 import func1\n' - 'def a_func():\n func1()\n') - pymod = self.project.get_module('mod') + self.mod1.write("def func1():\n pass\n") + self.mod.write("from pkg1.mod1 import func1\n" "def a_func():\n func1()\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() self.assertEqual( - 'from pkg1.mod1 import func1\ndef a_func():\n func1()\n', - module_with_imports.get_changed_source()) + "from pkg1.mod1 import func1\ndef a_func():\n func1()\n", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_and_froms4(self): - self.mod1.write('def func1():\n pass\n') - self.mod.write('from pkg1.mod1 import func1\nclass A(object):\n' - ' def a_func(self):\n func1()\n') - pymod = self.project.get_module('mod') + self.mod1.write("def func1():\n pass\n") + self.mod.write( + "from pkg1.mod1 import func1\nclass A(object):\n" + " def a_func(self):\n func1()\n" + ) + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('from pkg1.mod1 import func1\nclass A(object):\n' - ' def a_func(self):\n func1()\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1.mod1 import func1\nclass A(object):\n" + " def a_func(self):\n func1()\n", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_and_getting_attributes(self): - self.mod1.write('class A(object):\n def f(self):\n pass\n') - self.mod.write('from pkg1.mod1 import A\nvar = A().f()') - pymod = self.project.get_module('mod') + self.mod1.write("class A(object):\n def f(self):\n pass\n") + self.mod.write("from pkg1.mod1 import A\nvar = A().f()") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('from pkg1.mod1 import A\nvar = A().f()', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1.mod1 import A\nvar = A().f()", + module_with_imports.get_changed_source(), + ) def test_removing_unused_imports_function_parameters(self): - self.mod1.write('def func1():\n pass\n') - self.mod.write('import pkg1\ndef a_func(pkg1):\n my_var = pkg1\n') - pymod = self.project.get_module('mod') + self.mod1.write("def func1():\n pass\n") + self.mod.write("import pkg1\ndef a_func(pkg1):\n my_var = pkg1\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('def a_func(pkg1):\n my_var = pkg1\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "def a_func(pkg1):\n my_var = pkg1\n", + module_with_imports.get_changed_source(), + ) def test_trivial_expanding_star_imports(self): - self.mod1.write('def a_func():\n pass\n' - 'def another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import *\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n" "def another_func():\n pass\n") + self.mod.write("from pkg1.mod1 import *\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.expand_stars() - self.assertEqual('', module_with_imports.get_changed_source()) + self.assertEqual("", module_with_imports.get_changed_source()) def test_expanding_star_imports(self): - self.mod1.write('def a_func():\n pass\n' - 'def another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import *\na_func()\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n" "def another_func():\n pass\n") + self.mod.write("from pkg1.mod1 import *\na_func()\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.expand_stars() - self.assertEqual('from pkg1.mod1 import a_func\na_func()\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1.mod1 import a_func\na_func()\n", + module_with_imports.get_changed_source(), + ) def test_removing_duplicate_imports(self): - self.mod.write('import pkg1\nimport pkg1\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1\nimport pkg1\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_duplicates() - self.assertEqual('import pkg1\n', - module_with_imports.get_changed_source()) + self.assertEqual("import pkg1\n", module_with_imports.get_changed_source()) def test_removing_duplicates_and_reoccuring_names(self): - self.mod.write('import pkg2.mod2\nimport pkg2.mod3\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg2.mod2\nimport pkg2.mod3\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_duplicates() - self.assertEqual('import pkg2.mod2\nimport pkg2.mod3\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "import pkg2.mod2\nimport pkg2.mod3\n", + module_with_imports.get_changed_source(), + ) def test_removing_duplicate_imports_for_froms(self): - self.mod1.write( - 'def a_func():\n pass\ndef another_func():\n pass\n') - self.mod.write('from pkg1 import a_func\n' - 'from pkg1 import a_func, another_func\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\ndef another_func():\n pass\n") + self.mod.write( + "from pkg1 import a_func\n" "from pkg1 import a_func, another_func\n" + ) + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_duplicates() - self.assertEqual('from pkg1 import a_func, another_func\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "from pkg1 import a_func, another_func\n", + module_with_imports.get_changed_source(), + ) def test_transforming_froms_to_normal_changing_imports(self): - self.mod1.write('def a_func():\n pass\n') - self.mod.write('from pkg1.mod1 import a_func\nprint(a_func)\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n") + self.mod.write("from pkg1.mod1 import a_func\nprint(a_func)\n") + pymod = self.project.get_module("mod") changed_module = self.import_tools.froms_to_imports(pymod) - self.assertEqual('import pkg1.mod1\nprint(pkg1.mod1.a_func)\n', - changed_module) + self.assertEqual("import pkg1.mod1\nprint(pkg1.mod1.a_func)\n", changed_module) def test_transforming_froms_to_normal_changing_occurances(self): - self.mod1.write('def a_func():\n pass\n') - self.mod.write('from pkg1.mod1 import a_func\na_func()') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n") + self.mod.write("from pkg1.mod1 import a_func\na_func()") + pymod = self.project.get_module("mod") changed_module = self.import_tools.froms_to_imports(pymod) - self.assertEqual('import pkg1.mod1\npkg1.mod1.a_func()', - changed_module) + self.assertEqual("import pkg1.mod1\npkg1.mod1.a_func()", changed_module) def test_transforming_froms_to_normal_for_multi_imports(self): - self.mod1.write('def a_func():\n pass\n' - 'def another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import *\na_func()\nanother_func()\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n" "def another_func():\n pass\n") + self.mod.write("from pkg1.mod1 import *\na_func()\nanother_func()\n") + pymod = self.project.get_module("mod") changed_module = self.import_tools.froms_to_imports(pymod) self.assertEqual( - 'import pkg1.mod1\npkg1.mod1.a_func()\npkg1.mod1.another_func()\n', - changed_module) + "import pkg1.mod1\npkg1.mod1.a_func()\npkg1.mod1.another_func()\n", + changed_module, + ) def test_transform_froms_to_norm_for_multi_imports_inside_parens(self): - self.mod1.write('def a_func():\n pass\n' - 'def another_func():\n pass\n') - self.mod.write('from pkg1.mod1 import (a_func, \n another_func)' - '\na_func()\nanother_func()\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n" "def another_func():\n pass\n") + self.mod.write( + "from pkg1.mod1 import (a_func, \n another_func)" + "\na_func()\nanother_func()\n" + ) + pymod = self.project.get_module("mod") changed_module = self.import_tools.froms_to_imports(pymod) self.assertEqual( - 'import pkg1.mod1\npkg1.mod1.a_func()\npkg1.mod1.another_func()\n', - changed_module) + "import pkg1.mod1\npkg1.mod1.a_func()\npkg1.mod1.another_func()\n", + changed_module, + ) def test_transforming_froms_to_normal_from_stars(self): - self.mod1.write('def a_func():\n pass\n') - self.mod.write('from pkg1.mod1 import *\na_func()\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n") + self.mod.write("from pkg1.mod1 import *\na_func()\n") + pymod = self.project.get_module("mod") changed_module = self.import_tools.froms_to_imports(pymod) - self.assertEqual('import pkg1.mod1\npkg1.mod1.a_func()\n', - changed_module) + self.assertEqual("import pkg1.mod1\npkg1.mod1.a_func()\n", changed_module) def test_transforming_froms_to_normal_from_stars2(self): - self.mod1.write('a_var = 10') - self.mod.write('import pkg1.mod1\nfrom pkg1.mod1 import a_var\n' - 'def a_func():\n print(pkg1.mod1, a_var)\n') - pymod = self.project.get_module('mod') + self.mod1.write("a_var = 10") + self.mod.write( + "import pkg1.mod1\nfrom pkg1.mod1 import a_var\n" + "def a_func():\n print(pkg1.mod1, a_var)\n" + ) + pymod = self.project.get_module("mod") changed_module = self.import_tools.froms_to_imports(pymod) - self.assertEqual('import pkg1.mod1\n' - 'def a_func():\n ' - 'print(pkg1.mod1, pkg1.mod1.a_var)\n', - changed_module) + self.assertEqual( + "import pkg1.mod1\n" + "def a_func():\n " + "print(pkg1.mod1, pkg1.mod1.a_var)\n", + changed_module, + ) def test_transforming_froms_to_normal_from_with_alias(self): - self.mod1.write('def a_func():\n pass\n') - self.mod.write( - 'from pkg1.mod1 import a_func as another_func\nanother_func()\n') - pymod = self.project.get_module('mod') + self.mod1.write("def a_func():\n pass\n") + self.mod.write("from pkg1.mod1 import a_func as another_func\nanother_func()\n") + pymod = self.project.get_module("mod") changed_module = self.import_tools.froms_to_imports(pymod) - self.assertEqual('import pkg1.mod1\npkg1.mod1.a_func()\n', - changed_module) + self.assertEqual("import pkg1.mod1\npkg1.mod1.a_func()\n", changed_module) def test_transforming_froms_to_normal_for_relatives(self): - self.mod2.write('def a_func():\n pass\n') - self.mod3.write('from mod2 import *\na_func()\n') + self.mod2.write("def a_func():\n pass\n") + self.mod3.write("from mod2 import *\na_func()\n") pymod = self.project.get_pymodule(self.mod3) changed_module = self.import_tools.froms_to_imports(pymod) - self.assertEqual('import pkg2.mod2\npkg2.mod2.a_func()\n', - changed_module) + self.assertEqual("import pkg2.mod2\npkg2.mod2.a_func()\n", changed_module) def test_transforming_froms_to_normal_for_os_path(self): - self.mod.write('from os import path\npath.exists(\'.\')\n') + self.mod.write("from os import path\npath.exists('.')\n") pymod = self.project.get_pymodule(self.mod) changed_module = self.import_tools.froms_to_imports(pymod) - self.assertEqual('import os\nos.path.exists(\'.\')\n', changed_module) + self.assertEqual("import os\nos.path.exists('.')\n", changed_module) def test_transform_relatives_imports_to_abs_imports_doing_nothing(self): - self.mod2.write('from pkg1 import mod1\nimport mod1\n') + self.mod2.write("from pkg1 import mod1\nimport mod1\n") pymod = self.project.get_pymodule(self.mod2) - self.assertEqual('from pkg1 import mod1\nimport mod1\n', - self.import_tools.relatives_to_absolutes(pymod)) + self.assertEqual( + "from pkg1 import mod1\nimport mod1\n", + self.import_tools.relatives_to_absolutes(pymod), + ) def test_transform_relatives_to_absolute_imports_for_normal_imports(self): - self.mod2.write('import mod3\n') + self.mod2.write("import mod3\n") pymod = self.project.get_pymodule(self.mod2) - self.assertEqual('import pkg2.mod3\n', - self.import_tools.relatives_to_absolutes(pymod)) + self.assertEqual( + "import pkg2.mod3\n", self.import_tools.relatives_to_absolutes(pymod) + ) def test_transform_relatives_imports_to_absolute_imports_for_froms(self): - self.mod3.write('def a_func():\n pass\n') - self.mod2.write('from mod3 import a_func\n') + self.mod3.write("def a_func():\n pass\n") + self.mod2.write("from mod3 import a_func\n") pymod = self.project.get_pymodule(self.mod2) - self.assertEqual('from pkg2.mod3 import a_func\n', - self.import_tools.relatives_to_absolutes(pymod)) + self.assertEqual( + "from pkg2.mod3 import a_func\n", + self.import_tools.relatives_to_absolutes(pymod), + ) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_transform_rel_imports_to_abs_imports_for_new_relatives(self): - self.mod3.write('def a_func():\n pass\n') - self.mod2.write('from .mod3 import a_func\n') + self.mod3.write("def a_func():\n pass\n") + self.mod2.write("from .mod3 import a_func\n") pymod = self.project.get_pymodule(self.mod2) - self.assertEqual('from pkg2.mod3 import a_func\n', - self.import_tools.relatives_to_absolutes(pymod)) + self.assertEqual( + "from pkg2.mod3 import a_func\n", + self.import_tools.relatives_to_absolutes(pymod), + ) def test_transform_relatives_to_absolute_imports_for_normal_imports2(self): - self.mod2.write('import mod3\nprint(mod3)') + self.mod2.write("import mod3\nprint(mod3)") pymod = self.project.get_pymodule(self.mod2) - self.assertEqual('import pkg2.mod3\nprint(pkg2.mod3)', - self.import_tools.relatives_to_absolutes(pymod)) + self.assertEqual( + "import pkg2.mod3\nprint(pkg2.mod3)", + self.import_tools.relatives_to_absolutes(pymod), + ) def test_transform_relatives_to_absolute_imports_for_aliases(self): - self.mod2.write('import mod3 as mod3\nprint(mod3)') + self.mod2.write("import mod3 as mod3\nprint(mod3)") pymod = self.project.get_pymodule(self.mod2) - self.assertEqual('import pkg2.mod3 as mod3\nprint(mod3)', - self.import_tools.relatives_to_absolutes(pymod)) + self.assertEqual( + "import pkg2.mod3 as mod3\nprint(mod3)", + self.import_tools.relatives_to_absolutes(pymod), + ) def test_organizing_imports(self): - self.mod1.write('import mod1\n') + self.mod1.write("import mod1\n") pymod = self.project.get_pymodule(self.mod1) - self.assertEqual('', self.import_tools.organize_imports(pymod)) + self.assertEqual("", self.import_tools.organize_imports(pymod)) def test_organizing_imports_without_deduplication(self): - contents = 'from pkg2 import mod2\nfrom pkg2 import mod3\n' + contents = "from pkg2 import mod2\nfrom pkg2 import mod3\n" self.mod.write(contents) pymod = self.project.get_pymodule(self.mod) - self.project.prefs['split_imports'] = True - self.assertEqual(contents, - self.import_tools.organize_imports(pymod, - unused=False)) + self.project.prefs["split_imports"] = True + self.assertEqual( + contents, self.import_tools.organize_imports(pymod, unused=False) + ) def test_splitting_imports(self): - self.mod.write('from pkg1 import mod1\nfrom pkg2 import mod2, mod3\n') + self.mod.write("from pkg1 import mod1\nfrom pkg2 import mod2, mod3\n") pymod = self.project.get_pymodule(self.mod) - self.project.prefs['split_imports'] = True - self.assertEqual('from pkg1 import mod1\nfrom pkg2 import mod2\n' - 'from pkg2 import mod3\n', - self.import_tools.organize_imports(pymod, - unused=False)) + self.project.prefs["split_imports"] = True + self.assertEqual( + "from pkg1 import mod1\nfrom pkg2 import mod2\n" "from pkg2 import mod3\n", + self.import_tools.organize_imports(pymod, unused=False), + ) def test_splitting_imports_no_pull_to_top(self): - self.mod.write('from pkg2 import mod3, mod4\n' - 'from pkg1 import mod2\nfrom pkg1 import mod1\n') + self.mod.write( + "from pkg2 import mod3, mod4\n" + "from pkg1 import mod2\nfrom pkg1 import mod1\n" + ) pymod = self.project.get_pymodule(self.mod) - self.project.prefs['split_imports'] = True - self.project.prefs['pull_imports_to_top'] = False - self.assertEqual('from pkg1 import mod2\nfrom pkg1 import mod1\n' - 'from pkg2 import mod3\nfrom pkg2 import mod4\n', - self.import_tools.organize_imports(pymod, - sort=False, - unused=False)) + self.project.prefs["split_imports"] = True + self.project.prefs["pull_imports_to_top"] = False + self.assertEqual( + "from pkg1 import mod2\nfrom pkg1 import mod1\n" + "from pkg2 import mod3\nfrom pkg2 import mod4\n", + self.import_tools.organize_imports(pymod, sort=False, unused=False), + ) def test_splitting_imports_with_filter(self): - self.mod.write('from pkg1 import mod1, mod2\n' - 'from pkg2 import mod3, mod4\n') + self.mod.write("from pkg1 import mod1, mod2\n" "from pkg2 import mod3, mod4\n") pymod = self.project.get_pymodule(self.mod) - self.project.prefs['split_imports'] = True + self.project.prefs["split_imports"] = True def import_filter(stmt): - return stmt.import_info.module_name == 'pkg1' + return stmt.import_info.module_name == "pkg1" self.assertEqual( - 'from pkg1 import mod1\nfrom pkg1 import mod2\n' - 'from pkg2 import mod3, mod4\n', - self.import_tools.organize_imports(pymod, unused=False, - import_filter=import_filter)) + "from pkg1 import mod1\nfrom pkg1 import mod2\n" + "from pkg2 import mod3, mod4\n", + self.import_tools.organize_imports( + pymod, unused=False, import_filter=import_filter + ), + ) def test_splitting_duplicate_imports(self): - self.mod.write('from pkg2 import mod1\nfrom pkg2 import mod1, mod2\n') + self.mod.write("from pkg2 import mod1\nfrom pkg2 import mod1, mod2\n") pymod = self.project.get_pymodule(self.mod) - self.project.prefs['split_imports'] = True - self.assertEqual('from pkg2 import mod1\nfrom pkg2 import mod2\n', - self.import_tools.organize_imports(pymod, - unused=False)) + self.project.prefs["split_imports"] = True + self.assertEqual( + "from pkg2 import mod1\nfrom pkg2 import mod2\n", + self.import_tools.organize_imports(pymod, unused=False), + ) def test_splitting_duplicate_imports2(self): - self.mod.write('from pkg2 import mod1, mod3\n' - 'from pkg2 import mod1, mod2\n' - 'from pkg2 import mod2, mod3\n') + self.mod.write( + "from pkg2 import mod1, mod3\n" + "from pkg2 import mod1, mod2\n" + "from pkg2 import mod2, mod3\n" + ) pymod = self.project.get_pymodule(self.mod) - self.project.prefs['split_imports'] = True - self.assertEqual('from pkg2 import mod1\nfrom pkg2 import mod2\n' - 'from pkg2 import mod3\n', - self.import_tools.organize_imports(pymod, - unused=False)) + self.project.prefs["split_imports"] = True + self.assertEqual( + "from pkg2 import mod1\nfrom pkg2 import mod2\n" "from pkg2 import mod3\n", + self.import_tools.organize_imports(pymod, unused=False), + ) def test_removing_self_imports(self): - self.mod.write('import mod\nmod.a_var = 1\nprint(mod.a_var)\n') + self.mod.write("import mod\nmod.a_var = 1\nprint(mod.a_var)\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('a_var = 1\nprint(a_var)\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "a_var = 1\nprint(a_var)\n", self.import_tools.organize_imports(pymod) + ) def test_removing_self_imports2(self): - self.mod1.write('import pkg1.mod1\npkg1.mod1.a_var = 1\n' - 'print(pkg1.mod1.a_var)\n') + self.mod1.write( + "import pkg1.mod1\npkg1.mod1.a_var = 1\n" "print(pkg1.mod1.a_var)\n" + ) pymod = self.project.get_pymodule(self.mod1) - self.assertEqual('a_var = 1\nprint(a_var)\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "a_var = 1\nprint(a_var)\n", self.import_tools.organize_imports(pymod) + ) def test_removing_self_imports_with_as(self): - self.mod.write('import mod as mymod\n' - 'mymod.a_var = 1\nprint(mymod.a_var)\n') + self.mod.write("import mod as mymod\n" "mymod.a_var = 1\nprint(mymod.a_var)\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('a_var = 1\nprint(a_var)\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "a_var = 1\nprint(a_var)\n", self.import_tools.organize_imports(pymod) + ) def test_removing_self_imports_for_froms(self): - self.mod1.write('from pkg1 import mod1\n' - 'mod1.a_var = 1\nprint(mod1.a_var)\n') + self.mod1.write("from pkg1 import mod1\n" "mod1.a_var = 1\nprint(mod1.a_var)\n") pymod = self.project.get_pymodule(self.mod1) - self.assertEqual('a_var = 1\nprint(a_var)\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "a_var = 1\nprint(a_var)\n", self.import_tools.organize_imports(pymod) + ) def test_removing_self_imports_for_froms_with_as(self): - self.mod1.write('from pkg1 import mod1 as mymod\n' - 'mymod.a_var = 1\nprint(mymod.a_var)\n') + self.mod1.write( + "from pkg1 import mod1 as mymod\n" "mymod.a_var = 1\nprint(mymod.a_var)\n" + ) pymod = self.project.get_pymodule(self.mod1) - self.assertEqual('a_var = 1\nprint(a_var)\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "a_var = 1\nprint(a_var)\n", self.import_tools.organize_imports(pymod) + ) def test_removing_self_imports_for_froms2(self): - self.mod.write('from mod import a_var\na_var = 1\nprint(a_var)\n') + self.mod.write("from mod import a_var\na_var = 1\nprint(a_var)\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('a_var = 1\nprint(a_var)\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "a_var = 1\nprint(a_var)\n", self.import_tools.organize_imports(pymod) + ) def test_removing_self_imports_for_froms3(self): - self.mod.write('from mod import a_var\na_var = 1\nprint(a_var)\n') + self.mod.write("from mod import a_var\na_var = 1\nprint(a_var)\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('a_var = 1\nprint(a_var)\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "a_var = 1\nprint(a_var)\n", self.import_tools.organize_imports(pymod) + ) def test_removing_self_imports_for_froms4(self): - self.mod.write('from mod import a_var as myvar\n' - 'a_var = 1\nprint(myvar)\n') + self.mod.write("from mod import a_var as myvar\n" "a_var = 1\nprint(myvar)\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('a_var = 1\nprint(a_var)\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "a_var = 1\nprint(a_var)\n", self.import_tools.organize_imports(pymod) + ) def test_removing_self_imports_with_no_dot_after_mod(self): - self.mod.write('import mod\nprint(mod)\n') + self.mod.write("import mod\nprint(mod)\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('import mod\n\n\nprint(mod)\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "import mod\n\n\nprint(mod)\n", self.import_tools.organize_imports(pymod) + ) def test_removing_self_imports_with_no_dot_after_mod2(self): - self.mod.write('import mod\na_var = 1\n' - 'print(mod\\\n \\\n .var)\n\n') + self.mod.write("import mod\na_var = 1\n" "print(mod\\\n \\\n .var)\n\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('a_var = 1\nprint(var)\n\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "a_var = 1\nprint(var)\n\n", self.import_tools.organize_imports(pymod) + ) def test_removing_self_imports_for_from_import_star(self): - self.mod.write('from mod import *\na_var = 1\nprint(myvar)\n') + self.mod.write("from mod import *\na_var = 1\nprint(myvar)\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('a_var = 1\nprint(myvar)\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "a_var = 1\nprint(myvar)\n", self.import_tools.organize_imports(pymod) + ) def test_not_removing_future_imports(self): - self.mod.write('from __future__ import division\n') + self.mod.write("from __future__ import division\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('from __future__ import division\n', - self.import_tools.organize_imports(pymod)) + self.assertEqual( + "from __future__ import division\n", + self.import_tools.organize_imports(pymod), + ) def test_sorting_empty_imports(self): - self.mod.write('') + self.mod.write("") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('', self.import_tools.sort_imports(pymod)) + self.assertEqual("", self.import_tools.sort_imports(pymod)) def test_sorting_one_import(self): - self.mod.write('import pkg1.mod1\n') + self.mod.write("import pkg1.mod1\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('import pkg1.mod1\n', - self.import_tools.sort_imports(pymod)) + self.assertEqual("import pkg1.mod1\n", self.import_tools.sort_imports(pymod)) def test_sorting_imports_alphabetically(self): - self.mod.write('import pkg2.mod2\nimport pkg1.mod1\n') + self.mod.write("import pkg2.mod2\nimport pkg1.mod1\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('import pkg1.mod1\nimport pkg2.mod2\n', - self.import_tools.sort_imports(pymod)) + self.assertEqual( + "import pkg1.mod1\nimport pkg2.mod2\n", + self.import_tools.sort_imports(pymod), + ) def test_sorting_imports_purely_alphabetically(self): - self.mod.write('from pkg2 import mod3 as mod0\n' - 'import pkg2.mod2\nimport pkg1.mod1\n') + self.mod.write( + "from pkg2 import mod3 as mod0\n" "import pkg2.mod2\nimport pkg1.mod1\n" + ) pymod = self.project.get_pymodule(self.mod) - self.project.prefs['sort_imports_alphabetically'] = True - self.assertEqual('import pkg1.mod1\nimport pkg2.mod2\n' - 'from pkg2 import mod3 as mod0\n', - self.import_tools.sort_imports(pymod)) + self.project.prefs["sort_imports_alphabetically"] = True + self.assertEqual( + "import pkg1.mod1\nimport pkg2.mod2\n" "from pkg2 import mod3 as mod0\n", + self.import_tools.sort_imports(pymod), + ) def test_sorting_imports_and_froms(self): - self.mod.write('import pkg2.mod2\nfrom pkg1 import mod1\n') + self.mod.write("import pkg2.mod2\nfrom pkg1 import mod1\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('import pkg2.mod2\nfrom pkg1 import mod1\n', - self.import_tools.sort_imports(pymod)) + self.assertEqual( + "import pkg2.mod2\nfrom pkg1 import mod1\n", + self.import_tools.sort_imports(pymod), + ) def test_sorting_imports_and_standard_modules(self): - self.mod.write('import pkg1\nimport sys\n') + self.mod.write("import pkg1\nimport sys\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('import sys\n\nimport pkg1\n', - self.import_tools.sort_imports(pymod)) + self.assertEqual( + "import sys\n\nimport pkg1\n", self.import_tools.sort_imports(pymod) + ) def test_sorting_imports_and_standard_modules2(self): - self.mod.write('import sys\n\nimport time\n') + self.mod.write("import sys\n\nimport time\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('import sys\nimport time\n', - self.import_tools.sort_imports(pymod)) + self.assertEqual( + "import sys\nimport time\n", self.import_tools.sort_imports(pymod) + ) def test_sorting_only_standard_modules(self): - self.mod.write('import sys\n') + self.mod.write("import sys\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('import sys\n', - self.import_tools.sort_imports(pymod)) + self.assertEqual("import sys\n", self.import_tools.sort_imports(pymod)) def test_sorting_third_party(self): - self.mod.write('import pkg1\nimport a_third_party\n') + self.mod.write("import pkg1\nimport a_third_party\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('import a_third_party\n\nimport pkg1\n', - self.import_tools.sort_imports(pymod)) + self.assertEqual( + "import a_third_party\n\nimport pkg1\n", + self.import_tools.sort_imports(pymod), + ) def test_sorting_only_third_parties(self): - self.mod.write('import a_third_party\na_var = 1\n') + self.mod.write("import a_third_party\na_var = 1\n") pymod = self.project.get_pymodule(self.mod) - self.assertEqual('import a_third_party\n\n\na_var = 1\n', - self.import_tools.sort_imports(pymod)) + self.assertEqual( + "import a_third_party\n\n\na_var = 1\n", + self.import_tools.sort_imports(pymod), + ) def test_simple_handling_long_imports(self): - self.mod.write('import pkg1.mod1\n\n\nm = pkg1.mod1\n') + self.mod.write("import pkg1.mod1\n\n\nm = pkg1.mod1\n") pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'import pkg1.mod1\n\n\nm = pkg1.mod1\n', - self.import_tools.handle_long_imports(pymod, maxdots=2)) + "import pkg1.mod1\n\n\nm = pkg1.mod1\n", + self.import_tools.handle_long_imports(pymod, maxdots=2), + ) def test_handling_long_imports_for_many_dots(self): - self.mod.write('import p1.p2.p3.m1\n\n\nm = p1.p2.p3.m1\n') + self.mod.write("import p1.p2.p3.m1\n\n\nm = p1.p2.p3.m1\n") pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'from p1.p2.p3 import m1\n\n\nm = m1\n', - self.import_tools.handle_long_imports(pymod, maxdots=2)) + "from p1.p2.p3 import m1\n\n\nm = m1\n", + self.import_tools.handle_long_imports(pymod, maxdots=2), + ) def test_handling_long_imports_for_their_length(self): - self.mod.write('import p1.p2.p3.m1\n\n\nm = p1.p2.p3.m1\n') + self.mod.write("import p1.p2.p3.m1\n\n\nm = p1.p2.p3.m1\n") pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'import p1.p2.p3.m1\n\n\nm = p1.p2.p3.m1\n', - self.import_tools.handle_long_imports(pymod, maxdots=3, - maxlength=20)) + "import p1.p2.p3.m1\n\n\nm = p1.p2.p3.m1\n", + self.import_tools.handle_long_imports(pymod, maxdots=3, maxlength=20), + ) def test_handling_long_imports_for_many_dots2(self): - self.mod.write('import p1.p2.p3.m1\n\n\nm = p1.p2.p3.m1\n') + self.mod.write("import p1.p2.p3.m1\n\n\nm = p1.p2.p3.m1\n") pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'from p1.p2.p3 import m1\n\n\nm = m1\n', - self.import_tools.handle_long_imports(pymod, maxdots=3, - maxlength=10)) + "from p1.p2.p3 import m1\n\n\nm = m1\n", + self.import_tools.handle_long_imports(pymod, maxdots=3, maxlength=10), + ) def test_handling_long_imports_with_one_letter_last(self): - self.mod.write('import p1.p2.p3.l\n\n\nm = p1.p2.p3.l\n') + self.mod.write("import p1.p2.p3.l\n\n\nm = p1.p2.p3.l\n") pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'from p1.p2.p3 import l\n\n\nm = l\n', - self.import_tools.handle_long_imports(pymod, maxdots=2)) + "from p1.p2.p3 import l\n\n\nm = l\n", + self.import_tools.handle_long_imports(pymod, maxdots=2), + ) def test_empty_removing_unused_imports_and_eating_blank_lines(self): - self.mod.write('import pkg1\nimport pkg2\n\n\nprint(pkg1)\n') - pymod = self.project.get_module('mod') + self.mod.write("import pkg1\nimport pkg2\n\n\nprint(pkg1)\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) module_with_imports.remove_unused_imports() - self.assertEqual('import pkg1\n\n\nprint(pkg1)\n', - module_with_imports.get_changed_source()) + self.assertEqual( + "import pkg1\n\n\nprint(pkg1)\n", module_with_imports.get_changed_source() + ) def test_sorting_imports_moving_to_top(self): - self.mod.write('import mod\ndef f():\n print(mod, pkg1, pkg2)\n' - 'import pkg1\nimport pkg2\n') - pymod = self.project.get_module('mod') - self.assertEqual('import mod\nimport pkg1\nimport pkg2\n\n\n' - 'def f():\n print(mod, pkg1, pkg2)\n', - self.import_tools.sort_imports(pymod)) + self.mod.write( + "import mod\ndef f():\n print(mod, pkg1, pkg2)\n" + "import pkg1\nimport pkg2\n" + ) + pymod = self.project.get_module("mod") + self.assertEqual( + "import mod\nimport pkg1\nimport pkg2\n\n\n" + "def f():\n print(mod, pkg1, pkg2)\n", + self.import_tools.sort_imports(pymod), + ) def test_sorting_imports_moving_to_top2(self): - self.mod.write('def f():\n print(mod)\nimport mod\n') - pymod = self.project.get_module('mod') - self.assertEqual('import mod\n\n\ndef f():\n print(mod)\n', - self.import_tools.sort_imports(pymod)) + self.mod.write("def f():\n print(mod)\nimport mod\n") + pymod = self.project.get_module("mod") + self.assertEqual( + "import mod\n\n\ndef f():\n print(mod)\n", + self.import_tools.sort_imports(pymod), + ) # Sort pulls imports to the top anyway def test_sorting_imports_no_pull_to_top(self): - code = ('import pkg2\ndef f():\n print(mod, pkg1, pkg2)\n' - 'import pkg1\nimport mod\n') + code = ( + "import pkg2\ndef f():\n print(mod, pkg1, pkg2)\n" + "import pkg1\nimport mod\n" + ) self.mod.write(code) - pymod = self.project.get_module('mod') - self.project.prefs['pull_imports_to_top'] = False + pymod = self.project.get_module("mod") + self.project.prefs["pull_imports_to_top"] = False self.assertEqual( - 'import mod\nimport pkg1\nimport pkg2\n\n\n' - 'def f():\n print(mod, pkg1, pkg2)\n', - self.import_tools.sort_imports(pymod)) + "import mod\nimport pkg1\nimport pkg2\n\n\n" + "def f():\n print(mod, pkg1, pkg2)\n", + self.import_tools.sort_imports(pymod), + ) def test_sorting_imports_moving_to_top_and_module_docs(self): - self.mod.write('"""\ndocs\n"""\ndef f():' - '\n print(mod)\nimport mod\n') - pymod = self.project.get_module('mod') + self.mod.write('"""\ndocs\n"""\ndef f():' "\n print(mod)\nimport mod\n") + pymod = self.project.get_module("mod") self.assertEqual( '"""\ndocs\n"""\nimport mod\n\n\ndef f():\n print(mod)\n', - self.import_tools.sort_imports(pymod)) + self.import_tools.sort_imports(pymod), + ) def test_sorting_imports_moving_to_top_and_module_docs2(self): - self.mod.write('"""\ndocs\n"""\n\n\nimport bbb\nimport aaa\n' - 'def f():\n print(mod)\nimport mod\n') - pymod = self.project.get_module('mod') + self.mod.write( + '"""\ndocs\n"""\n\n\nimport bbb\nimport aaa\n' + "def f():\n print(mod)\nimport mod\n" + ) + pymod = self.project.get_module("mod") self.assertEqual( '"""\ndocs\n"""\n\n\nimport aaa\nimport bbb\n\n' - 'import mod\n\n\ndef f():\n print(mod)\n', - self.import_tools.sort_imports(pymod)) + "import mod\n\n\ndef f():\n print(mod)\n", + self.import_tools.sort_imports(pymod), + ) def test_get_changed_source_preserves_blank_lines(self): self.mod.write( '__author__ = "author"\n\nimport aaa\n\nimport bbb\n\n' - 'def f():\n print(mod)\n') - pymod = self.project.get_module('mod') + "def f():\n print(mod)\n" + ) + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) self.assertEqual( 'import aaa\n\nimport bbb\n\n__author__ = "author"\n\n' - 'def f():\n print(mod)\n', - module_with_imports.get_changed_source()) + "def f():\n print(mod)\n", + module_with_imports.get_changed_source(), + ) def test_sorting_future_imports(self): - self.mod.write('import os\nfrom __future__ import devision\n') - pymod = self.project.get_module('mod') + self.mod.write("import os\nfrom __future__ import devision\n") + pymod = self.project.get_module("mod") self.assertEqual( - 'from __future__ import devision\n\nimport os\n', - self.import_tools.sort_imports(pymod)) + "from __future__ import devision\n\nimport os\n", + self.import_tools.sort_imports(pymod), + ) def test_organizing_imports_all_star(self): - code = expected = dedent('''\ + code = expected = dedent("""\ from package import some_name __all__ = ["some_name"] - ''') + """) self.mod.write(code) pymod = self.project.get_pymodule(self.mod) - self.assertEqual( - expected, - self.import_tools.organize_imports(pymod)) + self.assertEqual(expected, self.import_tools.organize_imports(pymod)) def test_organizing_imports_all_star_with_variables(self): - code = expected = dedent('''\ + code = expected = dedent("""\ from package import name_one, name_two @@ -962,29 +1046,25 @@ def test_organizing_imports_all_star_with_variables(self): else: foo = 'name_two' __all__ = [foo] - ''') + """) self.mod.write(code) pymod = self.project.get_pymodule(self.mod) - self.assertEqual( - expected, - self.import_tools.organize_imports(pymod)) + self.assertEqual(expected, self.import_tools.organize_imports(pymod)) def test_organizing_imports_all_star_with_inline_if(self): - code = expected = dedent('''\ + code = expected = dedent("""\ from package import name_one, name_two __all__ = ['name_one' if something() else 'name_two'] - ''') + """) self.mod.write(code) pymod = self.project.get_pymodule(self.mod) - self.assertEqual( - expected, - self.import_tools.organize_imports(pymod)) + self.assertEqual(expected, self.import_tools.organize_imports(pymod)) - @testutils.only_for_versions_higher('3') + @testutils.only_for_versions_higher("3") def test_organizing_imports_all_star_tolerates_non_list_of_str_1(self): - code = expected = dedent('''\ + code = expected = dedent("""\ from package import name_one, name_two @@ -992,152 +1072,156 @@ def test_organizing_imports_all_star_tolerates_non_list_of_str_1(self): __all__ = [bar, *abc] + mylist __all__ = [foo, 'name_one', *abc] __all__ = [it for it in mylist] - ''') + """) self.mod.write(code) pymod = self.project.get_pymodule(self.mod) - self.assertEqual( - expected, - self.import_tools.organize_imports(pymod)) + self.assertEqual(expected, self.import_tools.organize_imports(pymod)) def test_organizing_imports_all_star_tolerates_non_list_of_str_2(self): - code = expected = dedent('''\ + code = expected = dedent("""\ from package import name_one, name_two foo = 'name_two' __all__ = [foo, 3, 'name_one'] __all__ = [it for it in mylist] - ''') + """) self.mod.write(code) pymod = self.project.get_pymodule(self.mod) - self.assertEqual( - expected, - self.import_tools.organize_imports(pymod)) + self.assertEqual(expected, self.import_tools.organize_imports(pymod)) @testutils.time_limit(60) def test_organizing_imports_all_star_no_infinite_loop(self): - code = expected = dedent('''\ + code = expected = dedent("""\ from package import name_one, name_two foo = bar bar = foo __all__ = [foo, 'name_one', 'name_two'] - ''') + """) self.mod.write(code) pymod = self.project.get_pymodule(self.mod) - self.assertEqual( - expected, - self.import_tools.organize_imports(pymod)) + self.assertEqual(expected, self.import_tools.organize_imports(pymod)) def test_customized_import_organization(self): - self.mod.write('import sys\nimport sys\n') + self.mod.write("import sys\nimport sys\n") pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'import sys\n', - self.import_tools.organize_imports(pymod, unused=False)) + "import sys\n", self.import_tools.organize_imports(pymod, unused=False) + ) def test_customized_import_organization2(self): - self.mod.write('import sys\n') + self.mod.write("import sys\n") pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'import sys\n', - self.import_tools.organize_imports(pymod, unused=False)) + "import sys\n", self.import_tools.organize_imports(pymod, unused=False) + ) def test_customized_import_organization3(self): - self.mod.write('import sys\nimport mod\n\n\nvar = 1\nprint(mod.var)\n') + self.mod.write("import sys\nimport mod\n\n\nvar = 1\nprint(mod.var)\n") pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'import sys\n\n\nvar = 1\nprint(var)\n', - self.import_tools.organize_imports(pymod, unused=False)) + "import sys\n\n\nvar = 1\nprint(var)\n", + self.import_tools.organize_imports(pymod, unused=False), + ) def test_trivial_filtered_expand_stars(self): - self.pkg1.get_child('__init__.py').write('var1 = 1\n') - self.pkg2.get_child('__init__.py').write('var2 = 1\n') - self.mod.write('from pkg1 import *\nfrom pkg2 import *\n\n' - 'print(var1, var2)\n') + self.pkg1.get_child("__init__.py").write("var1 = 1\n") + self.pkg2.get_child("__init__.py").write("var2 = 1\n") + self.mod.write( + "from pkg1 import *\nfrom pkg2 import *\n\n" "print(var1, var2)\n" + ) pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'from pkg1 import *\nfrom pkg2 import *\n\nprint(var1, var2)\n', - self.import_tools.expand_stars(pymod, lambda stmt: False)) + "from pkg1 import *\nfrom pkg2 import *\n\nprint(var1, var2)\n", + self.import_tools.expand_stars(pymod, lambda stmt: False), + ) def _line_filter(self, lineno): def import_filter(import_stmt): return import_stmt.start_line <= lineno < import_stmt.end_line + return import_filter def test_filtered_expand_stars(self): - self.pkg1.get_child('__init__.py').write('var1 = 1\n') - self.pkg2.get_child('__init__.py').write('var2 = 1\n') - self.mod.write('from pkg1 import *\nfrom pkg2 import *\n\n' - 'print(var1, var2)\n') + self.pkg1.get_child("__init__.py").write("var1 = 1\n") + self.pkg2.get_child("__init__.py").write("var2 = 1\n") + self.mod.write( + "from pkg1 import *\nfrom pkg2 import *\n\n" "print(var1, var2)\n" + ) pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'from pkg1 import *\nfrom pkg2 import var2\n\nprint(var1, var2)\n', - self.import_tools.expand_stars(pymod, self._line_filter(2))) + "from pkg1 import *\nfrom pkg2 import var2\n\nprint(var1, var2)\n", + self.import_tools.expand_stars(pymod, self._line_filter(2)), + ) def test_filtered_relative_to_absolute(self): - self.mod3.write('var = 1') - self.mod2.write('import mod3\n\nprint(mod3.var)\n') + self.mod3.write("var = 1") + self.mod2.write("import mod3\n\nprint(mod3.var)\n") pymod = self.project.get_pymodule(self.mod2) self.assertEqual( - 'import mod3\n\nprint(mod3.var)\n', - self.import_tools.relatives_to_absolutes( - pymod, lambda stmt: False)) + "import mod3\n\nprint(mod3.var)\n", + self.import_tools.relatives_to_absolutes(pymod, lambda stmt: False), + ) self.assertEqual( - 'import pkg2.mod3\n\nprint(pkg2.mod3.var)\n', - self.import_tools.relatives_to_absolutes( - pymod, self._line_filter(1))) + "import pkg2.mod3\n\nprint(pkg2.mod3.var)\n", + self.import_tools.relatives_to_absolutes(pymod, self._line_filter(1)), + ) def test_filtered_froms_to_normals(self): - self.pkg1.get_child('__init__.py').write('var1 = 1\n') - self.pkg2.get_child('__init__.py').write('var2 = 1\n') - self.mod.write('from pkg1 import var1\nfrom pkg2 import var2\n\n' - 'print(var1, var2)\n') + self.pkg1.get_child("__init__.py").write("var1 = 1\n") + self.pkg2.get_child("__init__.py").write("var2 = 1\n") + self.mod.write( + "from pkg1 import var1\nfrom pkg2 import var2\n\n" "print(var1, var2)\n" + ) pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'from pkg1 import var1\nfrom pkg2 ' - 'import var2\n\nprint(var1, var2)\n', - self.import_tools.expand_stars(pymod, lambda stmt: False)) + "from pkg1 import var1\nfrom pkg2 " "import var2\n\nprint(var1, var2)\n", + self.import_tools.expand_stars(pymod, lambda stmt: False), + ) self.assertEqual( - 'from pkg1 import var1\nimport pkg2\n\nprint(var1, pkg2.var2)\n', - self.import_tools.froms_to_imports(pymod, self._line_filter(2))) + "from pkg1 import var1\nimport pkg2\n\nprint(var1, pkg2.var2)\n", + self.import_tools.froms_to_imports(pymod, self._line_filter(2)), + ) def test_filtered_froms_to_normals2(self): - self.pkg1.get_child('__init__.py').write('var1 = 1\n') - self.pkg2.get_child('__init__.py').write('var2 = 1\n') - self.mod.write('from pkg1 import *\nfrom pkg2 import *\n\n' - 'print(var1, var2)\n') + self.pkg1.get_child("__init__.py").write("var1 = 1\n") + self.pkg2.get_child("__init__.py").write("var2 = 1\n") + self.mod.write( + "from pkg1 import *\nfrom pkg2 import *\n\n" "print(var1, var2)\n" + ) pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'from pkg1 import *\nimport pkg2\n\nprint(var1, pkg2.var2)\n', - self.import_tools.froms_to_imports(pymod, self._line_filter(2))) + "from pkg1 import *\nimport pkg2\n\nprint(var1, pkg2.var2)\n", + self.import_tools.froms_to_imports(pymod, self._line_filter(2)), + ) def test_filtered_handle_long_imports(self): - self.mod.write('import p1.p2.p3.m1\nimport pkg1.mod1\n\n\n' - 'm = p1.p2.p3.m1, pkg1.mod1\n') + self.mod.write( + "import p1.p2.p3.m1\nimport pkg1.mod1\n\n\n" "m = p1.p2.p3.m1, pkg1.mod1\n" + ) pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'import p1.p2.p3.m1\nfrom pkg1 import mod1\n\n\n' - 'm = p1.p2.p3.m1, mod1\n', + "import p1.p2.p3.m1\nfrom pkg1 import mod1\n\n\n" "m = p1.p2.p3.m1, mod1\n", self.import_tools.handle_long_imports( - pymod, maxlength=5, - import_filter=self._line_filter(2))) + pymod, maxlength=5, import_filter=self._line_filter(2) + ), + ) def test_filtering_and_import_actions_with_more_than_one_phase(self): - self.pkg1.get_child('__init__.py').write('var1 = 1\n') - self.pkg2.get_child('__init__.py').write('var2 = 1\n') - self.mod.write('from pkg1 import *\nfrom pkg2 import *\n\n' - 'print(var2)\n') + self.pkg1.get_child("__init__.py").write("var1 = 1\n") + self.pkg2.get_child("__init__.py").write("var2 = 1\n") + self.mod.write("from pkg1 import *\nfrom pkg2 import *\n\n" "print(var2)\n") pymod = self.project.get_pymodule(self.mod) self.assertEqual( - 'from pkg2 import *\n\nprint(var2)\n', - self.import_tools.expand_stars(pymod, self._line_filter(1))) + "from pkg2 import *\n\nprint(var2)\n", + self.import_tools.expand_stars(pymod, self._line_filter(1)), + ) def test_non_existent_module_and_used_imports(self): - self.mod.write( - 'from does_not_exist import func\n\nfunc()\n') - pymod = self.project.get_module('mod') + self.mod.write("from does_not_exist import func\n\nfunc()\n") + pymod = self.project.get_module("mod") module_with_imports = self.import_tools.module_imports(pymod) imports = module_with_imports.get_used_imports(pymod) @@ -1145,67 +1229,66 @@ def test_non_existent_module_and_used_imports(self): class AddImportTest(unittest.TestCase): - def setUp(self): super(AddImportTest, self).setUp() self.project = testutils.sample_project() - self.mod1 = testutils.create_module(self.project, 'mod1') - self.mod2 = testutils.create_module(self.project, 'mod2') - self.pkg = testutils.create_package(self.project, 'pkg') - self.mod3 = testutils.create_module(self.project, 'mod3', self.pkg) + self.mod1 = testutils.create_module(self.project, "mod1") + self.mod2 = testutils.create_module(self.project, "mod2") + self.pkg = testutils.create_package(self.project, "pkg") + self.mod3 = testutils.create_module(self.project, "mod3", self.pkg) def tearDown(self): testutils.remove_project(self.project) super(AddImportTest, self).tearDown() def test_normal_imports(self): - self.mod2.write('myvar = None\n') - self.mod1.write('\n') - pymod = self.project.get_module('mod1') - result, name = add_import(self.project, pymod, 'mod2', 'myvar') - self.assertEqual('import mod2\n', result) - self.assertEqual('mod2.myvar', name) + self.mod2.write("myvar = None\n") + self.mod1.write("\n") + pymod = self.project.get_module("mod1") + result, name = add_import(self.project, pymod, "mod2", "myvar") + self.assertEqual("import mod2\n", result) + self.assertEqual("mod2.myvar", name) def test_not_reimporting_a_name(self): - self.mod2.write('myvar = None\n') - self.mod1.write('from mod2 import myvar\n') - pymod = self.project.get_module('mod1') - result, name = add_import(self.project, pymod, 'mod2', 'myvar') - self.assertEqual('from mod2 import myvar\n', result) - self.assertEqual('myvar', name) + self.mod2.write("myvar = None\n") + self.mod1.write("from mod2 import myvar\n") + pymod = self.project.get_module("mod1") + result, name = add_import(self.project, pymod, "mod2", "myvar") + self.assertEqual("from mod2 import myvar\n", result) + self.assertEqual("myvar", name) def test_adding_import_when_siblings_are_imported(self): - self.mod2.write('var1 = None\nvar2 = None\n') - self.mod1.write('from mod2 import var1\n') - pymod = self.project.get_module('mod1') - result, name = add_import(self.project, pymod, 'mod2', 'var2') - self.assertEqual('from mod2 import var1, var2\n', result) - self.assertEqual('var2', name) + self.mod2.write("var1 = None\nvar2 = None\n") + self.mod1.write("from mod2 import var1\n") + pymod = self.project.get_module("mod1") + result, name = add_import(self.project, pymod, "mod2", "var2") + self.assertEqual("from mod2 import var1, var2\n", result) + self.assertEqual("var2", name) def test_adding_import_when_the_package_is_imported(self): - self.pkg.get_child('__init__.py').write('var1 = None\n') - self.mod3.write('var2 = None\n') - self.mod1.write('from pkg import var1\n') - pymod = self.project.get_module('mod1') - result, name = add_import(self.project, pymod, 'pkg.mod3', 'var2') - self.assertEqual('from pkg import var1, mod3\n', result) - self.assertEqual('mod3.var2', name) + self.pkg.get_child("__init__.py").write("var1 = None\n") + self.mod3.write("var2 = None\n") + self.mod1.write("from pkg import var1\n") + pymod = self.project.get_module("mod1") + result, name = add_import(self.project, pymod, "pkg.mod3", "var2") + self.assertEqual("from pkg import var1, mod3\n", result) + self.assertEqual("mod3.var2", name) def test_adding_import_for_modules_instead_of_names(self): - self.pkg.get_child('__init__.py').write('var1 = None\n') - self.mod3.write('\n') - self.mod1.write('from pkg import var1\n') - pymod = self.project.get_module('mod1') - result, name = add_import(self.project, pymod, 'pkg.mod3', None) - self.assertEqual('from pkg import var1, mod3\n', result) - self.assertEqual('mod3', name) + self.pkg.get_child("__init__.py").write("var1 = None\n") + self.mod3.write("\n") + self.mod1.write("from pkg import var1\n") + pymod = self.project.get_module("mod1") + result, name = add_import(self.project, pymod, "pkg.mod3", None) + self.assertEqual("from pkg import var1, mod3\n", result) + self.assertEqual("mod3", name) def test_adding_import_for_modules_with_normal_duplicate_imports(self): - self.pkg.get_child('__init__.py').write('var1 = None\n') - self.mod3.write('\n') - self.mod1.write('import pkg.mod3\n') - pymod = self.project.get_module('mod1') - result, name = add_import(self.project, pymod, 'pkg.mod3', None) - self.assertEqual('import pkg.mod3\n', result) - self.assertEqual('pkg.mod3', name) + self.pkg.get_child("__init__.py").write("var1 = None\n") + self.mod3.write("\n") + self.mod1.write("import pkg.mod3\n") + pymod = self.project.get_module("mod1") + result, name = add_import(self.project, pymod, "pkg.mod3", None) + self.assertEqual("import pkg.mod3\n", result) + self.assertEqual("pkg.mod3", name) diff --git a/ropetest/refactor/inlinetest.py b/ropetest/refactor/inlinetest.py index c1566d4cc..b7ce85e25 100644 --- a/ropetest/refactor/inlinetest.py +++ b/ropetest/refactor/inlinetest.py @@ -9,13 +9,12 @@ class InlineTest(unittest.TestCase): - def setUp(self): super(InlineTest, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - self.mod = testutils.create_module(self.project, 'mod') - self.mod2 = testutils.create_module(self.project, 'mod2') + self.mod = testutils.create_module(self.project, "mod") + self.mod2 = testutils.create_module(self.project, "mod2") def tearDown(self): testutils.remove_project(self.project) @@ -33,623 +32,640 @@ def _inline2(self, resource, offset, **kwds): return self.mod.read() def test_simple_case(self): - code = 'a_var = 10\nanother_var = a_var\n' - refactored = self._inline(code, code.index('a_var') + 1) - self.assertEqual('another_var = 10\n', refactored) + code = "a_var = 10\nanother_var = a_var\n" + refactored = self._inline(code, code.index("a_var") + 1) + self.assertEqual("another_var = 10\n", refactored) def test_empty_case(self): - code = 'a_var = 10\n' - refactored = self._inline(code, code.index('a_var') + 1) - self.assertEqual('', refactored) + code = "a_var = 10\n" + refactored = self._inline(code, code.index("a_var") + 1) + self.assertEqual("", refactored) def test_long_definition(self): - code = 'a_var = 10 + (10 + 10)\nanother_var = a_var\n' - refactored = self._inline(code, code.index('a_var') + 1) - self.assertEqual('another_var = 10 + (10 + 10)\n', refactored) + code = "a_var = 10 + (10 + 10)\nanother_var = a_var\n" + refactored = self._inline(code, code.index("a_var") + 1) + self.assertEqual("another_var = 10 + (10 + 10)\n", refactored) def test_explicit_continuation(self): - code = 'a_var = (10 +\n 10)\nanother_var = a_var\n' - refactored = self._inline(code, code.index('a_var') + 1) - self.assertEqual('another_var = (10 + 10)\n', refactored) + code = "a_var = (10 +\n 10)\nanother_var = a_var\n" + refactored = self._inline(code, code.index("a_var") + 1) + self.assertEqual("another_var = (10 + 10)\n", refactored) def test_implicit_continuation(self): - code = 'a_var = 10 +\\\n 10\nanother_var = a_var\n' - refactored = self._inline(code, code.index('a_var') + 1) - self.assertEqual('another_var = 10 + 10\n', refactored) + code = "a_var = 10 +\\\n 10\nanother_var = a_var\n" + refactored = self._inline(code, code.index("a_var") + 1) + self.assertEqual("another_var = 10 + 10\n", refactored) def test_inlining_at_the_end_of_input(self): - code = 'a = 1\nb = a' - refactored = self._inline(code, code.index('a') + 1) - self.assertEqual('b = 1', refactored) + code = "a = 1\nb = a" + refactored = self._inline(code, code.index("a") + 1) + self.assertEqual("b = 1", refactored) def test_on_classes(self): - code = 'class AClass(object):\n pass\n' + code = "class AClass(object):\n pass\n" with self.assertRaises(rope.base.exceptions.RefactoringError): - self._inline(code, code.index('AClass') + 1) + self._inline(code, code.index("AClass") + 1) def test_multiple_assignments(self): - code = 'a_var = 10\na_var = 20\n' + code = "a_var = 10\na_var = 20\n" with self.assertRaises(rope.base.exceptions.RefactoringError): - self._inline(code, code.index('a_var') + 1) + self._inline(code, code.index("a_var") + 1) def test_tuple_assignments(self): - code = 'a_var, another_var = (20, 30)\n' + code = "a_var, another_var = (20, 30)\n" with self.assertRaises(rope.base.exceptions.RefactoringError): - self._inline(code, code.index('a_var') + 1) + self._inline(code, code.index("a_var") + 1) def test_on_unknown_vars(self): - code = 'a_var = another_var\n' + code = "a_var = another_var\n" with self.assertRaises(rope.base.exceptions.RefactoringError): - self._inline(code, code.index('another_var') + 1) + self._inline(code, code.index("another_var") + 1) def test_attribute_inlining(self): - code = 'class A(object):\n def __init__(self):\n' \ - ' self.an_attr = 3\n range(self.an_attr)\n' - refactored = self._inline(code, code.index('an_attr') + 1) - expected = 'class A(object):\n def __init__(self):\n' \ - ' range(3)\n' + code = ( + "class A(object):\n def __init__(self):\n" + " self.an_attr = 3\n range(self.an_attr)\n" + ) + refactored = self._inline(code, code.index("an_attr") + 1) + expected = "class A(object):\n def __init__(self):\n" " range(3)\n" self.assertEqual(expected, refactored) def test_attribute_inlining2(self): - code = 'class A(object):\n def __init__(self):\n' \ - ' self.an_attr = 3\n range(self.an_attr)\n' \ - 'a = A()\nrange(a.an_attr)' - refactored = self._inline(code, code.index('an_attr') + 1) - expected = 'class A(object):\n def __init__(self):\n' \ - ' range(3)\n' \ - 'a = A()\nrange(3)' + code = ( + "class A(object):\n def __init__(self):\n" + " self.an_attr = 3\n range(self.an_attr)\n" + "a = A()\nrange(a.an_attr)" + ) + refactored = self._inline(code, code.index("an_attr") + 1) + expected = ( + "class A(object):\n def __init__(self):\n" + " range(3)\n" + "a = A()\nrange(3)" + ) self.assertEqual(expected, refactored) def test_a_function_with_no_occurance(self): - self.mod.write('def a_func():\n pass\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('', self.mod.read()) + self.mod.write("def a_func():\n pass\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("", self.mod.read()) def test_a_function_with_no_occurance2(self): - self.mod.write('a_var = 10\ndef a_func():\n pass\nprint(a_var)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('a_var = 10\nprint(a_var)\n', self.mod.read()) + self.mod.write("a_var = 10\ndef a_func():\n pass\nprint(a_var)\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("a_var = 10\nprint(a_var)\n", self.mod.read()) def test_replacing_calls_with_function_definition_in_other_modules(self): - self.mod.write('def a_func():\n print(1)\n') - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('import mod\nmod.a_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('import mod\nprint(1)\n', mod1.read()) + self.mod.write("def a_func():\n print(1)\n") + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("import mod\nmod.a_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("import mod\nprint(1)\n", mod1.read()) def test_replacing_calls_with_function_definition_in_other_modules2(self): - self.mod.write('def a_func():\n print(1)\n') - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('import mod\nif True:\n mod.a_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('import mod\nif True:\n print(1)\n', mod1.read()) + self.mod.write("def a_func():\n print(1)\n") + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("import mod\nif True:\n mod.a_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("import mod\nif True:\n print(1)\n", mod1.read()) def test_replacing_calls_with_method_definition_in_other_modules(self): - self.mod.write('class A(object):\n var = 10\n' - ' def a_func(self):\n print(1)\n') - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('import mod\nmod.A().a_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('import mod\nprint(1)\n', mod1.read()) - self.assertEqual('class A(object):\n var = 10\n', self.mod.read()) + self.mod.write( + "class A(object):\n var = 10\n" + " def a_func(self):\n print(1)\n" + ) + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("import mod\nmod.A().a_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("import mod\nprint(1)\n", mod1.read()) + self.assertEqual("class A(object):\n var = 10\n", self.mod.read()) def test_replacing_calls_with_function_definition_in_defining_module(self): - self.mod.write('def a_func():\n print(1)\na_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('print(1)\n', self.mod.read()) + self.mod.write("def a_func():\n print(1)\na_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("print(1)\n", self.mod.read()) def test_replac_calls_with_function_definition_in_defining_module2(self): - self.mod.write('def a_func():\n ' - 'for i in range(10):\n print(1)\na_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('for i in range(10):\n print(1)\n', - self.mod.read()) + self.mod.write( + "def a_func():\n " "for i in range(10):\n print(1)\na_func()\n" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("for i in range(10):\n print(1)\n", self.mod.read()) def test_replacing_calls_with_method_definition_in_defining_modules(self): - self.mod.write('class A(object):\n var = 10\n' - ' def a_func(self):\n print(1)\nA().a_func()') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('class A(object):\n var = 10\nprint(1)\n', - self.mod.read()) + self.mod.write( + "class A(object):\n var = 10\n" + " def a_func(self):\n print(1)\nA().a_func()" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("class A(object):\n var = 10\nprint(1)\n", self.mod.read()) def test_parameters_with_the_same_name_as_passed(self): - self.mod.write('def a_func(var):\n ' - 'print(var)\nvar = 1\na_func(var)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('var = 1\nprint(var)\n', self.mod.read()) + self.mod.write("def a_func(var):\n " "print(var)\nvar = 1\na_func(var)\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("var = 1\nprint(var)\n", self.mod.read()) def test_parameters_with_the_same_name_as_passed2(self): - self.mod.write('def a_func(var):\n ' - 'print(var)\nvar = 1\na_func(var=var)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('var = 1\nprint(var)\n', self.mod.read()) + self.mod.write( + "def a_func(var):\n " "print(var)\nvar = 1\na_func(var=var)\n" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("var = 1\nprint(var)\n", self.mod.read()) def test_simple_parameters_renaming(self): - self.mod.write('def a_func(param):\n ' - 'print(param)\nvar = 1\na_func(var)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('var = 1\nprint(var)\n', self.mod.read()) + self.mod.write( + "def a_func(param):\n " "print(param)\nvar = 1\na_func(var)\n" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("var = 1\nprint(var)\n", self.mod.read()) def test_simple_parameters_renaming_for_multiple_params(self): - self.mod.write('def a_func(param1, param2):\n p = param1 + param2\n' - 'var1 = 1\nvar2 = 1\na_func(var1, var2)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('var1 = 1\nvar2 = 1\np = var1 + var2\n', - self.mod.read()) + self.mod.write( + "def a_func(param1, param2):\n p = param1 + param2\n" + "var1 = 1\nvar2 = 1\na_func(var1, var2)\n" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("var1 = 1\nvar2 = 1\np = var1 + var2\n", self.mod.read()) def test_parameters_renaming_for_passed_constants(self): - self.mod.write('def a_func(param):\n print(param)\na_func(1)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('print(1)\n', self.mod.read()) + self.mod.write("def a_func(param):\n print(param)\na_func(1)\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("print(1)\n", self.mod.read()) def test_parameters_renaming_for_passed_statements(self): - self.mod.write('def a_func(param):\n ' - 'print(param)\na_func((1 + 2) / 3)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('print((1 + 2) / 3)\n', self.mod.read()) + self.mod.write("def a_func(param):\n " "print(param)\na_func((1 + 2) / 3)\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("print((1 + 2) / 3)\n", self.mod.read()) def test_simple_parameters_renam_for_multiple_params_using_keywords(self): - self.mod.write('def a_func(param1, param2):\n ' - 'p = param1 + param2\n' - 'var1 = 1\nvar2 = 1\n' - 'a_func(param2=var1, param1=var2)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('var1 = 1\nvar2 = 1\np = var2 + var1\n', - self.mod.read()) + self.mod.write( + "def a_func(param1, param2):\n " + "p = param1 + param2\n" + "var1 = 1\nvar2 = 1\n" + "a_func(param2=var1, param1=var2)\n" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("var1 = 1\nvar2 = 1\np = var2 + var1\n", self.mod.read()) def test_simple_params_renam_for_multi_params_using_mixed_keywords(self): - self.mod.write('def a_func(param1, param2):\n p = param1 + param2\n' - 'var1 = 1\nvar2 = 1\na_func(var2, param2=var1)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('var1 = 1\nvar2 = 1\np = var2 + var1\n', - self.mod.read()) + self.mod.write( + "def a_func(param1, param2):\n p = param1 + param2\n" + "var1 = 1\nvar2 = 1\na_func(var2, param2=var1)\n" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("var1 = 1\nvar2 = 1\np = var2 + var1\n", self.mod.read()) def test_simple_putting_in_default_arguments(self): - self.mod.write('def a_func(param=None):\n print(param)\n' - 'a_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('print(None)\n', self.mod.read()) + self.mod.write("def a_func(param=None):\n print(param)\n" "a_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("print(None)\n", self.mod.read()) def test_overriding_default_arguments(self): - self.mod.write('def a_func(param1=1, param2=2):' - '\n print(param1, param2)\n' - 'a_func(param2=3)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('print(1, 3)\n', self.mod.read()) + self.mod.write( + "def a_func(param1=1, param2=2):" + "\n print(param1, param2)\n" + "a_func(param2=3)\n" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("print(1, 3)\n", self.mod.read()) def test_arguments_containing_comparisons(self): - self.mod.write('def a_func(param1, param2, param3):' - '\n param2.name\n' - 'a_func(2 <= 1, item, True)\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('item.name\n', self.mod.read()) + self.mod.write( + "def a_func(param1, param2, param3):" + "\n param2.name\n" + "a_func(2 <= 1, item, True)\n" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("item.name\n", self.mod.read()) def test_badly_formatted_text(self): - self.mod.write('def a_func ( param1 = 1 ,param2 = 2 ) :' - '\n print(param1, param2)\n' - 'a_func ( param2 \n = 3 ) \n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('print(1, 3)\n', self.mod.read()) + self.mod.write( + "def a_func ( param1 = 1 ,param2 = 2 ) :" + "\n print(param1, param2)\n" + "a_func ( param2 \n = 3 ) \n" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("print(1, 3)\n", self.mod.read()) def test_passing_first_arguments_for_methods(self): - a_class = 'class A(object):\n' \ - ' def __init__(self):\n' \ - ' self.var = 1\n' \ - ' self.a_func(self.var)\n' \ - ' def a_func(self, param):\n' \ - ' print(param)\n' + a_class = ( + "class A(object):\n" + " def __init__(self):\n" + " self.var = 1\n" + " self.a_func(self.var)\n" + " def a_func(self, param):\n" + " print(param)\n" + ) self.mod.write(a_class) - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - expected = 'class A(object):\n' \ - ' def __init__(self):\n' \ - ' self.var = 1\n' \ - ' print(self.var)\n' + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + expected = ( + "class A(object):\n" + " def __init__(self):\n" + " self.var = 1\n" + " print(self.var)\n" + ) self.assertEqual(expected, self.mod.read()) def test_passing_first_arguments_for_methods2(self): - a_class = 'class A(object):\n' \ - ' def __init__(self):\n' \ - ' self.var = 1\n' \ - ' def a_func(self, param):\n' \ - ' print(param, self.var)\n' \ - 'an_a = A()\n' \ - 'an_a.a_func(1)\n' + a_class = ( + "class A(object):\n" + " def __init__(self):\n" + " self.var = 1\n" + " def a_func(self, param):\n" + " print(param, self.var)\n" + "an_a = A()\n" + "an_a.a_func(1)\n" + ) self.mod.write(a_class) - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - expected = 'class A(object):\n' \ - ' def __init__(self):\n' \ - ' self.var = 1\n' \ - 'an_a = A()\n' \ - 'print(1, an_a.var)\n' + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + expected = ( + "class A(object):\n" + " def __init__(self):\n" + " self.var = 1\n" + "an_a = A()\n" + "print(1, an_a.var)\n" + ) self.assertEqual(expected, self.mod.read()) def test_passing_first_arguments_for_methods3(self): - a_class = 'class A(object):\n' \ - ' def __init__(self):\n' \ - ' self.var = 1\n' \ - ' def a_func(self, param):\n' \ - ' print(param, self.var)\n' \ - 'an_a = A()\n' \ - 'A.a_func(an_a, 1)\n' + a_class = ( + "class A(object):\n" + " def __init__(self):\n" + " self.var = 1\n" + " def a_func(self, param):\n" + " print(param, self.var)\n" + "an_a = A()\n" + "A.a_func(an_a, 1)\n" + ) self.mod.write(a_class) - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - expected = 'class A(object):\n' \ - ' def __init__(self):\n' \ - ' self.var = 1\n' \ - 'an_a = A()\n' \ - 'print(1, an_a.var)\n' + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + expected = ( + "class A(object):\n" + " def __init__(self):\n" + " self.var = 1\n" + "an_a = A()\n" + "print(1, an_a.var)\n" + ) self.assertEqual(expected, self.mod.read()) def test_inlining_staticmethods(self): - a_class = 'class A(object):\n' \ - ' @staticmethod\n' \ - ' def a_func(param):\n' \ - ' print(param)\n' \ - 'A.a_func(1)\n' + a_class = ( + "class A(object):\n" + " @staticmethod\n" + " def a_func(param):\n" + " print(param)\n" + "A.a_func(1)\n" + ) self.mod.write(a_class) - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - expected = 'class A(object):\n' \ - ' pass\n' \ - 'print(1)\n' + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + expected = "class A(object):\n" " pass\n" "print(1)\n" self.assertEqual(expected, self.mod.read()) def test_static_methods2(self): - a_class = 'class A(object):\n' \ - ' var = 10\n' \ - ' @staticmethod\n' \ - ' def a_func(param):\n' \ - ' print(param)\n' \ - 'an_a = A()\n' \ - 'an_a.a_func(1)\n' \ - 'A.a_func(2)\n' + a_class = ( + "class A(object):\n" + " var = 10\n" + " @staticmethod\n" + " def a_func(param):\n" + " print(param)\n" + "an_a = A()\n" + "an_a.a_func(1)\n" + "A.a_func(2)\n" + ) self.mod.write(a_class) - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - expected = 'class A(object):\n' \ - ' var = 10\n' \ - 'an_a = A()\n' \ - 'print(1)\n' \ - 'print(2)\n' + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + expected = ( + "class A(object):\n" + " var = 10\n" + "an_a = A()\n" + "print(1)\n" + "print(2)\n" + ) self.assertEqual(expected, self.mod.read()) def test_inlining_classmethods(self): - a_class = 'class A(object):\n' \ - ' @classmethod\n' \ - ' def a_func(cls, param):\n' \ - ' print(param)\n' \ - 'A.a_func(1)\n' + a_class = ( + "class A(object):\n" + " @classmethod\n" + " def a_func(cls, param):\n" + " print(param)\n" + "A.a_func(1)\n" + ) self.mod.write(a_class) - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - expected = 'class A(object):\n' \ - ' pass\n' \ - 'print(1)\n' + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + expected = "class A(object):\n" " pass\n" "print(1)\n" self.assertEqual(expected, self.mod.read()) def test_inlining_classmethods2(self): - a_class = 'class A(object):\n' \ - ' @classmethod\n' \ - ' def a_func(cls, param):\n' \ - ' return cls\n' \ - 'print(A.a_func(1))\n' + a_class = ( + "class A(object):\n" + " @classmethod\n" + " def a_func(cls, param):\n" + " return cls\n" + "print(A.a_func(1))\n" + ) self.mod.write(a_class) - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - expected = 'class A(object):\n' \ - ' pass\n' \ - 'print(A)\n' + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + expected = "class A(object):\n" " pass\n" "print(A)\n" self.assertEqual(expected, self.mod.read()) def test_simple_return_values_and_inlining_functions(self): - self.mod.write('def a_func():\n return 1\na = a_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('a = 1\n', - self.mod.read()) + self.mod.write("def a_func():\n return 1\na = a_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("a = 1\n", self.mod.read()) def test_simple_return_values_and_inlining_lonely_functions(self): - self.mod.write('def a_func():\n return 1\na_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('1\n', self.mod.read()) + self.mod.write("def a_func():\n return 1\na_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("1\n", self.mod.read()) def test_empty_returns_and_inlining_lonely_functions(self): - self.mod.write('def a_func():\n ' - 'if True:\n return\na_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('if True:\n pass\n', self.mod.read()) + self.mod.write("def a_func():\n " "if True:\n return\na_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("if True:\n pass\n", self.mod.read()) def test_multiple_returns(self): - self.mod.write('def less_than_five(var):\n if var < 5:\n' - ' return True\n return False\n' - 'a = less_than_five(2)\n') + self.mod.write( + "def less_than_five(var):\n if var < 5:\n" + " return True\n return False\n" + "a = less_than_five(2)\n" + ) with self.assertRaises(rope.base.exceptions.RefactoringError): - self._inline2(self.mod, self.mod.read().index('less') + 1) + self._inline2(self.mod, self.mod.read().index("less") + 1) def test_multiple_returns_and_not_using_the_value(self): - self.mod.write('def less_than_five(var):\n if var < 5:\n' - ' return True\n ' - 'return False\nless_than_five(2)\n') - self._inline2(self.mod, self.mod.read().index('less') + 1) - self.assertEqual('if 2 < 5:\n True\nFalse\n', self.mod.read()) + self.mod.write( + "def less_than_five(var):\n if var < 5:\n" + " return True\n " + "return False\nless_than_five(2)\n" + ) + self._inline2(self.mod, self.mod.read().index("less") + 1) + self.assertEqual("if 2 < 5:\n True\nFalse\n", self.mod.read()) def test_raising_exception_for_list_arguments(self): - self.mod.write('def a_func(*args):\n print(args)\na_func(1)\n') + self.mod.write("def a_func(*args):\n print(args)\na_func(1)\n") with self.assertRaises(rope.base.exceptions.RefactoringError): - self._inline2(self.mod, self.mod.read().index('a_func') + 1) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) def test_raising_exception_for_list_keywods(self): - self.mod.write('def a_func(**kwds):\n print(kwds)\na_func(n=1)\n') + self.mod.write("def a_func(**kwds):\n print(kwds)\na_func(n=1)\n") with self.assertRaises(rope.base.exceptions.RefactoringError): - self._inline2(self.mod, self.mod.read().index('a_func') + 1) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) def test_function_parameters_and_returns_in_other_functions(self): - code = 'def a_func(param1, param2):\n' \ - ' return param1 + param2\n' \ - 'range(a_func(20, param2=abs(10)))\n' + code = ( + "def a_func(param1, param2):\n" + " return param1 + param2\n" + "range(a_func(20, param2=abs(10)))\n" + ) self.mod.write(code) - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('range(20 + abs(10))\n', self.mod.read()) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("range(20 + abs(10))\n", self.mod.read()) def test_function_references_other_than_call(self): - self.mod.write('def a_func(param):\n print(param)\nf = a_func\n') + self.mod.write("def a_func(param):\n print(param)\nf = a_func\n") with self.assertRaises(rope.base.exceptions.RefactoringError): - self._inline2(self.mod, self.mod.read().index('a_func') + 1) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) def test_function_referencing_itself(self): - self.mod.write('def a_func(var):\n func = a_func\n') + self.mod.write("def a_func(var):\n func = a_func\n") with self.assertRaises(rope.base.exceptions.RefactoringError): - self._inline2(self.mod, self.mod.read().index('a_func') + 1) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) def test_recursive_functions(self): - self.mod.write('def a_func(var):\n a_func(var)\n') + self.mod.write("def a_func(var):\n a_func(var)\n") with self.assertRaises(rope.base.exceptions.RefactoringError): - self._inline2(self.mod, self.mod.read().index('a_func') + 1) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) # TODO: inlining on function parameters def xxx_test_inlining_function_default_parameters(self): - self.mod.write('def a_func(p1=1):\n pass\na_func()\n') - self._inline2(self.mod, self.mod.read().index('p1') + 1) - self.assertEqual('def a_func(p1=1):\n pass\na_func()\n', - self.mod.read()) + self.mod.write("def a_func(p1=1):\n pass\na_func()\n") + self._inline2(self.mod, self.mod.read().index("p1") + 1) + self.assertEqual("def a_func(p1=1):\n pass\na_func()\n", self.mod.read()) def test_simple_inlining_after_extra_indented_lines(self): - self.mod.write('def a_func():\n for i in range(10):\n pass\n' - 'if True:\n pass\na_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('if True:\n pass\nfor i in range(10):' - '\n pass\n', - self.mod.read()) + self.mod.write( + "def a_func():\n for i in range(10):\n pass\n" + "if True:\n pass\na_func()\n" + ) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual( + "if True:\n pass\nfor i in range(10):" "\n pass\n", self.mod.read() + ) def test_inlining_a_function_with_pydoc(self): self.mod.write('def a_func():\n """docs"""\n a = 1\na_func()') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('a = 1\n', self.mod.read()) + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("a = 1\n", self.mod.read()) def test_inlining_methods(self): - self.mod.write("class A(object):\n name = 'hey'\n" - " def get_name(self):\n return self.name\n" - "a = A()\nname = a.get_name()\n") - self._inline2(self.mod, self.mod.read().rindex('get_name') + 1) - self.assertEqual("class A(object):\n name = 'hey'\n" - "a = A()\nname = a.name\n", self.mod.read()) + self.mod.write( + "class A(object):\n name = 'hey'\n" + " def get_name(self):\n return self.name\n" + "a = A()\nname = a.get_name()\n" + ) + self._inline2(self.mod, self.mod.read().rindex("get_name") + 1) + self.assertEqual( + "class A(object):\n name = 'hey'\n" "a = A()\nname = a.name\n", + self.mod.read(), + ) def test_simple_returns_with_backslashes(self): - self.mod.write('def a_func():\n return 1' - '\\\n + 2\na = a_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('a = 1 + 2\n', self.mod.read()) + self.mod.write("def a_func():\n return 1" "\\\n + 2\na = a_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("a = 1 + 2\n", self.mod.read()) def test_a_function_with_pass_body(self): - self.mod.write('def a_func():\n print(1)\na = a_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func') + 1) - self.assertEqual('print(1)\na = None\n', self.mod.read()) + self.mod.write("def a_func():\n print(1)\na = a_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func") + 1) + self.assertEqual("print(1)\na = None\n", self.mod.read()) def test_inlining_the_last_method_of_a_class(self): - self.mod.write('class A(object):\n' - ' def a_func(self):\n pass\n') - self._inline2(self.mod, self.mod.read().rindex('a_func') + 1) - self.assertEqual('class A(object):\n pass\n', - self.mod.read()) + self.mod.write("class A(object):\n" " def a_func(self):\n pass\n") + self._inline2(self.mod, self.mod.read().rindex("a_func") + 1) + self.assertEqual("class A(object):\n pass\n", self.mod.read()) def test_adding_needed_imports_in_the_dest_module(self): - self.mod.write('import sys\n\ndef ver():\n print(sys.version)\n') - self.mod2.write('import mod\n\nmod.ver()') - self._inline2(self.mod, self.mod.read().index('ver') + 1) - self.assertEqual('import mod\nimport sys\n\nprint(sys.version)\n', - self.mod2.read()) + self.mod.write("import sys\n\ndef ver():\n print(sys.version)\n") + self.mod2.write("import mod\n\nmod.ver()") + self._inline2(self.mod, self.mod.read().index("ver") + 1) + self.assertEqual( + "import mod\nimport sys\n\nprint(sys.version)\n", self.mod2.read() + ) def test_adding_needed_imports_in_the_dest_module_removing_selfs(self): - self.mod.write('import mod2\n\ndef f():\n print(mod2.var)\n') - self.mod2.write('import mod\n\nvar = 1\nmod.f()\n') - self._inline2(self.mod, self.mod.read().index('f(') + 1) - self.assertEqual('import mod\n\nvar = 1\nprint(var)\n', - self.mod2.read()) + self.mod.write("import mod2\n\ndef f():\n print(mod2.var)\n") + self.mod2.write("import mod\n\nvar = 1\nmod.f()\n") + self._inline2(self.mod, self.mod.read().index("f(") + 1) + self.assertEqual("import mod\n\nvar = 1\nprint(var)\n", self.mod2.read()) def test_handling_relative_imports_when_inlining(self): - pkg = testutils.create_package(self.project, 'pkg') - mod3 = testutils.create_module(self.project, 'mod3', pkg) - mod4 = testutils.create_module(self.project, 'mod4', pkg) - mod4.write('var = 1\n') - mod3.write('from . import mod4\n\ndef f():\n print(mod4.var)\n') - self.mod.write('import pkg.mod3\n\npkg.mod3.f()\n') - self._inline2(self.mod, self.mod.read().index('f(') + 1) + pkg = testutils.create_package(self.project, "pkg") + mod3 = testutils.create_module(self.project, "mod3", pkg) + mod4 = testutils.create_module(self.project, "mod4", pkg) + mod4.write("var = 1\n") + mod3.write("from . import mod4\n\ndef f():\n print(mod4.var)\n") + self.mod.write("import pkg.mod3\n\npkg.mod3.f()\n") + self._inline2(self.mod, self.mod.read().index("f(") + 1) # Cannot determine the exact import - self.assertTrue('\n\nprint(mod4.var)\n' in self.mod.read()) + self.assertTrue("\n\nprint(mod4.var)\n" in self.mod.read()) def test_adding_needed_imports_for_elements_in_source(self): - self.mod.write('def f1():\n return f2()\ndef f2():\n return 1\n') - self.mod2.write('import mod\n\nprint(mod.f1())\n') - self._inline2(self.mod, self.mod.read().index('f1') + 1) - self.assertEqual('import mod\nfrom mod import f2\n\nprint(f2())\n', - self.mod2.read()) + self.mod.write("def f1():\n return f2()\ndef f2():\n return 1\n") + self.mod2.write("import mod\n\nprint(mod.f1())\n") + self._inline2(self.mod, self.mod.read().index("f1") + 1) + self.assertEqual( + "import mod\nfrom mod import f2\n\nprint(f2())\n", self.mod2.read() + ) def test_relative_imports_and_changing_inlining_body(self): - pkg = testutils.create_package(self.project, 'pkg') - mod3 = testutils.create_module(self.project, 'mod3', pkg) - mod4 = testutils.create_module(self.project, 'mod4', pkg) - mod4.write('var = 1\n') - mod3.write('import mod4\n\ndef f():\n print(mod4.var)\n') - self.mod.write('import pkg.mod3\n\npkg.mod3.f()\n') - self._inline2(self.mod, self.mod.read().index('f(') + 1) + pkg = testutils.create_package(self.project, "pkg") + mod3 = testutils.create_module(self.project, "mod3", pkg) + mod4 = testutils.create_module(self.project, "mod4", pkg) + mod4.write("var = 1\n") + mod3.write("import mod4\n\ndef f():\n print(mod4.var)\n") + self.mod.write("import pkg.mod3\n\npkg.mod3.f()\n") + self._inline2(self.mod, self.mod.read().index("f(") + 1) self.assertEqual( - 'import pkg.mod3\nimport pkg.mod4\n\nprint(pkg.mod4.var)\n', - self.mod.read()) + "import pkg.mod3\nimport pkg.mod4\n\nprint(pkg.mod4.var)\n", self.mod.read() + ) def test_inlining_with_different_returns(self): - self.mod.write('def f(p):\n return p\n' - 'print(f(1))\nprint(f(2))\nprint(f(1))\n') - self._inline2(self.mod, self.mod.read().index('f(') + 1) - self.assertEqual('print(1)\nprint(2)\nprint(1)\n', - self.mod.read()) + self.mod.write( + "def f(p):\n return p\n" "print(f(1))\nprint(f(2))\nprint(f(1))\n" + ) + self._inline2(self.mod, self.mod.read().index("f(") + 1) + self.assertEqual("print(1)\nprint(2)\nprint(1)\n", self.mod.read()) def test_not_removing_definition_for_variables(self): - code = 'a_var = 10\nanother_var = a_var\n' - refactored = self._inline(code, code.index('a_var') + 1, - remove=False) - self.assertEqual('a_var = 10\nanother_var = 10\n', refactored) + code = "a_var = 10\nanother_var = a_var\n" + refactored = self._inline(code, code.index("a_var") + 1, remove=False) + self.assertEqual("a_var = 10\nanother_var = 10\n", refactored) def test_not_removing_definition_for_methods(self): - code = 'def func():\n print(1)\n\nfunc()\n' - refactored = self._inline(code, code.index('func') + 1, - remove=False) - self.assertEqual('def func():\n print(1)\n\nprint(1)\n', - refactored) + code = "def func():\n print(1)\n\nfunc()\n" + refactored = self._inline(code, code.index("func") + 1, remove=False) + self.assertEqual("def func():\n print(1)\n\nprint(1)\n", refactored) def test_only_current_for_methods(self): - code = 'def func():\n print(1)\n\nfunc()\nfunc()\n' - refactored = self._inline(code, code.rindex('func') + 1, - remove=False, only_current=True) - self.assertEqual('def func():\n print(1)\n\nfunc()\nprint(1)\n', - refactored) + code = "def func():\n print(1)\n\nfunc()\nfunc()\n" + refactored = self._inline( + code, code.rindex("func") + 1, remove=False, only_current=True + ) + self.assertEqual("def func():\n print(1)\n\nfunc()\nprint(1)\n", refactored) def test_only_current_for_variables(self): - code = 'one = 1\n\na = one\nb = one\n' - refactored = self._inline(code, code.rindex('one') + 1, - remove=False, only_current=True) - self.assertEqual('one = 1\n\na = one\nb = 1\n', refactored) + code = "one = 1\n\na = one\nb = one\n" + refactored = self._inline( + code, code.rindex("one") + 1, remove=False, only_current=True + ) + self.assertEqual("one = 1\n\na = one\nb = 1\n", refactored) def test_inlining_one_line_functions(self): - code = 'def f(): return 1\nvar = f()\n' - refactored = self._inline(code, code.rindex('f')) - self.assertEqual('var = 1\n', refactored) + code = "def f(): return 1\nvar = f()\n" + refactored = self._inline(code, code.rindex("f")) + self.assertEqual("var = 1\n", refactored) def test_inlining_one_line_functions_with_breaks(self): - code = 'def f(\np): return p\nvar = f(1)\n' - refactored = self._inline(code, code.rindex('f')) - self.assertEqual('var = 1\n', refactored) + code = "def f(\np): return p\nvar = f(1)\n" + refactored = self._inline(code, code.rindex("f")) + self.assertEqual("var = 1\n", refactored) def test_inlining_one_line_functions_with_breaks2(self): - code = 'def f(\n): return 1\nvar = f()\n' - refactored = self._inline(code, code.rindex('f')) - self.assertEqual('var = 1\n', refactored) + code = "def f(\n): return 1\nvar = f()\n" + refactored = self._inline(code, code.rindex("f")) + self.assertEqual("var = 1\n", refactored) def test_resources_parameter(self): - self.mod.write('def a_func():\n print(1)\n') - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('import mod\nmod.a_func()\n') - self._inline2(self.mod, self.mod.read().index('a_func'), - resources=[self.mod]) - self.assertEqual('', self.mod.read()) - self.assertEqual('import mod\nmod.a_func()\n', mod1.read()) + self.mod.write("def a_func():\n print(1)\n") + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("import mod\nmod.a_func()\n") + self._inline2(self.mod, self.mod.read().index("a_func"), resources=[self.mod]) + self.assertEqual("", self.mod.read()) + self.assertEqual("import mod\nmod.a_func()\n", mod1.read()) def test_inlining_parameters(self): - code = 'def f(p=1):\n pass\nf()\n' - result = self._inline(code, code.index('p')) - self.assertEqual('def f(p=1):\n pass\nf(1)\n', result) + code = "def f(p=1):\n pass\nf()\n" + result = self._inline(code, code.index("p")) + self.assertEqual("def f(p=1):\n pass\nf(1)\n", result) def test_inlining_function_with_line_breaks_in_args(self): - code = 'def f(p): return p\nvar = f(1 +\n1)\n' - refactored = self._inline(code, code.rindex('f')) - self.assertEqual('var = 1 + 1\n', refactored) + code = "def f(p): return p\nvar = f(1 +\n1)\n" + refactored = self._inline(code, code.rindex("f")) + self.assertEqual("var = 1 + 1\n", refactored) def test_inlining_variables_before_comparison(self): - code = 'start = 1\nprint(start <= 2)\n' - refactored = self._inline(code, code.index('start')) - self.assertEqual('print(1 <= 2)\n', refactored) + code = "start = 1\nprint(start <= 2)\n" + refactored = self._inline(code, code.index("start")) + self.assertEqual("print(1 <= 2)\n", refactored) def test_inlining_variables_in_other_modules(self): - self.mod.write('myvar = 1\n') - self.mod2.write('import mod\nprint(mod.myvar)\n') + self.mod.write("myvar = 1\n") + self.mod2.write("import mod\nprint(mod.myvar)\n") self._inline2(self.mod, 2) - self.assertEqual('import mod\nprint(1)\n', self.mod2.read()) + self.assertEqual("import mod\nprint(1)\n", self.mod2.read()) def test_inlining_variables_and_back_importing(self): - self.mod.write('mainvar = 1\nmyvar = mainvar\n') - self.mod2.write('import mod\nprint(mod.myvar)\n') - self._inline2(self.mod, self.mod.read().index('myvar')) - expected = 'import mod\n' \ - 'from mod import mainvar\n' \ - 'print(mainvar)\n' + self.mod.write("mainvar = 1\nmyvar = mainvar\n") + self.mod2.write("import mod\nprint(mod.myvar)\n") + self._inline2(self.mod, self.mod.read().index("myvar")) + expected = "import mod\n" "from mod import mainvar\n" "print(mainvar)\n" self.assertEqual(expected, self.mod2.read()) def test_inlining_variables_and_importing_used_imports(self): - self.mod.write('import sys\nmyvar = sys.argv\n') - self.mod2.write('import mod\nprint(mod.myvar)\n') - self._inline2(self.mod, self.mod.read().index('myvar')) - expected = 'import mod\n' \ - 'import sys\n' \ - 'print(sys.argv)\n' + self.mod.write("import sys\nmyvar = sys.argv\n") + self.mod2.write("import mod\nprint(mod.myvar)\n") + self._inline2(self.mod, self.mod.read().index("myvar")) + expected = "import mod\n" "import sys\n" "print(sys.argv)\n" self.assertEqual(expected, self.mod2.read()) def test_inlining_variables_and_removing_old_froms(self): - self.mod.write('var = 1\n') - self.mod2.write('from mod import var\nprint(var)\n') - self._inline2(self.mod2, self.mod2.read().rindex('var')) - self.assertEqual('print(1)\n', self.mod2.read()) + self.mod.write("var = 1\n") + self.mod2.write("from mod import var\nprint(var)\n") + self._inline2(self.mod2, self.mod2.read().rindex("var")) + self.assertEqual("print(1)\n", self.mod2.read()) def test_inlining_method_and_removing_old_froms(self): - self.mod.write('def f(): return 1\n') - self.mod2.write('from mod import f\nprint(f())\n') - self._inline2(self.mod2, self.mod2.read().rindex('f')) - self.assertEqual('print(1)\n', self.mod2.read()) + self.mod.write("def f(): return 1\n") + self.mod2.write("from mod import f\nprint(f())\n") + self._inline2(self.mod2, self.mod2.read().rindex("f")) + self.assertEqual("print(1)\n", self.mod2.read()) def test_inlining_functions_in_other_modules_and_only_current(self): - code1 = 'def f():\n' \ - ' return 1\n' \ - 'print(f())\n' - code2 = 'import mod\n' \ - 'print(mod.f())\n' \ - 'print(mod.f())\n' + code1 = "def f():\n" " return 1\n" "print(f())\n" + code2 = "import mod\n" "print(mod.f())\n" "print(mod.f())\n" self.mod.write(code1) self.mod2.write(code2) - self._inline2(self.mod2, self.mod2.read().rindex('f'), - remove=False, only_current=True) - expected2 = 'import mod\n' \ - 'print(mod.f())\n' \ - 'print(1)\n' + self._inline2( + self.mod2, self.mod2.read().rindex("f"), remove=False, only_current=True + ) + expected2 = "import mod\n" "print(mod.f())\n" "print(1)\n" self.assertEqual(code1, self.mod.read()) self.assertEqual(expected2, self.mod2.read()) def test_inlining_variables_in_other_modules_and_only_current(self): - code1 = 'var = 1\n' \ - 'print(var)\n' - code2 = 'import mod\n' \ - 'print(mod.var)\n' \ - 'print(mod.var)\n' + code1 = "var = 1\n" "print(var)\n" + code2 = "import mod\n" "print(mod.var)\n" "print(mod.var)\n" self.mod.write(code1) self.mod2.write(code2) - self._inline2(self.mod2, self.mod2.read().rindex('var'), - remove=False, only_current=True) - expected2 = 'import mod\n' \ - 'print(mod.var)\n' \ - 'print(1)\n' + self._inline2( + self.mod2, self.mod2.read().rindex("var"), remove=False, only_current=True + ) + expected2 = "import mod\n" "print(mod.var)\n" "print(1)\n" self.assertEqual(code1, self.mod.read()) self.assertEqual(expected2, self.mod2.read()) def test_inlining_does_not_change_string_constants(self): - code = 'var = 1\n' \ - 'print("var\\\n' \ - '")\n' - expected = 'var = 1\n' \ - 'print("var\\\n' \ - '")\n' - refactored = self._inline(code, code.rindex('var'), - remove=False, only_current=True, docs=False) + code = "var = 1\n" 'print("var\\\n' '")\n' + expected = "var = 1\n" 'print("var\\\n' '")\n' + refactored = self._inline( + code, code.rindex("var"), remove=False, only_current=True, docs=False + ) self.assertEqual(expected, refactored) def test_inlining_does_change_string_constants_if_docs_is_set(self): - code = 'var = 1\n' \ - 'print("var\\\n' \ - '")\n' - expected = 'var = 1\n' \ - 'print("1\\\n' \ - '")\n' - refactored = self._inline(code, code.rindex('var'), - remove=False, only_current=True, docs=True) + code = "var = 1\n" 'print("var\\\n' '")\n' + expected = "var = 1\n" 'print("1\\\n' '")\n' + refactored = self._inline( + code, code.rindex("var"), remove=False, only_current=True, docs=True + ) self.assertEqual(expected, refactored) diff --git a/ropetest/refactor/movetest.py b/ropetest/refactor/movetest.py index 499fc6bf4..bbdd29c72 100644 --- a/ropetest/refactor/movetest.py +++ b/ropetest/refactor/movetest.py @@ -9,788 +9,798 @@ class MoveRefactoringTest(unittest.TestCase): - def setUp(self): super(MoveRefactoringTest, self).setUp() self.project = testutils.sample_project() - self.mod1 = testutils.create_module(self.project, 'mod1') - self.mod2 = testutils.create_module(self.project, 'mod2') - self.mod3 = testutils.create_module(self.project, 'mod3') - self.pkg = testutils.create_package(self.project, 'pkg') - self.mod4 = testutils.create_module(self.project, 'mod4', self.pkg) - self.mod5 = testutils.create_module(self.project, 'mod5', self.pkg) + self.mod1 = testutils.create_module(self.project, "mod1") + self.mod2 = testutils.create_module(self.project, "mod2") + self.mod3 = testutils.create_module(self.project, "mod3") + self.pkg = testutils.create_package(self.project, "pkg") + self.mod4 = testutils.create_module(self.project, "mod4", self.pkg) + self.mod5 = testutils.create_module(self.project, "mod5", self.pkg) def tearDown(self): testutils.remove_project(self.project) super(MoveRefactoringTest, self).tearDown() def _move(self, resource, offset, dest_resource): - changes = move.create_move(self.project, resource, offset).\ - get_changes(dest_resource) + changes = move.create_move(self.project, resource, offset).get_changes( + dest_resource + ) self.project.do(changes) def test_move_constant(self): - self.mod1.write('foo = 123\n') - self._move(self.mod1, self.mod1.read().index('foo') + 1, - self.mod2) - self.assertEqual('', self.mod1.read()) - self.assertEqual('foo = 123\n', self.mod2.read()) + self.mod1.write("foo = 123\n") + self._move(self.mod1, self.mod1.read().index("foo") + 1, self.mod2) + self.assertEqual("", self.mod1.read()) + self.assertEqual("foo = 123\n", self.mod2.read()) def test_move_constant_2(self): - self.mod1.write('bar = 321\nfoo = 123\n') - self._move(self.mod1, self.mod1.read().index('foo') + 1, - self.mod2) - self.assertEqual('bar = 321\n', self.mod1.read()) - self.assertEqual('foo = 123\n', self.mod2.read()) + self.mod1.write("bar = 321\nfoo = 123\n") + self._move(self.mod1, self.mod1.read().index("foo") + 1, self.mod2) + self.assertEqual("bar = 321\n", self.mod1.read()) + self.assertEqual("foo = 123\n", self.mod2.read()) def test_move_constant_multiline(self): - self.mod1.write('foo = (\n 123\n)\n') - self._move(self.mod1, self.mod1.read().index('foo') + 1, - self.mod2) - self.assertEqual('', self.mod1.read()) - self.assertEqual('foo = (\n 123\n)\n', self.mod2.read()) + self.mod1.write("foo = (\n 123\n)\n") + self._move(self.mod1, self.mod1.read().index("foo") + 1, self.mod2) + self.assertEqual("", self.mod1.read()) + self.assertEqual("foo = (\n 123\n)\n", self.mod2.read()) def test_move_constant_multiple_statements(self): - self.mod1.write('foo = 123\nfoo += 3\nfoo = 4\n') - self._move(self.mod1, self.mod1.read().index('foo') + 1, - self.mod2) - self.assertEqual('import mod2\nmod2.foo += 3\nmod2.foo = 4\n', - self.mod1.read()) - self.assertEqual('foo = 123\n', self.mod2.read()) + self.mod1.write("foo = 123\nfoo += 3\nfoo = 4\n") + self._move(self.mod1, self.mod1.read().index("foo") + 1, self.mod2) + self.assertEqual("import mod2\nmod2.foo += 3\nmod2.foo = 4\n", self.mod1.read()) + self.assertEqual("foo = 123\n", self.mod2.read()) def test_simple_moving(self): - self.mod1.write('class AClass(object):\n pass\n') - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod2) - self.assertEqual('', self.mod1.read()) - self.assertEqual('class AClass(object):\n pass\n', - self.mod2.read()) + self.mod1.write("class AClass(object):\n pass\n") + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod2) + self.assertEqual("", self.mod1.read()) + self.assertEqual("class AClass(object):\n pass\n", self.mod2.read()) def test_moving_with_comment_prefix(self): - self.mod1.write('a = 1\n# 1\n# 2\nclass AClass(object):\n pass\n') - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod2) - self.assertEqual('a = 1\n', self.mod1.read()) - self.assertEqual('# 1\n# 2\nclass AClass(object):\n pass\n', - self.mod2.read()) + self.mod1.write("a = 1\n# 1\n# 2\nclass AClass(object):\n pass\n") + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod2) + self.assertEqual("a = 1\n", self.mod1.read()) + self.assertEqual( + "# 1\n# 2\nclass AClass(object):\n pass\n", self.mod2.read() + ) def test_moving_with_comment_prefix_imports(self): - self.mod1.write('import foo\na = 1\n# 1\n# 2\n' - 'class AClass(foo.FooClass):\n pass\n') - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod2) - self.assertEqual('a = 1\n', self.mod1.read()) - self.assertEqual('import foo\n\n\n# 1\n# 2\n' - 'class AClass(foo.FooClass):\n pass\n', - self.mod2.read()) + self.mod1.write( + "import foo\na = 1\n# 1\n# 2\n" "class AClass(foo.FooClass):\n pass\n" + ) + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod2) + self.assertEqual("a = 1\n", self.mod1.read()) + self.assertEqual( + "import foo\n\n\n# 1\n# 2\n" "class AClass(foo.FooClass):\n pass\n", + self.mod2.read(), + ) def test_changing_other_modules_replacing_normal_imports(self): - self.mod1.write('class AClass(object):\n pass\n') - self.mod3.write('import mod1\na_var = mod1.AClass()\n') - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod2) - self.assertEqual('import mod2\na_var = mod2.AClass()\n', - self.mod3.read()) + self.mod1.write("class AClass(object):\n pass\n") + self.mod3.write("import mod1\na_var = mod1.AClass()\n") + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod2) + self.assertEqual("import mod2\na_var = mod2.AClass()\n", self.mod3.read()) def test_changing_other_modules_adding_normal_imports(self): - self.mod1.write('class AClass(object):\n pass\n' - 'def a_function():\n pass\n') - self.mod3.write('import mod1\na_var = mod1.AClass()\n' - 'mod1.a_function()') - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod2) - self.assertEqual('import mod1\nimport mod2\na_var = mod2.AClass()\n' + - 'mod1.a_function()', self.mod3.read()) + self.mod1.write( + "class AClass(object):\n pass\n" "def a_function():\n pass\n" + ) + self.mod3.write("import mod1\na_var = mod1.AClass()\n" "mod1.a_function()") + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod2) + self.assertEqual( + "import mod1\nimport mod2\na_var = mod2.AClass()\n" + "mod1.a_function()", + self.mod3.read(), + ) def test_adding_imports_prefer_from_module(self): - self.project.prefs['prefer_module_from_imports'] = True - self.mod1.write('class AClass(object):\n pass\n' - 'def a_function():\n pass\n') - self.mod3.write('import mod1\na_var = mod1.AClass()\n' - 'mod1.a_function()') + self.project.prefs["prefer_module_from_imports"] = True + self.mod1.write( + "class AClass(object):\n pass\n" "def a_function():\n pass\n" + ) + self.mod3.write("import mod1\na_var = mod1.AClass()\n" "mod1.a_function()") # Move to mod4 which is in a different package - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod4) - self.assertEqual('import mod1\nfrom pkg import mod4\n' - 'a_var = mod4.AClass()\nmod1.a_function()', - self.mod3.read()) + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod4) + self.assertEqual( + "import mod1\nfrom pkg import mod4\n" + "a_var = mod4.AClass()\nmod1.a_function()", + self.mod3.read(), + ) def test_adding_imports_noprefer_from_module(self): - self.project.prefs['prefer_module_from_imports'] = False - self.mod1.write('class AClass(object):\n pass\n' - 'def a_function():\n pass\n') - self.mod3.write('import mod1\na_var = mod1.AClass()\n' - 'mod1.a_function()') + self.project.prefs["prefer_module_from_imports"] = False + self.mod1.write( + "class AClass(object):\n pass\n" "def a_function():\n pass\n" + ) + self.mod3.write("import mod1\na_var = mod1.AClass()\n" "mod1.a_function()") # Move to mod4 which is in a different package - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod4) - self.assertEqual('import mod1\nimport pkg.mod4\n' - 'a_var = pkg.mod4.AClass()\nmod1.a_function()', - self.mod3.read()) + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod4) + self.assertEqual( + "import mod1\nimport pkg.mod4\n" + "a_var = pkg.mod4.AClass()\nmod1.a_function()", + self.mod3.read(), + ) def test_adding_imports_prefer_from_module_top_level_module(self): - self.project.prefs['prefer_module_from_imports'] = True - self.mod1.write('class AClass(object):\n pass\n' - 'def a_function():\n pass\n') - self.mod3.write('import mod1\na_var = mod1.AClass()\n' - 'mod1.a_function()') - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod2) - self.assertEqual('import mod1\nimport mod2\na_var = mod2.AClass()\n' + - 'mod1.a_function()', self.mod3.read()) + self.project.prefs["prefer_module_from_imports"] = True + self.mod1.write( + "class AClass(object):\n pass\n" "def a_function():\n pass\n" + ) + self.mod3.write("import mod1\na_var = mod1.AClass()\n" "mod1.a_function()") + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod2) + self.assertEqual( + "import mod1\nimport mod2\na_var = mod2.AClass()\n" + "mod1.a_function()", + self.mod3.read(), + ) def test_changing_other_modules_removing_from_imports(self): - self.mod1.write('class AClass(object):\n pass\n') - self.mod3.write('from mod1 import AClass\na_var = AClass()\n') - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod2) - self.assertEqual('import mod2\na_var = mod2.AClass()\n', - self.mod3.read()) + self.mod1.write("class AClass(object):\n pass\n") + self.mod3.write("from mod1 import AClass\na_var = AClass()\n") + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod2) + self.assertEqual("import mod2\na_var = mod2.AClass()\n", self.mod3.read()) def test_changing_source_module(self): - self.mod1.write('class AClass(object):\n pass\na_var = AClass()\n') - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod2) - self.assertEqual('import mod2\na_var = mod2.AClass()\n', - self.mod1.read()) + self.mod1.write("class AClass(object):\n pass\na_var = AClass()\n") + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod2) + self.assertEqual("import mod2\na_var = mod2.AClass()\n", self.mod1.read()) def test_changing_destination_module(self): - self.mod1.write('class AClass(object):\n pass\n') - self.mod2.write('from mod1 import AClass\na_var = AClass()\n') - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod2) - self.assertEqual('class AClass(object):\n ' - 'pass\na_var = AClass()\n', - self.mod2.read()) + self.mod1.write("class AClass(object):\n pass\n") + self.mod2.write("from mod1 import AClass\na_var = AClass()\n") + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod2) + self.assertEqual( + "class AClass(object):\n " "pass\na_var = AClass()\n", self.mod2.read() + ) def test_folder_destination(self): - folder = self.project.root.create_folder('folder') - self.mod1.write('class AClass(object):\n pass\n') + folder = self.project.root.create_folder("folder") + self.mod1.write("class AClass(object):\n pass\n") with self.assertRaises(exceptions.RefactoringError): - self._move(self.mod1, self.mod1.read().index('AClass') + 1, folder) + self._move(self.mod1, self.mod1.read().index("AClass") + 1, folder) def test_raising_exception_for_moving_non_global_elements(self): - self.mod1.write( - 'def a_func():\n class AClass(object):\n pass\n') + self.mod1.write("def a_func():\n class AClass(object):\n pass\n") with self.assertRaises(exceptions.RefactoringError): - self._move(self.mod1, self.mod1.read().index('AClass') + 1, - self.mod2) + self._move(self.mod1, self.mod1.read().index("AClass") + 1, self.mod2) def test_raising_an_exception_for_moving_non_global_variable(self): - code = 'class TestClass:\n CONSTANT = 5\n' + code = "class TestClass:\n CONSTANT = 5\n" self.mod1.write(code) with self.assertRaises(exceptions.RefactoringError): - mover = move.create_move(self.project, self.mod1, - code.index('CONSTANT') + 1) + mover = move.create_move( + self.project, self.mod1, code.index("CONSTANT") + 1 + ) def test_raising_exception_for_mov_glob_elemnts_to_the_same_module(self): - self.mod1.write('def a_func():\n pass\n') + self.mod1.write("def a_func():\n pass\n") with self.assertRaises(exceptions.RefactoringError): - self._move(self.mod1, self.mod1.read().index('a_func'), self.mod1) + self._move(self.mod1, self.mod1.read().index("a_func"), self.mod1) def test_moving_used_imports_to_destination_module(self): - self.mod3.write('a_var = 10') - code = 'import mod3\n' \ - 'from mod3 import a_var\n' \ - 'def a_func():\n' \ - ' print(mod3, a_var)\n' - self.mod1.write(code) - self._move(self.mod1, code.index('a_func') + 1, self.mod2) - expected = 'import mod3\n' \ - 'from mod3 import a_var\n\n\n' \ - 'def a_func():\n print(mod3, a_var)\n' + self.mod3.write("a_var = 10") + code = ( + "import mod3\n" + "from mod3 import a_var\n" + "def a_func():\n" + " print(mod3, a_var)\n" + ) + self.mod1.write(code) + self._move(self.mod1, code.index("a_func") + 1, self.mod2) + expected = ( + "import mod3\n" + "from mod3 import a_var\n\n\n" + "def a_func():\n print(mod3, a_var)\n" + ) self.assertEqual(expected, self.mod2.read()) def test_moving_used_names_to_destination_module2(self): - code = 'a_var = 10\n' \ - 'def a_func():\n' \ - ' print(a_var)\n' - self.mod1.write(code) - self._move(self.mod1, code.index('a_func') + 1, self.mod2) - self.assertEqual('a_var = 10\n', self.mod1.read()) - expected = 'from mod1 import a_var\n\n\n' \ - 'def a_func():\n' \ - ' print(a_var)\n' + code = "a_var = 10\n" "def a_func():\n" " print(a_var)\n" + self.mod1.write(code) + self._move(self.mod1, code.index("a_func") + 1, self.mod2) + self.assertEqual("a_var = 10\n", self.mod1.read()) + expected = "from mod1 import a_var\n\n\n" "def a_func():\n" " print(a_var)\n" self.assertEqual(expected, self.mod2.read()) def test_moving_used_underlined_names_to_destination_module(self): - code = '_var = 10\n' \ - 'def a_func():\n' \ - ' print(_var)\n' - self.mod1.write(code) - self._move(self.mod1, code.index('a_func') + 1, self.mod2) - expected = 'from mod1 import _var\n\n\n' \ - 'def a_func():\n' \ - ' print(_var)\n' + code = "_var = 10\n" "def a_func():\n" " print(_var)\n" + self.mod1.write(code) + self._move(self.mod1, code.index("a_func") + 1, self.mod2) + expected = "from mod1 import _var\n\n\n" "def a_func():\n" " print(_var)\n" self.assertEqual(expected, self.mod2.read()) def test_moving_and_used_relative_imports(self): - code = 'import mod5\n' \ - 'def a_func():\n' \ - ' print(mod5)\n' + code = "import mod5\n" "def a_func():\n" " print(mod5)\n" self.mod4.write(code) - self._move(self.mod4, code.index('a_func') + 1, self.mod1) - expected = 'import pkg.mod5\n\n\n' \ - 'def a_func():\n' \ - ' print(pkg.mod5)\n' + self._move(self.mod4, code.index("a_func") + 1, self.mod1) + expected = "import pkg.mod5\n\n\n" "def a_func():\n" " print(pkg.mod5)\n" self.assertEqual(expected, self.mod1.read()) - self.assertEqual('', self.mod4.read()) + self.assertEqual("", self.mod4.read()) def test_moving_modules(self): - code = 'import mod1\nprint(mod1)' + code = "import mod1\nprint(mod1)" self.mod2.write(code) - self._move(self.mod2, code.index('mod1') + 1, self.pkg) - expected = 'import pkg.mod1\nprint(pkg.mod1)' + self._move(self.mod2, code.index("mod1") + 1, self.pkg) + expected = "import pkg.mod1\nprint(pkg.mod1)" self.assertEqual(expected, self.mod2.read()) - self.assertTrue(not self.mod1.exists() and - self.project.find_module('pkg.mod1') is not None) + self.assertTrue( + not self.mod1.exists() and self.project.find_module("pkg.mod1") is not None + ) def test_moving_modules_and_removing_out_of_date_imports(self): - code = 'import pkg.mod4\nprint(pkg.mod4)' + code = "import pkg.mod4\nprint(pkg.mod4)" self.mod2.write(code) - self._move(self.mod2, code.index('mod4') + 1, self.project.root) - expected = 'import mod4\nprint(mod4)' + self._move(self.mod2, code.index("mod4") + 1, self.project.root) + expected = "import mod4\nprint(mod4)" self.assertEqual(expected, self.mod2.read()) - self.assertTrue(self.project.find_module('mod4') is not None) + self.assertTrue(self.project.find_module("mod4") is not None) def test_moving_modules_and_removing_out_of_date_froms(self): - code = 'from pkg import mod4\nprint(mod4)' + code = "from pkg import mod4\nprint(mod4)" self.mod2.write(code) - self._move(self.mod2, code.index('mod4') + 1, self.project.root) - self.assertEqual('import mod4\nprint(mod4)', self.mod2.read()) + self._move(self.mod2, code.index("mod4") + 1, self.project.root) + self.assertEqual("import mod4\nprint(mod4)", self.mod2.read()) def test_moving_modules_and_removing_out_of_date_froms2(self): - self.mod4.write('a_var = 10') - code = 'from pkg.mod4 import a_var\nprint(a_var)\n' + self.mod4.write("a_var = 10") + code = "from pkg.mod4 import a_var\nprint(a_var)\n" self.mod2.write(code) - self._move(self.mod2, code.index('mod4') + 1, self.project.root) - expected = 'from mod4 import a_var\nprint(a_var)\n' + self._move(self.mod2, code.index("mod4") + 1, self.project.root) + expected = "from mod4 import a_var\nprint(a_var)\n" self.assertEqual(expected, self.mod2.read()) def test_moving_modules_and_relative_import(self): - self.mod4.write('import mod5\nprint(mod5)\n') - code = 'import pkg.mod4\nprint(pkg.mod4)' + self.mod4.write("import mod5\nprint(mod5)\n") + code = "import pkg.mod4\nprint(pkg.mod4)" self.mod2.write(code) - self._move(self.mod2, code.index('mod4') + 1, self.project.root) - moved = self.project.find_module('mod4') - expected = 'import pkg.mod5\nprint(pkg.mod5)\n' + self._move(self.mod2, code.index("mod4") + 1, self.project.root) + moved = self.project.find_module("mod4") + expected = "import pkg.mod5\nprint(pkg.mod5)\n" self.assertEqual(expected, moved.read()) def test_moving_module_kwarg_same_name_as_old(self): - self.mod1.write('def foo(mod1=0):\n pass') - code = 'import mod1\nmod1.foo(mod1=1)' + self.mod1.write("def foo(mod1=0):\n pass") + code = "import mod1\nmod1.foo(mod1=1)" self.mod2.write(code) self._move(self.mod1, None, self.pkg) - moved = self.project.find_module('mod2') - expected = 'import pkg.mod1\npkg.mod1.foo(mod1=1)' + moved = self.project.find_module("mod2") + expected = "import pkg.mod1\npkg.mod1.foo(mod1=1)" self.assertEqual(expected, moved.read()) def test_moving_packages(self): - pkg2 = testutils.create_package(self.project, 'pkg2') - code = 'import pkg.mod4\nprint(pkg.mod4)' + pkg2 = testutils.create_package(self.project, "pkg2") + code = "import pkg.mod4\nprint(pkg.mod4)" self.mod1.write(code) - self._move(self.mod1, code.index('pkg') + 1, pkg2) + self._move(self.mod1, code.index("pkg") + 1, pkg2) self.assertFalse(self.pkg.exists()) - self.assertTrue(self.project.find_module('pkg2.pkg.mod4') is not None) - self.assertTrue(self.project.find_module('pkg2.pkg.mod4') is not None) - self.assertTrue(self.project.find_module('pkg2.pkg.mod5') is not None) - expected = 'import pkg2.pkg.mod4\nprint(pkg2.pkg.mod4)' + self.assertTrue(self.project.find_module("pkg2.pkg.mod4") is not None) + self.assertTrue(self.project.find_module("pkg2.pkg.mod4") is not None) + self.assertTrue(self.project.find_module("pkg2.pkg.mod5") is not None) + expected = "import pkg2.pkg.mod4\nprint(pkg2.pkg.mod4)" self.assertEqual(expected, self.mod1.read()) def test_moving_modules_with_self_imports(self): - self.mod1.write('import mod1\nprint(mod1)\n') - self.mod2.write('import mod1\n') - self._move(self.mod2, self.mod2.read().index('mod1') + 1, self.pkg) - moved = self.project.find_module('pkg.mod1') - self.assertEqual('import pkg.mod1\nprint(pkg.mod1)\n', moved.read()) + self.mod1.write("import mod1\nprint(mod1)\n") + self.mod2.write("import mod1\n") + self._move(self.mod2, self.mod2.read().index("mod1") + 1, self.pkg) + moved = self.project.find_module("pkg.mod1") + self.assertEqual("import pkg.mod1\nprint(pkg.mod1)\n", moved.read()) def test_moving_modules_with_from_imports(self): - pkg2 = testutils.create_package(self.project, 'pkg2') - code = ('from pkg import mod4\n' - 'print(mod4)') + pkg2 = testutils.create_package(self.project, "pkg2") + code = "from pkg import mod4\n" "print(mod4)" self.mod1.write(code) - self._move(self.mod1, code.index('pkg') + 1, pkg2) + self._move(self.mod1, code.index("pkg") + 1, pkg2) self.assertFalse(self.pkg.exists()) - self.assertTrue(self.project.find_module('pkg2.pkg.mod4') is not None) - self.assertTrue(self.project.find_module('pkg2.pkg.mod5') is not None) - expected = ('from pkg2.pkg import mod4\n' - 'print(mod4)') + self.assertTrue(self.project.find_module("pkg2.pkg.mod4") is not None) + self.assertTrue(self.project.find_module("pkg2.pkg.mod5") is not None) + expected = "from pkg2.pkg import mod4\n" "print(mod4)" self.assertEqual(expected, self.mod1.read()) def test_moving_modules_with_from_import(self): - pkg2 = testutils.create_package(self.project, 'pkg2') - pkg3 = testutils.create_package(self.project, 'pkg3', pkg2) - pkg4 = testutils.create_package(self.project, 'pkg4', pkg3) - code = ('from pkg import mod4\n' - 'print(mod4)') + pkg2 = testutils.create_package(self.project, "pkg2") + pkg3 = testutils.create_package(self.project, "pkg3", pkg2) + pkg4 = testutils.create_package(self.project, "pkg4", pkg3) + code = "from pkg import mod4\n" "print(mod4)" self.mod1.write(code) self._move(self.mod4, None, pkg4) - self.assertTrue( - self.project.find_module('pkg2.pkg3.pkg4.mod4') is not None) - expected = ('from pkg2.pkg3.pkg4 import mod4\n' - 'print(mod4)') + self.assertTrue(self.project.find_module("pkg2.pkg3.pkg4.mod4") is not None) + expected = "from pkg2.pkg3.pkg4 import mod4\n" "print(mod4)" self.assertEqual(expected, self.mod1.read()) def test_moving_modules_with_multi_from_imports(self): - pkg2 = testutils.create_package(self.project, 'pkg2') - pkg3 = testutils.create_package(self.project, 'pkg3', pkg2) - pkg4 = testutils.create_package(self.project, 'pkg4', pkg3) - code = ('from pkg import mod4, mod5\n' - 'print(mod4)') + pkg2 = testutils.create_package(self.project, "pkg2") + pkg3 = testutils.create_package(self.project, "pkg3", pkg2) + pkg4 = testutils.create_package(self.project, "pkg4", pkg3) + code = "from pkg import mod4, mod5\n" "print(mod4)" self.mod1.write(code) self._move(self.mod4, None, pkg4) - self.assertTrue( - self.project.find_module('pkg2.pkg3.pkg4.mod4') is not None) - expected = ('from pkg import mod5\n' - 'from pkg2.pkg3.pkg4 import mod4\n' - 'print(mod4)') + self.assertTrue(self.project.find_module("pkg2.pkg3.pkg4.mod4") is not None) + expected = ( + "from pkg import mod5\n" "from pkg2.pkg3.pkg4 import mod4\n" "print(mod4)" + ) self.assertEqual(expected, self.mod1.read()) def test_moving_modules_with_from_and_normal_imports(self): - pkg2 = testutils.create_package(self.project, 'pkg2') - pkg3 = testutils.create_package(self.project, 'pkg3', pkg2) - pkg4 = testutils.create_package(self.project, 'pkg4', pkg3) - code = ('from pkg import mod4\n' - 'import pkg.mod4\n' - 'print(mod4)\n' - 'print(pkg.mod4)') + pkg2 = testutils.create_package(self.project, "pkg2") + pkg3 = testutils.create_package(self.project, "pkg3", pkg2) + pkg4 = testutils.create_package(self.project, "pkg4", pkg3) + code = ( + "from pkg import mod4\n" + "import pkg.mod4\n" + "print(mod4)\n" + "print(pkg.mod4)" + ) self.mod1.write(code) self._move(self.mod4, None, pkg4) - self.assertTrue( - self.project.find_module('pkg2.pkg3.pkg4.mod4') is not None) - expected = ('import pkg2.pkg3.pkg4.mod4\n' - 'from pkg2.pkg3.pkg4 import mod4\n' - 'print(mod4)\n' - 'print(pkg2.pkg3.pkg4.mod4)') + self.assertTrue(self.project.find_module("pkg2.pkg3.pkg4.mod4") is not None) + expected = ( + "import pkg2.pkg3.pkg4.mod4\n" + "from pkg2.pkg3.pkg4 import mod4\n" + "print(mod4)\n" + "print(pkg2.pkg3.pkg4.mod4)" + ) self.assertEqual(expected, self.mod1.read()) def test_moving_modules_with_normal_and_from_imports(self): - pkg2 = testutils.create_package(self.project, 'pkg2') - pkg3 = testutils.create_package(self.project, 'pkg3', pkg2) - pkg4 = testutils.create_package(self.project, 'pkg4', pkg3) - code = ('import pkg.mod4\n' - 'from pkg import mod4\n' - 'print(mod4)\n' - 'print(pkg.mod4)') + pkg2 = testutils.create_package(self.project, "pkg2") + pkg3 = testutils.create_package(self.project, "pkg3", pkg2) + pkg4 = testutils.create_package(self.project, "pkg4", pkg3) + code = ( + "import pkg.mod4\n" + "from pkg import mod4\n" + "print(mod4)\n" + "print(pkg.mod4)" + ) self.mod1.write(code) self._move(self.mod4, None, pkg4) - self.assertTrue( - self.project.find_module('pkg2.pkg3.pkg4.mod4') is not None) - expected = ('import pkg2.pkg3.pkg4.mod4\n' - 'from pkg2.pkg3.pkg4 import mod4\n' - 'print(mod4)\n' - 'print(pkg2.pkg3.pkg4.mod4)') + self.assertTrue(self.project.find_module("pkg2.pkg3.pkg4.mod4") is not None) + expected = ( + "import pkg2.pkg3.pkg4.mod4\n" + "from pkg2.pkg3.pkg4 import mod4\n" + "print(mod4)\n" + "print(pkg2.pkg3.pkg4.mod4)" + ) self.assertEqual(expected, self.mod1.read()) def test_moving_modules_from_import_variable(self): - pkg2 = testutils.create_package(self.project, 'pkg2') - pkg3 = testutils.create_package(self.project, 'pkg3', pkg2) - pkg4 = testutils.create_package(self.project, 'pkg4', pkg3) - code = ('from pkg.mod4 import foo\n' - 'print(foo)') + pkg2 = testutils.create_package(self.project, "pkg2") + pkg3 = testutils.create_package(self.project, "pkg3", pkg2) + pkg4 = testutils.create_package(self.project, "pkg4", pkg3) + code = "from pkg.mod4 import foo\n" "print(foo)" self.mod1.write(code) self._move(self.mod4, None, pkg4) - self.assertTrue( - self.project.find_module('pkg2.pkg3.pkg4.mod4') is not None) - expected = ('from pkg2.pkg3.pkg4.mod4 import foo\n' - 'print(foo)') + self.assertTrue(self.project.find_module("pkg2.pkg3.pkg4.mod4") is not None) + expected = "from pkg2.pkg3.pkg4.mod4 import foo\n" "print(foo)" self.assertEqual(expected, self.mod1.read()) def test_moving_modules_normal_import(self): - pkg2 = testutils.create_package(self.project, 'pkg2') - pkg3 = testutils.create_package(self.project, 'pkg3', pkg2) - pkg4 = testutils.create_package(self.project, 'pkg4', pkg3) - code = ('import pkg.mod4\n' - 'print(pkg.mod4)') + pkg2 = testutils.create_package(self.project, "pkg2") + pkg3 = testutils.create_package(self.project, "pkg3", pkg2) + pkg4 = testutils.create_package(self.project, "pkg4", pkg3) + code = "import pkg.mod4\n" "print(pkg.mod4)" self.mod1.write(code) self._move(self.mod4, None, pkg4) - self.assertTrue( - self.project.find_module('pkg2.pkg3.pkg4.mod4') is not None) - expected = ('import pkg2.pkg3.pkg4.mod4\n' - 'print(pkg2.pkg3.pkg4.mod4)') + self.assertTrue(self.project.find_module("pkg2.pkg3.pkg4.mod4") is not None) + expected = "import pkg2.pkg3.pkg4.mod4\n" "print(pkg2.pkg3.pkg4.mod4)" self.assertEqual(expected, self.mod1.read()) def test_moving_package_with_from_and_normal_imports(self): - pkg2 = testutils.create_package(self.project, 'pkg2') - code = ('from pkg import mod4\n' - 'import pkg.mod4\n' - 'print(pkg.mod4)\n' - 'print(mod4)') - self.mod1.write(code) - self._move(self.mod1, code.index('pkg') + 1, pkg2) + pkg2 = testutils.create_package(self.project, "pkg2") + code = ( + "from pkg import mod4\n" + "import pkg.mod4\n" + "print(pkg.mod4)\n" + "print(mod4)" + ) + self.mod1.write(code) + self._move(self.mod1, code.index("pkg") + 1, pkg2) self.assertFalse(self.pkg.exists()) - self.assertTrue(self.project.find_module('pkg2.pkg.mod4') is not None) - self.assertTrue(self.project.find_module('pkg2.pkg.mod5') is not None) - expected = ('from pkg2.pkg import mod4\n' - 'import pkg2.pkg.mod4\n' - 'print(pkg2.pkg.mod4)\n' - 'print(mod4)') + self.assertTrue(self.project.find_module("pkg2.pkg.mod4") is not None) + self.assertTrue(self.project.find_module("pkg2.pkg.mod5") is not None) + expected = ( + "from pkg2.pkg import mod4\n" + "import pkg2.pkg.mod4\n" + "print(pkg2.pkg.mod4)\n" + "print(mod4)" + ) self.assertEqual(expected, self.mod1.read()) def test_moving_package_with_from_and_normal_imports2(self): - pkg2 = testutils.create_package(self.project, 'pkg2') - code = ('import pkg.mod4\n' - 'from pkg import mod4\n' - 'print(pkg.mod4)\n' - 'print(mod4)') - self.mod1.write(code) - self._move(self.mod1, code.index('pkg') + 1, pkg2) + pkg2 = testutils.create_package(self.project, "pkg2") + code = ( + "import pkg.mod4\n" + "from pkg import mod4\n" + "print(pkg.mod4)\n" + "print(mod4)" + ) + self.mod1.write(code) + self._move(self.mod1, code.index("pkg") + 1, pkg2) self.assertFalse(self.pkg.exists()) - self.assertTrue(self.project.find_module('pkg2.pkg.mod4') is not None) - self.assertTrue(self.project.find_module('pkg2.pkg.mod5') is not None) - expected = ('import pkg2.pkg.mod4\n' - 'from pkg2.pkg import mod4\n' - 'print(pkg2.pkg.mod4)\n' - 'print(mod4)') + self.assertTrue(self.project.find_module("pkg2.pkg.mod4") is not None) + self.assertTrue(self.project.find_module("pkg2.pkg.mod5") is not None) + expected = ( + "import pkg2.pkg.mod4\n" + "from pkg2.pkg import mod4\n" + "print(pkg2.pkg.mod4)\n" + "print(mod4)" + ) self.assertEqual(expected, self.mod1.read()) def test_moving_package_and_retaining_blank_lines(self): - pkg2 = testutils.create_package(self.project, 'pkg2', self.pkg) - code = ('"""Docstring followed by blank lines."""\n\n' - 'import pkg.mod4\n\n' - 'from pkg import mod4\n' - 'from x import y\n' - 'from y import z\n' - 'from a import b\n' - 'from b import c\n' - 'print(pkg.mod4)\n' - 'print(mod4)') + pkg2 = testutils.create_package(self.project, "pkg2", self.pkg) + code = ( + '"""Docstring followed by blank lines."""\n\n' + "import pkg.mod4\n\n" + "from pkg import mod4\n" + "from x import y\n" + "from y import z\n" + "from a import b\n" + "from b import c\n" + "print(pkg.mod4)\n" + "print(mod4)" + ) self.mod1.write(code) self._move(self.mod4, None, pkg2) - expected = ('"""Docstring followed by blank lines."""\n\n' - 'import pkg.pkg2.mod4\n\n' - 'from x import y\n' - 'from y import z\n' - 'from a import b\n' - 'from b import c\n' - 'from pkg.pkg2 import mod4\n' - 'print(pkg.pkg2.mod4)\n' - 'print(mod4)') + expected = ( + '"""Docstring followed by blank lines."""\n\n' + "import pkg.pkg2.mod4\n\n" + "from x import y\n" + "from y import z\n" + "from a import b\n" + "from b import c\n" + "from pkg.pkg2 import mod4\n" + "print(pkg.pkg2.mod4)\n" + "print(mod4)" + ) self.assertEqual(expected, self.mod1.read()) def test_moving_functions_to_imported_module(self): - code = 'import mod1\n' \ - 'def a_func():\n' \ - ' var = mod1.a_var\n' - self.mod1.write('a_var = 1\n') + code = "import mod1\n" "def a_func():\n" " var = mod1.a_var\n" + self.mod1.write("a_var = 1\n") self.mod2.write(code) - self._move(self.mod2, code.index('a_func') + 1, self.mod1) - expected = 'def a_func():\n' \ - ' var = a_var\n' \ - 'a_var = 1\n' + self._move(self.mod2, code.index("a_func") + 1, self.mod1) + expected = "def a_func():\n" " var = a_var\n" "a_var = 1\n" self.assertEqual(expected, self.mod1.read()) def test_moving_resources_using_move_module_refactoring(self): - self.mod1.write('a_var = 1') - self.mod2.write('import mod1\nmy_var = mod1.a_var\n') + self.mod1.write("a_var = 1") + self.mod2.write("import mod1\nmy_var = mod1.a_var\n") mover = move.create_move(self.project, self.mod1) mover.get_changes(self.pkg).do() - expected = 'import pkg.mod1\nmy_var = pkg.mod1.a_var\n' + expected = "import pkg.mod1\nmy_var = pkg.mod1.a_var\n" self.assertEqual(expected, self.mod2.read()) - self.assertTrue(self.pkg.get_child('mod1.py') is not None) + self.assertTrue(self.pkg.get_child("mod1.py") is not None) def test_moving_resources_using_move_module_for_packages(self): - self.mod1.write('import pkg\nmy_pkg = pkg') - pkg2 = testutils.create_package(self.project, 'pkg2') + self.mod1.write("import pkg\nmy_pkg = pkg") + pkg2 = testutils.create_package(self.project, "pkg2") mover = move.create_move(self.project, self.pkg) mover.get_changes(pkg2).do() - expected = 'import pkg2.pkg\nmy_pkg = pkg2.pkg' + expected = "import pkg2.pkg\nmy_pkg = pkg2.pkg" self.assertEqual(expected, self.mod1.read()) - self.assertTrue(pkg2.get_child('pkg') is not None) + self.assertTrue(pkg2.get_child("pkg") is not None) def test_moving_resources_using_move_module_for_init_dot_py(self): - self.mod1.write('import pkg\nmy_pkg = pkg') - pkg2 = testutils.create_package(self.project, 'pkg2') - init = self.pkg.get_child('__init__.py') + self.mod1.write("import pkg\nmy_pkg = pkg") + pkg2 = testutils.create_package(self.project, "pkg2") + init = self.pkg.get_child("__init__.py") mover = move.create_move(self.project, init) mover.get_changes(pkg2).do() - self.assertEqual('import pkg2.pkg\nmy_pkg = pkg2.pkg', - self.mod1.read()) - self.assertTrue(pkg2.get_child('pkg') is not None) + self.assertEqual("import pkg2.pkg\nmy_pkg = pkg2.pkg", self.mod1.read()) + self.assertTrue(pkg2.get_child("pkg") is not None) def test_moving_module_and_star_imports(self): - self.mod1.write('a_var = 1') - self.mod2.write('from mod1 import *\na = a_var\n') + self.mod1.write("a_var = 1") + self.mod2.write("from mod1 import *\na = a_var\n") mover = move.create_move(self.project, self.mod1) mover.get_changes(self.pkg).do() - self.assertEqual('from pkg.mod1 import *\na = a_var\n', - self.mod2.read()) + self.assertEqual("from pkg.mod1 import *\na = a_var\n", self.mod2.read()) def test_moving_module_and_not_removing_blanks_after_imports(self): - self.mod4.write('a_var = 1') - self.mod2.write('from pkg import mod4\n' - 'import os\n\n\nprint(mod4.a_var)\n') + self.mod4.write("a_var = 1") + self.mod2.write("from pkg import mod4\n" "import os\n\n\nprint(mod4.a_var)\n") mover = move.create_move(self.project, self.mod4) mover.get_changes(self.project.root).do() - self.assertEqual('import os\nimport mod4\n\n\n' - 'print(mod4.a_var)\n', self.mod2.read()) + self.assertEqual( + "import os\nimport mod4\n\n\n" "print(mod4.a_var)\n", self.mod2.read() + ) def test_moving_module_refactoring_and_nonexistent_destinations(self): - self.mod4.write('a_var = 1') - self.mod2.write('from pkg import mod4\n' - 'import os\n\n\nprint(mod4.a_var)\n') + self.mod4.write("a_var = 1") + self.mod2.write("from pkg import mod4\n" "import os\n\n\nprint(mod4.a_var)\n") with self.assertRaises(exceptions.RefactoringError): mover = move.create_move(self.project, self.mod4) mover.get_changes(None).do() def test_moving_methods_choosing_the_correct_class(self): - code = 'class A(object):\n def a_method(self):\n pass\n' + code = "class A(object):\n def a_method(self):\n pass\n" self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) self.assertTrue(isinstance(mover, move.MoveMethod)) def test_moving_methods_getting_new_method_for_empty_methods(self): - code = 'class A(object):\n def a_method(self):\n pass\n' + code = "class A(object):\n def a_method(self):\n pass\n" self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - self.assertEqual('def new_method(self):\n pass\n', - mover.get_new_method('new_method')) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + self.assertEqual( + "def new_method(self):\n pass\n", mover.get_new_method("new_method") + ) def test_moving_methods_getting_new_method_for_constant_methods(self): - code = 'class A(object):\n def a_method(self):\n return 1\n' + code = "class A(object):\n def a_method(self):\n return 1\n" self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - self.assertEqual('def new_method(self):\n return 1\n', - mover.get_new_method('new_method')) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + self.assertEqual( + "def new_method(self):\n return 1\n", mover.get_new_method("new_method") + ) def test_moving_methods_getting_new_method_passing_simple_paremters(self): - code = 'class A(object):\n' \ - ' def a_method(self, p):\n return p\n' + code = "class A(object):\n" " def a_method(self, p):\n return p\n" self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - self.assertEqual('def new_method(self, p):\n return p\n', - mover.get_new_method('new_method')) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + self.assertEqual( + "def new_method(self, p):\n return p\n", + mover.get_new_method("new_method"), + ) def test_moving_methods_getting_new_method_using_main_object(self): - code = 'class A(object):\n attr = 1\n' \ - ' def a_method(host):\n return host.attr\n' + code = ( + "class A(object):\n attr = 1\n" + " def a_method(host):\n return host.attr\n" + ) self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - self.assertEqual('def new_method(self, host):' - '\n return host.attr\n', - mover.get_new_method('new_method')) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + self.assertEqual( + "def new_method(self, host):" "\n return host.attr\n", + mover.get_new_method("new_method"), + ) def test_moving_methods_getting_new_method_renaming_main_object(self): - code = 'class A(object):\n attr = 1\n' \ - ' def a_method(self):\n return self.attr\n' + code = ( + "class A(object):\n attr = 1\n" + " def a_method(self):\n return self.attr\n" + ) self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - self.assertEqual('def new_method(self, host):' - '\n return host.attr\n', - mover.get_new_method('new_method')) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + self.assertEqual( + "def new_method(self, host):" "\n return host.attr\n", + mover.get_new_method("new_method"), + ) def test_moving_methods_gettin_new_method_with_keyword_arguments(self): - code = 'class A(object):\n attr = 1\n' \ - ' def a_method(self, p=None):\n return p\n' + code = ( + "class A(object):\n attr = 1\n" + " def a_method(self, p=None):\n return p\n" + ) self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - self.assertEqual('def new_method(self, p=None):\n return p\n', - mover.get_new_method('new_method')) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + self.assertEqual( + "def new_method(self, p=None):\n return p\n", + mover.get_new_method("new_method"), + ) def test_moving_methods_gettin_new_method_with_many_kinds_arguments(self): - code = 'class A(object):\n attr = 1\n' \ - ' def a_method(self, p1, *args, **kwds):\n' \ - ' return self.attr\n' - self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - expected = 'def new_method(self, host, p1, *args, **kwds):\n' \ - ' return host.attr\n' - self.assertEqual(expected, mover.get_new_method('new_method')) + code = ( + "class A(object):\n attr = 1\n" + " def a_method(self, p1, *args, **kwds):\n" + " return self.attr\n" + ) + self.mod1.write(code) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + expected = ( + "def new_method(self, host, p1, *args, **kwds):\n" " return host.attr\n" + ) + self.assertEqual(expected, mover.get_new_method("new_method")) def test_moving_methods_getting_new_method_for_multi_line_methods(self): - code = 'class A(object):\n' \ - ' def a_method(self):\n' \ - ' a = 2\n' \ - ' return a\n' - self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) + code = ( + "class A(object):\n" + " def a_method(self):\n" + " a = 2\n" + " return a\n" + ) + self.mod1.write(code) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) self.assertEqual( - 'def new_method(self):\n a = 2\n return a\n', - mover.get_new_method('new_method')) + "def new_method(self):\n a = 2\n return a\n", + mover.get_new_method("new_method"), + ) def test_moving_methods_getting_old_method_for_constant_methods(self): - self.mod2.write('class B(object):\n pass\n') - code = 'import mod2\n\n' \ - 'class A(object):\n' \ - ' attr = mod2.B()\n' \ - ' def a_method(self):\n' \ - ' return 1\n' - self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - mover.get_changes('attr', 'new_method').do() - expected = 'import mod2\n\n' \ - 'class A(object):\n' \ - ' attr = mod2.B()\n' \ - ' def a_method(self):\n' \ - ' return self.attr.new_method()\n' + self.mod2.write("class B(object):\n pass\n") + code = ( + "import mod2\n\n" + "class A(object):\n" + " attr = mod2.B()\n" + " def a_method(self):\n" + " return 1\n" + ) + self.mod1.write(code) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + mover.get_changes("attr", "new_method").do() + expected = ( + "import mod2\n\n" + "class A(object):\n" + " attr = mod2.B()\n" + " def a_method(self):\n" + " return self.attr.new_method()\n" + ) self.assertEqual(expected, self.mod1.read()) def test_moving_methods_getting_getting_changes_for_goal_class(self): - self.mod2.write('class B(object):\n var = 1\n') - code = 'import mod2\n\n' \ - 'class A(object):\n' \ - ' attr = mod2.B()\n' \ - ' def a_method(self):\n' \ - ' return 1\n' - self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - mover.get_changes('attr', 'new_method').do() - expected = 'class B(object):\n' \ - ' var = 1\n\n\n' \ - ' def new_method(self):\n' \ - ' return 1\n' + self.mod2.write("class B(object):\n var = 1\n") + code = ( + "import mod2\n\n" + "class A(object):\n" + " attr = mod2.B()\n" + " def a_method(self):\n" + " return 1\n" + ) + self.mod1.write(code) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + mover.get_changes("attr", "new_method").do() + expected = ( + "class B(object):\n" + " var = 1\n\n\n" + " def new_method(self):\n" + " return 1\n" + ) self.assertEqual(expected, self.mod2.read()) def test_moving_methods_getting_getting_changes_for_goal_class2(self): - code = 'class B(object):\n var = 1\n\n' \ - 'class A(object):\n attr = B()\n' \ - ' def a_method(self):\n return 1\n' - self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - mover.get_changes('attr', 'new_method').do() + code = ( + "class B(object):\n var = 1\n\n" + "class A(object):\n attr = B()\n" + " def a_method(self):\n return 1\n" + ) + self.mod1.write(code) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + mover.get_changes("attr", "new_method").do() self.assertEqual( - 'class B(object):\n var = 1\n\n\n' - ' def new_method(self):\n' - ' return 1\n\n' - 'class A(object):\n attr = B()\n' - ' def a_method(self):\n' - ' return self.attr.new_method()\n', - self.mod1.read()) + "class B(object):\n var = 1\n\n\n" + " def new_method(self):\n" + " return 1\n\n" + "class A(object):\n attr = B()\n" + " def a_method(self):\n" + " return self.attr.new_method()\n", + self.mod1.read(), + ) def test_moving_methods_and_nonexistent_attributes(self): - code = 'class A(object):\n' \ - ' def a_method(self):\n return 1\n' + code = "class A(object):\n" " def a_method(self):\n return 1\n" self.mod1.write(code) with self.assertRaises(exceptions.RefactoringError): - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - mover.get_changes('x', 'new_method') + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + mover.get_changes("x", "new_method") def test_unknown_attribute_type(self): - code = 'class A(object):\n attr = 1\n' \ - ' def a_method(self):\n return 1\n' + code = ( + "class A(object):\n attr = 1\n" + " def a_method(self):\n return 1\n" + ) self.mod1.write(code) with self.assertRaises(exceptions.RefactoringError): - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - mover.get_changes('attr', 'new_method') + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + mover.get_changes("attr", "new_method") def test_moving_methods_and_moving_used_imports(self): - self.mod2.write('class B(object):\n var = 1\n') - code = 'import sys\nimport mod2\n\n' \ - 'class A(object):\n' \ - ' attr = mod2.B()\n' \ - ' def a_method(self):\n' \ - ' return sys.version\n' - self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - mover.get_changes('attr', 'new_method').do() - code = 'import sys\n' \ - 'class B(object):\n' \ - ' var = 1\n\n\n' \ - ' def new_method(self):\n' \ - ' return sys.version\n' + self.mod2.write("class B(object):\n var = 1\n") + code = ( + "import sys\nimport mod2\n\n" + "class A(object):\n" + " attr = mod2.B()\n" + " def a_method(self):\n" + " return sys.version\n" + ) + self.mod1.write(code) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + mover.get_changes("attr", "new_method").do() + code = ( + "import sys\n" + "class B(object):\n" + " var = 1\n\n\n" + " def new_method(self):\n" + " return sys.version\n" + ) self.assertEqual(code, self.mod2.read()) def test_moving_methods_getting_getting_changes_for_goal_class3(self): - self.mod2.write('class B(object):\n pass\n') - code = 'import mod2\n\n' \ - 'class A(object):\n' \ - ' attr = mod2.B()\n' \ - ' def a_method(self):\n' \ - ' return 1\n' - self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - mover.get_changes('attr', 'new_method').do() - expected = 'class B(object):\n\n' \ - ' def new_method(self):\n' \ - ' return 1\n' + self.mod2.write("class B(object):\n pass\n") + code = ( + "import mod2\n\n" + "class A(object):\n" + " attr = mod2.B()\n" + " def a_method(self):\n" + " return 1\n" + ) + self.mod1.write(code) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + mover.get_changes("attr", "new_method").do() + expected = ( + "class B(object):\n\n" " def new_method(self):\n" " return 1\n" + ) self.assertEqual(expected, self.mod2.read()) def test_moving_methods_and_source_class_with_parameters(self): - self.mod2.write('class B(object):\n pass\n') - code = 'import mod2\n\n' \ - 'class A(object):\n' \ - ' attr = mod2.B()\n' \ - ' def a_method(self, p):\n return p\n' - self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('a_method')) - mover.get_changes('attr', 'new_method').do() - expected1 = 'import mod2\n\n' \ - 'class A(object):\n' \ - ' attr = mod2.B()\n' \ - ' def a_method(self, p):\n' \ - ' return self.attr.new_method(p)\n' + self.mod2.write("class B(object):\n pass\n") + code = ( + "import mod2\n\n" + "class A(object):\n" + " attr = mod2.B()\n" + " def a_method(self, p):\n return p\n" + ) + self.mod1.write(code) + mover = move.create_move(self.project, self.mod1, code.index("a_method")) + mover.get_changes("attr", "new_method").do() + expected1 = ( + "import mod2\n\n" + "class A(object):\n" + " attr = mod2.B()\n" + " def a_method(self, p):\n" + " return self.attr.new_method(p)\n" + ) self.assertEqual(expected1, self.mod1.read()) - expected2 = 'class B(object):\n\n' \ - ' def new_method(self, p):\n' \ - ' return p\n' + expected2 = ( + "class B(object):\n\n" " def new_method(self, p):\n" " return p\n" + ) self.assertEqual(expected2, self.mod2.read()) def test_moving_globals_to_a_module_with_only_docstrings(self): - self.mod1.write('import sys\n\n\ndef f():\n print(sys.version)\n') + self.mod1.write("import sys\n\n\ndef f():\n print(sys.version)\n") self.mod2.write('"""doc\n\nMore docs ...\n\n"""\n') - mover = move.create_move(self.project, self.mod1, - self.mod1.read().index('f()') + 1) + mover = move.create_move( + self.project, self.mod1, self.mod1.read().index("f()") + 1 + ) self.project.do(mover.get_changes(self.mod2)) self.assertEqual( '"""doc\n\nMore docs ...\n\n"""\n' - 'import sys\n\n\ndef f():\n print(sys.version)\n', - self.mod2.read()) + "import sys\n\n\ndef f():\n print(sys.version)\n", + self.mod2.read(), + ) def test_moving_globals_to_a_module_with_only_docstrings2(self): - code = 'import os\n' \ - 'import sys\n\n\n' \ - 'def f():\n' \ - ' print(sys.version, os.path)\n' + code = ( + "import os\n" + "import sys\n\n\n" + "def f():\n" + " print(sys.version, os.path)\n" + ) self.mod1.write(code) self.mod2.write('"""doc\n\nMore docs ...\n\n"""\n') - mover = move.create_move(self.project, self.mod1, - self.mod1.read().index('f()') + 1) + mover = move.create_move( + self.project, self.mod1, self.mod1.read().index("f()") + 1 + ) self.project.do(mover.get_changes(self.mod2)) - expected = '"""doc\n\nMore docs ...\n\n"""\n' \ - 'import os\n' \ - 'import sys\n\n\n' \ - 'def f():\n' \ - ' print(sys.version, os.path)\n' + expected = ( + '"""doc\n\nMore docs ...\n\n"""\n' + "import os\n" + "import sys\n\n\n" + "def f():\n" + " print(sys.version, os.path)\n" + ) self.assertEqual(expected, self.mod2.read()) def test_moving_a_global_when_it_is_used_after_a_multiline_str(self): code = 'def f():\n pass\ns = """\\\n"""\nr = f()\n' self.mod1.write(code) - mover = move.create_move(self.project, self.mod1, - code.index('f()') + 1) + mover = move.create_move(self.project, self.mod1, code.index("f()") + 1) self.project.do(mover.get_changes(self.mod2)) expected = 'import mod2\ns = """\\\n"""\nr = mod2.f()\n' self.assertEqual(expected, self.mod1.read()) def test_raising_an_exception_when_moving_non_package_folders(self): - dir = self.project.root.create_folder('dir') + dir = self.project.root.create_folder("dir") with self.assertRaises(exceptions.RefactoringError): move.create_move(self.project, dir) def test_moving_to_a_module_with_encoding_cookie(self): - code1 = '# -*- coding: utf-8 -*-' + code1 = "# -*- coding: utf-8 -*-" self.mod1.write(code1) - code2 = 'def f(): pass\n' + code2 = "def f(): pass\n" self.mod2.write(code2) - mover = move.create_move(self.project, self.mod2, - code2.index('f()') + 1) + mover = move.create_move(self.project, self.mod2, code2.index("f()") + 1) self.project.do(mover.get_changes(self.mod1)) - expected = '%s\n%s' % (code1, code2) + expected = "%s\n%s" % (code1, code2) self.assertEqual(expected, self.mod1.read()) diff --git a/ropetest/refactor/multiprojecttest.py b/ropetest/refactor/multiprojecttest.py index 61f8c5896..1581838ab 100644 --- a/ropetest/refactor/multiprojecttest.py +++ b/ropetest/refactor/multiprojecttest.py @@ -8,14 +8,13 @@ class MultiProjectRefactoringTest(unittest.TestCase): - def setUp(self): super(MultiProjectRefactoringTest, self).setUp() - self.project1 = testutils.sample_project(foldername='testproject1') - self.project2 = testutils.sample_project(foldername='testproject2') - self.mod1 = self.project1.root.create_file('mod1.py') - self.other = self.project1.root.create_file('other.py') - self.mod2 = self.project2.root.create_file('mod2.py') + self.project1 = testutils.sample_project(foldername="testproject1") + self.project2 = testutils.sample_project(foldername="testproject2") + self.mod1 = self.project1.root.create_file("mod1.py") + self.other = self.project1.root.create_file("other.py") + self.mod2 = self.project2.root.create_file("mod2.py") def tearDown(self): testutils.remove_project(self.project1) @@ -23,46 +22,43 @@ def tearDown(self): super(MultiProjectRefactoringTest, self).tearDown() def test_trivial_rename(self): - self.mod1.write('var = 1\n') - refactoring = multiproject.MultiProjectRefactoring( - rename.Rename, []) + self.mod1.write("var = 1\n") + refactoring = multiproject.MultiProjectRefactoring(rename.Rename, []) renamer = refactoring(self.project1, self.mod1, 1) - multiproject.perform(renamer.get_all_changes('newvar')) - self.assertEqual('newvar = 1\n', self.mod1.read()) + multiproject.perform(renamer.get_all_changes("newvar")) + self.assertEqual("newvar = 1\n", self.mod1.read()) def test_rename(self): - self.mod1.write('var = 1\n') - self.mod2.write('import mod1\nmyvar = mod1.var\n') + self.mod1.write("var = 1\n") + self.mod2.write("import mod1\nmyvar = mod1.var\n") refactoring = multiproject.MultiProjectRefactoring( - rename.Rename, [self.project2]) + rename.Rename, [self.project2] + ) renamer = refactoring(self.project1, self.mod1, 1) - multiproject.perform(renamer.get_all_changes('newvar')) - self.assertEqual('newvar = 1\n', self.mod1.read()) - self.assertEqual('import mod1\nmyvar = mod1.newvar\n', - self.mod2.read()) + multiproject.perform(renamer.get_all_changes("newvar")) + self.assertEqual("newvar = 1\n", self.mod1.read()) + self.assertEqual("import mod1\nmyvar = mod1.newvar\n", self.mod2.read()) def test_move(self): - self.mod1.write('def a_func():\n pass\n') - self.mod2.write('import mod1\nmyvar = mod1.a_func()\n') + self.mod1.write("def a_func():\n pass\n") + self.mod2.write("import mod1\nmyvar = mod1.a_func()\n") refactoring = multiproject.MultiProjectRefactoring( - move.create_move, [self.project2]) - renamer = refactoring(self.project1, self.mod1, - self.mod1.read().index('_func')) + move.create_move, [self.project2] + ) + renamer = refactoring(self.project1, self.mod1, self.mod1.read().index("_func")) multiproject.perform(renamer.get_all_changes(self.other)) - self.assertEqual('', self.mod1.read()) - self.assertEqual('def a_func():\n pass\n', self.other.read()) - self.assertEqual('import other\nmyvar = other.a_func()\n', - self.mod2.read()) + self.assertEqual("", self.mod1.read()) + self.assertEqual("def a_func():\n pass\n", self.other.read()) + self.assertEqual("import other\nmyvar = other.a_func()\n", self.mod2.read()) def test_rename_from_the_project_not_containing_the_change(self): - self.project2.get_prefs().add('python_path', self.project1.address) - self.mod1.write('var = 1\n') - self.mod2.write('import mod1\nmyvar = mod1.var\n') + self.project2.get_prefs().add("python_path", self.project1.address) + self.mod1.write("var = 1\n") + self.mod2.write("import mod1\nmyvar = mod1.var\n") refactoring = multiproject.MultiProjectRefactoring( - rename.Rename, [self.project1]) - renamer = refactoring(self.project2, self.mod2, - self.mod2.read().rindex('var')) - multiproject.perform(renamer.get_all_changes('newvar')) - self.assertEqual('newvar = 1\n', self.mod1.read()) - self.assertEqual('import mod1\nmyvar = mod1.newvar\n', - self.mod2.read()) + rename.Rename, [self.project1] + ) + renamer = refactoring(self.project2, self.mod2, self.mod2.read().rindex("var")) + multiproject.perform(renamer.get_all_changes("newvar")) + self.assertEqual("newvar = 1\n", self.mod1.read()) + self.assertEqual("import mod1\nmyvar = mod1.newvar\n", self.mod2.read()) diff --git a/ropetest/refactor/patchedasttest.py b/ropetest/refactor/patchedasttest.py index 9f58abbd8..cb9e21a5f 100644 --- a/ropetest/refactor/patchedasttest.py +++ b/ropetest/refactor/patchedasttest.py @@ -15,11 +15,11 @@ except NameError: basestring = (str, bytes) -NameConstant = 'Name' if sys.version_info <= (3, 8) else 'NameConstant' -Bytes = 'Bytes' if (3, 0) <= sys.version_info <= (3, 8) else 'Str' +NameConstant = "Name" if sys.version_info <= (3, 8) else "NameConstant" +Bytes = "Bytes" if (3, 0) <= sys.version_info <= (3, 8) else "Str" -class PatchedASTTest(unittest.TestCase): +class PatchedASTTest(unittest.TestCase): def setUp(self): super(PatchedASTTest, self).setUp() @@ -36,798 +36,1055 @@ def test_bytes_string(self): checker.check_children(Bytes, [str_fragment]) def test_integer_literals_and_region(self): - source = 'a = 10\n' + source = "a = 10\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('10') - checker.check_region('Num', start, start + 2) + start = source.index("10") + checker.check_region("Num", start, start + 2) def test_negative_integer_literals_and_region(self): - source = 'a = -10\n' + source = "a = -10\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('-10') + start = source.index("-10") end = start + 3 # Python 3 parses as UnaryOp(op=USub(), operand=Num(n=10)) if pycompat.PY3: - start += 1 - checker.check_region('Num', start, end) + start += 1 + checker.check_region("Num", start, end) def test_scientific_integer_literals_and_region(self): - source = 'a = -1.0e-3\n' + source = "a = -1.0e-3\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('-1.0e-3') + start = source.index("-1.0e-3") end = start + 7 # Python 3 parses as UnaryOp(op=USub(), operand=Num(n=10)) if pycompat.PY3: - start += 1 - checker.check_region('Num', start, end) + start += 1 + checker.check_region("Num", start, end) def test_hex_integer_literals_and_region(self): - source = 'a = 0x1\n' + source = "a = 0x1\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('0x1') - checker.check_region('Num', start, start + 3) + start = source.index("0x1") + checker.check_region("Num", start, start + 3) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_long_literals_and_region(self): - source = 'a = 0x1L\n' + source = "a = 0x1L\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('0x1L') - checker.check_region('Num', start, start + 4) + start = source.index("0x1L") + checker.check_region("Num", start, start + 4) def test_octal_integer_literals_and_region(self): - source = 'a = -0125e1\n' + source = "a = -0125e1\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('-0125e1') + start = source.index("-0125e1") end = start + 7 # Python 3 parses as UnaryOp(op=USub(), operand=Num(n=10)) if pycompat.PY3: - start += 1 - checker.check_region('Num', start, end) + start += 1 + checker.check_region("Num", start, end) def test_integer_literals_and_sorted_children(self): - source = 'a = 10\n' + source = "a = 10\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) # start = source.index('10') - checker.check_children('Num', ['10']) + checker.check_children("Num", ["10"]) def test_ellipsis(self): - source = 'a[...]\n' + source = "a[...]\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('...') - checker.check_region('Ellipsis', start, start + len('...')) + start = source.index("...") + checker.check_region("Ellipsis", start, start + len("...")) def test_ass_name_node(self): - source = 'a = 10\n' + source = "a = 10\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('a') - checker.check_region('Name', start, start + 1) - checker.check_children('Name', ['a']) + start = source.index("a") + checker.check_region("Name", start, start + 1) + checker.check_children("Name", ["a"]) def test_assign_node(self): - source = 'a = 10\n' + source = "a = 10\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('a') # noqa - checker.check_region('Assign', 0, len(source) - 1) - checker.check_children( - 'Assign', ['Name', ' ', '=', ' ', 'Num']) + start = source.index("a") # noqa + checker.check_region("Assign", 0, len(source) - 1) + checker.check_children("Assign", ["Name", " ", "=", " ", "Num"]) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_ann_assign_node_without_target(self): - source = 'a: List[int]\n' + source = "a: List[int]\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('a') # noqa - checker.check_region('AnnAssign', 0, len(source) - 1) - checker.check_children( - 'AnnAssign', ['Name', '', ':', ' ', 'Subscript']) + start = source.index("a") # noqa + checker.check_region("AnnAssign", 0, len(source) - 1) + checker.check_children("AnnAssign", ["Name", "", ":", " ", "Subscript"]) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_ann_assign_node_with_target(self): - source = 'a: int = 10\n' + source = "a: int = 10\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('a') # noqa - checker.check_region('AnnAssign', 0, len(source) - 1) + start = source.index("a") # noqa + checker.check_region("AnnAssign", 0, len(source) - 1) checker.check_children( - 'AnnAssign', ['Name', '', ':', ' ', 'Name', ' ', '=', ' ', 'Num']) + "AnnAssign", ["Name", "", ":", " ", "Name", " ", "=", " ", "Num"] + ) def test_add_node(self): - source = '1 + 2\n' + source = "1 + 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('BinOp', 0, len(source) - 1) - checker.check_children( - 'BinOp', ['Num', ' ', '+', ' ', 'Num']) + checker.check_region("BinOp", 0, len(source) - 1) + checker.check_children("BinOp", ["Num", " ", "+", " ", "Num"]) def test_lshift_node(self): - source = '1 << 2\n' + source = "1 << 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('BinOp', 0, len(source) - 1) - checker.check_children( - 'BinOp', ['Num', ' ', '<<', ' ', 'Num']) + checker.check_region("BinOp", 0, len(source) - 1) + checker.check_children("BinOp", ["Num", " ", "<<", " ", "Num"]) def test_and_node(self): - source = 'True and True\n' + source = "True and True\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('BoolOp', 0, len(source) - 1) - checker.check_children( - 'BoolOp', [NameConstant, ' ', 'and', ' ', NameConstant]) + checker.check_region("BoolOp", 0, len(source) - 1) + checker.check_children("BoolOp", [NameConstant, " ", "and", " ", NameConstant]) def test_basic_closing_parens(self): - source = '1 + (2)\n' + source = "1 + (2)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('BinOp', 0, len(source) - 1) - checker.check_children( - 'BinOp', ['Num', ' ', '+', ' (', 'Num', ')']) + checker.check_region("BinOp", 0, len(source) - 1) + checker.check_children("BinOp", ["Num", " ", "+", " (", "Num", ")"]) def test_basic_opening_parens(self): - source = '(1) + 2\n' + source = "(1) + 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('BinOp', 0, len(source) - 1) - checker.check_children( - 'BinOp', ['(', 'Num', ') ', '+', ' ', 'Num']) + checker.check_region("BinOp", 0, len(source) - 1) + checker.check_children("BinOp", ["(", "Num", ") ", "+", " ", "Num"]) def test_basic_opening_biway(self): - source = '(1) + (2)\n' + source = "(1) + (2)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('BinOp', 0, len(source) - 1) - checker.check_children( - 'BinOp', ['(', 'Num', ') ', '+', ' (', 'Num', ')']) + checker.check_region("BinOp", 0, len(source) - 1) + checker.check_children("BinOp", ["(", "Num", ") ", "+", " (", "Num", ")"]) def test_basic_opening_double(self): - source = '1 + ((2))\n' + source = "1 + ((2))\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('BinOp', 0, len(source) - 1) - checker.check_children( - 'BinOp', ['Num', ' ', '+', ' ((', 'Num', '))']) + checker.check_region("BinOp", 0, len(source) - 1) + checker.check_children("BinOp", ["Num", " ", "+", " ((", "Num", "))"]) def test_handling_comments(self): - source = '(1 + #(\n2)\n' + source = "(1 + #(\n2)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'BinOp', ['Num', ' ', '+', ' #(\n', 'Num']) + checker.check_children("BinOp", ["Num", " ", "+", " #(\n", "Num"]) def test_handling_parens_with_spaces(self): - source = '1 + (2\n )\n' + source = "1 + (2\n )\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'BinOp', ['Num', ' ', '+', ' (', 'Num', '\n )']) + checker.check_children("BinOp", ["Num", " ", "+", " (", "Num", "\n )"]) def test_handling_strings(self): source = '1 + "("\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'BinOp', ['Num', ' ', '+', ' ', 'Str']) + checker.check_children("BinOp", ["Num", " ", "+", " ", "Str"]) def test_handling_implicit_string_concatenation(self): source = "a = '1''2'" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Assign', ['Name', ' ', '=', ' ', 'Str']) - checker.check_children('Str', ["'1''2'"]) + checker.check_children("Assign", ["Name", " ", "=", " ", "Str"]) + checker.check_children("Str", ["'1''2'"]) def test_handling_implicit_string_concatenation_line_breaks(self): source = "a = '1' \\\n'2'" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Assign', ['Name', ' ', '=', ' ', 'Str']) - checker.check_children('Str', ["'1' \\\n'2'"]) + checker.check_children("Assign", ["Name", " ", "=", " ", "Str"]) + checker.check_children("Str", ["'1' \\\n'2'"]) def test_handling_explicit_string_concatenation_line_breaks(self): source = "a = ('1' \n'2')" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Assign', ['Name', ' ', '=', ' (', 'Str', ')']) - checker.check_children('Str', ["'1' \n'2'"]) + checker.check_children("Assign", ["Name", " ", "=", " (", "Str", ")"]) + checker.check_children("Str", ["'1' \n'2'"]) def test_not_concatenating_strings_on_separate_lines(self): source = "'1'\n'2'\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children('Module', ['', 'Expr', '\n', 'Expr', '\n']) + checker.check_children("Module", ["", "Expr", "\n", "Expr", "\n"]) def test_handling_raw_strings(self): source = 'r"abc"\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Str', ['r"abc"']) + checker.check_children("Str", ['r"abc"']) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_handling_format_strings_basic(self): source = '1 + f"abc{a}"\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'JoinedStr', ['f"', 'abc', 'FormattedValue', '', '"']) - checker.check_children( - 'FormattedValue', ['{', '', 'Name', '', '}']) + checker.check_children("JoinedStr", ['f"', "abc", "FormattedValue", "", '"']) + checker.check_children("FormattedValue", ["{", "", "Name", "", "}"]) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_handling_format_strings_with_implicit_join(self): source = '''"1" + rf'abc{a}' f"""xxx{b} """\n''' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'JoinedStr', ["rf'", 'abc', 'FormattedValue', '\' f"""xxx', 'FormattedValue', ' ', '"""']) - checker.check_children( - 'FormattedValue', ['{', '', 'Name', '', '}']) + "JoinedStr", + [ + "rf'", + "abc", + "FormattedValue", + '\' f"""xxx', + "FormattedValue", + " ", + '"""', + ], + ) + checker.check_children("FormattedValue", ["{", "", "Name", "", "}"]) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_handling_format_strings_with_format_spec(self): source = 'f"abc{a:01}"\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) + checker.check_children("JoinedStr", ['f"', "abc", "FormattedValue", "", '"']) checker.check_children( - 'JoinedStr', ['f"', 'abc', 'FormattedValue', '', '"']) - checker.check_children( - 'FormattedValue', ['{', '', 'Name', '', ':', '', '01', '', '}']) + "FormattedValue", ["{", "", "Name", "", ":", "", "01", "", "}"] + ) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_handling_format_strings_with_inner_format_spec(self): source = 'f"abc{a:{length}01}"\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) + checker.check_children("JoinedStr", ['f"', "abc", "FormattedValue", "", '"']) checker.check_children( - 'JoinedStr', ['f"', 'abc', 'FormattedValue', '', '"']) - checker.check_children( - 'FormattedValue', ['{', '', 'Name', '', ':', '{', 'Name', '}', '01', '', '}']) + "FormattedValue", + ["{", "", "Name", "", ":", "{", "Name", "}", "01", "", "}"], + ) - @testutils.only_for_versions_higher('3.6') + @testutils.only_for_versions_higher("3.6") def test_handling_format_strings_with_expression(self): source = 'f"abc{a + b}"\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'JoinedStr', ['f"', 'abc', 'FormattedValue', '', '"']) - checker.check_children( - 'FormattedValue', ['{', '', 'BinOp', '', '}']) + checker.check_children("JoinedStr", ['f"', "abc", "FormattedValue", "", '"']) + checker.check_children("FormattedValue", ["{", "", "BinOp", "", "}"]) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_long_integer_literals(self): source = "0x1L + a" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'BinOp', ['Num', ' ', '+', ' ', 'Name']) - checker.check_children('Num', ['0x1L']) + checker.check_children("BinOp", ["Num", " ", "+", " ", "Name"]) + checker.check_children("Num", ["0x1L"]) def test_complex_number_literals(self): source = "1.0e2j + a" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'BinOp', ['Num', ' ', '+', ' ', 'Name']) - checker.check_children('Num', ['1.0e2j']) + checker.check_children("BinOp", ["Num", " ", "+", " ", "Name"]) + checker.check_children("Num", ["1.0e2j"]) def test_ass_attr_node(self): - source = 'a.b = 1\n' + source = "a.b = 1\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Attribute', 0, source.index('=') - 1) - checker.check_children('Attribute', ['Name', '', '.', '', 'b']) + checker.check_region("Attribute", 0, source.index("=") - 1) + checker.check_children("Attribute", ["Name", "", ".", "", "b"]) def test_ass_list_node(self): - source = '[a, b] = 1, 2\n' + source = "[a, b] = 1, 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('List', 0, source.index(']') + 1) - checker.check_children('List', ['[', '', 'Name', '', ',', - ' ', 'Name', '', ']']) + checker.check_region("List", 0, source.index("]") + 1) + checker.check_children("List", ["[", "", "Name", "", ",", " ", "Name", "", "]"]) def test_ass_tuple(self): - source = 'a, b = range(2)\n' + source = "a, b = range(2)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Tuple', 0, source.index('=') - 1) - checker.check_children( - 'Tuple', ['Name', '', ',', ' ', 'Name']) + checker.check_region("Tuple", 0, source.index("=") - 1) + checker.check_children("Tuple", ["Name", "", ",", " ", "Name"]) def test_ass_tuple2(self): - source = '(a, b) = range(2)\n' + source = "(a, b) = range(2)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Tuple', 0, source.index('=') - 1) + checker.check_region("Tuple", 0, source.index("=") - 1) checker.check_children( - 'Tuple', ['(', '', 'Name', '', ',', ' ', 'Name', '', ')']) + "Tuple", ["(", "", "Name", "", ",", " ", "Name", "", ")"] + ) def test_assert(self): - source = 'assert True\n' + source = "assert True\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Assert', 0, len(source) - 1) - checker.check_children( - 'Assert', ['assert', ' ', NameConstant]) + checker.check_region("Assert", 0, len(source) - 1) + checker.check_children("Assert", ["assert", " ", NameConstant]) def test_assert2(self): source = 'assert True, "error"\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Assert', 0, len(source) - 1) + checker.check_region("Assert", 0, len(source) - 1) checker.check_children( - 'Assert', ['assert', ' ', NameConstant, '', ',', ' ', 'Str']) + "Assert", ["assert", " ", NameConstant, "", ",", " ", "Str"] + ) def test_aug_assign_node(self): - source = 'a += 1\n' + source = "a += 1\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('a') # noqa - checker.check_region('AugAssign', 0, len(source) - 1) - checker.check_children( - 'AugAssign', ['Name', ' ', '+', '', '=', ' ', 'Num']) + start = source.index("a") # noqa + checker.check_region("AugAssign", 0, len(source) - 1) + checker.check_children("AugAssign", ["Name", " ", "+", "", "=", " ", "Num"]) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_back_quotenode(self): - source = '`1`\n' + source = "`1`\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Repr', 0, len(source) - 1) - checker.check_children( - 'Repr', ['`', '', 'Num', '', '`']) + checker.check_region("Repr", 0, len(source) - 1) + checker.check_children("Repr", ["`", "", "Num", "", "`"]) def test_bitand(self): - source = '1 & 2\n' + source = "1 & 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('BinOp', 0, len(source) - 1) - checker.check_children( - 'BinOp', ['Num', ' ', '&', ' ', 'Num']) + checker.check_region("BinOp", 0, len(source) - 1) + checker.check_children("BinOp", ["Num", " ", "&", " ", "Num"]) def test_bitor(self): - source = '1 | 2\n' + source = "1 | 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'BinOp', ['Num', ' ', '|', ' ', 'Num']) + checker.check_children("BinOp", ["Num", " ", "|", " ", "Num"]) def test_call_func(self): - source = 'f(1, 2)\n' + source = "f(1, 2)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Call', 0, len(source) - 1) + checker.check_region("Call", 0, len(source) - 1) checker.check_children( - 'Call', ['Name', '', '(', '', 'Num', '', ',', - ' ', 'Num', '', ')']) + "Call", ["Name", "", "(", "", "Num", "", ",", " ", "Num", "", ")"] + ) def test_call_func_and_keywords(self): - source = 'f(1, p=2)\n' + source = "f(1, p=2)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'Call', ['Name', '', '(', '', 'Num', '', ',', - ' ', 'keyword', '', ')']) + "Call", ["Name", "", "(", "", "Num", "", ",", " ", "keyword", "", ")"] + ) - @testutils.only_for_versions_lower('3.5') + @testutils.only_for_versions_lower("3.5") def test_call_func_and_star_args(self): - source = 'f(1, *args)\n' + source = "f(1, *args)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'Call', ['Name', '', '(', '', 'Num', '', ',', - ' ', '*', '', 'Name', '', ')']) + "Call", ["Name", "", "(", "", "Num", "", ",", " ", "*", "", "Name", "", ")"] + ) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_call_func_and_star_argspython35(self): - source = 'f(1, *args)\n' + source = "f(1, *args)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'Call', ['Name', '', '(', '', 'Num', '', ',', - ' *', 'Starred', '', ')']) + "Call", ["Name", "", "(", "", "Num", "", ",", " *", "Starred", "", ")"] + ) - @testutils.only_for_versions_lower('3.5') + @testutils.only_for_versions_lower("3.5") def test_call_func_and_only_dstar_args(self): - source = 'f(**kwds)\n' + source = "f(**kwds)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Call', ['Name', '', '(', '', '**', '', 'Name', '', ')']) + checker.check_children("Call", ["Name", "", "(", "", "**", "", "Name", "", ")"]) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_call_func_and_only_dstar_args_python35(self): - source = 'f(**kwds)\n' + source = "f(**kwds)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Call', ['Name', '', '(', '**', 'keyword', '', ')']) + checker.check_children("Call", ["Name", "", "(", "**", "keyword", "", ")"]) - @testutils.only_for_versions_lower('3.5') + @testutils.only_for_versions_lower("3.5") def test_call_func_and_both_varargs_and_kwargs(self): - source = 'f(*args, **kwds)\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Call', ['Name', '', '(', '', '*', '', 'Name', '', ',', - ' ', '**', '', 'Name', '', ')']) + source = "f(*args, **kwds)\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_children( + "Call", + [ + "Name", + "", + "(", + "", + "*", + "", + "Name", + "", + ",", + " ", + "**", + "", + "Name", + "", + ")", + ], + ) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_call_func_and_both_varargs_and_kwargs_python35(self): - source = 'f(*args, **kwds)\n' + source = "f(*args, **kwds)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'Call', ['Name', '', '(', '*', 'Starred', '', ',', - ' **', 'keyword', '', ')']) + "Call", + ["Name", "", "(", "*", "Starred", "", ",", " **", "keyword", "", ")"], + ) def test_class_node(self): source = 'class A(object):\n """class docs"""\n pass\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Class', 0, len(source) - 1) - checker.check_children( - 'Class', ['class', ' ', 'A', '', '(', '', 'Name', '', ')', - '', ':', '\n ', 'Expr', '\n ', 'Pass']) + checker.check_region("Class", 0, len(source) - 1) + checker.check_children( + "Class", + [ + "class", + " ", + "A", + "", + "(", + "", + "Name", + "", + ")", + "", + ":", + "\n ", + "Expr", + "\n ", + "Pass", + ], + ) def test_class_with_no_bases(self): - source = 'class A:\n pass\n' + source = "class A:\n pass\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Class', 0, len(source) - 1) - checker.check_children( - 'Class', ['class', ' ', 'A', '', ':', '\n ', 'Pass']) + checker.check_region("Class", 0, len(source) - 1) + checker.check_children("Class", ["class", " ", "A", "", ":", "\n ", "Pass"]) def test_simple_compare(self): - source = '1 < 2\n' + source = "1 < 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Compare', 0, len(source) - 1) - checker.check_children( - 'Compare', ['Num', ' ', '<', ' ', 'Num']) + checker.check_region("Compare", 0, len(source) - 1) + checker.check_children("Compare", ["Num", " ", "<", " ", "Num"]) def test_multiple_compare(self): - source = '1 < 2 <= 3\n' + source = "1 < 2 <= 3\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Compare', 0, len(source) - 1) + checker.check_region("Compare", 0, len(source) - 1) checker.check_children( - 'Compare', ['Num', ' ', '<', ' ', 'Num', ' ', - '<=', ' ', 'Num']) + "Compare", ["Num", " ", "<", " ", "Num", " ", "<=", " ", "Num"] + ) def test_decorators_node(self): - source = '@d\ndef f():\n pass\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_region('FunctionDef', 0, len(source) - 1) - checker.check_children( - 'FunctionDef', - ['@', '', 'Name', '\n', 'def', ' ', 'f', '', '(', '', 'arguments', - '', ')', '', ':', '\n ', 'Pass']) + source = "@d\ndef f():\n pass\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_region("FunctionDef", 0, len(source) - 1) + checker.check_children( + "FunctionDef", + [ + "@", + "", + "Name", + "\n", + "def", + " ", + "f", + "", + "(", + "", + "arguments", + "", + ")", + "", + ":", + "\n ", + "Pass", + ], + ) - @testutils.only_for('2.6') + @testutils.only_for("2.6") def test_decorators_for_classes(self): - source = '@d\nclass C(object):\n pass\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_region('ClassDef', 0, len(source) - 1) - checker.check_children( - 'ClassDef', - ['@', '', 'Name', '\n', 'class', ' ', 'C', '', '(', '', 'Name', - '', ')', '', ':', '\n ', 'Pass']) + source = "@d\nclass C(object):\n pass\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_region("ClassDef", 0, len(source) - 1) + checker.check_children( + "ClassDef", + [ + "@", + "", + "Name", + "\n", + "class", + " ", + "C", + "", + "(", + "", + "Name", + "", + ")", + "", + ":", + "\n ", + "Pass", + ], + ) def test_both_varargs_and_kwargs(self): - source = 'def f(*args, **kwds):\n pass\n' + source = "def f(*args, **kwds):\n pass\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'arguments', ['*', '', 'args', '', ',', ' ', '**', '', 'kwds']) + "arguments", ["*", "", "args", "", ",", " ", "**", "", "kwds"] + ) def test_function_node(self): - source = 'def f():\n pass\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_region('Function', 0, len(source) - 1) - checker.check_children('Function', - ['def', ' ', 'f', '', '(', '', 'arguments', '', - ')', '', ':', '\n ', 'Pass']) + source = "def f():\n pass\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_region("Function", 0, len(source) - 1) + checker.check_children( + "Function", + [ + "def", + " ", + "f", + "", + "(", + "", + "arguments", + "", + ")", + "", + ":", + "\n ", + "Pass", + ], + ) - @testutils.only_for_versions_higher('3.5') + @testutils.only_for_versions_higher("3.5") def test_async_function_node(self): - source = 'async def f():\n pass\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_region('AsyncFunction', 0, len(source) - 1) - checker.check_children('AsyncFunction', - ['async', ' ', 'def', ' ', 'f', '', '(', '', 'arguments', '', - ')', '', ':', '\n ', 'Pass']) + source = "async def f():\n pass\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_region("AsyncFunction", 0, len(source) - 1) + checker.check_children( + "AsyncFunction", + [ + "async", + " ", + "def", + " ", + "f", + "", + "(", + "", + "arguments", + "", + ")", + "", + ":", + "\n ", + "Pass", + ], + ) def test_function_node2(self): source = 'def f(p1, **p2):\n """docs"""\n pass\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Function', 0, len(source) - 1) - checker.check_children( - 'Function', ['def', ' ', 'f', '', '(', '', 'arguments', - '', ')', '', ':', '\n ', 'Expr', '\n ', - 'Pass']) + checker.check_region("Function", 0, len(source) - 1) + checker.check_children( + "Function", + [ + "def", + " ", + "f", + "", + "(", + "", + "arguments", + "", + ")", + "", + ":", + "\n ", + "Expr", + "\n ", + "Pass", + ], + ) expected_child = pycompat.ast_arg_type.__name__ checker.check_children( - 'arguments', [expected_child, '', ',', - ' ', '**', '', 'p2']) + "arguments", [expected_child, "", ",", " ", "**", "", "p2"] + ) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_function_node_and_tuple_parameters(self): - source = 'def f(a, (b, c)):\n pass\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_region('Function', 0, len(source) - 1) - checker.check_children( - 'Function', ['def', ' ', 'f', '', '(', '', 'arguments', - '', ')', '', ':', '\n ', 'Pass']) - checker.check_children( - 'arguments', ['Name', '', ',', ' ', 'Tuple']) + source = "def f(a, (b, c)):\n pass\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_region("Function", 0, len(source) - 1) + checker.check_children( + "Function", + [ + "def", + " ", + "f", + "", + "(", + "", + "arguments", + "", + ")", + "", + ":", + "\n ", + "Pass", + ], + ) + checker.check_children("arguments", ["Name", "", ",", " ", "Tuple"]) def test_dict_node(self): - source = '{1: 2, 3: 4}\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_region('Dict', 0, len(source) - 1) - checker.check_children( - 'Dict', ['{', '', 'Num', '', ':', ' ', 'Num', '', ',', - ' ', 'Num', '', ':', ' ', 'Num', '', '}']) + source = "{1: 2, 3: 4}\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_region("Dict", 0, len(source) - 1) + checker.check_children( + "Dict", + [ + "{", + "", + "Num", + "", + ":", + " ", + "Num", + "", + ",", + " ", + "Num", + "", + ":", + " ", + "Num", + "", + "}", + ], + ) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_dict_node_with_unpacking(self): - source = '{**dict1, **dict2}\n' + source = "{**dict1, **dict2}\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Dict', 0, len(source) - 1) + checker.check_region("Dict", 0, len(source) - 1) checker.check_children( - 'Dict', ['{', '', '**', '', 'Name', '', ',', - ' ', '**', '', 'Name', '', '}']) + "Dict", ["{", "", "**", "", "Name", "", ",", " ", "**", "", "Name", "", "}"] + ) def test_div_node(self): - source = '1 / 2\n' + source = "1 / 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('BinOp', 0, len(source) - 1) - checker.check_children('BinOp', ['Num', ' ', '/', ' ', 'Num']) + checker.check_region("BinOp", 0, len(source) - 1) + checker.check_children("BinOp", ["Num", " ", "/", " ", "Num"]) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_simple_exec_node(self): source = 'exec ""\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Exec', 0, len(source) - 1) - checker.check_children('Exec', ['exec', '', '', ' ', 'Str', '', '']) + checker.check_region("Exec", 0, len(source) - 1) + checker.check_children("Exec", ["exec", "", "", " ", "Str", "", ""]) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_exec_node(self): source = 'exec "" in locals(), globals()\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Exec', 0, len(source) - 1) - checker.check_children( - 'Exec', ['exec', '', '', ' ', 'Str', ' ', 'in', - ' ', 'Call', '', ',', ' ', 'Call', '', '']) + checker.check_region("Exec", 0, len(source) - 1) + checker.check_children( + "Exec", + [ + "exec", + "", + "", + " ", + "Str", + " ", + "in", + " ", + "Call", + "", + ",", + " ", + "Call", + "", + "", + ], + ) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_exec_node_with_parens(self): source = 'exec("", locals(), globals())\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Exec', 0, len(source) - 1) - checker.check_children( - 'Exec', ['exec', '', '(', '', 'Str', '', ',', - ' ', 'Call', '', ',', ' ', 'Call', '', ')']) + checker.check_region("Exec", 0, len(source) - 1) + checker.check_children( + "Exec", + [ + "exec", + "", + "(", + "", + "Str", + "", + ",", + " ", + "Call", + "", + ",", + " ", + "Call", + "", + ")", + ], + ) def test_for_node(self): - source = dedent('''\ + source = dedent("""\ for i in range(1): pass else: pass - ''') - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_region('For', 0, len(source) - 1) - checker.check_children( - 'For', ['for', ' ', 'Name', ' ', 'in', ' ', 'Call', '', - ':', '\n ', 'Pass', '\n', - 'else', '', ':', '\n ', 'Pass']) + """) + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_region("For", 0, len(source) - 1) + checker.check_children( + "For", + [ + "for", + " ", + "Name", + " ", + "in", + " ", + "Call", + "", + ":", + "\n ", + "Pass", + "\n", + "else", + "", + ":", + "\n ", + "Pass", + ], + ) - @testutils.only_for_versions_higher('3.5') + @testutils.only_for_versions_higher("3.5") def test_async_for_node(self): - source = dedent('''\ + source = dedent("""\ async def foo(): async for i in range(1): pass else: pass - ''') - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_region('AsyncFor', source.index('async for'), len(source) - 1) - checker.check_children( - 'AsyncFor', ['async', ' ', 'for', ' ', 'Name', ' ', 'in', ' ', 'Call', '', - ':', '\n ', 'Pass', '\n ', - 'else', '', ':', '\n ', 'Pass']) + """) + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_region("AsyncFor", source.index("async for"), len(source) - 1) + checker.check_children( + "AsyncFor", + [ + "async", + " ", + "for", + " ", + "Name", + " ", + "in", + " ", + "Call", + "", + ":", + "\n ", + "Pass", + "\n ", + "else", + "", + ":", + "\n ", + "Pass", + ], + ) - @testutils.only_for_versions_higher('3.8') + @testutils.only_for_versions_higher("3.8") def test_named_expr_node(self): - source = 'if a := 10 == 10:\n pass\n' + source = "if a := 10 == 10:\n pass\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.index('a') - checker.check_region('NamedExpr', start, start + 13) - checker.check_children('NamedExpr', ['Name', ' ', ':=', ' ', 'Compare']) + start = source.index("a") + checker.check_region("NamedExpr", start, start + 13) + checker.check_children("NamedExpr", ["Name", " ", ":=", " ", "Compare"]) def test_normal_from_node(self): - source = 'from x import y\n' + source = "from x import y\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('ImportFrom', 0, len(source) - 1) + checker.check_region("ImportFrom", 0, len(source) - 1) checker.check_children( - 'ImportFrom', ['from', ' ', 'x', ' ', 'import', ' ', 'alias']) - checker.check_children('alias', ['y']) + "ImportFrom", ["from", " ", "x", " ", "import", " ", "alias"] + ) + checker.check_children("alias", ["y"]) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_from_node(self): - source = 'from ..x import y as z\n' + source = "from ..x import y as z\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('ImportFrom', 0, len(source) - 1) + checker.check_region("ImportFrom", 0, len(source) - 1) checker.check_children( - 'ImportFrom', ['from', ' ', '..', '', 'x', ' ', - 'import', ' ', 'alias']) - checker.check_children('alias', ['y', ' ', 'as', ' ', 'z']) + "ImportFrom", ["from", " ", "..", "", "x", " ", "import", " ", "alias"] + ) + checker.check_children("alias", ["y", " ", "as", " ", "z"]) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_from_node_relative_import(self): - source = 'from . import y as z\n' + source = "from . import y as z\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('ImportFrom', 0, len(source) - 1) + checker.check_region("ImportFrom", 0, len(source) - 1) checker.check_children( - 'ImportFrom', ['from', ' ', '.', '', '', ' ', - 'import', ' ', 'alias']) - checker.check_children('alias', ['y', ' ', 'as', ' ', 'z']) + "ImportFrom", ["from", " ", ".", "", "", " ", "import", " ", "alias"] + ) + checker.check_children("alias", ["y", " ", "as", " ", "z"]) def test_simple_gen_expr_node(self): - source = 'zip(i for i in x)\n' + source = "zip(i for i in x)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('GeneratorExp', 4, len(source) - 2) + checker.check_region("GeneratorExp", 4, len(source) - 2) + checker.check_children("GeneratorExp", ["Name", " ", "comprehension"]) checker.check_children( - 'GeneratorExp', ['Name', ' ', 'comprehension']) - checker.check_children( - 'comprehension', ['for', ' ', 'Name', ' ', 'in', ' ', 'Name']) + "comprehension", ["for", " ", "Name", " ", "in", " ", "Name"] + ) def test_gen_expr_node_handling_surrounding_parens(self): - source = '(i for i in x)\n' + source = "(i for i in x)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('GeneratorExp', 0, len(source) - 1) + checker.check_region("GeneratorExp", 0, len(source) - 1) checker.check_children( - 'GeneratorExp', ['(', '', 'Name', ' ', 'comprehension', '', ')']) + "GeneratorExp", ["(", "", "Name", " ", "comprehension", "", ")"] + ) def test_gen_expr_node2(self): - source = 'zip(i for i in range(1) if i == 1)\n' + source = "zip(i for i in range(1) if i == 1)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'comprehension', ['for', ' ', 'Name', ' ', 'in', ' ', 'Call', - ' ', 'if', ' ', 'Compare']) + "comprehension", + ["for", " ", "Name", " ", "in", " ", "Call", " ", "if", " ", "Compare"], + ) def test_get_attr_node(self): - source = 'a.b\n' + source = "a.b\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Attribute', 0, len(source) - 1) - checker.check_children('Attribute', ['Name', '', '.', '', 'b']) + checker.check_region("Attribute", 0, len(source) - 1) + checker.check_children("Attribute", ["Name", "", ".", "", "b"]) def test_global_node(self): - source = 'global a, b\n' + source = "global a, b\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Global', 0, len(source) - 1) - checker.check_children('Global', ['global', ' ', 'a', '', ',', ' ', - 'b']) + checker.check_region("Global", 0, len(source) - 1) + checker.check_children("Global", ["global", " ", "a", "", ",", " ", "b"]) def test_if_node(self): - source = 'if True:\n pass\nelse:\n pass\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_region('If', 0, len(source) - 1) - checker.check_children( - 'If', ['if', ' ', NameConstant, '', ':', '\n ', 'Pass', '\n', - 'else', '', ':', '\n ', 'Pass']) + source = "if True:\n pass\nelse:\n pass\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_region("If", 0, len(source) - 1) + checker.check_children( + "If", + [ + "if", + " ", + NameConstant, + "", + ":", + "\n ", + "Pass", + "\n", + "else", + "", + ":", + "\n ", + "Pass", + ], + ) def test_if_node2(self): - source = 'if True:\n pass\nelif False:\n pass\n' + source = "if True:\n pass\nelif False:\n pass\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('If', 0, len(source) - 1) + checker.check_region("If", 0, len(source) - 1) checker.check_children( - 'If', ['if', ' ', NameConstant, '', ':', '\n ', 'Pass', '\n', - 'If']) + "If", ["if", " ", NameConstant, "", ":", "\n ", "Pass", "\n", "If"] + ) def test_if_node3(self): - source = 'if True:\n pass\nelse:\n' \ - ' if True:\n pass\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_region('If', 0, len(source) - 1) - checker.check_children( - 'If', ['if', ' ', NameConstant, '', ':', '\n ', 'Pass', '\n', - 'else', '', ':', '\n ', 'If']) + source = "if True:\n pass\nelse:\n" " if True:\n pass\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_region("If", 0, len(source) - 1) + checker.check_children( + "If", + [ + "if", + " ", + NameConstant, + "", + ":", + "\n ", + "Pass", + "\n", + "else", + "", + ":", + "\n ", + "If", + ], + ) def test_import_node(self): - source = 'import a, b as c\n' + source = "import a, b as c\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Import', 0, len(source) - 1) + checker.check_region("Import", 0, len(source) - 1) checker.check_children( - 'Import', ['import', ' ', 'alias', '', ',', ' ', 'alias']) + "Import", ["import", " ", "alias", "", ",", " ", "alias"] + ) def test_lambda_node(self): - source = 'lambda a, b=1, *z: None\n' + source = "lambda a, b=1, *z: None\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Lambda', 0, len(source) - 1) + checker.check_region("Lambda", 0, len(source) - 1) checker.check_children( - 'Lambda', ['lambda', ' ', 'arguments', '', ':', ' ', NameConstant]) + "Lambda", ["lambda", " ", "arguments", "", ":", " ", NameConstant] + ) expected_child = pycompat.ast_arg_type.__name__ checker.check_children( - 'arguments', [expected_child, '', ',', ' ', - expected_child, '', '=', '', - 'Num', '', ',', ' ', '*', '', 'z']) + "arguments", + [ + expected_child, + "", + ",", + " ", + expected_child, + "", + "=", + "", + "Num", + "", + ",", + " ", + "*", + "", + "z", + ], + ) def test_list_node(self): - source = '[1, 2]\n' + source = "[1, 2]\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('List', 0, len(source) - 1) - checker.check_children( - 'List', ['[', '', 'Num', '', ',', ' ', 'Num', '', ']']) + checker.check_region("List", 0, len(source) - 1) + checker.check_children("List", ["[", "", "Num", "", ",", " ", "Num", "", "]"]) def test_list_comp_node(self): - source = '[i for i in range(1) if True]\n' + source = "[i for i in range(1) if True]\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('ListComp', 0, len(source) - 1) + checker.check_region("ListComp", 0, len(source) - 1) checker.check_children( - 'ListComp', ['[', '', 'Name', ' ', 'comprehension', '', ']']) + "ListComp", ["[", "", "Name", " ", "comprehension", "", "]"] + ) checker.check_children( - 'comprehension', ['for', ' ', 'Name', ' ', 'in', ' ', - 'Call', ' ', 'if', ' ', NameConstant]) + "comprehension", + ["for", " ", "Name", " ", "in", " ", "Call", " ", "if", " ", NameConstant], + ) def test_list_comp_node_with_multiple_comprehensions(self): - source = '[i for i in range(1) for j in range(1) if True]\n' + source = "[i for i in range(1) for j in range(1) if True]\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('ListComp', 0, len(source) - 1) + checker.check_region("ListComp", 0, len(source) - 1) checker.check_children( - 'ListComp', ['[', '', 'Name', ' ', 'comprehension', - ' ', 'comprehension', '', ']']) + "ListComp", + ["[", "", "Name", " ", "comprehension", " ", "comprehension", "", "]"], + ) checker.check_children( - 'comprehension', ['for', ' ', 'Name', ' ', 'in', ' ', - 'Call', ' ', 'if', ' ', NameConstant]) + "comprehension", + ["for", " ", "Name", " ", "in", " ", "Call", " ", "if", " ", NameConstant], + ) def test_set_node(self): # make sure we are in a python version with set literals - source = '{1, 2}\n' + source = "{1, 2}\n" try: eval(source) @@ -836,13 +1093,12 @@ def test_set_node(self): ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Set', 0, len(source) - 1) - checker.check_children( - 'Set', ['{', '', 'Num', '', ',', ' ', 'Num', '', '}']) + checker.check_region("Set", 0, len(source) - 1) + checker.check_children("Set", ["{", "", "Num", "", ",", " ", "Num", "", "}"]) def test_set_comp_node(self): # make sure we are in a python version with set comprehensions - source = '{i for i in range(1) if True}\n' + source = "{i for i in range(1) if True}\n" try: eval(source) @@ -851,16 +1107,18 @@ def test_set_comp_node(self): ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('SetComp', 0, len(source) - 1) + checker.check_region("SetComp", 0, len(source) - 1) checker.check_children( - 'SetComp', ['{', '', 'Name', ' ', 'comprehension', '', '}']) + "SetComp", ["{", "", "Name", " ", "comprehension", "", "}"] + ) checker.check_children( - 'comprehension', ['for', ' ', 'Name', ' ', 'in', ' ', - 'Call', ' ', 'if', ' ', NameConstant]) + "comprehension", + ["for", " ", "Name", " ", "in", " ", "Call", " ", "if", " ", NameConstant], + ) def test_dict_comp_node(self): # make sure we are in a python version with dict comprehensions - source = '{i:i for i in range(1) if True}\n' + source = "{i:i for i in range(1) if True}\n" try: eval(source) @@ -869,352 +1127,462 @@ def test_dict_comp_node(self): ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('DictComp', 0, len(source) - 1) + checker.check_region("DictComp", 0, len(source) - 1) checker.check_children( - 'DictComp', ['{', '', 'Name', '', ':', '', 'Name', - ' ', 'comprehension', '', '}']) + "DictComp", + ["{", "", "Name", "", ":", "", "Name", " ", "comprehension", "", "}"], + ) checker.check_children( - 'comprehension', ['for', ' ', 'Name', ' ', 'in', ' ', - 'Call', ' ', 'if', ' ', NameConstant]) + "comprehension", + ["for", " ", "Name", " ", "in", " ", "Call", " ", "if", " ", NameConstant], + ) def test_ext_slice_node(self): - source = 'x = xs[0,:]\n' + source = "x = xs[0,:]\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) if sys.version_info >= (3, 9): - checker.check_region('Tuple', 7, len(source) - 2) - checker.check_children('Tuple', ['Num', '', ',', '', 'Slice']) + checker.check_region("Tuple", 7, len(source) - 2) + checker.check_children("Tuple", ["Num", "", ",", "", "Slice"]) else: - checker.check_region('ExtSlice', 7, len(source) - 2) - checker.check_children('ExtSlice', ['Index', '', ',', '', 'Slice']) + checker.check_region("ExtSlice", 7, len(source) - 2) + checker.check_children("ExtSlice", ["Index", "", ",", "", "Slice"]) def test_simple_module_node(self): - source = 'pass\n' + source = "pass\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Module', 0, len(source)) - checker.check_children('Module', ['', 'Pass', '\n']) + checker.check_region("Module", 0, len(source)) + checker.check_children("Module", ["", "Pass", "\n"]) def test_module_node(self): source = '"""docs"""\npass\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Module', 0, len(source)) - checker.check_children('Module', ['', 'Expr', '\n', 'Pass', '\n']) - checker.check_children('Str', ['"""docs"""']) + checker.check_region("Module", 0, len(source)) + checker.check_children("Module", ["", "Expr", "\n", "Pass", "\n"]) + checker.check_children("Str", ['"""docs"""']) def test_not_and_or_nodes(self): - source = 'not True or False\n' + source = "not True or False\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children('Expr', ['BoolOp']) - checker.check_children('BoolOp', ['UnaryOp', ' ', 'or', ' ', NameConstant]) + checker.check_children("Expr", ["BoolOp"]) + checker.check_children("BoolOp", ["UnaryOp", " ", "or", " ", NameConstant]) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_print_node(self): - source = 'print >>out, 1,\n' + source = "print >>out, 1,\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Print', 0, len(source) - 1) - checker.check_children('Print', ['print', ' ', '>>', '', 'Name', '', - ',', ' ', 'Num', '', ',']) + checker.check_region("Print", 0, len(source) - 1) + checker.check_children( + "Print", ["print", " ", ">>", "", "Name", "", ",", " ", "Num", "", ","] + ) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_printnl_node(self): - source = 'print 1\n' + source = "print 1\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Print', 0, len(source) - 1) - checker.check_children('Print', ['print', ' ', 'Num']) + checker.check_region("Print", 0, len(source) - 1) + checker.check_children("Print", ["print", " ", "Num"]) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_raise_node_for_python2(self): - source = 'raise x, y, z\n' + source = "raise x, y, z\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Raise', 0, len(source) - 1) + checker.check_region("Raise", 0, len(source) - 1) checker.check_children( - 'Raise', ['raise', ' ', 'Name', '', ',', ' ', 'Name', '', ',', - ' ', 'Name']) + "Raise", ["raise", " ", "Name", "", ",", " ", "Name", "", ",", " ", "Name"] + ) # @#testutils.only_for('3') - @unittest.skipIf(sys.version < '3', 'This is wrong') + @unittest.skipIf(sys.version < "3", "This is wrong") def test_raise_node_for_python3(self): - source = 'raise x(y)\n' + source = "raise x(y)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_region('Raise', 0, len(source) - 1) - checker.check_children( - 'Raise', ['raise', ' ', 'Call']) + checker.check_region("Raise", 0, len(source) - 1) + checker.check_children("Raise", ["raise", " ", "Call"]) def test_return_node(self): - source = 'def f():\n return None\n' + source = "def f():\n return None\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children('Return', ['return', ' ', NameConstant]) + checker.check_children("Return", ["return", " ", NameConstant]) def test_empty_return_node(self): - source = 'def f():\n return\n' + source = "def f():\n return\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children('Return', ['return']) + checker.check_children("Return", ["return"]) def test_simple_slice_node(self): - source = 'a[1:2]\n' + source = "a[1:2]\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Subscript', ['Name', '', '[', '', 'Slice', '', ']']) - checker.check_children( - 'Slice', ['Num', '', ':', '', 'Num']) + checker.check_children("Subscript", ["Name", "", "[", "", "Slice", "", "]"]) + checker.check_children("Slice", ["Num", "", ":", "", "Num"]) def test_slice_node2(self): - source = 'a[:]\n' + source = "a[:]\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children('Subscript', ['Name', '', '[', '', 'Slice', - '', ']']) - checker.check_children('Slice', [':']) + checker.check_children("Subscript", ["Name", "", "[", "", "Slice", "", "]"]) + checker.check_children("Slice", [":"]) def test_simple_subscript(self): - source = 'a[1]\n' + source = "a[1]\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) if sys.version_info >= (3, 9): - checker.check_children( - 'Subscript', ['Name', '', '[', '', 'Num', '', ']']) + checker.check_children("Subscript", ["Name", "", "[", "", "Num", "", "]"]) else: - checker.check_children( - 'Subscript', ['Name', '', '[', '', 'Index', '', ']']) - checker.check_children('Index', ['Num']) + checker.check_children("Subscript", ["Name", "", "[", "", "Index", "", "]"]) + checker.check_children("Index", ["Num"]) def test_tuple_node(self): - source = '(1, 2)\n' + source = "(1, 2)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Tuple', ['(', '', 'Num', '', ',', ' ', 'Num', '', ')']) + checker.check_children("Tuple", ["(", "", "Num", "", ",", " ", "Num", "", ")"]) def test_tuple_node2(self): - source = '#(\n1, 2\n' + source = "#(\n1, 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children('Tuple', ['Num', '', ',', ' ', 'Num']) + checker.check_children("Tuple", ["Num", "", ",", " ", "Num"]) def test_one_item_tuple_node(self): - source = '(1,)\n' + source = "(1,)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children('Tuple', ['(', '', 'Num', ',', ')']) + checker.check_children("Tuple", ["(", "", "Num", ",", ")"]) def test_empty_tuple_node(self): - source = '()\n' + source = "()\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children('Tuple', ['(', '', ')']) + checker.check_children("Tuple", ["(", "", ")"]) def test_yield_node(self): - source = 'def f():\n yield None\n' + source = "def f():\n yield None\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children('Yield', ['yield', ' ', NameConstant]) + checker.check_children("Yield", ["yield", " ", NameConstant]) def test_while_node(self): - source = 'while True:\n pass\nelse:\n pass\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'While', ['while', ' ', NameConstant, '', ':', '\n ', 'Pass', '\n', - 'else', '', ':', '\n ', 'Pass']) + source = "while True:\n pass\nelse:\n pass\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_children( + "While", + [ + "while", + " ", + NameConstant, + "", + ":", + "\n ", + "Pass", + "\n", + "else", + "", + ":", + "\n ", + "Pass", + ], + ) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_with_node(self): - source = 'from __future__ import with_statement\n' +\ - 'with a as b:\n' +\ - ' pass\n' + source = ( + "from __future__ import with_statement\n" + "with a as b:\n" + " pass\n" + ) ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'With', ['with', ' ', 'Name', ' ', 'as', ' ', 'Name', '', ':', - '\n ', 'Pass']) + "With", + ["with", " ", "Name", " ", "as", " ", "Name", "", ":", "\n ", "Pass"], + ) def test_try_finally_node(self): - source = 'try:\n pass\nfinally:\n pass\n' + source = "try:\n pass\nfinally:\n pass\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - node_to_test = 'Try' if pycompat.PY3 else 'TryFinally' + node_to_test = "Try" if pycompat.PY3 else "TryFinally" if pycompat.PY3: - expected_children = ['try', '', ':', '\n ', - 'Pass', '\n', 'finally', - '', ':', '\n ', 'Pass'] + expected_children = [ + "try", + "", + ":", + "\n ", + "Pass", + "\n", + "finally", + "", + ":", + "\n ", + "Pass", + ] else: - expected_children = ['try', '', ':', '\n ', - 'Pass', '\n', 'finally', '', ':', '\n ', - 'Pass'] - checker.check_children( - node_to_test, expected_children) - - @testutils.only_for_versions_lower('3') + expected_children = [ + "try", + "", + ":", + "\n ", + "Pass", + "\n", + "finally", + "", + ":", + "\n ", + "Pass", + ] + checker.check_children(node_to_test, expected_children) + + @testutils.only_for_versions_lower("3") def test_try_except_node(self): - source = 'try:\n pass\nexcept Exception, e:\n pass\n' + source = "try:\n pass\nexcept Exception, e:\n pass\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'TryExcept', ['try', '', ':', '\n ', 'Pass', '\n', - ('excepthandler', 'ExceptHandler')]) + "TryExcept", + [ + "try", + "", + ":", + "\n ", + "Pass", + "\n", + ("excepthandler", "ExceptHandler"), + ], + ) checker.check_children( - ('excepthandler', 'ExceptHandler'), - ['except', ' ', 'Name', '', ',', ' ', 'Name', '', ':', - '\n ', 'Pass']) + ("excepthandler", "ExceptHandler"), + ["except", " ", "Name", "", ",", " ", "Name", "", ":", "\n ", "Pass"], + ) def test_try_except_node__with_as_syntax(self): - source = 'try:\n pass\nexcept Exception as e:\n pass\n' + source = "try:\n pass\nexcept Exception as e:\n pass\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - node_to_test = 'Try' if pycompat.PY3 else 'TryExcept' - checker.check_children( - node_to_test, ['try', '', ':', '\n ', 'Pass', '\n', - ('excepthandler', 'ExceptHandler')]) - expected_child = 'e' if pycompat.PY3 else 'Name' + node_to_test = "Try" if pycompat.PY3 else "TryExcept" checker.check_children( - ('excepthandler', 'ExceptHandler'), - ['except', ' ', 'Name', ' ', 'as', ' ', expected_child, '', ':', - '\n ', 'Pass']) + node_to_test, + [ + "try", + "", + ":", + "\n ", + "Pass", + "\n", + ("excepthandler", "ExceptHandler"), + ], + ) + expected_child = "e" if pycompat.PY3 else "Name" + checker.check_children( + ("excepthandler", "ExceptHandler"), + [ + "except", + " ", + "Name", + " ", + "as", + " ", + expected_child, + "", + ":", + "\n ", + "Pass", + ], + ) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_try_except_and_finally_node(self): - source = 'try:\n pass\nexcept:\n pass\nfinally:\n pass\n' + source = "try:\n pass\nexcept:\n pass\nfinally:\n pass\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - node_to_test = 'Try' if pycompat.PY3 else 'TryFinally' + node_to_test = "Try" if pycompat.PY3 else "TryFinally" if pycompat.PY3: - expected_children = ['try', '', ':', '\n ', 'Pass', '\n', - 'ExceptHandler', '\n', - 'finally', '', ':', '\n ', 'Pass'] + expected_children = [ + "try", + "", + ":", + "\n ", + "Pass", + "\n", + "ExceptHandler", + "\n", + "finally", + "", + ":", + "\n ", + "Pass", + ] else: - expected_children = ['TryExcept', '\n', - 'finally', '', ':', '\n ', 'Pass'] - checker.check_children( - node_to_test, - expected_children - ) + expected_children = [ + "TryExcept", + "\n", + "finally", + "", + ":", + "\n ", + "Pass", + ] + checker.check_children(node_to_test, expected_children) def test_ignoring_comments(self): - source = '#1\n1\n' + source = "#1\n1\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - start = source.rindex('1') - checker.check_region('Num', start, start + 1) + start = source.rindex("1") + checker.check_region("Num", start, start + 1) def test_simple_sliceobj(self): - source = 'a[1::3]\n' + source = "a[1::3]\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Slice', ['Num', '', ':', '', ':', '', 'Num']) + checker.check_children("Slice", ["Num", "", ":", "", ":", "", "Num"]) def test_ignoring_strings_that_start_with_a_char(self): source = 'r"""("""\n1\n' ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Module', ['', 'Expr', '\n', 'Expr', '\n']) + checker.check_children("Module", ["", "Expr", "\n", "Expr", "\n"]) - @testutils.only_for_versions_lower('3') + @testutils.only_for_versions_lower("3") def test_how_to_handle_old_not_equals(self): - source = '1 <> 2\n' + source = "1 <> 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Compare', ['Num', ' ', '<>', ' ', 'Num']) + checker.check_children("Compare", ["Num", " ", "<>", " ", "Num"]) def test_semicolon(self): - source = '1;\n' + source = "1;\n" patchedast.get_patched_ast(source, True) - @testutils.only_for('2.5') + @testutils.only_for("2.5") def test_if_exp_node(self): - source = '1 if True else 2\n' + source = "1 if True else 2\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'IfExp', ['Num', ' ', 'if', ' ', NameConstant, ' ', 'else', - ' ', 'Num']) + "IfExp", ["Num", " ", "if", " ", NameConstant, " ", "else", " ", "Num"] + ) def test_delete_node(self): - source = 'del a, b\n' + source = "del a, b\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Delete', ['del', ' ', 'Name', '', ',', ' ', 'Name']) + checker.check_children("Delete", ["del", " ", "Name", "", ",", " ", "Name"]) - @testutils.only_for_versions_lower('3.5') + @testutils.only_for_versions_lower("3.5") def test_starargs_before_keywords_legacy(self): - source = 'foo(*args, a=1)\n' + source = "foo(*args, a=1)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'Call', ['Name', '', '(', '', '*', '', 'Name', '', ',', ' ', - 'keyword', '', ')']) + "Call", + ["Name", "", "(", "", "*", "", "Name", "", ",", " ", "keyword", "", ")"], + ) - @testutils.only_for_versions_lower('3.5') + @testutils.only_for_versions_lower("3.5") def test_starargs_in_keywords_legacy(self): - source = 'foo(a=1, *args, b=2)\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Call', ['Name', '', '(', '', 'keyword', '', ',', ' ', '*', '', - 'Name', '', ',', ' ', 'keyword', '',')']) + source = "foo(a=1, *args, b=2)\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_children( + "Call", + [ + "Name", + "", + "(", + "", + "keyword", + "", + ",", + " ", + "*", + "", + "Name", + "", + ",", + " ", + "keyword", + "", + ")", + ], + ) - @testutils.only_for_versions_lower('3.5') + @testutils.only_for_versions_lower("3.5") def test_starargs_after_keywords_legacy(self): - source = 'foo(a=1, *args)\n' + source = "foo(a=1, *args)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'Call', ['Name', '', '(', '', 'keyword', '', ',', ' ', '*', '', - 'Name', '', ')']) + "Call", + ["Name", "", "(", "", "keyword", "", ",", " ", "*", "", "Name", "", ")"], + ) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_starargs_before_keywords(self): - source = 'foo(*args, a=1)\n' + source = "foo(*args, a=1)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'Call', ['Name', '', '(', '*', 'Starred', '', ',', ' ', - 'keyword', '', ')']) + "Call", ["Name", "", "(", "*", "Starred", "", ",", " ", "keyword", "", ")"] + ) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_starargs_in_keywords(self): - source = 'foo(a=1, *args, b=2)\n' - ast_frag = patchedast.get_patched_ast(source, True) - checker = _ResultChecker(self, ast_frag) - checker.check_children( - 'Call', ['Name', '', '(', '', 'keyword', '', ',', ' *', - 'Starred', '', ',', ' ', 'keyword', '',')']) + source = "foo(a=1, *args, b=2)\n" + ast_frag = patchedast.get_patched_ast(source, True) + checker = _ResultChecker(self, ast_frag) + checker.check_children( + "Call", + [ + "Name", + "", + "(", + "", + "keyword", + "", + ",", + " *", + "Starred", + "", + ",", + " ", + "keyword", + "", + ")", + ], + ) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_starargs_after_keywords(self): - source = 'foo(a=1, *args)\n' + source = "foo(a=1, *args)\n" ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) checker.check_children( - 'Call', ['Name', '', '(', '', 'keyword', '', ',', ' *', - 'Starred', '', ')']) + "Call", ["Name", "", "(", "", "keyword", "", ",", " *", "Starred", "", ")"] + ) - @testutils.only_for_versions_higher('3.5') + @testutils.only_for_versions_higher("3.5") def test_await_node(self): - source = dedent('''\ + source = dedent("""\ async def f(): await sleep() - ''') + """) ast_frag = patchedast.get_patched_ast(source, True) checker = _ResultChecker(self, ast_frag) - checker.check_children('Await', ['await', ' ', 'Call']) + checker.check_children("Await", ["await", " ", "Call"]) class _ResultChecker(object): - def __init__(self, test_case, ast): self.test_case = test_case self.ast = ast @@ -1222,7 +1590,7 @@ def __init__(self, test_case, ast): def check_region(self, text, start, end): node = self._find_node(text) if node is None: - self.test_case.fail('Node <%s> cannot be found' % text) + self.test_case.fail("Node <%s> cannot be found" % text) self.test_case.assertEqual((start, end), node.region) def _find_node(self, text): @@ -1235,8 +1603,13 @@ class Search(object): def __call__(self, node): for text in goal: - if sys.version_info >= (3, 8) and text in ['Num', 'Str', 'NameConstant', 'Ellipsis']: - text = 'Constant' + if sys.version_info >= (3, 8) and text in [ + "Num", + "Str", + "NameConstant", + "Ellipsis", + ]: + text = "Constant" if str(node).startswith(text): self.result = node break @@ -1244,6 +1617,7 @@ def __call__(self, node): self.result = node break return self.result is not None + search = Search() ast.call_for_nodes(self.ast, search, recursive=True) return search.result @@ -1251,7 +1625,7 @@ def __call__(self, node): def check_children(self, text, children): node = self._find_node(text) if node is None: - self.test_case.fail('Node <%s> cannot be found' % text) + self.test_case.fail("Node <%s> cannot be found" % text) result = list(node.sorted_children) self.test_case.assertEqual(len(children), len(result)) for expected, child in zip(children, result): @@ -1259,15 +1633,20 @@ def check_children(self, text, children): if not isinstance(expected, (tuple, list)): goals = [expected] for goal in goals: - if goal == '' or isinstance(child, basestring): + if goal == "" or isinstance(child, basestring): self.test_case.assertEqual(goal, child) break else: - self.test_case.assertNotEqual( - '', text, 'probably ignoring some node') - if sys.version_info >= (3, 8) and expected in ['Num', 'Str', 'NameConstant', 'Ellipsis']: - expected = 'Constant' + self.test_case.assertNotEqual("", text, "probably ignoring some node") + if sys.version_info >= (3, 8) and expected in [ + "Num", + "Str", + "NameConstant", + "Ellipsis", + ]: + expected = "Constant" self.test_case.assertTrue( child.__class__.__name__.startswith(expected), - msg='Expected <%s> but was <%s>' % - (expected, child.__class__.__name__)) + msg="Expected <%s> but was <%s>" + % (expected, child.__class__.__name__), + ) diff --git a/ropetest/refactor/renametest.py b/ropetest/refactor/renametest.py index 6d8bd6aee..112331238 100644 --- a/ropetest/refactor/renametest.py +++ b/ropetest/refactor/renametest.py @@ -1,5 +1,6 @@ import sys from textwrap import dedent + try: import unittest2 as unittest except ImportError: @@ -13,7 +14,6 @@ class RenameRefactoringTest(unittest.TestCase): - def setUp(self): super(RenameRefactoringTest, self).setUp() self.project = testutils.sample_project() @@ -23,601 +23,637 @@ def tearDown(self): super(RenameRefactoringTest, self).tearDown() def _local_rename(self, source_code, offset, new_name): - testmod = testutils.create_module(self.project, 'testmod') + testmod = testutils.create_module(self.project, "testmod") testmod.write(source_code) - changes = Rename(self.project, testmod, offset).\ - get_changes(new_name, resources=[testmod]) + changes = Rename(self.project, testmod, offset).get_changes( + new_name, resources=[testmod] + ) self.project.do(changes) return testmod.read() def _rename(self, resource, offset, new_name, **kwds): - changes = Rename(self.project, resource, offset).\ - get_changes(new_name, **kwds) + changes = Rename(self.project, resource, offset).get_changes(new_name, **kwds) self.project.do(changes) def test_simple_global_variable_renaming(self): - refactored = self._local_rename('a_var = 20\n', 2, 'new_var') - self.assertEqual('new_var = 20\n', refactored) + refactored = self._local_rename("a_var = 20\n", 2, "new_var") + self.assertEqual("new_var = 20\n", refactored) def test_variable_renaming_only_in_its_scope(self): refactored = self._local_rename( - 'a_var = 20\ndef a_func():\n a_var = 10\n', 32, 'new_var') - self.assertEqual('a_var = 20\ndef a_func():\n new_var = 10\n', - refactored) + "a_var = 20\ndef a_func():\n a_var = 10\n", 32, "new_var" + ) + self.assertEqual("a_var = 20\ndef a_func():\n new_var = 10\n", refactored) def test_not_renaming_dot_name(self): refactored = self._local_rename( - "replace = True\n'aaa'.replace('a', 'b')\n", 1, 'new_var') - self.assertEqual("new_var = True\n'aaa'.replace('a', 'b')\n", - refactored) + "replace = True\n'aaa'.replace('a', 'b')\n", 1, "new_var" + ) + self.assertEqual("new_var = True\n'aaa'.replace('a', 'b')\n", refactored) def test_renaming_multiple_names_in_the_same_line(self): refactored = self._local_rename( - 'a_var = 10\na_var = 10 + a_var / 2\n', 2, 'new_var') - self.assertEqual('new_var = 10\nnew_var = 10 + new_var / 2\n', - refactored) + "a_var = 10\na_var = 10 + a_var / 2\n", 2, "new_var" + ) + self.assertEqual("new_var = 10\nnew_var = 10 + new_var / 2\n", refactored) def test_renaming_names_when_getting_some_attribute(self): refactored = self._local_rename( - "a_var = 'a b c'\na_var.split('\\n')\n", 2, 'new_var') - self.assertEqual("new_var = 'a b c'\nnew_var.split('\\n')\n", - refactored) + "a_var = 'a b c'\na_var.split('\\n')\n", 2, "new_var" + ) + self.assertEqual("new_var = 'a b c'\nnew_var.split('\\n')\n", refactored) def test_renaming_names_when_getting_some_attribute2(self): refactored = self._local_rename( - "a_var = 'a b c'\na_var.split('\\n')\n", 20, 'new_var') - self.assertEqual("new_var = 'a b c'\nnew_var.split('\\n')\n", - refactored) + "a_var = 'a b c'\na_var.split('\\n')\n", 20, "new_var" + ) + self.assertEqual("new_var = 'a b c'\nnew_var.split('\\n')\n", refactored) def test_renaming_function_parameters1(self): refactored = self._local_rename( - "def f(a_param):\n print(a_param)\n", 8, 'new_param') - self.assertEqual("def f(new_param):\n print(new_param)\n", - refactored) + "def f(a_param):\n print(a_param)\n", 8, "new_param" + ) + self.assertEqual("def f(new_param):\n print(new_param)\n", refactored) def test_renaming_function_parameters2(self): refactored = self._local_rename( - "def f(a_param):\n print(a_param)\n", 30, 'new_param') - self.assertEqual("def f(new_param):\n print(new_param)\n", - refactored) + "def f(a_param):\n print(a_param)\n", 30, "new_param" + ) + self.assertEqual("def f(new_param):\n print(new_param)\n", refactored) def test_renaming_occurrences_inside_functions(self): - code = 'def a_func(p1):\n a = p1\na_func(1)\n' - refactored = self._local_rename(code, code.index('p1') + 1, - 'new_param') + code = "def a_func(p1):\n a = p1\na_func(1)\n" + refactored = self._local_rename(code, code.index("p1") + 1, "new_param") self.assertEqual( - 'def a_func(new_param):\n a = new_param\na_func(1)\n', - refactored) + "def a_func(new_param):\n a = new_param\na_func(1)\n", refactored + ) def test_renaming_comprehension_loop_variables(self): - code = '[b_var for b_var, c_var in d_var if b_var == c_var]' - refactored = self._local_rename(code, code.index('b_var') + 1, - 'new_var') + code = "[b_var for b_var, c_var in d_var if b_var == c_var]" + refactored = self._local_rename(code, code.index("b_var") + 1, "new_var") self.assertEqual( - '[new_var for new_var, c_var in d_var if new_var == c_var]', - refactored) + "[new_var for new_var, c_var in d_var if new_var == c_var]", refactored + ) def test_renaming_list_comprehension_loop_variables_in_assignment(self): - code = 'a_var = [b_var for b_var, c_var in d_var if b_var == c_var]' - refactored = self._local_rename(code, code.index('b_var') + 1, - 'new_var') + code = "a_var = [b_var for b_var, c_var in d_var if b_var == c_var]" + refactored = self._local_rename(code, code.index("b_var") + 1, "new_var") self.assertEqual( - 'a_var = [new_var for new_var, c_var in d_var if new_var == c_var]', - refactored) + "a_var = [new_var for new_var, c_var in d_var if new_var == c_var]", + refactored, + ) def test_renaming_generator_comprehension_loop_variables(self): - code = 'a_var = (b_var for b_var, c_var in d_var if b_var == c_var)' - refactored = self._local_rename(code, code.index('b_var') + 1, - 'new_var') + code = "a_var = (b_var for b_var, c_var in d_var if b_var == c_var)" + refactored = self._local_rename(code, code.index("b_var") + 1, "new_var") self.assertEqual( - 'a_var = (new_var for new_var, c_var in d_var if new_var == c_var)', - refactored) + "a_var = (new_var for new_var, c_var in d_var if new_var == c_var)", + refactored, + ) @unittest.expectedFailure def test_renaming_comprehension_loop_variables_scope(self): # FIXME: variable scoping for comprehensions is incorrect, we currently # don't create a scope for comprehension - code = dedent('''\ + code = dedent("""\ [b_var for b_var, c_var in d_var if b_var == c_var] b_var = 10 - ''') - refactored = self._local_rename(code, code.index('b_var') + 1, - 'new_var') + """) + refactored = self._local_rename(code, code.index("b_var") + 1, "new_var") self.assertEqual( - '[new_var for new_var, c_var in d_var if new_var == c_var]\nb_var = 10\n', - refactored) + "[new_var for new_var, c_var in d_var if new_var == c_var]\nb_var = 10\n", + refactored, + ) - @testutils.only_for_versions_higher('3.8') + @testutils.only_for_versions_higher("3.8") def test_renaming_inline_assignment(self): - code = dedent('''\ + code = dedent("""\ while a_var := next(foo): print(a_var) - ''') - refactored = self._local_rename(code, code.index('a_var') + 1, - 'new_var') + """) + refactored = self._local_rename(code, code.index("a_var") + 1, "new_var") self.assertEqual( - dedent('''\ + dedent("""\ while new_var := next(foo): print(new_var) - '''), + """), refactored, ) def test_renaming_arguments_for_normal_args_changing_calls(self): - code = 'def a_func(p1=None, p2=None):\n pass\na_func(p2=1)\n' - refactored = self._local_rename(code, code.index('p2') + 1, 'p3') + code = "def a_func(p1=None, p2=None):\n pass\na_func(p2=1)\n" + refactored = self._local_rename(code, code.index("p2") + 1, "p3") self.assertEqual( - 'def a_func(p1=None, p3=None):\n pass\na_func(p3=1)\n', - refactored) + "def a_func(p1=None, p3=None):\n pass\na_func(p3=1)\n", refactored + ) def test_renaming_function_parameters_of_class_init(self): - code = 'class A(object):\n def __init__(self, a_param):' \ - '\n pass\n' \ - 'a_var = A(a_param=1)\n' - refactored = self._local_rename(code, code.index('a_param') + 1, - 'new_param') - expected = 'class A(object):\n ' \ - 'def __init__(self, new_param):\n pass\n' \ - 'a_var = A(new_param=1)\n' + code = ( + "class A(object):\n def __init__(self, a_param):" + "\n pass\n" + "a_var = A(a_param=1)\n" + ) + refactored = self._local_rename(code, code.index("a_param") + 1, "new_param") + expected = ( + "class A(object):\n " + "def __init__(self, new_param):\n pass\n" + "a_var = A(new_param=1)\n" + ) self.assertEqual(expected, refactored) def test_rename_functions_parameters_and_occurences_in_other_modules(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('def a_func(a_param):\n print(a_param)\n') - mod2.write('from mod1 import a_func\na_func(a_param=10)\n') - self._rename(mod1, mod1.read().index('a_param') + 1, 'new_param') - self.assertEqual('def a_func(new_param):\n print(new_param)\n', - mod1.read()) - self.assertEqual('from mod1 import a_func\na_func(new_param=10)\n', - mod2.read()) + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("def a_func(a_param):\n print(a_param)\n") + mod2.write("from mod1 import a_func\na_func(a_param=10)\n") + self._rename(mod1, mod1.read().index("a_param") + 1, "new_param") + self.assertEqual("def a_func(new_param):\n print(new_param)\n", mod1.read()) + self.assertEqual("from mod1 import a_func\na_func(new_param=10)\n", mod2.read()) def test_renaming_with_backslash_continued_names(self): refactored = self._local_rename( - "replace = True\n'ali'.\\\nreplace\n", 2, 'is_replace') - self.assertEqual("is_replace = True\n'ali'.\\\nreplace\n", - refactored) + "replace = True\n'ali'.\\\nreplace\n", 2, "is_replace" + ) + self.assertEqual("is_replace = True\n'ali'.\\\nreplace\n", refactored) - @testutils.only_for('3.6') + @testutils.only_for("3.6") def test_renaming_occurrence_in_f_string(self): refactored = self._local_rename( - "a_var = 20\na_string=f'value: {a_var}'\n", 2, 'new_var') - self.assertEqual("new_var = 20\na_string=f'value: {new_var}'\n", - refactored) + "a_var = 20\na_string=f'value: {a_var}'\n", 2, "new_var" + ) + self.assertEqual("new_var = 20\na_string=f'value: {new_var}'\n", refactored) - @testutils.only_for('3.6') + @testutils.only_for("3.6") def test_renaming_occurrence_in_nested_f_string(self): refactored = self._local_rename( - "a_var = 20\na_string=f'{f\"{a_var}\"}'\n", 2, 'new_var') - self.assertEqual( - "new_var = 20\na_string=f'{f\"{new_var}\"}'\n", - refactored) + "a_var = 20\na_string=f'{f\"{a_var}\"}'\n", 2, "new_var" + ) + self.assertEqual("new_var = 20\na_string=f'{f\"{new_var}\"}'\n", refactored) - @testutils.only_for('3.6') + @testutils.only_for("3.6") def test_not_renaming_string_contents_in_f_string(self): refactored = self._local_rename( - "a_var = 20\na_string=f'{\"a_var\"}'\n", 2, 'new_var') - self.assertEqual("new_var = 20\na_string=f'{\"a_var\"}'\n", - refactored) + "a_var = 20\na_string=f'{\"a_var\"}'\n", 2, "new_var" + ) + self.assertEqual("new_var = 20\na_string=f'{\"a_var\"}'\n", refactored) def test_not_renaming_string_contents(self): - refactored = self._local_rename("a_var = 20\na_string='a_var'\n", - 2, 'new_var') - self.assertEqual("new_var = 20\na_string='a_var'\n", - refactored) + refactored = self._local_rename("a_var = 20\na_string='a_var'\n", 2, "new_var") + self.assertEqual("new_var = 20\na_string='a_var'\n", refactored) def test_not_renaming_comment_contents(self): - refactored = self._local_rename("a_var = 20\n# a_var\n", - 2, 'new_var') + refactored = self._local_rename("a_var = 20\n# a_var\n", 2, "new_var") self.assertEqual("new_var = 20\n# a_var\n", refactored) def test_renaming_all_occurances_in_containing_scope(self): - code = 'if True:\n a_var = 1\nelse:\n a_var = 20\n' - refactored = self._local_rename(code, 16, 'new_var') + code = "if True:\n a_var = 1\nelse:\n a_var = 20\n" + refactored = self._local_rename(code, 16, "new_var") self.assertEqual( - 'if True:\n new_var = 1\nelse:\n new_var = 20\n', refactored) + "if True:\n new_var = 1\nelse:\n new_var = 20\n", refactored + ) def test_renaming_a_variable_with_arguement_name(self): - code = 'a_var = 10\ndef a_func(a_var):\n print(a_var)\n' - refactored = self._local_rename(code, 1, 'new_var') + code = "a_var = 10\ndef a_func(a_var):\n print(a_var)\n" + refactored = self._local_rename(code, 1, "new_var") self.assertEqual( - 'new_var = 10\ndef a_func(a_var):\n print(a_var)\n', refactored) + "new_var = 10\ndef a_func(a_var):\n print(a_var)\n", refactored + ) def test_renaming_an_arguement_with_variable_name(self): - code = 'a_var = 10\ndef a_func(a_var):\n print(a_var)\n' - refactored = self._local_rename(code, len(code) - 3, 'new_var') + code = "a_var = 10\ndef a_func(a_var):\n print(a_var)\n" + refactored = self._local_rename(code, len(code) - 3, "new_var") self.assertEqual( - 'a_var = 10\ndef a_func(new_var):\n print(new_var)\n', - refactored) + "a_var = 10\ndef a_func(new_var):\n print(new_var)\n", refactored + ) def test_renaming_function_with_local_variable_name(self): - code = 'def a_func():\n a_func=20\na_func()' - refactored = self._local_rename(code, len(code) - 3, 'new_func') - self.assertEqual('def new_func():\n a_func=20\nnew_func()', - refactored) + code = "def a_func():\n a_func=20\na_func()" + refactored = self._local_rename(code, len(code) - 3, "new_func") + self.assertEqual("def new_func():\n a_func=20\nnew_func()", refactored) def test_renaming_functions(self): - code = 'def a_func():\n pass\na_func()\n' - refactored = self._local_rename(code, len(code) - 5, 'new_func') - self.assertEqual('def new_func():\n pass\nnew_func()\n', - refactored) + code = "def a_func():\n pass\na_func()\n" + refactored = self._local_rename(code, len(code) - 5, "new_func") + self.assertEqual("def new_func():\n pass\nnew_func()\n", refactored) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_renaming_async_function(self): - code = 'async def a_func():\n pass\na_func()' - refactored = self._local_rename(code, len(code) - 5, 'new_func') - self.assertEqual('async def new_func():\n pass\nnew_func()', - refactored) + code = "async def a_func():\n pass\na_func()" + refactored = self._local_rename(code, len(code) - 5, "new_func") + self.assertEqual("async def new_func():\n pass\nnew_func()", refactored) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_renaming_await(self): - code = 'async def b_func():\n pass\nasync def a_func():\n await b_func()' - refactored = self._local_rename(code, len(code) - 5, 'new_func') - self.assertEqual('async def new_func():\n pass\nasync def a_func():\n await new_func()', - refactored) - + code = "async def b_func():\n pass\nasync def a_func():\n await b_func()" + refactored = self._local_rename(code, len(code) - 5, "new_func") + self.assertEqual( + "async def new_func():\n pass\nasync def a_func():\n await new_func()", + refactored, + ) def test_renaming_functions_across_modules(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('def a_func():\n pass\na_func()\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('import mod1\nmod1.a_func()\n') - self._rename(mod1, len(mod1.read()) - 5, 'new_func') - self.assertEqual('def new_func():\n pass\nnew_func()\n', - mod1.read()) - self.assertEqual('import mod1\nmod1.new_func()\n', mod2.read()) + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("def a_func():\n pass\na_func()\n") + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("import mod1\nmod1.a_func()\n") + self._rename(mod1, len(mod1.read()) - 5, "new_func") + self.assertEqual("def new_func():\n pass\nnew_func()\n", mod1.read()) + self.assertEqual("import mod1\nmod1.new_func()\n", mod2.read()) def test_renaming_functions_across_modules_from_import(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('def a_func():\n pass\na_func()\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('from mod1 import a_func\na_func()\n') - self._rename(mod1, len(mod1.read()) - 5, 'new_func') - self.assertEqual('def new_func():\n pass\nnew_func()\n', - mod1.read()) - self.assertEqual('from mod1 import new_func\nnew_func()\n', - mod2.read()) + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("def a_func():\n pass\na_func()\n") + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("from mod1 import a_func\na_func()\n") + self._rename(mod1, len(mod1.read()) - 5, "new_func") + self.assertEqual("def new_func():\n pass\nnew_func()\n", mod1.read()) + self.assertEqual("from mod1 import new_func\nnew_func()\n", mod2.read()) def test_renaming_functions_from_another_module(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('def a_func():\n pass\na_func()\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('import mod1\nmod1.a_func()\n') - self._rename(mod2, len(mod2.read()) - 5, 'new_func') - self.assertEqual('def new_func():\n pass\nnew_func()\n', - mod1.read()) - self.assertEqual('import mod1\nmod1.new_func()\n', mod2.read()) + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("def a_func():\n pass\na_func()\n") + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("import mod1\nmod1.a_func()\n") + self._rename(mod2, len(mod2.read()) - 5, "new_func") + self.assertEqual("def new_func():\n pass\nnew_func()\n", mod1.read()) + self.assertEqual("import mod1\nmod1.new_func()\n", mod2.read()) def test_applying_all_changes_together(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('import mod2\nmod2.a_func()\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('def a_func():\n pass\na_func()\n') - self._rename(mod2, len(mod2.read()) - 5, 'new_func') - self.assertEqual('import mod2\nmod2.new_func()\n', mod1.read()) - self.assertEqual('def new_func():\n pass\nnew_func()\n', - mod2.read()) + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("import mod2\nmod2.a_func()\n") + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("def a_func():\n pass\na_func()\n") + self._rename(mod2, len(mod2.read()) - 5, "new_func") + self.assertEqual("import mod2\nmod2.new_func()\n", mod1.read()) + self.assertEqual("def new_func():\n pass\nnew_func()\n", mod2.read()) def test_renaming_modules(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('def a_func():\n pass\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('from mod1 import a_func\n') - self._rename(mod2, mod2.read().index('mod1') + 1, 'newmod') - self.assertTrue(not mod1.exists() and - self.project.find_module('newmod') is not None) - self.assertEqual('from newmod import a_func\n', mod2.read()) + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("def a_func():\n pass\n") + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("from mod1 import a_func\n") + self._rename(mod2, mod2.read().index("mod1") + 1, "newmod") + self.assertTrue( + not mod1.exists() and self.project.find_module("newmod") is not None + ) + self.assertEqual("from newmod import a_func\n", mod2.read()) def test_renaming_modules_aliased(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('def a_func():\n pass\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('import mod1 as m\nm.a_func()\n') - self._rename(mod1, None, 'newmod') - self.assertTrue(not mod1.exists() and - self.project.find_module('newmod') is not None) - self.assertEqual('import newmod as m\nm.a_func()\n', mod2.read()) + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("def a_func():\n pass\n") + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("import mod1 as m\nm.a_func()\n") + self._rename(mod1, None, "newmod") + self.assertTrue( + not mod1.exists() and self.project.find_module("newmod") is not None + ) + self.assertEqual("import newmod as m\nm.a_func()\n", mod2.read()) def test_renaming_packages(self): - pkg = testutils.create_package(self.project, 'pkg') - mod1 = testutils.create_module(self.project, 'mod1', pkg) - mod1.write('def a_func():\n pass\n') - mod2 = testutils.create_module(self.project, 'mod2', pkg) - mod2.write('from pkg.mod1 import a_func\n') - self._rename(mod2, 6, 'newpkg') - self.assertTrue(self.project.find_module('newpkg.mod1') is not None) - new_mod2 = self.project.find_module('newpkg.mod2') - self.assertEqual('from newpkg.mod1 import a_func\n', new_mod2.read()) + pkg = testutils.create_package(self.project, "pkg") + mod1 = testutils.create_module(self.project, "mod1", pkg) + mod1.write("def a_func():\n pass\n") + mod2 = testutils.create_module(self.project, "mod2", pkg) + mod2.write("from pkg.mod1 import a_func\n") + self._rename(mod2, 6, "newpkg") + self.assertTrue(self.project.find_module("newpkg.mod1") is not None) + new_mod2 = self.project.find_module("newpkg.mod2") + self.assertEqual("from newpkg.mod1 import a_func\n", new_mod2.read()) def test_module_dependencies(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('class AClass(object):\n pass\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('import mod1\na_var = mod1.AClass()\n') - self.project.get_pymodule(mod2).get_attributes()['mod1'] - mod1.write('def AClass():\n return 0\n') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("class AClass(object):\n pass\n") + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("import mod1\na_var = mod1.AClass()\n") + self.project.get_pymodule(mod2).get_attributes()["mod1"] + mod1.write("def AClass():\n return 0\n") - self._rename(mod2, len(mod2.read()) - 3, 'a_func') - self.assertEqual('def a_func():\n return 0\n', mod1.read()) - self.assertEqual('import mod1\na_var = mod1.a_func()\n', mod2.read()) + self._rename(mod2, len(mod2.read()) - 3, "a_func") + self.assertEqual("def a_func():\n return 0\n", mod1.read()) + self.assertEqual("import mod1\na_var = mod1.a_func()\n", mod2.read()) def test_renaming_class_attributes(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('class AClass(object):\n def __init__(self):\n' - ' self.an_attr = 10\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('import mod1\na_var = mod1.AClass()\n' - 'another_var = a_var.an_attr') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write( + "class AClass(object):\n def __init__(self):\n" + " self.an_attr = 10\n" + ) + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("import mod1\na_var = mod1.AClass()\n" "another_var = a_var.an_attr") - self._rename(mod1, mod1.read().index('an_attr'), 'attr') - self.assertEqual('class AClass(object):\n def __init__(self):\n' - ' self.attr = 10\n', mod1.read()) + self._rename(mod1, mod1.read().index("an_attr"), "attr") + self.assertEqual( + "class AClass(object):\n def __init__(self):\n" + " self.attr = 10\n", + mod1.read(), + ) self.assertEqual( - 'import mod1\na_var = mod1.AClass()\nanother_var = a_var.attr', - mod2.read()) + "import mod1\na_var = mod1.AClass()\nanother_var = a_var.attr", mod2.read() + ) def test_renaming_class_attributes2(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('class AClass(object):\n def __init__(self):\n' - ' an_attr = 10\n self.an_attr = 10\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('import mod1\na_var = mod1.AClass()\n' - 'another_var = a_var.an_attr') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write( + "class AClass(object):\n def __init__(self):\n" + " an_attr = 10\n self.an_attr = 10\n" + ) + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("import mod1\na_var = mod1.AClass()\n" "another_var = a_var.an_attr") - self._rename(mod1, mod1.read().rindex('an_attr'), 'attr') + self._rename(mod1, mod1.read().rindex("an_attr"), "attr") self.assertEqual( - 'class AClass(object):\n def __init__(self):\n' - ' an_attr = 10\n self.attr = 10\n', mod1.read()) + "class AClass(object):\n def __init__(self):\n" + " an_attr = 10\n self.attr = 10\n", + mod1.read(), + ) self.assertEqual( - 'import mod1\na_var = mod1.AClass()\nanother_var = a_var.attr', - mod2.read()) + "import mod1\na_var = mod1.AClass()\nanother_var = a_var.attr", mod2.read() + ) def test_renaming_methods_in_subclasses(self): - mod = testutils.create_module(self.project, 'mod1') - mod.write('class A(object):\n def a_method(self):\n pass\n' - 'class B(A):\n def a_method(self):\n pass\n') + mod = testutils.create_module(self.project, "mod1") + mod.write( + "class A(object):\n def a_method(self):\n pass\n" + "class B(A):\n def a_method(self):\n pass\n" + ) - self._rename(mod, mod.read().rindex('a_method') + 1, 'new_method', - in_hierarchy=True) + self._rename( + mod, mod.read().rindex("a_method") + 1, "new_method", in_hierarchy=True + ) self.assertEqual( - 'class A(object):\n def new_method(self):\n pass\n' - 'class B(A):\n def new_method(self):\n pass\n', - mod.read()) + "class A(object):\n def new_method(self):\n pass\n" + "class B(A):\n def new_method(self):\n pass\n", + mod.read(), + ) def test_renaming_methods_in_sibling_classes(self): - mod = testutils.create_module(self.project, 'mod1') - mod.write('class A(object):\n def a_method(self):\n pass\n' - 'class B(A):\n def a_method(self):\n pass\n' - 'class C(A):\n def a_method(self):\n pass\n') + mod = testutils.create_module(self.project, "mod1") + mod.write( + "class A(object):\n def a_method(self):\n pass\n" + "class B(A):\n def a_method(self):\n pass\n" + "class C(A):\n def a_method(self):\n pass\n" + ) - self._rename(mod, mod.read().rindex('a_method') + 1, 'new_method', - in_hierarchy=True) + self._rename( + mod, mod.read().rindex("a_method") + 1, "new_method", in_hierarchy=True + ) self.assertEqual( - 'class A(object):\n def new_method(self):\n pass\n' - 'class B(A):\n def new_method(self):\n pass\n' - 'class C(A):\n def new_method(self):\n pass\n', - mod.read()) + "class A(object):\n def new_method(self):\n pass\n" + "class B(A):\n def new_method(self):\n pass\n" + "class C(A):\n def new_method(self):\n pass\n", + mod.read(), + ) def test_not_renaming_methods_in_hierarchies(self): - mod = testutils.create_module(self.project, 'mod1') - mod.write('class A(object):\n def a_method(self):\n pass\n' - 'class B(A):\n def a_method(self):\n pass\n') + mod = testutils.create_module(self.project, "mod1") + mod.write( + "class A(object):\n def a_method(self):\n pass\n" + "class B(A):\n def a_method(self):\n pass\n" + ) - self._rename(mod, mod.read().rindex('a_method') + 1, 'new_method', - in_hierarchy=False) + self._rename( + mod, mod.read().rindex("a_method") + 1, "new_method", in_hierarchy=False + ) self.assertEqual( - 'class A(object):\n def a_method(self):\n pass\n' - 'class B(A):\n def new_method(self):\n pass\n', - mod.read()) + "class A(object):\n def a_method(self):\n pass\n" + "class B(A):\n def new_method(self):\n pass\n", + mod.read(), + ) def test_undoing_refactorings(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('def a_func():\n pass\na_func()\n') - self._rename(mod1, len(mod1.read()) - 5, 'new_func') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("def a_func():\n pass\na_func()\n") + self._rename(mod1, len(mod1.read()) - 5, "new_func") self.project.history.undo() - self.assertEqual('def a_func():\n pass\na_func()\n', mod1.read()) + self.assertEqual("def a_func():\n pass\na_func()\n", mod1.read()) def test_undoing_renaming_modules(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('def a_func():\n pass\n') - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('from mod1 import a_func\n') - self._rename(mod2, 6, 'newmod') + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("def a_func():\n pass\n") + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("from mod1 import a_func\n") + self._rename(mod2, 6, "newmod") self.project.history.undo() - self.assertEqual('mod1.py', mod1.path) - self.assertEqual('from mod1 import a_func\n', mod2.read()) + self.assertEqual("mod1.py", mod1.path) + self.assertEqual("from mod1 import a_func\n", mod2.read()) def test_rename_in_module_renaming_one_letter_names_for_expressions(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod1.write('a = 10\nprint(1+a)\n') - pymod = self.project.get_module('mod1') - old_pyname = pymod['a'] - finder = rope.refactor.occurrences.create_finder( - self.project, 'a', old_pyname) + mod1 = testutils.create_module(self.project, "mod1") + mod1.write("a = 10\nprint(1+a)\n") + pymod = self.project.get_module("mod1") + old_pyname = pymod["a"] + finder = rope.refactor.occurrences.create_finder(self.project, "a", old_pyname) refactored = rename.rename_in_module( - finder, 'new_var', pymodule=pymod, replace_primary=True) - self.assertEqual('new_var = 10\nprint(1+new_var)\n', refactored) + finder, "new_var", pymodule=pymod, replace_primary=True + ) + self.assertEqual("new_var = 10\nprint(1+new_var)\n", refactored) def test_renaming_for_loop_variable(self): - code = 'for var in range(10):\n print(var)\n' - refactored = self._local_rename(code, code.find('var') + 1, 'new_var') - self.assertEqual('for new_var in range(10):\n print(new_var)\n', - refactored) + code = "for var in range(10):\n print(var)\n" + refactored = self._local_rename(code, code.find("var") + 1, "new_var") + self.assertEqual("for new_var in range(10):\n print(new_var)\n", refactored) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_renaming_async_for_loop_variable(self): - code = 'async def func():\n async for var in range(10):\n print(var)\n' - refactored = self._local_rename(code, code.find('var') + 1, 'new_var') - self.assertEqual('async def func():\n async for new_var in range(10):\n print(new_var)\n', - refactored) + code = ( + "async def func():\n async for var in range(10):\n print(var)\n" + ) + refactored = self._local_rename(code, code.find("var") + 1, "new_var") + self.assertEqual( + "async def func():\n async for new_var in range(10):\n print(new_var)\n", + refactored, + ) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_renaming_async_with_context_manager(self): - code = 'def a_cm(): pass\n'\ - 'async def a_func():\n async with a_cm() as x: pass' - refactored = self._local_rename(code, code.find('a_cm') + 1, 'another_cm') - expected = 'def another_cm(): pass\n'\ - 'async def a_func():\n async with another_cm() as x: pass' + code = ( + "def a_cm(): pass\n" "async def a_func():\n async with a_cm() as x: pass" + ) + refactored = self._local_rename(code, code.find("a_cm") + 1, "another_cm") + expected = ( + "def another_cm(): pass\n" + "async def a_func():\n async with another_cm() as x: pass" + ) self.assertEqual(refactored, expected) - @testutils.only_for('3.5') + @testutils.only_for("3.5") def test_renaming_async_with_as_variable(self): - code = 'async def func():\n async with a_func() as var:\n print(var)\n' - refactored = self._local_rename(code, code.find('var') + 1, 'new_var') - self.assertEqual('async def func():\n async with a_func() as new_var:\n print(new_var)\n', - refactored) + code = ( + "async def func():\n async with a_func() as var:\n print(var)\n" + ) + refactored = self._local_rename(code, code.find("var") + 1, "new_var") + self.assertEqual( + "async def func():\n async with a_func() as new_var:\n print(new_var)\n", + refactored, + ) def test_renaming_parameters(self): - code = 'def a_func(param):\n print(param)\na_func(param=hey)\n' - refactored = self._local_rename(code, code.find('param') + 1, - 'new_param') - self.assertEqual('def a_func(new_param):\n print(new_param)\n' - 'a_func(new_param=hey)\n', refactored) + code = "def a_func(param):\n print(param)\na_func(param=hey)\n" + refactored = self._local_rename(code, code.find("param") + 1, "new_param") + self.assertEqual( + "def a_func(new_param):\n print(new_param)\n" "a_func(new_param=hey)\n", + refactored, + ) def test_renaming_assigned_parameters(self): - code = 'def f(p):\n p = p + 1\n return p\nf(p=1)\n' - refactored = self._local_rename(code, code.find('p'), 'arg') - self.assertEqual('def f(arg):\n arg = arg + 1\n' - ' return arg\nf(arg=1)\n', refactored) + code = "def f(p):\n p = p + 1\n return p\nf(p=1)\n" + refactored = self._local_rename(code, code.find("p"), "arg") + self.assertEqual( + "def f(arg):\n arg = arg + 1\n" " return arg\nf(arg=1)\n", refactored + ) def test_renaming_parameters_not_renaming_others(self): - code = 'def a_func(param):' \ - '\n print(param)\nparam=10\na_func(param)\n' - refactored = self._local_rename(code, code.find('param') + 1, - 'new_param') - self.assertEqual('def a_func(new_param):\n print(new_param)\n' - 'param=10\na_func(param)\n', refactored) + code = "def a_func(param):" "\n print(param)\nparam=10\na_func(param)\n" + refactored = self._local_rename(code, code.find("param") + 1, "new_param") + self.assertEqual( + "def a_func(new_param):\n print(new_param)\n" + "param=10\na_func(param)\n", + refactored, + ) def test_renaming_parameters_not_renaming_others2(self): - code = 'def a_func(param):\n print(param)\n' \ - 'param=10\na_func(param=param)' - refactored = self._local_rename(code, code.find('param') + 1, - 'new_param') - self.assertEqual('def a_func(new_param):\n print(new_param)\n' - 'param=10\na_func(new_param=param)', refactored) + code = "def a_func(param):\n print(param)\n" "param=10\na_func(param=param)" + refactored = self._local_rename(code, code.find("param") + 1, "new_param") + self.assertEqual( + "def a_func(new_param):\n print(new_param)\n" + "param=10\na_func(new_param=param)", + refactored, + ) def test_renaming_parameters_with_multiple_params(self): - code = 'def a_func(param1, param2):\n print(param1)\n'\ - 'a_func(param1=1, param2=2)\n' - refactored = self._local_rename(code, code.find('param1') + 1, - 'new_param') + code = ( + "def a_func(param1, param2):\n print(param1)\n" + "a_func(param1=1, param2=2)\n" + ) + refactored = self._local_rename(code, code.find("param1") + 1, "new_param") self.assertEqual( - 'def a_func(new_param, param2):\n print(new_param)\n' - 'a_func(new_param=1, param2=2)\n', refactored) + "def a_func(new_param, param2):\n print(new_param)\n" + "a_func(new_param=1, param2=2)\n", + refactored, + ) def test_renaming_parameters_with_multiple_params2(self): - code = 'def a_func(param1, param2):\n print(param1)\n' \ - 'a_func(param1=1, param2=2)\n' - refactored = self._local_rename(code, code.rfind('param2') + 1, - 'new_param') - self.assertEqual('def a_func(param1, new_param):\n print(param1)\n' - 'a_func(param1=1, new_param=2)\n', refactored) + code = ( + "def a_func(param1, param2):\n print(param1)\n" + "a_func(param1=1, param2=2)\n" + ) + refactored = self._local_rename(code, code.rfind("param2") + 1, "new_param") + self.assertEqual( + "def a_func(param1, new_param):\n print(param1)\n" + "a_func(param1=1, new_param=2)\n", + refactored, + ) def test_renaming_parameters_on_calls(self): - code = 'def a_func(param):\n print(param)\na_func(param = hey)\n' - refactored = self._local_rename(code, code.rfind('param') + 1, - 'new_param') - self.assertEqual('def a_func(new_param):\n print(new_param)\n' - 'a_func(new_param = hey)\n', refactored) + code = "def a_func(param):\n print(param)\na_func(param = hey)\n" + refactored = self._local_rename(code, code.rfind("param") + 1, "new_param") + self.assertEqual( + "def a_func(new_param):\n print(new_param)\n" + "a_func(new_param = hey)\n", + refactored, + ) def test_renaming_parameters_spaces_before_call(self): - code = 'def a_func(param):\n print(param)\na_func (param=hey)\n' - refactored = self._local_rename(code, code.rfind('param') + 1, - 'new_param') - self.assertEqual('def a_func(new_param):\n print(new_param)\n' - 'a_func (new_param=hey)\n', refactored) + code = "def a_func(param):\n print(param)\na_func (param=hey)\n" + refactored = self._local_rename(code, code.rfind("param") + 1, "new_param") + self.assertEqual( + "def a_func(new_param):\n print(new_param)\n" + "a_func (new_param=hey)\n", + refactored, + ) def test_renaming_parameter_like_objects_after_keywords(self): - code = 'def a_func(param):\n print(param)\ndict(param=hey)\n' - refactored = self._local_rename(code, code.find('param') + 1, - 'new_param') - self.assertEqual('def a_func(new_param):\n print(new_param)\n' - 'dict(param=hey)\n', refactored) + code = "def a_func(param):\n print(param)\ndict(param=hey)\n" + refactored = self._local_rename(code, code.find("param") + 1, "new_param") + self.assertEqual( + "def a_func(new_param):\n print(new_param)\n" "dict(param=hey)\n", + refactored, + ) def test_renaming_variables_in_init_dot_pys(self): - pkg = testutils.create_package(self.project, 'pkg') - init_dot_py = pkg.get_child('__init__.py') - init_dot_py.write('a_var = 10\n') - mod = testutils.create_module(self.project, 'mod') - mod.write('import pkg\nprint(pkg.a_var)\n') - self._rename(mod, mod.read().index('a_var') + 1, 'new_var') - self.assertEqual('new_var = 10\n', init_dot_py.read()) - self.assertEqual('import pkg\nprint(pkg.new_var)\n', mod.read()) + pkg = testutils.create_package(self.project, "pkg") + init_dot_py = pkg.get_child("__init__.py") + init_dot_py.write("a_var = 10\n") + mod = testutils.create_module(self.project, "mod") + mod.write("import pkg\nprint(pkg.a_var)\n") + self._rename(mod, mod.read().index("a_var") + 1, "new_var") + self.assertEqual("new_var = 10\n", init_dot_py.read()) + self.assertEqual("import pkg\nprint(pkg.new_var)\n", mod.read()) def test_renaming_variables_in_init_dot_pys2(self): - pkg = testutils.create_package(self.project, 'pkg') - init_dot_py = pkg.get_child('__init__.py') - init_dot_py.write('a_var = 10\n') - mod = testutils.create_module(self.project, 'mod') - mod.write('import pkg\nprint(pkg.a_var)\n') - self._rename(init_dot_py, - init_dot_py.read().index('a_var') + 1, 'new_var') - self.assertEqual('new_var = 10\n', init_dot_py.read()) - self.assertEqual('import pkg\nprint(pkg.new_var)\n', mod.read()) + pkg = testutils.create_package(self.project, "pkg") + init_dot_py = pkg.get_child("__init__.py") + init_dot_py.write("a_var = 10\n") + mod = testutils.create_module(self.project, "mod") + mod.write("import pkg\nprint(pkg.a_var)\n") + self._rename(init_dot_py, init_dot_py.read().index("a_var") + 1, "new_var") + self.assertEqual("new_var = 10\n", init_dot_py.read()) + self.assertEqual("import pkg\nprint(pkg.new_var)\n", mod.read()) def test_renaming_variables_in_init_dot_pys3(self): - pkg = testutils.create_package(self.project, 'pkg') - init_dot_py = pkg.get_child('__init__.py') - init_dot_py.write('a_var = 10\n') - mod = testutils.create_module(self.project, 'mod') - mod.write('import pkg\nprint(pkg.a_var)\n') - self._rename(mod, mod.read().index('a_var') + 1, 'new_var') - self.assertEqual('new_var = 10\n', init_dot_py.read()) - self.assertEqual('import pkg\nprint(pkg.new_var)\n', mod.read()) + pkg = testutils.create_package(self.project, "pkg") + init_dot_py = pkg.get_child("__init__.py") + init_dot_py.write("a_var = 10\n") + mod = testutils.create_module(self.project, "mod") + mod.write("import pkg\nprint(pkg.a_var)\n") + self._rename(mod, mod.read().index("a_var") + 1, "new_var") + self.assertEqual("new_var = 10\n", init_dot_py.read()) + self.assertEqual("import pkg\nprint(pkg.new_var)\n", mod.read()) def test_renaming_resources_using_rename_module_refactoring(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('a_var = 1') - mod2.write('import mod1\nmy_var = mod1.a_var\n') + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("a_var = 1") + mod2.write("import mod1\nmy_var = mod1.a_var\n") renamer = rename.Rename(self.project, mod1) - renamer.get_changes('newmod').do() - self.assertEqual('import newmod\nmy_var = newmod.a_var\n', - mod2.read()) + renamer.get_changes("newmod").do() + self.assertEqual("import newmod\nmy_var = newmod.a_var\n", mod2.read()) def test_renam_resources_using_rename_module_refactor_for_packages(self): - mod1 = testutils.create_module(self.project, 'mod1') - pkg = testutils.create_package(self.project, 'pkg') - mod1.write('import pkg\nmy_pkg = pkg') + mod1 = testutils.create_module(self.project, "mod1") + pkg = testutils.create_package(self.project, "pkg") + mod1.write("import pkg\nmy_pkg = pkg") renamer = rename.Rename(self.project, pkg) - renamer.get_changes('newpkg').do() - self.assertEqual('import newpkg\nmy_pkg = newpkg', mod1.read()) + renamer.get_changes("newpkg").do() + self.assertEqual("import newpkg\nmy_pkg = newpkg", mod1.read()) def test_renam_resources_use_rename_module_refactor_for_init_dot_py(self): - mod1 = testutils.create_module(self.project, 'mod1') - pkg = testutils.create_package(self.project, 'pkg') - mod1.write('import pkg\nmy_pkg = pkg') - renamer = rename.Rename(self.project, pkg.get_child('__init__.py')) - renamer.get_changes('newpkg').do() - self.assertEqual('import newpkg\nmy_pkg = newpkg', mod1.read()) + mod1 = testutils.create_module(self.project, "mod1") + pkg = testutils.create_package(self.project, "pkg") + mod1.write("import pkg\nmy_pkg = pkg") + renamer = rename.Rename(self.project, pkg.get_child("__init__.py")) + renamer.get_changes("newpkg").do() + self.assertEqual("import newpkg\nmy_pkg = newpkg", mod1.read()) def test_renaming_global_variables(self): - code = 'a_var = 1\ndef a_func():\n global a_var\n var = a_var\n' - refactored = self._local_rename(code, code.index('a_var'), 'new_var') + code = "a_var = 1\ndef a_func():\n global a_var\n var = a_var\n" + refactored = self._local_rename(code, code.index("a_var"), "new_var") self.assertEqual( - 'new_var = 1\ndef a_func():\n ' - 'global new_var\n var = new_var\n', - refactored) + "new_var = 1\ndef a_func():\n " "global new_var\n var = new_var\n", + refactored, + ) def test_renaming_global_variables2(self): - code = 'a_var = 1\ndef a_func():\n global a_var\n var = a_var\n' - refactored = self._local_rename(code, code.rindex('a_var'), 'new_var') + code = "a_var = 1\ndef a_func():\n global a_var\n var = a_var\n" + refactored = self._local_rename(code, code.rindex("a_var"), "new_var") self.assertEqual( - 'new_var = 1\ndef a_func():\n ' - 'global new_var\n var = new_var\n', - refactored) + "new_var = 1\ndef a_func():\n " "global new_var\n var = new_var\n", + refactored, + ) def test_renaming_when_unsure(self): - code = 'class C(object):\n def a_func(self):\n pass\n' \ - 'def f(arg):\n arg.a_func()\n' - mod1 = testutils.create_module(self.project, 'mod1') + code = ( + "class C(object):\n def a_func(self):\n pass\n" + "def f(arg):\n arg.a_func()\n" + ) + mod1 = testutils.create_module(self.project, "mod1") mod1.write(code) - self._rename(mod1, code.index('a_func'), - 'new_func', unsure=self._true) + self._rename(mod1, code.index("a_func"), "new_func", unsure=self._true) self.assertEqual( - 'class C(object):\n def new_func(self):\n pass\n' - 'def f(arg):\n arg.new_func()\n', - mod1.read()) + "class C(object):\n def new_func(self):\n pass\n" + "def f(arg):\n arg.new_func()\n", + mod1.read(), + ) def _true(self, *args): return True @@ -625,204 +661,224 @@ def _true(self, *args): def test_renaming_when_unsure_with_confirmation(self): def confirm(occurrence): return False - code = 'class C(object):\n def a_func(self):\n pass\n' \ - 'def f(arg):\n arg.a_func()\n' - mod1 = testutils.create_module(self.project, 'mod1') + + code = ( + "class C(object):\n def a_func(self):\n pass\n" + "def f(arg):\n arg.a_func()\n" + ) + mod1 = testutils.create_module(self.project, "mod1") mod1.write(code) - self._rename(mod1, code.index('a_func'), 'new_func', unsure=confirm) + self._rename(mod1, code.index("a_func"), "new_func", unsure=confirm) self.assertEqual( - 'class C(object):\n def new_func(self):\n pass\n' - 'def f(arg):\n arg.a_func()\n', mod1.read()) + "class C(object):\n def new_func(self):\n pass\n" + "def f(arg):\n arg.a_func()\n", + mod1.read(), + ) def test_renaming_when_unsure_not_renaming_knowns(self): - code = 'class C1(object):\n def a_func(self):\n pass\n' \ - 'class C2(object):\n def a_func(self):\n pass\n' \ - 'c1 = C1()\nc1.a_func()\nc2 = C2()\nc2.a_func()\n' - mod1 = testutils.create_module(self.project, 'mod1') + code = ( + "class C1(object):\n def a_func(self):\n pass\n" + "class C2(object):\n def a_func(self):\n pass\n" + "c1 = C1()\nc1.a_func()\nc2 = C2()\nc2.a_func()\n" + ) + mod1 = testutils.create_module(self.project, "mod1") mod1.write(code) - self._rename(mod1, code.index('a_func'), 'new_func', unsure=self._true) + self._rename(mod1, code.index("a_func"), "new_func", unsure=self._true) self.assertEqual( - 'class C1(object):\n def new_func(self):\n pass\n' - 'class C2(object):\n def a_func(self):\n pass\n' - 'c1 = C1()\nc1.new_func()\nc2 = C2()\nc2.a_func()\n', - mod1.read()) + "class C1(object):\n def new_func(self):\n pass\n" + "class C2(object):\n def a_func(self):\n pass\n" + "c1 = C1()\nc1.new_func()\nc2 = C2()\nc2.a_func()\n", + mod1.read(), + ) def test_renaming_in_strings_and_comments(self): - code = 'a_var = 1\n# a_var\n' - mod1 = testutils.create_module(self.project, 'mod1') + code = "a_var = 1\n# a_var\n" + mod1 = testutils.create_module(self.project, "mod1") mod1.write(code) - self._rename(mod1, code.index('a_var'), 'new_var', docs=True) - self.assertEqual('new_var = 1\n# new_var\n', mod1.read()) + self._rename(mod1, code.index("a_var"), "new_var", docs=True) + self.assertEqual("new_var = 1\n# new_var\n", mod1.read()) def test_not_renaming_in_strings_and_comments_where_not_visible(self): - code = 'def f():\n a_var = 1\n# a_var\n' - mod1 = testutils.create_module(self.project, 'mod1') + code = "def f():\n a_var = 1\n# a_var\n" + mod1 = testutils.create_module(self.project, "mod1") mod1.write(code) - self._rename(mod1, code.index('a_var'), 'new_var', docs=True) - self.assertEqual('def f():\n new_var = 1\n# a_var\n', mod1.read()) + self._rename(mod1, code.index("a_var"), "new_var", docs=True) + self.assertEqual("def f():\n new_var = 1\n# a_var\n", mod1.read()) def test_not_renaming_all_text_occurrences_in_strings_and_comments(self): - code = 'a_var = 1\n# a_vard _a_var\n' - mod1 = testutils.create_module(self.project, 'mod1') + code = "a_var = 1\n# a_vard _a_var\n" + mod1 = testutils.create_module(self.project, "mod1") mod1.write(code) - self._rename(mod1, code.index('a_var'), 'new_var', docs=True) - self.assertEqual('new_var = 1\n# a_vard _a_var\n', mod1.read()) + self._rename(mod1, code.index("a_var"), "new_var", docs=True) + self.assertEqual("new_var = 1\n# a_vard _a_var\n", mod1.read()) def test_renaming_occurrences_in_overwritten_scopes(self): refactored = self._local_rename( - 'a_var = 20\ndef f():\n print(a_var)\n' - 'def f():\n print(a_var)\n', 2, 'new_var') - self.assertEqual('new_var = 20\ndef f():\n print(new_var)\n' - 'def f():\n print(new_var)\n', refactored) + "a_var = 20\ndef f():\n print(a_var)\n" "def f():\n print(a_var)\n", + 2, + "new_var", + ) + self.assertEqual( + "new_var = 20\ndef f():\n print(new_var)\n" + "def f():\n print(new_var)\n", + refactored, + ) def test_renaming_occurrences_in_overwritten_scopes2(self): - code = 'def f():\n a_var = 1\n print(a_var)\n' \ - 'def f():\n a_var = 1\n print(a_var)\n' - refactored = self._local_rename(code, code.index('a_var') + 1, - 'new_var') - self.assertEqual(code.replace('a_var', 'new_var', 2), refactored) + code = ( + "def f():\n a_var = 1\n print(a_var)\n" + "def f():\n a_var = 1\n print(a_var)\n" + ) + refactored = self._local_rename(code, code.index("a_var") + 1, "new_var") + self.assertEqual(code.replace("a_var", "new_var", 2), refactored) - @testutils.only_for_versions_higher('3.5') + @testutils.only_for_versions_higher("3.5") def test_renaming_in_generalized_dict_unpacking(self): - code = dedent('''\ + code = dedent("""\ a_var = {**{'stuff': 'can'}, **{'stuff': 'crayon'}} if "stuff" in a_var: print("ya") - ''') - mod1 = testutils.create_module(self.project, 'mod1') + """) + mod1 = testutils.create_module(self.project, "mod1") mod1.write(code) - refactored = self._local_rename(code, code.index('a_var') + 1, - 'new_var') - expected = dedent('''\ + refactored = self._local_rename(code, code.index("a_var") + 1, "new_var") + expected = dedent("""\ new_var = {**{'stuff': 'can'}, **{'stuff': 'crayon'}} if "stuff" in new_var: print("ya") - ''') + """) self.assertEqual(expected, refactored) def test_dos_line_ending_and_renaming(self): - code = '\r\na = 1\r\n\r\nprint(2 + a + 2)\r\n' - offset = code.replace('\r\n', '\n').rindex('a') - refactored = self._local_rename(code, offset, 'b') - self.assertEqual('\nb = 1\n\nprint(2 + b + 2)\n', - refactored.replace('\r\n', '\n')) + code = "\r\na = 1\r\n\r\nprint(2 + a + 2)\r\n" + offset = code.replace("\r\n", "\n").rindex("a") + refactored = self._local_rename(code, offset, "b") + self.assertEqual( + "\nb = 1\n\nprint(2 + b + 2)\n", refactored.replace("\r\n", "\n") + ) def test_multi_byte_strs_and_renaming(self): - s = u'{LATIN SMALL LETTER I WITH DIAERESIS}' * 4 - code = u'# -*- coding: utf-8 -*-\n# ' + s + \ - '\na = 1\nprint(2 + a + 2)\n' - refactored = self._local_rename(code, code.rindex('a'), 'b') - self.assertEqual(u'# -*- coding: utf-8 -*-\n# ' + s + - '\nb = 1\nprint(2 + b + 2)\n', refactored) + s = u"{LATIN SMALL LETTER I WITH DIAERESIS}" * 4 + code = u"# -*- coding: utf-8 -*-\n# " + s + "\na = 1\nprint(2 + a + 2)\n" + refactored = self._local_rename(code, code.rindex("a"), "b") + self.assertEqual( + u"# -*- coding: utf-8 -*-\n# " + s + "\nb = 1\nprint(2 + b + 2)\n", + refactored, + ) def test_resources_parameter(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('def f():\n pass\n') - mod2.write('import mod1\nmod1.f()\n') - self._rename(mod1, mod1.read().rindex('f'), 'g', - resources=[mod1]) - self.assertEqual('def g():\n pass\n', mod1.read()) - self.assertEqual('import mod1\nmod1.f()\n', mod2.read()) + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("def f():\n pass\n") + mod2.write("import mod1\nmod1.f()\n") + self._rename(mod1, mod1.read().rindex("f"), "g", resources=[mod1]) + self.assertEqual("def g():\n pass\n", mod1.read()) + self.assertEqual("import mod1\nmod1.f()\n", mod2.read()) def test_resources_parameter_not_changing_defining_module(self): - mod1 = testutils.create_module(self.project, 'mod1') - mod2 = testutils.create_module(self.project, 'mod2') - mod1.write('def f():\n pass\n') - mod2.write('import mod1\nmod1.f()\n') - self._rename(mod1, mod1.read().rindex('f'), 'g', - resources=[mod2]) - self.assertEqual('def f():\n pass\n', mod1.read()) - self.assertEqual('import mod1\nmod1.g()\n', mod2.read()) + mod1 = testutils.create_module(self.project, "mod1") + mod2 = testutils.create_module(self.project, "mod2") + mod1.write("def f():\n pass\n") + mod2.write("import mod1\nmod1.f()\n") + self._rename(mod1, mod1.read().rindex("f"), "g", resources=[mod2]) + self.assertEqual("def f():\n pass\n", mod1.read()) + self.assertEqual("import mod1\nmod1.g()\n", mod2.read()) # XXX: with variables should not leak - @testutils.only_for('2.5') + @testutils.only_for("2.5") def xxx_test_with_statement_variables_should_not_leak(self): code = 'f = 1\nwith open("1.txt") as f:\n print(f)\n' if sys.version_info < (2, 6, 0): - code = 'from __future__ import with_statement\n' + code - mod1 = testutils.create_module(self.project, 'mod1') + code = "from __future__ import with_statement\n" + code + mod1 = testutils.create_module(self.project, "mod1") mod1.write(code) - self._rename(mod1, code.rindex('f'), 'file') + self._rename(mod1, code.rindex("f"), "file") expected = 'f = 1\nwith open("1.txt") as file:\n print(file)\n' self.assertEqual(expected, mod1.read()) class ChangeOccurrencesTest(unittest.TestCase): - def setUp(self): self.project = testutils.sample_project() - self.mod = testutils.create_module(self.project, 'mod') + self.mod = testutils.create_module(self.project, "mod") def tearDown(self): testutils.remove_project(self.project) super(ChangeOccurrencesTest, self).tearDown() def test_simple_case(self): - self.mod.write('a_var = 1\nprint(a_var)\n') - changer = rename.ChangeOccurrences(self.project, self.mod, - self.mod.read().index('a_var')) - changer.get_changes('new_var').do() - self.assertEqual('new_var = 1\nprint(new_var)\n', self.mod.read()) + self.mod.write("a_var = 1\nprint(a_var)\n") + changer = rename.ChangeOccurrences( + self.project, self.mod, self.mod.read().index("a_var") + ) + changer.get_changes("new_var").do() + self.assertEqual("new_var = 1\nprint(new_var)\n", self.mod.read()) def test_only_performing_inside_scopes(self): - self.mod.write('a_var = 1\nnew_var = 2\ndef f():\n print(a_var)\n') - changer = rename.ChangeOccurrences(self.project, self.mod, - self.mod.read().rindex('a_var')) - changer.get_changes('new_var').do() + self.mod.write("a_var = 1\nnew_var = 2\ndef f():\n print(a_var)\n") + changer = rename.ChangeOccurrences( + self.project, self.mod, self.mod.read().rindex("a_var") + ) + changer.get_changes("new_var").do() self.assertEqual( - 'a_var = 1\nnew_var = 2\ndef f():\n print(new_var)\n', - self.mod.read()) + "a_var = 1\nnew_var = 2\ndef f():\n print(new_var)\n", self.mod.read() + ) def test_only_performing_on_calls(self): - self.mod.write('def f1():\n pass\ndef f2():\n pass\n' - 'g = f1\na = f1()\n') - changer = rename.ChangeOccurrences(self.project, self.mod, - self.mod.read().rindex('f1')) - changer.get_changes('f2', only_calls=True).do() + self.mod.write( + "def f1():\n pass\ndef f2():\n pass\n" "g = f1\na = f1()\n" + ) + changer = rename.ChangeOccurrences( + self.project, self.mod, self.mod.read().rindex("f1") + ) + changer.get_changes("f2", only_calls=True).do() self.assertEqual( - 'def f1():\n pass\ndef f2():\n pass\ng = f1\na = f2()\n', - self.mod.read()) + "def f1():\n pass\ndef f2():\n pass\ng = f1\na = f2()\n", + self.mod.read(), + ) def test_only_performing_on_reads(self): - self.mod.write('a = 1\nb = 2\nprint(a)\n') - changer = rename.ChangeOccurrences(self.project, self.mod, - self.mod.read().rindex('a')) - changer.get_changes('b', writes=False).do() - self.assertEqual('a = 1\nb = 2\nprint(b)\n', self.mod.read()) + self.mod.write("a = 1\nb = 2\nprint(a)\n") + changer = rename.ChangeOccurrences( + self.project, self.mod, self.mod.read().rindex("a") + ) + changer.get_changes("b", writes=False).do() + self.assertEqual("a = 1\nb = 2\nprint(b)\n", self.mod.read()) class ImplicitInterfacesTest(unittest.TestCase): - def setUp(self): super(ImplicitInterfacesTest, self).setUp() self.project = testutils.sample_project(validate_objectdb=True) self.pycore = self.project.pycore - self.mod1 = testutils.create_module(self.project, 'mod1') - self.mod2 = testutils.create_module(self.project, 'mod2') + self.mod1 = testutils.create_module(self.project, "mod1") + self.mod2 = testutils.create_module(self.project, "mod2") def tearDown(self): testutils.remove_project(self.project) super(ImplicitInterfacesTest, self).tearDown() def _rename(self, resource, offset, new_name, **kwds): - changes = Rename(self.project, resource, offset).\ - get_changes(new_name, **kwds) + changes = Rename(self.project, resource, offset).get_changes(new_name, **kwds) self.project.do(changes) def test_performing_rename_on_parameters(self): - self.mod1.write('def f(arg):\n arg.run()\n') - self.mod2.write('import mod1\n\n\n' - 'class A(object):\n def run(self):\n pass\n' - 'class B(object):\n def run(self):\n pass\n' - 'mod1.f(A())\nmod1.f(B())\n') + self.mod1.write("def f(arg):\n arg.run()\n") + self.mod2.write( + "import mod1\n\n\n" + "class A(object):\n def run(self):\n pass\n" + "class B(object):\n def run(self):\n pass\n" + "mod1.f(A())\nmod1.f(B())\n" + ) self.pycore.analyze_module(self.mod2) - self._rename(self.mod1, self.mod1.read().index('run'), 'newrun') - self.assertEqual('def f(arg):\n arg.newrun()\n', self.mod1.read()) + self._rename(self.mod1, self.mod1.read().index("run"), "newrun") + self.assertEqual("def f(arg):\n arg.newrun()\n", self.mod1.read()) self.assertEqual( - 'import mod1\n\n\n' - 'class A(object):\n def newrun(self):\n pass\n' - 'class B(object):\n def newrun(self):\n pass\n' - 'mod1.f(A())\nmod1.f(B())\n', self.mod2.read()) + "import mod1\n\n\n" + "class A(object):\n def newrun(self):\n pass\n" + "class B(object):\n def newrun(self):\n pass\n" + "mod1.f(A())\nmod1.f(B())\n", + self.mod2.read(), + ) diff --git a/ropetest/refactor/restructuretest.py b/ropetest/refactor/restructuretest.py index 3c971a0d0..80d66e713 100644 --- a/ropetest/refactor/restructuretest.py +++ b/ropetest/refactor/restructuretest.py @@ -8,179 +8,182 @@ class RestructureTest(unittest.TestCase): - def setUp(self): super(RestructureTest, self).setUp() self.project = testutils.sample_project() self.pycore = self.project.pycore - self.mod = testutils.create_module(self.project, 'mod') + self.mod = testutils.create_module(self.project, "mod") def tearDown(self): testutils.remove_project(self.project) super(RestructureTest, self).tearDown() def test_trivial_case(self): - refactoring = restructure.Restructure(self.project, - 'a = 1', 'a = 0') - self.mod.write('b = 1\n') + refactoring = restructure.Restructure(self.project, "a = 1", "a = 0") + self.mod.write("b = 1\n") self.project.do(refactoring.get_changes()) - self.assertEqual('b = 1\n', self.mod.read()) + self.assertEqual("b = 1\n", self.mod.read()) def test_replacing_simple_patterns(self): - refactoring = restructure.Restructure(self.project, - 'a = 1', 'a = int(1)') - self.mod.write('a = 1\nb = 1\n') + refactoring = restructure.Restructure(self.project, "a = 1", "a = int(1)") + self.mod.write("a = 1\nb = 1\n") self.project.do(refactoring.get_changes()) - self.assertEqual('a = int(1)\nb = 1\n', self.mod.read()) + self.assertEqual("a = int(1)\nb = 1\n", self.mod.read()) def test_replacing_patterns_with_normal_names(self): refactoring = restructure.Restructure( - self.project, '${a} = 1', '${a} = int(1)', args={'a': 'exact'}) - self.mod.write('a = 1\nb = 1\n') + self.project, "${a} = 1", "${a} = int(1)", args={"a": "exact"} + ) + self.mod.write("a = 1\nb = 1\n") self.project.do(refactoring.get_changes()) - self.assertEqual('a = int(1)\nb = 1\n', self.mod.read()) + self.assertEqual("a = int(1)\nb = 1\n", self.mod.read()) def test_replacing_patterns_with_any_names(self): - refactoring = restructure.Restructure(self.project, - '${a} = 1', '${a} = int(1)') - self.mod.write('a = 1\nb = 1\n') + refactoring = restructure.Restructure(self.project, "${a} = 1", "${a} = int(1)") + self.mod.write("a = 1\nb = 1\n") self.project.do(refactoring.get_changes()) - self.assertEqual('a = int(1)\nb = int(1)\n', self.mod.read()) + self.assertEqual("a = int(1)\nb = int(1)\n", self.mod.read()) def test_replacing_patterns_with_any_names2(self): - refactoring = restructure.Restructure( - self.project, '${x} + ${x}', '${x} * 2') - self.mod.write('a = 1 + 1\n') + refactoring = restructure.Restructure(self.project, "${x} + ${x}", "${x} * 2") + self.mod.write("a = 1 + 1\n") self.project.do(refactoring.get_changes()) - self.assertEqual('a = 1 * 2\n', self.mod.read()) + self.assertEqual("a = 1 * 2\n", self.mod.read()) def test_replacing_patterns_with_checks(self): - self.mod.write('def f(p=1):\n return p\ng = f\ng()\n') + self.mod.write("def f(p=1):\n return p\ng = f\ng()\n") refactoring = restructure.Restructure( - self.project, '${f}()', '${f}(2)', args={'f': 'object=mod.f'}) + self.project, "${f}()", "${f}(2)", args={"f": "object=mod.f"} + ) self.project.do(refactoring.get_changes()) - self.assertEqual('def f(p=1):\n return p\ng = f\ng(2)\n', - self.mod.read()) + self.assertEqual("def f(p=1):\n return p\ng = f\ng(2)\n", self.mod.read()) def test_replacing_assignments_with_sets(self): refactoring = restructure.Restructure( - self.project, '${a} = ${b}', '${a}.set(${b})') - self.mod.write('a = 1\nb = 1\n') + self.project, "${a} = ${b}", "${a}.set(${b})" + ) + self.mod.write("a = 1\nb = 1\n") self.project.do(refactoring.get_changes()) - self.assertEqual('a.set(1)\nb.set(1)\n', self.mod.read()) + self.assertEqual("a.set(1)\nb.set(1)\n", self.mod.read()) def test_replacing_sets_with_assignments(self): refactoring = restructure.Restructure( - self.project, '${a}.set(${b})', '${a} = ${b}') - self.mod.write('a.set(1)\nb.set(1)\n') + self.project, "${a}.set(${b})", "${a} = ${b}" + ) + self.mod.write("a.set(1)\nb.set(1)\n") self.project.do(refactoring.get_changes()) - self.assertEqual('a = 1\nb = 1\n', self.mod.read()) + self.assertEqual("a = 1\nb = 1\n", self.mod.read()) def test_using_make_checks(self): - self.mod.write('def f(p=1):\n return p\ng = f\ng()\n') + self.mod.write("def f(p=1):\n return p\ng = f\ng()\n") refactoring = restructure.Restructure( - self.project, '${f}()', '${f}(2)', args={'f': 'object=mod.f'}) + self.project, "${f}()", "${f}(2)", args={"f": "object=mod.f"} + ) self.project.do(refactoring.get_changes()) - self.assertEqual('def f(p=1):\n return p\ng = f\ng(2)\n', - self.mod.read()) + self.assertEqual("def f(p=1):\n return p\ng = f\ng(2)\n", self.mod.read()) def test_using_make_checking_builtin_types(self): - self.mod.write('a = 1 + 1\n') + self.mod.write("a = 1 + 1\n") refactoring = restructure.Restructure( - self.project, '${i} + ${i}', '${i} * 2', - args={'i': 'type=__builtin__.int'}) + self.project, "${i} + ${i}", "${i} * 2", args={"i": "type=__builtin__.int"} + ) self.project.do(refactoring.get_changes()) - self.assertEqual('a = 1 * 2\n', self.mod.read()) + self.assertEqual("a = 1 * 2\n", self.mod.read()) def test_auto_indentation_when_no_indentation(self): - self.mod.write('a = 2\n') + self.mod.write("a = 2\n") refactoring = restructure.Restructure( - self.project, '${a} = 2', '${a} = 1\n${a} += 1') + self.project, "${a} = 2", "${a} = 1\n${a} += 1" + ) self.project.do(refactoring.get_changes()) - self.assertEqual('a = 1\na += 1\n', self.mod.read()) + self.assertEqual("a = 1\na += 1\n", self.mod.read()) def test_auto_indentation(self): - self.mod.write('def f():\n a = 2\n') + self.mod.write("def f():\n a = 2\n") refactoring = restructure.Restructure( - self.project, '${a} = 2', '${a} = 1\n${a} += 1') + self.project, "${a} = 2", "${a} = 1\n${a} += 1" + ) self.project.do(refactoring.get_changes()) - self.assertEqual('def f():\n a = 1\n a += 1\n', self.mod.read()) + self.assertEqual("def f():\n a = 1\n a += 1\n", self.mod.read()) def test_auto_indentation_and_not_indenting_blanks(self): - self.mod.write('def f():\n a = 2\n') + self.mod.write("def f():\n a = 2\n") refactoring = restructure.Restructure( - self.project, '${a} = 2', '${a} = 1\n\n${a} += 1') + self.project, "${a} = 2", "${a} = 1\n\n${a} += 1" + ) self.project.do(refactoring.get_changes()) - self.assertEqual('def f():\n a = 1\n\n a += 1\n', - self.mod.read()) + self.assertEqual("def f():\n a = 1\n\n a += 1\n", self.mod.read()) def test_importing_names(self): - self.mod.write('a = 2\n') + self.mod.write("a = 2\n") refactoring = restructure.Restructure( - self.project, '${a} = 2', '${a} = myconsts.two', - imports=['import myconsts']) + self.project, "${a} = 2", "${a} = myconsts.two", imports=["import myconsts"] + ) self.project.do(refactoring.get_changes()) - self.assertEqual('import myconsts\na = myconsts.two\n', - self.mod.read()) + self.assertEqual("import myconsts\na = myconsts.two\n", self.mod.read()) def test_not_importing_names_when_there_are_no_changes(self): - self.mod.write('a = True\n') + self.mod.write("a = True\n") refactoring = restructure.Restructure( - self.project, '${a} = 2', '${a} = myconsts.two', - imports=['import myconsts']) + self.project, "${a} = 2", "${a} = myconsts.two", imports=["import myconsts"] + ) self.project.do(refactoring.get_changes()) - self.assertEqual('a = True\n', self.mod.read()) + self.assertEqual("a = True\n", self.mod.read()) def test_handling_containing_matches(self): - self.mod.write('a = 1 / 2 / 3\n') + self.mod.write("a = 1 / 2 / 3\n") refactoring = restructure.Restructure( - self.project, '${a} / ${b}', '${a} // ${b}') + self.project, "${a} / ${b}", "${a} // ${b}" + ) self.project.do(refactoring.get_changes()) - self.assertEqual('a = 1 // 2 // 3\n', self.mod.read()) + self.assertEqual("a = 1 // 2 // 3\n", self.mod.read()) def test_handling_overlapping_matches(self): - self.mod.write('a = 1\na = 1\na = 1\n') - refactoring = restructure.Restructure( - self.project, 'a = 1\na = 1\n', 'b = 1') + self.mod.write("a = 1\na = 1\na = 1\n") + refactoring = restructure.Restructure(self.project, "a = 1\na = 1\n", "b = 1") self.project.do(refactoring.get_changes()) - self.assertEqual('b = 1\na = 1\n', self.mod.read()) + self.assertEqual("b = 1\na = 1\n", self.mod.read()) def test_preventing_stack_overflow_when_matching(self): - self.mod.write('1\n') - refactoring = restructure.Restructure(self.project, '${a}', '${a}') + self.mod.write("1\n") + refactoring = restructure.Restructure(self.project, "${a}", "${a}") self.project.do(refactoring.get_changes()) - self.assertEqual('1\n', self.mod.read()) + self.assertEqual("1\n", self.mod.read()) def test_performing_a_restructuring_to_all_modules(self): - mod2 = testutils.create_module(self.project, 'mod2') - self.mod.write('a = 1\n') - mod2.write('b = 1\n') - refactoring = restructure.Restructure(self.project, '1', '2 / 1') + mod2 = testutils.create_module(self.project, "mod2") + self.mod.write("a = 1\n") + mod2.write("b = 1\n") + refactoring = restructure.Restructure(self.project, "1", "2 / 1") self.project.do(refactoring.get_changes()) - self.assertEqual('a = 2 / 1\n', self.mod.read()) - self.assertEqual('b = 2 / 1\n', mod2.read()) + self.assertEqual("a = 2 / 1\n", self.mod.read()) + self.assertEqual("b = 2 / 1\n", mod2.read()) def test_performing_a_restructuring_to_selected_modules(self): - mod2 = testutils.create_module(self.project, 'mod2') - self.mod.write('a = 1\n') - mod2.write('b = 1\n') - refactoring = restructure.Restructure(self.project, '1', '2 / 1') + mod2 = testutils.create_module(self.project, "mod2") + self.mod.write("a = 1\n") + mod2.write("b = 1\n") + refactoring = restructure.Restructure(self.project, "1", "2 / 1") self.project.do(refactoring.get_changes(resources=[mod2])) - self.assertEqual('a = 1\n', self.mod.read()) - self.assertEqual('b = 2 / 1\n', mod2.read()) + self.assertEqual("a = 1\n", self.mod.read()) + self.assertEqual("b = 2 / 1\n", mod2.read()) def test_unsure_argument_of_default_wildcard(self): self.mod.write('def f(p):\n return p * 2\nx = "" * 2\ni = 1 * 2\n') refactoring = restructure.Restructure( - self.project, '${s} * 2', 'dup(${s})', - args={'s': {'type': '__builtins__.str', 'unsure': True}}) + self.project, + "${s} * 2", + "dup(${s})", + args={"s": {"type": "__builtins__.str", "unsure": True}}, + ) self.project.do(refactoring.get_changes()) - self.assertEqual('def f(p):\n return dup(p)\nx = dup("")\n' - 'i = 1 * 2\n', self.mod.read()) + self.assertEqual( + 'def f(p):\n return dup(p)\nx = dup("")\n' "i = 1 * 2\n", self.mod.read() + ) def test_statement_after_string_and_column(self): mod_text = 'def f(x):\n if a == "a": raise Exception("test")\n' self.mod.write(mod_text) - refactoring = restructure.Restructure(self.project, '${a}', '${a}') + refactoring = restructure.Restructure(self.project, "${a}", "${a}") self.project.do(refactoring.get_changes()) self.assertEqual(mod_text, self.mod.read()) diff --git a/ropetest/refactor/similarfindertest.py b/ropetest/refactor/similarfindertest.py index 0ffb26c65..6f9068bf9 100644 --- a/ropetest/refactor/similarfindertest.py +++ b/ropetest/refactor/similarfindertest.py @@ -8,11 +8,10 @@ class SimilarFinderTest(unittest.TestCase): - def setUp(self): super(SimilarFinderTest, self).setUp() self.project = testutils.sample_project() - self.mod = testutils.create_module(self.project, 'mod') + self.mod = testutils.create_module(self.project, "mod") def tearDown(self): testutils.remove_project(self.project) @@ -24,247 +23,247 @@ def _create_finder(self, source, **kwds): return similarfinder.SimilarFinder(pymodule, **kwds) def test_trivial_case(self): - finder = self._create_finder('') - self.assertEqual([], list(finder.get_match_regions('10'))) + finder = self._create_finder("") + self.assertEqual([], list(finder.get_match_regions("10"))) def test_constant_integer(self): - source = 'a = 10\n' + source = "a = 10\n" finder = self._create_finder(source) - result = [(source.index('10'), source.index('10') + 2)] - self.assertEqual(result, list(finder.get_match_regions('10'))) + result = [(source.index("10"), source.index("10") + 2)] + self.assertEqual(result, list(finder.get_match_regions("10"))) def test_bool_is_not_similar_to_integer(self): - source = 'a = False\nb = 0' + source = "a = False\nb = 0" finder = self._create_finder(source) - result = [(source.index('False'), source.index('False') + len('False'))] - self.assertEqual(result, list(finder.get_match_regions('False'))) + result = [(source.index("False"), source.index("False") + len("False"))] + self.assertEqual(result, list(finder.get_match_regions("False"))) def test_simple_addition(self): - source = 'a = 1 + 2\n' + source = "a = 1 + 2\n" finder = self._create_finder(source) - result = [(source.index('1'), source.index('2') + 1)] - self.assertEqual(result, list(finder.get_match_regions('1 + 2'))) + result = [(source.index("1"), source.index("2") + 1)] + self.assertEqual(result, list(finder.get_match_regions("1 + 2"))) def test_simple_addition2(self): - source = 'a = 1 +2\n' + source = "a = 1 +2\n" finder = self._create_finder(source) - result = [(source.index('1'), source.index('2') + 1)] - self.assertEqual(result, list(finder.get_match_regions('1 + 2'))) + result = [(source.index("1"), source.index("2") + 1)] + self.assertEqual(result, list(finder.get_match_regions("1 + 2"))) def test_simple_assign_statements(self): - source = 'a = 1 + 2\n' + source = "a = 1 + 2\n" finder = self._create_finder(source) - self.assertEqual([(0, len(source) - 1)], - list(finder.get_match_regions('a = 1 + 2'))) + self.assertEqual( + [(0, len(source) - 1)], list(finder.get_match_regions("a = 1 + 2")) + ) def test_simple_multiline_statements(self): - source = 'a = 1\nb = 2\n' + source = "a = 1\nb = 2\n" finder = self._create_finder(source) - self.assertEqual([(0, len(source) - 1)], - list(finder.get_match_regions('a = 1\nb = 2'))) + self.assertEqual( + [(0, len(source) - 1)], list(finder.get_match_regions("a = 1\nb = 2")) + ) def test_multiple_matches(self): - source = 'a = 1 + 1\n' + source = "a = 1 + 1\n" finder = self._create_finder(source) - result = list(finder.get_match_regions('1')) + result = list(finder.get_match_regions("1")) self.assertEqual(2, len(result)) - start1 = source.index('1') + start1 = source.index("1") self.assertEqual((start1, start1 + 1), result[0]) - start2 = source.rindex('1') + start2 = source.rindex("1") self.assertEqual((start2, start2 + 1), result[1]) def test_multiple_matches2(self): - source = 'a = 1\nb = 2\n\na = 1\nb = 2\n' + source = "a = 1\nb = 2\n\na = 1\nb = 2\n" finder = self._create_finder(source) - self.assertEqual( - 2, len(list(finder.get_match_regions('a = 1\nb = 2')))) + self.assertEqual(2, len(list(finder.get_match_regions("a = 1\nb = 2")))) def test_restricting_the_region_to_search(self): - source = '1\n\n1\n' + source = "1\n\n1\n" finder = self._create_finder(source) - result = list(finder.get_match_regions('1', start=2)) - start = source.rfind('1') + result = list(finder.get_match_regions("1", start=2)) + start = source.rfind("1") self.assertEqual([(start, start + 1)], result) def test_matching_basic_patterns(self): - source = 'b = a\n' + source = "b = a\n" finder = self._create_finder(source) - result = list(finder.get_match_regions('${a}', args={'a': 'exact'})) - start = source.rfind('a') + result = list(finder.get_match_regions("${a}", args={"a": "exact"})) + start = source.rfind("a") self.assertEqual([(start, start + 1)], result) def test_match_get_ast(self): - source = 'b = a\n' + source = "b = a\n" finder = self._create_finder(source) - result = list(finder.get_matches('${a}', args={'a': 'exact'})) - self.assertEqual('a', result[0].get_ast('a').id) + result = list(finder.get_matches("${a}", args={"a": "exact"})) + self.assertEqual("a", result[0].get_ast("a").id) def test_match_get_ast_for_statements(self): - source = 'b = a\n' + source = "b = a\n" finder = self._create_finder(source) - result = list(finder.get_matches('b = ${a}')) - self.assertEqual('a', result[0].get_ast('a').id) + result = list(finder.get_matches("b = ${a}")) + self.assertEqual("a", result[0].get_ast("a").id) def test_matching_multiple_patterns(self): - source = 'c = a + b\n' + source = "c = a + b\n" finder = self._create_finder(source) - result = list(finder.get_matches('${a} + ${b}')) - self.assertEqual('a', result[0].get_ast('a').id) - self.assertEqual('b', result[0].get_ast('b').id) + result = list(finder.get_matches("${a} + ${b}")) + self.assertEqual("a", result[0].get_ast("a").id) + self.assertEqual("b", result[0].get_ast("b").id) def test_matching_any_patterns(self): - source = 'b = a\n' + source = "b = a\n" finder = self._create_finder(source) - result = list(finder.get_matches('b = ${x}')) - self.assertEqual('a', result[0].get_ast('x').id) + result = list(finder.get_matches("b = ${x}")) + self.assertEqual("a", result[0].get_ast("x").id) def test_matching_any_patterns_repeating(self): - source = 'b = 1 + 1\n' + source = "b = 1 + 1\n" finder = self._create_finder(source) - result = list(finder.get_matches('b = ${x} + ${x}')) - self.assertEqual(1, result[0].get_ast('x').n) + result = list(finder.get_matches("b = ${x} + ${x}")) + self.assertEqual(1, result[0].get_ast("x").n) def test_matching_any_patterns_not_matching_different_nodes(self): - source = 'b = 1 + 2\n' + source = "b = 1 + 2\n" finder = self._create_finder(source) - result = list(finder.get_matches('b = ${x} + ${x}')) + result = list(finder.get_matches("b = ${x} + ${x}")) self.assertEqual(0, len(result)) def test_matching_normal_names_and_assname(self): - source = 'a = 1\n' + source = "a = 1\n" finder = self._create_finder(source) - result = list(finder.get_matches('${a} = 1')) - self.assertEqual('a', result[0].get_ast('a').id) + result = list(finder.get_matches("${a} = 1")) + self.assertEqual("a", result[0].get_ast("a").id) def test_matching_normal_names_and_assname2(self): - source = 'a = 1\n' + source = "a = 1\n" finder = self._create_finder(source) - result = list(finder.get_matches('${a}', args={'a': 'exact'})) + result = list(finder.get_matches("${a}", args={"a": "exact"})) self.assertEqual(1, len(result)) def test_matching_normal_names_and_attributes(self): - source = 'x.a = 1\n' + source = "x.a = 1\n" finder = self._create_finder(source) - result = list(finder.get_matches('${a} = 1', args={'a': 'exact'})) + result = list(finder.get_matches("${a} = 1", args={"a": "exact"})) self.assertEqual(0, len(result)) def test_functions_not_matching_when_only_first_parameters(self): - source = 'f(1, 2)\n' + source = "f(1, 2)\n" finder = self._create_finder(source) - self.assertEqual(0, len(list(finder.get_matches('f(1)')))) + self.assertEqual(0, len(list(finder.get_matches("f(1)")))) def test_matching_nested_try_finally(self): - source = 'if 1:\n try:\n pass\n except:\n pass\n' - pattern = 'try:\n pass\nexcept:\n pass\n' + source = "if 1:\n try:\n pass\n except:\n pass\n" + pattern = "try:\n pass\nexcept:\n pass\n" finder = self._create_finder(source) self.assertEqual(1, len(list(finder.get_matches(pattern)))) def test_matching_dicts_inside_functions(self): - source = 'def f(p):\n d = {1: p.x}\n' - pattern = '{1: ${a}.x}' + source = "def f(p):\n d = {1: p.x}\n" + pattern = "{1: ${a}.x}" finder = self._create_finder(source) self.assertEqual(1, len(list(finder.get_matches(pattern)))) class CheckingFinderTest(unittest.TestCase): - def setUp(self): super(CheckingFinderTest, self).setUp() self.project = testutils.sample_project() - self.mod1 = testutils.create_module(self.project, 'mod1') + self.mod1 = testutils.create_module(self.project, "mod1") def tearDown(self): testutils.remove_project(self.project) super(CheckingFinderTest, self).tearDown() def test_trivial_case(self): - self.mod1.write('') + self.mod1.write("") pymodule = self.project.get_pymodule(self.mod1) finder = similarfinder.SimilarFinder(pymodule) - self.assertEqual([], list(finder.get_matches('10', {}))) + self.assertEqual([], list(finder.get_matches("10", {}))) def test_simple_finding(self): - self.mod1.write('class A(object):\n pass\na = A()\n') + self.mod1.write("class A(object):\n pass\na = A()\n") pymodule = self.project.get_pymodule(self.mod1) finder = similarfinder.SimilarFinder(pymodule) - result = list(finder.get_matches('${anything} = ${A}()', {})) + result = list(finder.get_matches("${anything} = ${A}()", {})) self.assertEqual(1, len(result)) def test_not_matching_when_the_name_does_not_match(self): - self.mod1.write('class A(object):\n pass\na = list()\n') + self.mod1.write("class A(object):\n pass\na = list()\n") pymodule = self.project.get_pymodule(self.mod1) finder = similarfinder.SimilarFinder(pymodule) - result = list(finder.get_matches('${anything} = ${C}()', - {'C': 'name=mod1.A'})) + result = list(finder.get_matches("${anything} = ${C}()", {"C": "name=mod1.A"})) self.assertEqual(0, len(result)) def test_not_matching_unknowns_finding(self): - self.mod1.write('class A(object):\n pass\na = unknown()\n') + self.mod1.write("class A(object):\n pass\na = unknown()\n") pymodule = self.project.get_pymodule(self.mod1) finder = similarfinder.SimilarFinder(pymodule) - result = list(finder.get_matches('${anything} = ${C}()', - {'C': 'name=mod1.A'})) + result = list(finder.get_matches("${anything} = ${C}()", {"C": "name=mod1.A"})) self.assertEqual(0, len(result)) def test_finding_and_matching_pyobjects(self): - source = 'class A(object):\n pass\nNewA = A\na = NewA()\n' + source = "class A(object):\n pass\nNewA = A\na = NewA()\n" self.mod1.write(source) pymodule = self.project.get_pymodule(self.mod1) finder = similarfinder.SimilarFinder(pymodule) - result = list(finder.get_matches('${anything} = ${A}()', - {'A': 'object=mod1.A'})) + result = list( + finder.get_matches("${anything} = ${A}()", {"A": "object=mod1.A"}) + ) self.assertEqual(1, len(result)) - start = source.rindex('a =') + start = source.rindex("a =") self.assertEqual((start, len(source) - 1), result[0].get_region()) def test_finding_and_matching_types(self): - source = 'class A(object):\n def f(self):\n pass\n' \ - 'a = A()\nb = a.f()\n' + source = ( + "class A(object):\n def f(self):\n pass\n" "a = A()\nb = a.f()\n" + ) self.mod1.write(source) pymodule = self.project.get_pymodule(self.mod1) finder = similarfinder.SimilarFinder(pymodule) - result = list(finder.get_matches('${anything} = ${inst}.f()', - {'inst': 'type=mod1.A'})) + result = list( + finder.get_matches("${anything} = ${inst}.f()", {"inst": "type=mod1.A"}) + ) self.assertEqual(1, len(result)) - start = source.rindex('b') + start = source.rindex("b") self.assertEqual((start, len(source) - 1), result[0].get_region()) def test_checking_the_type_of_an_ass_name_node(self): - self.mod1.write('class A(object):\n pass\nan_a = A()\n') + self.mod1.write("class A(object):\n pass\nan_a = A()\n") pymodule = self.project.get_pymodule(self.mod1) finder = similarfinder.SimilarFinder(pymodule) - result = list(finder.get_matches('${a} = ${assigned}', - {'a': 'type=mod1.A'})) + result = list(finder.get_matches("${a} = ${assigned}", {"a": "type=mod1.A"})) self.assertEqual(1, len(result)) def test_checking_instance_of_an_ass_name_node(self): - self.mod1.write('class A(object):\n pass\n' - 'class B(A):\n pass\nb = B()\n') + self.mod1.write( + "class A(object):\n pass\n" "class B(A):\n pass\nb = B()\n" + ) pymodule = self.project.get_pymodule(self.mod1) finder = similarfinder.SimilarFinder(pymodule) - result = list(finder.get_matches('${a} = ${assigned}', - {'a': 'instance=mod1.A'})) + result = list( + finder.get_matches("${a} = ${assigned}", {"a": "instance=mod1.A"}) + ) self.assertEqual(1, len(result)) def test_checking_equality_of_imported_pynames(self): - mod2 = testutils.create_module(self.project, 'mod2') - mod2.write('class A(object):\n pass\n') - self.mod1.write('from mod2 import A\nan_a = A()\n') + mod2 = testutils.create_module(self.project, "mod2") + mod2.write("class A(object):\n pass\n") + self.mod1.write("from mod2 import A\nan_a = A()\n") pymod1 = self.project.get_pymodule(self.mod1) finder = similarfinder.SimilarFinder(pymod1) - result = list(finder.get_matches('${a_class}()', - {'a_class': 'name=mod2.A'})) + result = list(finder.get_matches("${a_class}()", {"a_class": "name=mod2.A"})) self.assertEqual(1, len(result)) class TemplateTest(unittest.TestCase): - def test_simple_templates(self): - template = similarfinder.CodeTemplate('${a}\n') - self.assertEqual(set(['a']), set(template.get_names())) + template = similarfinder.CodeTemplate("${a}\n") + self.assertEqual(set(["a"]), set(template.get_names())) def test_ignoring_matches_in_comments(self): - template = similarfinder.CodeTemplate('#${a}\n') + template = similarfinder.CodeTemplate("#${a}\n") self.assertEqual({}.keys(), template.get_names()) def test_ignoring_matches_in_strings(self): @@ -272,9 +271,9 @@ def test_ignoring_matches_in_strings(self): self.assertEqual({}.keys(), template.get_names()) def test_simple_substitution(self): - template = similarfinder.CodeTemplate('${a}\n') - self.assertEqual('b\n', template.substitute({'a': 'b'})) + template = similarfinder.CodeTemplate("${a}\n") + self.assertEqual("b\n", template.substitute({"a": "b"})) def test_substituting_multiple_names(self): - template = similarfinder.CodeTemplate('${a}, ${b}\n') - self.assertEqual('1, 2\n', template.substitute({'a': '1', 'b': '2'})) + template = similarfinder.CodeTemplate("${a}, ${b}\n") + self.assertEqual("1, 2\n", template.substitute({"a": "1", "b": "2"})) diff --git a/ropetest/refactor/suitestest.py b/ropetest/refactor/suitestest.py index 722aecab5..a70506ecd 100644 --- a/ropetest/refactor/suitestest.py +++ b/ropetest/refactor/suitestest.py @@ -8,7 +8,6 @@ class SuiteTest(unittest.TestCase): - def setUp(self): super(SuiteTest, self).setUp() @@ -16,58 +15,55 @@ def tearDown(self): super(SuiteTest, self).tearDown() def test_trivial_case(self): - root = source_suite_tree('') + root = source_suite_tree("") self.assertEqual(1, root.get_start()) self.assertEqual(0, len(root.get_children())) def test_simple_ifs(self): - root = source_suite_tree('if True:\n pass') + root = source_suite_tree("if True:\n pass") self.assertEqual(1, len(root.get_children())) def test_simple_else(self): - root = source_suite_tree( - 'if True:\n pass\nelse:\n pass\n') + root = source_suite_tree("if True:\n pass\nelse:\n pass\n") self.assertEqual(2, len(root.get_children())) self.assertEqual(1, root.get_children()[1].get_start()) def test_for(self): - root = source_suite_tree( - '\nfor i in range(10):\n pass\nelse:\n pass\n') + root = source_suite_tree("\nfor i in range(10):\n pass\nelse:\n pass\n") self.assertEqual(2, len(root.get_children())) self.assertEqual(2, root.get_children()[1].get_start()) def test_while(self): - root = source_suite_tree( - 'while True:\n pass\n') + root = source_suite_tree("while True:\n pass\n") self.assertEqual(1, len(root.get_children())) self.assertEqual(1, root.get_children()[0].get_start()) def test_with(self): root = source_suite_tree( - 'from __future__ import with_statement\nwith file(x): pass\n') + "from __future__ import with_statement\nwith file(x): pass\n" + ) self.assertEqual(1, len(root.get_children())) self.assertEqual(2, root.get_children()[0].get_start()) def test_try_finally(self): - root = source_suite_tree( - 'try:\n pass\nfinally:\n pass\n') + root = source_suite_tree("try:\n pass\nfinally:\n pass\n") self.assertEqual(2, len(root.get_children())) self.assertEqual(1, root.get_children()[0].get_start()) def test_try_except(self): - root = source_suite_tree( - 'try:\n pass\nexcept:\n pass\nelse:\n pass\n') + root = source_suite_tree("try:\n pass\nexcept:\n pass\nelse:\n pass\n") self.assertEqual(3, len(root.get_children())) self.assertEqual(1, root.get_children()[2].get_start()) def test_try_except_finally(self): root = source_suite_tree( - 'try:\n pass\nexcept:\n pass\nfinally:\n pass\n') + "try:\n pass\nexcept:\n pass\nfinally:\n pass\n" + ) self.assertEqual(3, len(root.get_children())) self.assertEqual(1, root.get_children()[2].get_start()) def test_local_start_and_end(self): - root = source_suite_tree('if True:\n pass\nelse:\n pass\n') + root = source_suite_tree("if True:\n pass\nelse:\n pass\n") self.assertEqual(1, root.local_start()) self.assertEqual(4, root.local_end()) if_suite = root.get_children()[0] @@ -78,17 +74,16 @@ def test_local_start_and_end(self): self.assertEqual(4, else_suite.local_end()) def test_find_suite(self): - root = source_suite_tree('\n') + root = source_suite_tree("\n") self.assertEqual(root, root.find_suite(1)) def test_find_suite_for_ifs(self): - root = source_suite_tree('if True:\n pass\n') + root = source_suite_tree("if True:\n pass\n") if_suite = root.get_children()[0] self.assertEqual(if_suite, root.find_suite(2)) def test_find_suite_for_between_suites(self): - root = source_suite_tree( - 'if True:\n pass\nprint(1)\nif True:\n pass\n') + root = source_suite_tree("if True:\n pass\nprint(1)\nif True:\n pass\n") if_suite1 = root.get_children()[0] if_suite2 = root.get_children()[1] self.assertEqual(if_suite1, root.find_suite(2)) @@ -96,41 +91,39 @@ def test_find_suite_for_between_suites(self): self.assertEqual(root, root.find_suite(3)) def test_simple_find_visible(self): - root = source_suite_tree('a = 1\n') + root = source_suite_tree("a = 1\n") self.assertEqual(1, suites.find_visible_for_suite(root, [1])) def test_simple_find_visible_ifs(self): - root = source_suite_tree('\nif True:\n a = 1\n b = 2\n') + root = source_suite_tree("\nif True:\n a = 1\n b = 2\n") self.assertEqual(root.find_suite(3), root.find_suite(4)) self.assertEqual(3, suites.find_visible_for_suite(root, [3, 4])) def test_simple_find_visible_for_else(self): - root = source_suite_tree('\nif True:\n pass\nelse: pass\n') + root = source_suite_tree("\nif True:\n pass\nelse: pass\n") self.assertEqual(2, suites.find_visible_for_suite(root, [2, 4])) def test_simple_find_visible_for_different_suites(self): - root = source_suite_tree('if True:\n pass\na = 1\n' - 'if False:\n pass\n') + root = source_suite_tree("if True:\n pass\na = 1\n" "if False:\n pass\n") self.assertEqual(1, suites.find_visible_for_suite(root, [2, 3])) self.assertEqual(5, suites.find_visible_for_suite(root, [5])) self.assertEqual(1, suites.find_visible_for_suite(root, [2, 5])) def test_not_always_selecting_scope_start(self): root = source_suite_tree( - 'if True:\n a = 1\n if True:\n pass\n' - ' else:\n pass\n') + "if True:\n a = 1\n if True:\n pass\n" + " else:\n pass\n" + ) self.assertEqual(3, suites.find_visible_for_suite(root, [4, 6])) self.assertEqual(3, suites.find_visible_for_suite(root, [3, 5])) self.assertEqual(3, suites.find_visible_for_suite(root, [4, 5])) def test_ignoring_functions(self): - root = source_suite_tree( - 'def f():\n pass\na = 1\n') + root = source_suite_tree("def f():\n pass\na = 1\n") self.assertEqual(3, suites.find_visible_for_suite(root, [2, 3])) def test_ignoring_classes(self): - root = source_suite_tree( - 'a = 1\nclass C():\n pass\n') + root = source_suite_tree("a = 1\nclass C():\n pass\n") self.assertEqual(1, suites.find_visible_for_suite(root, [1, 3])) diff --git a/ropetest/refactor/usefunctiontest.py b/ropetest/refactor/usefunctiontest.py index e5f4977cf..6c6fd2a1e 100644 --- a/ropetest/refactor/usefunctiontest.py +++ b/ropetest/refactor/usefunctiontest.py @@ -9,113 +9,110 @@ class UseFunctionTest(unittest.TestCase): - def setUp(self): super(UseFunctionTest, self).setUp() self.project = testutils.sample_project() - self.mod1 = testutils.create_module(self.project, 'mod1') - self.mod2 = testutils.create_module(self.project, 'mod2') + self.mod1 = testutils.create_module(self.project, "mod1") + self.mod2 = testutils.create_module(self.project, "mod2") def tearDown(self): testutils.remove_project(self.project) super(UseFunctionTest, self).tearDown() def test_simple_case(self): - code = 'def f():\n pass\n' + code = "def f():\n pass\n" self.mod1.write(code) - user = UseFunction(self.project, self.mod1, code.rindex('f')) + user = UseFunction(self.project, self.mod1, code.rindex("f")) self.project.do(user.get_changes()) self.assertEqual(code, self.mod1.read()) def test_simple_function(self): - code = 'def f(p):\n print(p)\nprint(1)\n' + code = "def f(p):\n print(p)\nprint(1)\n" self.mod1.write(code) - user = UseFunction(self.project, self.mod1, code.rindex('f')) + user = UseFunction(self.project, self.mod1, code.rindex("f")) self.project.do(user.get_changes()) - self.assertEqual('def f(p):\n print(p)\nf(1)\n', - self.mod1.read()) + self.assertEqual("def f(p):\n print(p)\nf(1)\n", self.mod1.read()) def test_simple_function2(self): - code = 'def f(p):\n print(p + 1)\nprint(1 + 1)\n' + code = "def f(p):\n print(p + 1)\nprint(1 + 1)\n" self.mod1.write(code) - user = UseFunction(self.project, self.mod1, code.rindex('f')) + user = UseFunction(self.project, self.mod1, code.rindex("f")) self.project.do(user.get_changes()) - self.assertEqual('def f(p):\n print(p + 1)\nf(1)\n', - self.mod1.read()) + self.assertEqual("def f(p):\n print(p + 1)\nf(1)\n", self.mod1.read()) def test_functions_with_multiple_statements(self): - code = 'def f(p):\n r = p + 1\n print(r)\nr = 2 + 1\nprint(r)\n' + code = "def f(p):\n r = p + 1\n print(r)\nr = 2 + 1\nprint(r)\n" self.mod1.write(code) - user = UseFunction(self.project, self.mod1, code.rindex('f')) + user = UseFunction(self.project, self.mod1, code.rindex("f")) self.project.do(user.get_changes()) - self.assertEqual('def f(p):\n r = p + 1\n print(r)\nf(2)\n', - self.mod1.read()) + self.assertEqual( + "def f(p):\n r = p + 1\n print(r)\nf(2)\n", self.mod1.read() + ) def test_returning(self): - code = 'def f(p):\n return p + 1\nr = 2 + 1\nprint(r)\n' + code = "def f(p):\n return p + 1\nr = 2 + 1\nprint(r)\n" self.mod1.write(code) - user = UseFunction(self.project, self.mod1, code.rindex('f')) + user = UseFunction(self.project, self.mod1, code.rindex("f")) self.project.do(user.get_changes()) self.assertEqual( - 'def f(p):\n return p + 1\nr = f(2)\nprint(r)\n', - self.mod1.read()) + "def f(p):\n return p + 1\nr = f(2)\nprint(r)\n", self.mod1.read() + ) def test_returning_a_single_expression(self): - code = 'def f(p):\n return p + 1\nprint(2 + 1)\n' + code = "def f(p):\n return p + 1\nprint(2 + 1)\n" self.mod1.write(code) - user = UseFunction(self.project, self.mod1, code.rindex('f')) + user = UseFunction(self.project, self.mod1, code.rindex("f")) self.project.do(user.get_changes()) - self.assertEqual( - 'def f(p):\n return p + 1\nprint(f(2))\n', - self.mod1.read()) + self.assertEqual("def f(p):\n return p + 1\nprint(f(2))\n", self.mod1.read()) def test_occurrences_in_other_modules(self): - code = 'def f(p):\n return p + 1\n' + code = "def f(p):\n return p + 1\n" self.mod1.write(code) - user = UseFunction(self.project, self.mod1, code.rindex('f')) - self.mod2.write('print(2 + 1)\n') + user = UseFunction(self.project, self.mod1, code.rindex("f")) + self.mod2.write("print(2 + 1)\n") self.project.do(user.get_changes()) - self.assertEqual('import mod1\nprint(mod1.f(2))\n', - self.mod2.read()) + self.assertEqual("import mod1\nprint(mod1.f(2))\n", self.mod2.read()) def test_when_performing_on_non_functions(self): - code = 'var = 1\n' + code = "var = 1\n" self.mod1.write(code) with self.assertRaises(exceptions.RefactoringError): - UseFunction(self.project, self.mod1, code.rindex('var')) + UseFunction(self.project, self.mod1, code.rindex("var")) def test_differing_in_the_inner_temp_names(self): - code = 'def f(p):\n a = p + 1\n print(a)\nb = 2 + 1\nprint(b)\n' + code = "def f(p):\n a = p + 1\n print(a)\nb = 2 + 1\nprint(b)\n" self.mod1.write(code) - user = UseFunction(self.project, self.mod1, code.rindex('f')) + user = UseFunction(self.project, self.mod1, code.rindex("f")) self.project.do(user.get_changes()) - self.assertEqual('def f(p):\n a = p + 1\n print(a)\nf(2)\n', - self.mod1.read()) + self.assertEqual( + "def f(p):\n a = p + 1\n print(a)\nf(2)\n", self.mod1.read() + ) # TODO: probably new options should be added to restructure def xxx_test_being_a_bit_more_intelligent_when_returning_assigneds(self): - code = 'def f(p):\n a = p + 1\n return a\n'\ - 'var = 2 + 1\nprint(var)\n' + code = "def f(p):\n a = p + 1\n return a\n" "var = 2 + 1\nprint(var)\n" self.mod1.write(code) - user = UseFunction(self.project, self.mod1, code.rindex('f')) + user = UseFunction(self.project, self.mod1, code.rindex("f")) self.project.do(user.get_changes()) - self.assertEqual('def f(p):\n a = p + 1\n return a\n' - 'var = f(p)\nprint(var)\n', self.mod1.read()) + self.assertEqual( + "def f(p):\n a = p + 1\n return a\n" "var = f(p)\nprint(var)\n", + self.mod1.read(), + ) def test_exception_when_performing_a_function_with_yield(self): - code = 'def func():\n yield 1\n' + code = "def func():\n yield 1\n" self.mod1.write(code) with self.assertRaises(exceptions.RefactoringError): - UseFunction(self.project, self.mod1, code.index('func')) + UseFunction(self.project, self.mod1, code.index("func")) def test_exception_when_performing_a_function_two_returns(self): - code = 'def func():\n return 1\n return 2\n' + code = "def func():\n return 1\n return 2\n" self.mod1.write(code) with self.assertRaises(exceptions.RefactoringError): - UseFunction(self.project, self.mod1, code.index('func')) + UseFunction(self.project, self.mod1, code.index("func")) def test_exception_when_returns_is_not_the_last_statement(self): - code = 'def func():\n return 2\n a = 1\n' + code = "def func():\n return 2\n a = 1\n" self.mod1.write(code) with self.assertRaises(exceptions.RefactoringError): - UseFunction(self.project, self.mod1, code.index('func')) + UseFunction(self.project, self.mod1, code.index("func")) diff --git a/ropetest/runmodtest.py b/ropetest/runmodtest.py index 43816d19e..5c66132a8 100644 --- a/ropetest/runmodtest.py +++ b/ropetest/runmodtest.py @@ -1,4 +1,5 @@ import os + try: import unittest2 as unittest except ImportError: @@ -10,7 +11,6 @@ class PythonFileRunnerTest(unittest.TestCase): - def setUp(self): super(PythonFileRunnerTest, self).setUp() self.project = testutils.sample_project() @@ -20,139 +20,147 @@ def tearDown(self): testutils.remove_project(self.project) super(PythonFileRunnerTest, self).tearDown() - def make_sample_python_file(self, file_path, - get_text_function_source=None): + def make_sample_python_file(self, file_path, get_text_function_source=None): self.project.root.create_file(file_path) file = self.project.get_resource(file_path) if not get_text_function_source: get_text_function_source = "def get_text():\n return 'run'\n\n" - file_content = get_text_function_source + \ - "output = open('output.txt', 'w')\n" \ + file_content = ( + get_text_function_source + "output = open('output.txt', 'w')\n" "output.write(get_text())\noutput.close()\n" + ) file.write(file_content) def get_output_file_content(self, file_path): try: - output_path = '' - last_slash = file_path.rfind('/') + output_path = "" + last_slash = file_path.rfind("/") if last_slash != -1: - output_path = file_path[0:last_slash + 1] - file = self.project.get_resource(output_path + 'output.txt') + output_path = file_path[0 : last_slash + 1] + file = self.project.get_resource(output_path + "output.txt") return file.read() except exceptions.ResourceNotFoundError: - return '' + return "" def test_making_runner(self): - file_path = 'sample.py' + file_path = "sample.py" self.make_sample_python_file(file_path) file_resource = self.project.get_resource(file_path) runner = self.pycore.run_module(file_resource) runner.wait_process() - self.assertEqual('run', self.get_output_file_content(file_path)) + self.assertEqual("run", self.get_output_file_content(file_path)) def test_passing_arguments(self): - file_path = 'sample.py' - function_source = 'import sys\ndef get_text():' \ - '\n return str(sys.argv[1:])\n' + file_path = "sample.py" + function_source = ( + "import sys\ndef get_text():" "\n return str(sys.argv[1:])\n" + ) self.make_sample_python_file(file_path, function_source) file_resource = self.project.get_resource(file_path) - runner = self.pycore.run_module(file_resource, args=['hello', 'world']) + runner = self.pycore.run_module(file_resource, args=["hello", "world"]) runner.wait_process() - self.assertTrue(self.get_output_file_content( - file_path).endswith("['hello', 'world']")) + self.assertTrue( + self.get_output_file_content(file_path).endswith("['hello', 'world']") + ) def test_passing_arguments_with_spaces(self): - file_path = 'sample.py' - function_source = 'import sys\ndef get_text():' \ - '\n return str(sys.argv[1:])\n' + file_path = "sample.py" + function_source = ( + "import sys\ndef get_text():" "\n return str(sys.argv[1:])\n" + ) self.make_sample_python_file(file_path, function_source) file_resource = self.project.get_resource(file_path) - runner = self.pycore.run_module(file_resource, args=['hello world']) + runner = self.pycore.run_module(file_resource, args=["hello world"]) runner.wait_process() - self.assertTrue(self.get_output_file_content( - file_path).endswith("['hello world']")) + self.assertTrue( + self.get_output_file_content(file_path).endswith("['hello world']") + ) def test_killing_runner(self): - file_path = 'sample.py' - self.make_sample_python_file(file_path, - 'def get_text():' - '\n import time' - '\n time.sleep(1)' - "\n return 'run'\n") + file_path = "sample.py" + self.make_sample_python_file( + file_path, + "def get_text():" + "\n import time" + "\n time.sleep(1)" + "\n return 'run'\n", + ) file_resource = self.project.get_resource(file_path) runner = self.pycore.run_module(file_resource) runner.kill_process() - self.assertEqual('', self.get_output_file_content(file_path)) + self.assertEqual("", self.get_output_file_content(file_path)) def test_running_nested_files(self): - self.project.root.create_folder('src') - file_path = 'src/sample.py' + self.project.root.create_folder("src") + file_path = "src/sample.py" self.make_sample_python_file(file_path) file_resource = self.project.get_resource(file_path) runner = self.pycore.run_module(file_resource) runner.wait_process() - self.assertEqual('run', self.get_output_file_content(file_path)) + self.assertEqual("run", self.get_output_file_content(file_path)) def test_setting_process_input(self): - file_path = 'sample.py' - self.make_sample_python_file(file_path, - "def get_text():" + - "\n import sys" - "\n return sys.stdin.readline()\n") - temp_file_name = 'processtest.tmp' + file_path = "sample.py" + self.make_sample_python_file( + file_path, + "def get_text():" + "\n import sys" + "\n return sys.stdin.readline()\n", + ) + temp_file_name = "processtest.tmp" try: - temp_file = open(temp_file_name, 'w') - temp_file.write('input text\n') + temp_file = open(temp_file_name, "w") + temp_file.write("input text\n") temp_file.close() file_resource = self.project.get_resource(file_path) stdin = open(temp_file_name) runner = self.pycore.run_module(file_resource, stdin=stdin) runner.wait_process() stdin.close() - self.assertEqual('input text\n', - self.get_output_file_content(file_path)) + self.assertEqual("input text\n", self.get_output_file_content(file_path)) finally: os.remove(temp_file_name) def test_setting_process_output(self): - file_path = 'sample.py' - self.make_sample_python_file(file_path, - "def get_text():" + - "\n print('output text')" - "\n return 'run'\n") - temp_file_name = 'processtest.tmp' + file_path = "sample.py" + self.make_sample_python_file( + file_path, + "def get_text():" + "\n print('output text')" "\n return 'run'\n", + ) + temp_file_name = "processtest.tmp" try: file_resource = self.project.get_resource(file_path) - stdout = open(temp_file_name, 'w') + stdout = open(temp_file_name, "w") runner = self.pycore.run_module(file_resource, stdout=stdout) runner.wait_process() stdout.close() - temp_file = open(temp_file_name, 'r') - self.assertEqual('output text\n', temp_file.read()) + temp_file = open(temp_file_name, "r") + self.assertEqual("output text\n", temp_file.read()) temp_file.close() finally: os.remove(temp_file_name) def test_setting_pythonpath(self): - src = self.project.root.create_folder('src') - src.create_file('sample.py') - src.get_child('sample.py').write('def f():\n pass\n') - self.project.root.create_folder('test') - file_path = 'test/test.py' - self.make_sample_python_file(file_path, - "def get_text():\n" - " import sample" - "\n sample.f()\n return'run'\n") + src = self.project.root.create_folder("src") + src.create_file("sample.py") + src.get_child("sample.py").write("def f():\n pass\n") + self.project.root.create_folder("test") + file_path = "test/test.py" + self.make_sample_python_file( + file_path, + "def get_text():\n" + " import sample" + "\n sample.f()\n return'run'\n", + ) file_resource = self.project.get_resource(file_path) runner = self.pycore.run_module(file_resource) runner.wait_process() - self.assertEqual('run', self.get_output_file_content(file_path)) + self.assertEqual("run", self.get_output_file_content(file_path)) def test_making_runner_when_doi_is_disabled(self): - self.project.set('enable_doi', False) - file_path = 'sample.py' + self.project.set("enable_doi", False) + file_path = "sample.py" self.make_sample_python_file(file_path) file_resource = self.project.get_resource(file_path) runner = self.pycore.run_module(file_resource) runner.wait_process() - self.assertEqual('run', self.get_output_file_content(file_path)) + self.assertEqual("run", self.get_output_file_content(file_path)) diff --git a/ropetest/simplifytest.py b/ropetest/simplifytest.py index 77c790a09..55cb207fb 100644 --- a/ropetest/simplifytest.py +++ b/ropetest/simplifytest.py @@ -7,7 +7,6 @@ class SimplifyTest(unittest.TestCase): - def setUp(self): super(SimplifyTest, self).setUp() @@ -15,7 +14,7 @@ def tearDown(self): super(SimplifyTest, self).tearDown() def test_trivial_case(self): - self.assertEqual('', simplify.real_code('')) + self.assertEqual("", simplify.real_code("")) def test_empty_strs(self): code = 's = ""\n' @@ -26,7 +25,7 @@ def test_blanking_strs(self): self.assertEqual('s = " "\n', simplify.real_code(code)) def test_changing_to_double_quotes(self): - code = 's = \'\'\n' + code = "s = ''\n" self.assertEqual('s = ""\n', simplify.real_code(code)) def test_changing_to_double_quotes2(self): @@ -34,30 +33,29 @@ def test_changing_to_double_quotes2(self): self.assertEqual('s = " "\n', simplify.real_code(code)) def test_removing_comments(self): - code = '# c\n' - self.assertEqual(' \n', simplify.real_code(code)) + code = "# c\n" + self.assertEqual(" \n", simplify.real_code(code)) def test_removing_comments_that_contain_strings(self): code = '# "c"\n' - self.assertEqual(' \n', simplify.real_code(code)) + self.assertEqual(" \n", simplify.real_code(code)) def test_removing_strings_containing_comments(self): code = '"#c"\n' self.assertEqual('" "\n', simplify.real_code(code)) def test_joining_implicit_continuations(self): - code = '(\n)\n' - self.assertEqual('( )\n', simplify.real_code(code)) + code = "(\n)\n" + self.assertEqual("( )\n", simplify.real_code(code)) def test_joining_explicit_continuations(self): - code = '1 + \\\n 2\n' - self.assertEqual('1 + 2\n', simplify.real_code(code)) + code = "1 + \\\n 2\n" + self.assertEqual("1 + 2\n", simplify.real_code(code)) def test_replacing_tabs(self): - code = '1\t+\t2\n' - self.assertEqual('1 + 2\n', simplify.real_code(code)) + code = "1\t+\t2\n" + self.assertEqual("1 + 2\n", simplify.real_code(code)) def test_replacing_semicolons(self): - code = 'a = 1;b = 2\n' - self.assertEqual('a = 1\nb = 2\n', simplify.real_code(code)) - + code = "a = 1;b = 2\n" + self.assertEqual("a = 1\nb = 2\n", simplify.real_code(code)) diff --git a/ropetest/testutils.py b/ropetest/testutils.py index 0b4e98ebe..416c78158 100644 --- a/ropetest/testutils.py +++ b/ropetest/testutils.py @@ -2,8 +2,8 @@ import shutil import sys import logging -logging.basicConfig(format='%(levelname)s:%(funcName)s:%(message)s', - level=logging.INFO) + +logging.basicConfig(format="%(levelname)s:%(funcName)s:%(message)s", level=logging.INFO) try: import unittest2 as unittest except ImportError: @@ -15,26 +15,31 @@ def sample_project(root=None, foldername=None, **kwds): if root is None: - root = 'sample_project' + root = "sample_project" if foldername: root = foldername # HACK: Using ``/dev/shm/`` for faster tests - if os.name == 'posix': - if os.path.isdir('/dev/shm') and os.access('/dev/shm', os.W_OK): - root = '/dev/shm/' + root - elif os.path.isdir('/tmp') and os.access('/tmp', os.W_OK): - root = '/tmp/' + root + if os.name == "posix": + if os.path.isdir("/dev/shm") and os.access("/dev/shm", os.W_OK): + root = "/dev/shm/" + root + elif os.path.isdir("/tmp") and os.access("/tmp", os.W_OK): + root = "/tmp/" + root logging.debug("Using %s as root of the project.", root) # Using these prefs for faster tests - prefs = {'save_objectdb': False, 'save_history': False, - 'validate_objectdb': False, 'automatic_soa': False, - 'ignored_resources': ['.ropeproject', '*.pyc'], - 'import_dynload_stdmods': False} + prefs = { + "save_objectdb": False, + "save_history": False, + "validate_objectdb": False, + "automatic_soa": False, + "ignored_resources": [".ropeproject", "*.pyc"], + "import_dynload_stdmods": False, + } prefs.update(kwds) remove_recursively(root) project = rope.base.project.Project(root, **prefs) return project + create_module = generate.create_module create_package = generate.create_package @@ -46,8 +51,9 @@ def remove_project(project): def remove_recursively(path): import time + # windows sometimes raises exceptions instead of removing files - if os.name == 'nt' or sys.platform == 'cygwin': + if os.name == "nt" or sys.platform == "cygwin": for i in range(12): try: _remove_recursively(path) @@ -74,28 +80,32 @@ def only_for(version): """Should be used as a decorator for a unittest.TestCase test method""" return unittest.skipIf( sys.version < version, - 'This test requires at least {0} version of Python.'.format(version)) + "This test requires at least {0} version of Python.".format(version), + ) def only_for_versions_lower(version): """Should be used as a decorator for a unittest.TestCase test method""" return unittest.skipIf( sys.version > version, - 'This test requires version of Python lower than {0}'.format(version)) + "This test requires version of Python lower than {0}".format(version), + ) + def only_for_versions_higher(version): """Should be used as a decorator for a unittest.TestCase test method""" return unittest.skipIf( sys.version < version, - 'This test requires version of Python higher than {0}'.format(version)) + "This test requires version of Python higher than {0}".format(version), + ) + def skipNotPOSIX(): - return unittest.skipIf(os.name != 'posix', - 'This test works only on POSIX') + return unittest.skipIf(os.name != "posix", "This test works only on POSIX") def time_limit(timeout): - if not any(procname in sys.argv[0] for procname in {'pytest', 'py.test'}): + if not any(procname in sys.argv[0] for procname in {"pytest", "py.test"}): # no-op when running tests without pytest return lambda *args, **kwargs: lambda func: func diff --git a/ropetest/type_hinting_test.py b/ropetest/type_hinting_test.py index e34c2b32a..668a90c91 100644 --- a/ropetest/type_hinting_test.py +++ b/ropetest/type_hinting_test.py @@ -10,7 +10,6 @@ class AbstractHintingTest(unittest.TestCase): - def setUp(self): super(AbstractHintingTest, self).setUp() self.project = testutils.sample_project() @@ -28,63 +27,65 @@ def assert_completion_in_result(self, name, scope, result): for proposal in result: if proposal.name == name and proposal.scope == scope: return - self.fail('completion <%s> in scope %r not proposed, available names: %r' % ( - name, - scope, - [(i.name, i.scope) for i in result] - )) + self.fail( + "completion <%s> in scope %r not proposed, available names: %r" + % (name, scope, [(i.name, i.scope) for i in result]) + ) def assert_completion_not_in_result(self, name, scope, result): for proposal in result: if proposal.name == name and proposal.scope == scope: - self.fail('completion <%s> was proposed' % name) + self.fail("completion <%s> was proposed" % name) def run(self, result=None): - if self.__class__.__name__.startswith('Abstract'): + if self.__class__.__name__.startswith("Abstract"): return super(AbstractHintingTest, self).run(result) class DocstringParamHintingTest(AbstractHintingTest): - def test_hint_param(self): - code = 'class Sample(object):\n' \ - ' def a_method(self, a_arg):\n' \ - ' """:type a_arg: threading.Thread"""\n' \ - ' a_arg.is_a' + code = ( + "class Sample(object):\n" + " def a_method(self, a_arg):\n" + ' """:type a_arg: threading.Thread"""\n' + " a_arg.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hierarchical_hint_param(self): - code = 'class ISample(object):\n' \ - ' def a_method(self, a_arg):\n' \ - ' """:type a_arg: threading.Thread"""\n' \ - '\n\n' \ - 'class Sample(ISample):\n' \ - ' def a_method(self, a_arg):\n' \ - ' a_arg.is_a' + code = ( + "class ISample(object):\n" + " def a_method(self, a_arg):\n" + ' """:type a_arg: threading.Thread"""\n' + "\n\n" + "class Sample(ISample):\n" + " def a_method(self, a_arg):\n" + " a_arg.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) class DocstringReturnHintingTest(AbstractHintingTest): - def test_hierarchical_hint_rtype(self): - code = 'class ISample(object):\n' \ - ' def b_method(self):\n' \ - ' """:rtype: threading.Thread"""\n' \ - '\n\n' \ - 'class Sample(ISample):\n' \ - ' def b_method(self):\n' \ - ' pass\n' \ - ' def a_method(self):\n' \ - ' self.b_method().is_a' + code = ( + "class ISample(object):\n" + " def b_method(self):\n" + ' """:rtype: threading.Thread"""\n' + "\n\n" + "class Sample(ISample):\n" + " def b_method(self):\n" + " pass\n" + " def a_method(self):\n" + " self.b_method().is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) class AbstractAssignmentHintingTest(AbstractHintingTest): - def _make_class_hint(self, type_str): raise NotImplementedError @@ -92,258 +93,299 @@ def _make_constructor_hint(self, type_str): raise NotImplementedError def test_hint_attr(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('threading.Thread') + \ - ' def a_method(self):\n' \ - ' self.a_attr.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("threading.Thread") + + " def a_method(self):\n" + " self.a_attr.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hierarchical_hint_attr(self): - code = 'class ISample(object):\n' \ - + self._make_class_hint('threading.Thread') + \ - '\n\n' \ - 'class Sample(ISample):\n' \ - ' a_attr = None\n'\ - ' def a_method(self):\n' \ - ' self.a_attr.is_a' + code = ( + "class ISample(object):\n" + + self._make_class_hint("threading.Thread") + + "\n\n" + "class Sample(ISample):\n" + " a_attr = None\n" + " def a_method(self):\n" + " self.a_attr.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_defined_by_constructor(self): - code = 'class Sample(object):\n' \ - ' def __init__(self, arg):\n' \ - + self._make_constructor_hint('threading.Thread') + \ - ' def a_method(self):\n' \ - ' self.a_attr.is_a' + code = ( + "class Sample(object):\n" + " def __init__(self, arg):\n" + + self._make_constructor_hint("threading.Thread") + + " def a_method(self):\n" + " self.a_attr.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_attr_redefined_by_constructor(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('threading.Thread') + \ - ' def __init__(self):\n' \ - ' self.a_attr = None\n' \ - ' def a_method(self):\n' \ - ' self.a_attr.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("threading.Thread") + + " def __init__(self):\n" + " self.a_attr = None\n" + " def a_method(self):\n" + " self.a_attr.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hierarchical_hint_attr_redefined_by_constructor(self): - code = 'class ISample(object):\n' \ - + self._make_class_hint('threading.Thread') + \ - '\n\n' \ - 'class Sample(ISample):\n' \ - ' def __init__(self):\n' \ - ' self.a_attr = None\n' \ - ' def a_method(self):\n' \ - ' self.a_attr.is_a' + code = ( + "class ISample(object):\n" + + self._make_class_hint("threading.Thread") + + "\n\n" + "class Sample(ISample):\n" + " def __init__(self):\n" + " self.a_attr = None\n" + " def a_method(self):\n" + " self.a_attr.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_attr_for_pre_defined_type(self): - code = 'class Other(object):\n' \ - ' def is_alive(self):\n' \ - ' pass\n' \ - '\n\n' \ - 'class Sample(object):\n' \ - + self._make_class_hint('Other') + \ - ' def a_method(self):\n' \ - ' self.a_attr.is_a' + code = ( + "class Other(object):\n" + " def is_alive(self):\n" + " pass\n" + "\n\n" + "class Sample(object):\n" + + self._make_class_hint("Other") + + " def a_method(self):\n" + " self.a_attr.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_attr_for_post_defined_type(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('Other') + \ - ' def a_method(self):\n' \ - ' self.a_attr.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("Other") + + " def a_method(self):\n" + " self.a_attr.is_a" + ) offset = len(code) - code += '\n\n' \ - 'class Other(object):\n' \ - ' def is_alive(self):\n' \ - ' pass\n' + code += ( + "\n\n" "class Other(object):\n" " def is_alive(self):\n" " pass\n" + ) result = self._assist(code, offset) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_parametrized_list(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('list[threading.Thread]') + \ - ' def a_method(self):\n' \ - ' for i in self.a_attr:\n' \ - ' i.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("list[threading.Thread]") + + " def a_method(self):\n" + " for i in self.a_attr:\n" + " i.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_parametrized_tuple(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('tuple[threading.Thread]') + \ - ' def a_method(self):\n' \ - ' for i in self.a_attr:\n' \ - ' i.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("tuple[threading.Thread]") + + " def a_method(self):\n" + " for i in self.a_attr:\n" + " i.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_parametrized_set(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('set[threading.Thread]') + \ - ' def a_method(self):\n' \ - ' for i in self.a_attr:\n' \ - ' i.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("set[threading.Thread]") + + " def a_method(self):\n" + " for i in self.a_attr:\n" + " i.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_parametrized_iterable(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('collections.Iterable[threading.Thread]') + \ - ' def a_method(self):\n' \ - ' for i in self.a_attr:\n' \ - ' i.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("collections.Iterable[threading.Thread]") + + " def a_method(self):\n" + " for i in self.a_attr:\n" + " i.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_parametrized_iterator(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('collections.Iterator[threading.Thread]') + \ - ' def a_method(self):\n' \ - ' for i in self.a_attr:\n' \ - ' i.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("collections.Iterator[threading.Thread]") + + " def a_method(self):\n" + " for i in self.a_attr:\n" + " i.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_parametrized_dict_key(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('dict[str, threading.Thread]') + \ - ' def a_method(self):\n' \ - ' for i in self.a_attr.keys():\n' \ - ' i.sta' + code = ( + "class Sample(object):\n" + + self._make_class_hint("dict[str, threading.Thread]") + + " def a_method(self):\n" + " for i in self.a_attr.keys():\n" + " i.sta" + ) result = self._assist(code) - self.assert_completion_in_result('startswith', 'builtin', result) + self.assert_completion_in_result("startswith", "builtin", result) def test_hint_parametrized_dict_value(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('dict[str, threading.Thread]') + \ - ' def a_method(self):\n' \ - ' for i in self.a_attr.values():\n' \ - ' i.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("dict[str, threading.Thread]") + + " def a_method(self):\n" + " for i in self.a_attr.values():\n" + " i.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_parametrized_nested_tuple_list(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('tuple[list[threading.Thread]]') + \ - ' def a_method(self):\n' \ - ' for j in self.a_attr:\n' \ - ' for i in j:\n' \ - ' i.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("tuple[list[threading.Thread]]") + + " def a_method(self):\n" + " for j in self.a_attr:\n" + " for i in j:\n" + " i.is_a" + ) result = self._assist(code) - self.assert_completion_in_result('is_alive', 'attribute', result) + self.assert_completion_in_result("is_alive", "attribute", result) def test_hint_or(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('str | threading.Thread') + \ - ' def a_method(self):\n' \ - ' for i in self.a_attr.values():\n' \ - ' i.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("str | threading.Thread") + + " def a_method(self):\n" + " for i in self.a_attr.values():\n" + " i.is_a" + ) result = self._assist(code) # Be sure, there isn't errors currently # self.assert_completion_in_result('is_alive', 'attribute', result) def test_hint_nonexistent(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('sdfdsf.asdfasdf.sdfasdf.Dffg') + \ - ' def a_method(self):\n' \ - ' for i in self.a_attr.values():\n' \ - ' i.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("sdfdsf.asdfasdf.sdfasdf.Dffg") + + " def a_method(self):\n" + " for i in self.a_attr.values():\n" + " i.is_a" + ) self._assist(code) def test_hint_invalid_syntax(self): - code = 'class Sample(object):\n' \ - + self._make_class_hint('sdf | & # &*') + \ - ' def a_method(self):\n' \ - ' for i in self.a_attr.values():\n' \ - ' i.is_a' + code = ( + "class Sample(object):\n" + + self._make_class_hint("sdf | & # &*") + + " def a_method(self):\n" + " for i in self.a_attr.values():\n" + " i.is_a" + ) self._assist(code) class DocstringNoneAssignmentHintingTest(AbstractAssignmentHintingTest): - def _make_class_hint(self, type_str): - return ' """:type a_attr: ' + type_str + '"""\n' \ - ' a_attr = None\n' + return ' """:type a_attr: ' + type_str + '"""\n' " a_attr = None\n" def _make_constructor_hint(self, type_str): - return ' """:type arg: ' + type_str + '"""\n' \ - ' self.a_attr = arg\n' + return ( + ' """:type arg: ' + type_str + '"""\n' " self.a_attr = arg\n" + ) class DocstringNotImplementedAssignmentHintingTest(AbstractAssignmentHintingTest): - def _make_class_hint(self, type_str): - return ' """:type a_attr: ' + type_str + '"""\n' \ - ' a_attr = NotImplemented\n' + return ( + ' """:type a_attr: ' + type_str + '"""\n' " a_attr = NotImplemented\n" + ) def _make_constructor_hint(self, type_str): - return ' """:type arg: ' + type_str + '"""\n' \ - ' self.a_attr = arg\n' - + return ( + ' """:type arg: ' + type_str + '"""\n' " self.a_attr = arg\n" + ) class PEP0484CommentNoneAssignmentHintingTest(AbstractAssignmentHintingTest): - def _make_class_hint(self, type_str): - return ' a_attr = None # type: ' + type_str + '\n' + return " a_attr = None # type: " + type_str + "\n" def _make_constructor_hint(self, type_str): - return ' self.a_attr = None # type: ' + type_str + '\n' + return " self.a_attr = None # type: " + type_str + "\n" class PEP0484CommentNotImplementedAssignmentHintingTest(AbstractAssignmentHintingTest): - def _make_class_hint(self, type_str): - return ' a_attr = NotImplemented # type: ' + type_str + '\n' + return " a_attr = NotImplemented # type: " + type_str + "\n" def _make_constructor_hint(self, type_str): - return ' self.a_attr = NotImplemented # type: ' + type_str + '\n' + return " self.a_attr = NotImplemented # type: " + type_str + "\n" class EvaluateTest(unittest.TestCase): - def test_parser(self): tests = [ - ("Foo", - "(name Foo)"), - ("mod1.Foo", - "(name mod1.Foo)"), - ("mod1.mod2.Foo", - "(name mod1.mod2.Foo)"), - ("Foo[Bar]", - "('[' (name Foo) [(name Bar)])"), - ("Foo[Bar1, Bar2, Bar3]", - "('[' (name Foo) [(name Bar1), (name Bar2), (name Bar3)])"), - ("Foo[Bar[Baz]]", - "('[' (name Foo) [('[' (name Bar) [(name Baz)])])"), - ("Foo[Bar1[Baz1], Bar2[Baz2]]", - "('[' (name Foo) [('[' (name Bar1) [(name Baz1)]), ('[' (name Bar2) [(name Baz2)])])"), - ("mod1.mod2.Foo[Bar]", - "('[' (name mod1.mod2.Foo) [(name Bar)])"), - ("mod1.mod2.Foo[mod1.mod2.Bar]", - "('[' (name mod1.mod2.Foo) [(name mod1.mod2.Bar)])"), - ("mod1.mod2.Foo[Bar1, Bar2, Bar3]", - "('[' (name mod1.mod2.Foo) [(name Bar1), (name Bar2), (name Bar3)])"), - ("mod1.mod2.Foo[mod1.mod2.Bar[mod1.mod2.Baz]]", - "('[' (name mod1.mod2.Foo) [('[' (name mod1.mod2.Bar) [(name mod1.mod2.Baz)])])"), - ("mod1.mod2.Foo[mod1.mod2.Bar1[mod1.mod2.Baz1], mod1.mod2.Bar2[mod1.mod2.Baz2]]", - "('[' (name mod1.mod2.Foo) [('[' (name mod1.mod2.Bar1) [(name mod1.mod2.Baz1)]), ('[' (name mod1.mod2.Bar2) [(name mod1.mod2.Baz2)])])"), - ("(Foo, Bar) -> Baz", - "('(' [(name Foo), (name Bar)] (name Baz))"), + ("Foo", "(name Foo)"), + ("mod1.Foo", "(name mod1.Foo)"), + ("mod1.mod2.Foo", "(name mod1.mod2.Foo)"), + ("Foo[Bar]", "('[' (name Foo) [(name Bar)])"), ( - "(mod1.mod2.Foo[mod1.mod2.Bar1[mod1.mod2.Baz1], mod1.mod2.Bar2[mod1.mod2.Baz2]], mod1.mod2.Bar[mod1.mod2.Bar1[mod1.mod2.Baz1], mod1.mod2.Bar2[mod1.mod2.Baz2]]) -> mod1.mod2.Baz[mod1.mod2.Bar1[mod1.mod2.Baz1], mod1.mod2.Bar2[mod1.mod2.Baz2]]", - "('(' [('[' (name mod1.mod2.Foo) [('[' (name mod1.mod2.Bar1) [(name mod1.mod2.Baz1)]), ('[' (name mod1.mod2.Bar2) [(name mod1.mod2.Baz2)])]), ('[' (name mod1.mod2.Bar) [('[' (name mod1.mod2.Bar1) [(name mod1.mod2.Baz1)]), ('[' (name mod1.mod2.Bar2) [(name mod1.mod2.Baz2)])])] ('[' (name mod1.mod2.Baz) [('[' (name mod1.mod2.Bar1) [(name mod1.mod2.Baz1)]), ('[' (name mod1.mod2.Bar2) [(name mod1.mod2.Baz2)])]))"), - ("(Foo, Bar) -> Baz | Foo[Bar[Baz]]", - "('|' ('(' [(name Foo), (name Bar)] (name Baz)) ('[' (name Foo) [('[' (name Bar) [(name Baz)])]))"), - ("Foo[Bar[Baz | (Foo, Bar) -> Baz]]", - "('[' (name Foo) [('[' (name Bar) [('|' (name Baz) ('(' [(name Foo), (name Bar)] (name Baz)))])])"), + "Foo[Bar1, Bar2, Bar3]", + "('[' (name Foo) [(name Bar1), (name Bar2), (name Bar3)])", + ), + ("Foo[Bar[Baz]]", "('[' (name Foo) [('[' (name Bar) [(name Baz)])])"), + ( + "Foo[Bar1[Baz1], Bar2[Baz2]]", + "('[' (name Foo) [('[' (name Bar1) [(name Baz1)]), ('[' (name Bar2) [(name Baz2)])])", + ), + ("mod1.mod2.Foo[Bar]", "('[' (name mod1.mod2.Foo) [(name Bar)])"), + ( + "mod1.mod2.Foo[mod1.mod2.Bar]", + "('[' (name mod1.mod2.Foo) [(name mod1.mod2.Bar)])", + ), + ( + "mod1.mod2.Foo[Bar1, Bar2, Bar3]", + "('[' (name mod1.mod2.Foo) [(name Bar1), (name Bar2), (name Bar3)])", + ), + ( + "mod1.mod2.Foo[mod1.mod2.Bar[mod1.mod2.Baz]]", + "('[' (name mod1.mod2.Foo) [('[' (name mod1.mod2.Bar) [(name mod1.mod2.Baz)])])", + ), + ( + "mod1.mod2.Foo[mod1.mod2.Bar1[mod1.mod2.Baz1], mod1.mod2.Bar2[mod1.mod2.Baz2]]", + "('[' (name mod1.mod2.Foo) [('[' (name mod1.mod2.Bar1) [(name mod1.mod2.Baz1)]), ('[' (name mod1.mod2.Bar2) [(name mod1.mod2.Baz2)])])", + ), + ("(Foo, Bar) -> Baz", "('(' [(name Foo), (name Bar)] (name Baz))"), + ( + "(mod1.mod2.Foo[mod1.mod2.Bar1[mod1.mod2.Baz1], mod1.mod2.Bar2[mod1.mod2.Baz2]], mod1.mod2.Bar[mod1.mod2.Bar1[mod1.mod2.Baz1], mod1.mod2.Bar2[mod1.mod2.Baz2]]) -> mod1.mod2.Baz[mod1.mod2.Bar1[mod1.mod2.Baz1], mod1.mod2.Bar2[mod1.mod2.Baz2]]", + "('(' [('[' (name mod1.mod2.Foo) [('[' (name mod1.mod2.Bar1) [(name mod1.mod2.Baz1)]), ('[' (name mod1.mod2.Bar2) [(name mod1.mod2.Baz2)])]), ('[' (name mod1.mod2.Bar) [('[' (name mod1.mod2.Bar1) [(name mod1.mod2.Baz1)]), ('[' (name mod1.mod2.Bar2) [(name mod1.mod2.Baz2)])])] ('[' (name mod1.mod2.Baz) [('[' (name mod1.mod2.Bar1) [(name mod1.mod2.Baz1)]), ('[' (name mod1.mod2.Bar2) [(name mod1.mod2.Baz2)])]))", + ), + ( + "(Foo, Bar) -> Baz | Foo[Bar[Baz]]", + "('|' ('(' [(name Foo), (name Bar)] (name Baz)) ('[' (name Foo) [('[' (name Bar) [(name Baz)])]))", + ), + ( + "Foo[Bar[Baz | (Foo, Bar) -> Baz]]", + "('[' (name Foo) [('[' (name Bar) [('|' (name Baz) ('(' [(name Foo), (name Bar)] (name Baz)))])])", + ), ] for t, expected in tests: @@ -352,19 +394,20 @@ def test_parser(self): class RegressionHintingTest(AbstractHintingTest): - def test_hierarchical_hint_for_mutable_attr_type(self): """Test for #157, AttributeError: 'PyObject' object has no attribute 'get_doc'""" - code = 'class SuperClass(object):\n' \ - ' def __init__(self):\n' \ - ' self.foo = None\n' \ - '\n\n' \ - 'class SubClass(SuperClass):\n' \ - ' def __init__(self):\n' \ - ' super(SubClass, self).__init__()\n' \ - ' self.bar = 3\n' \ - '\n\n' \ - ' def foo(self):\n' \ - ' return self.bar' + code = ( + "class SuperClass(object):\n" + " def __init__(self):\n" + " self.foo = None\n" + "\n\n" + "class SubClass(SuperClass):\n" + " def __init__(self):\n" + " super(SubClass, self).__init__()\n" + " self.bar = 3\n" + "\n\n" + " def foo(self):\n" + " return self.bar" + ) result = self._assist(code) - self.assert_completion_in_result('bar', 'attribute', result) + self.assert_completion_in_result("bar", "attribute", result) diff --git a/setup.py b/setup.py index d3611ccdd..eec776cd4 100644 --- a/setup.py +++ b/setup.py @@ -8,68 +8,73 @@ classifiers = [ - 'Development Status :: 4 - Beta', - 'Operating System :: OS Independent', - 'Environment :: X11 Applications', - 'Environment :: Win32 (MS Windows)', - 'Environment :: MacOS X', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: GNU Lesser General Public License v3 or later (LGPLv3+)', - 'Natural Language :: English', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Topic :: Software Development'] + "Development Status :: 4 - Beta", + "Operating System :: OS Independent", + "Environment :: X11 Applications", + "Environment :: Win32 (MS Windows)", + "Environment :: MacOS X", + "Intended Audience :: Developers", + "License :: OSI Approved :: GNU Lesser General Public License v3 or later (LGPLv3+)", + "Natural Language :: English", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Topic :: Software Development", +] def get_long_description(): - lines = io.open('README.rst', 'r', - encoding='utf8').read().splitlines(False) - end = lines.index('Getting Started') - return '\n' + '\n'.join(lines[:end]) + '\n' + lines = io.open("README.rst", "r", encoding="utf8").read().splitlines(False) + end = lines.index("Getting Started") + return "\n" + "\n".join(lines[:end]) + "\n" def get_version(): version = None - with io.open(os.path.join( - os.path.dirname(__file__), 'rope', '__init__.py')) as inif: + with io.open( + os.path.join(os.path.dirname(__file__), "rope", "__init__.py") + ) as inif: for line in inif: - if line.startswith('VERSION'): - version = line.split('=')[1].strip(" \t'\n") + if line.startswith("VERSION"): + version = line.split("=")[1].strip(" \t'\n") break return version -setup(name='rope', - version=get_version(), - description='a python refactoring library...', - long_description=get_long_description(), - long_description_content_type='text/x-rst', - author='Ali Gholami Rudi', - author_email='aligrudi@users.sourceforge.net', - url='https://github.com/python-rope/rope', - packages=['rope', - 'rope.base', - 'rope.base.oi', - 'rope.base.oi.type_hinting', - 'rope.base.oi.type_hinting.providers', - 'rope.base.oi.type_hinting.resolvers', - 'rope.base.utils', - 'rope.contrib', - 'rope.refactor', - 'rope.refactor.importutils'], - license='LGPL-3.0-or-later', - classifiers=classifiers, - extras_require={ - 'dev': [ - 'pytest', - 'pytest-timeout', - ] - }) +setup( + name="rope", + version=get_version(), + description="a python refactoring library...", + long_description=get_long_description(), + long_description_content_type="text/x-rst", + author="Ali Gholami Rudi", + author_email="aligrudi@users.sourceforge.net", + url="https://github.com/python-rope/rope", + packages=[ + "rope", + "rope.base", + "rope.base.oi", + "rope.base.oi.type_hinting", + "rope.base.oi.type_hinting.providers", + "rope.base.oi.type_hinting.resolvers", + "rope.base.utils", + "rope.contrib", + "rope.refactor", + "rope.refactor.importutils", + ], + license="LGPL-3.0-or-later", + classifiers=classifiers, + extras_require={ + "dev": [ + "pytest", + "pytest-timeout", + ] + }, +)