Actions¶
Each rule can return an Action, which is simply an instruction to refactor. The main entry point of an Action is its apply() method, which takes the source code and returns the refactored version of it.
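To make the "source in, refactored source out" contract concrete, here is a minimal sketch that uses only the standard library's ast module, not refactor itself. The helper name apply_sketch is hypothetical and is not the library's implementation. It also assumes ASCII source, because ast column offsets count UTF-8 bytes.

```python
import ast


def apply_sketch(source: str, node: ast.AST, new_node: ast.AST) -> str:
    """Hypothetical helper: swap the text of `node` for the unparsed `new_node`."""
    lines = source.splitlines(keepends=True)
    # Turn (line, column) positions into absolute offsets into `source`.
    # Assumption: ASCII-only source, so byte offsets equal character offsets.
    start = sum(len(line) for line in lines[: node.lineno - 1]) + node.col_offset
    end = sum(len(line) for line in lines[: node.end_lineno - 1]) + node.end_col_offset
    return source[:start] + ast.unparse(new_node) + source[end:]


source = "x = 1 + 2  # keep this comment\n"
binop = ast.parse(source).body[0].value  # the `1 + 2` expression
replacement = ast.BinOp(left=binop.left, op=ast.Sub(), right=binop.right)

refactored = apply_sketch(source, binop, replacement)
print(refactored, end="")  # x = 1 - 2  # keep this comment
```

Only the span covered by the node is rewritten, so the rest of the line, including the comment, is left untouched.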
Action¶
The default apply() method implemented in refactor.Action calls the build() method and replaces the original node with the new node that build() returns. Here is a simple action that replaces + with - in a binary operation:
import ast

import refactor


class CustomBuildAction(refactor.Action):
    def build(self):
        # Start from a copy of the matched node and swap its operator
        new_node = self.branch()
        new_node.op = ast.Sub()
        return new_node
The branch() method returns an exact copy of the current node. On that copy, we change op to point to ast.Sub (-) instead of ast.Add (+).
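The copy-then-modify step can be tried in isolation with the standard library. In this sketch copy.deepcopy is only a stand-in for branch() (how branch() copies the node internally is not stated here), and the point is that editing the copy leaves the original node as it was.

```python
import ast
import copy

tree = ast.parse("a + b", mode="eval")
original = tree.body  # an ast.BinOp whose op is ast.Add()

# Stand-in for self.branch(): copy the node, then edit only the copy.
new_node = copy.deepcopy(original)
new_node.op = ast.Sub()

print(ast.unparse(original))  # a + b
print(ast.unparse(new_node))  # a - b
```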
Let’s also write the rule that uses it:
class ReplaceAdd(refactor.Rule):
    def match(self, node):
        # Only binary operations that use `+` are matched
        assert isinstance(node, ast.BinOp)
        assert isinstance(node.op, ast.Add)
        return CustomBuildAction(node)
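The assert statements in match() act as guards. The sketch below illustrates that style with a hypothetical driver, find_matches, which assumes a failed assertion simply means "this node does not match". It uses only the standard library and is not how refactor is implemented.

```python
import ast


def match_add(node):
    # Same guards as ReplaceAdd.match; here we return the node itself.
    assert isinstance(node, ast.BinOp)
    assert isinstance(node.op, ast.Add)
    return node


def find_matches(source, match):
    """Hypothetical driver: collect every node that `match` accepts."""
    matches = []
    for node in ast.walk(ast.parse(source)):
        try:
            result = match(node)
        except AssertionError:
            continue  # assumption: a failed guard means "no match"
        if result is not None:
            matches.append(result)
    return matches


matches = find_matches("x = a + b\ny = c * d\nz = e + f\n", match_add)
print([ast.unparse(m) for m in matches])  # ['a + b', 'e + f']
```

Note that Python strips assert statements when run with the -O flag, so guards written this way rely on assertions being enabled.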
ReplacementAction¶
If you want to build the new node in the rule rather than in a custom action’s build(), you can simply use ReplacementAction. It replaces the node (first argument) with the target (second argument), and its build() method returns the target.
import copy


class ReplaceAdd(refactor.Rule):
    def match(self, node):
        assert isinstance(node, ast.BinOp)
        assert isinstance(node.op, ast.Add)

        # Build the replacement here instead of in an action's build()
        new_node = copy.deepcopy(node)
        new_node.op = ast.Sub()
        return refactor.ReplacementAction(node, new_node)
NewStatementAction¶
If you don’t want to replace a node but rather add a new statement after it (e.g. adding a new import if it doesn’t already exist), this is where NewStatementAction comes into play. Let’s write a simple example that adds an exit() call to the end of every main() function:
from dataclasses import dataclass


@dataclass
class AddExitAction(refactor.NewStatementAction):
    status_code: int

    def build(self):
        return ast.Call(
            ast.Name("exit", ast.Load()),
            args=[ast.Constant(self.status_code)],
            keywords=[],
        )


class AddExitCalls(refactor.Rule):
    def match(self, node):
        # find all main() functions
        assert isinstance(node, ast.FunctionDef)
        assert node.name == "main"

        # ensure the last statement *is not* exit()
        last_stmt = node.body[-1]
        if (
            isinstance(last_stmt, ast.Expr)
            and isinstance(call := last_stmt.value, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id == "exit"
        ):
            return None

        # add a new statement after the last statement of
        # the current function to call exit()
        return AddExitAction(last_stmt, 0)
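To see the intended effect of this rule without running refactor, the sketch below performs the same transformation with the standard library's ast module alone. The names ends_with_exit and add_exit_calls are hypothetical helpers, and the call is wrapped in ast.Expr because that is how the ast module represents an expression used as a statement. Unlike refactor, this approach regenerates the whole file with ast.unparse, so comments and original formatting are lost.

```python
import ast


def ends_with_exit(func: ast.FunctionDef) -> bool:
    """Return True if the function's last statement is a bare exit(...) call."""
    last_stmt = func.body[-1]
    return (
        isinstance(last_stmt, ast.Expr)
        and isinstance(last_stmt.value, ast.Call)
        and isinstance(last_stmt.value.func, ast.Name)
        and last_stmt.value.func.id == "exit"
    )


def add_exit_calls(source: str, status_code: int = 0) -> str:
    """Hypothetical helper: append exit(status_code) to every main() function."""
    tree = ast.parse(source)
    # Snapshot the nodes first so the tree is not modified while walking it.
    for node in list(ast.walk(tree)):
        if (
            isinstance(node, ast.FunctionDef)
            and node.name == "main"
            and not ends_with_exit(node)
        ):
            call = ast.Call(
                ast.Name("exit", ast.Load()),
                args=[ast.Constant(status_code)],
                keywords=[],
            )
            node.body.append(ast.Expr(call))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


source = "def main():\n    print(1)\n\ndef helper():\n    pass\n"
result = add_exit_calls(source)
print(result)
```

Because of the ends_with_exit check, running add_exit_calls on its own output does not add a second exit() call.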