Your first Refactor rule#
If you have already read what is Refactor and curious to learn more about how to actually use it, this is where you should start. In this example, we’ll be building our own little tool to analyze bad code patterns that are nearly impossible to locate with common text-processing tools and we will write a program that can automatically fix them!
Our first transformation#
As for our first ever transformation, we want to write a simple program that would find all additions and subtractions, and if both sides are constant it would transform that operation into its result.
PI = 3.14
TAU = PI + PI
# This is a very interesting constant that can
# solve world's all problems.
WEIRD_MATH_CONSTANT = PI + TAU
def make_computation(x: int, y: int, z: int) -> float:
result = (
WEIRD_MATH_CONSTANT * 2 + (5 + 3) # A very complex math equation
) + 8 # Don't forget the 8 here
# This would help us find the point of origin
result += 2 + 1
return 3.14 - 6.28 + result + z * 2
“contract” + “transformation” = “rule”#
Each separate analyzer in Refactor corresponds to a “rule”, an abstract component that would structurally hold the information regarding what we are looking for (e.g., an addition operation) and how we would transform the source code if we find what we are looking for.
import ast
import refactor
Every rule starts with creating a class that inherits from refactor.Rule
and defines a custom match()
method. The match()
method is the
only entry-point to the rule, and it will be called with every node of the tree during the
traversal (in a breadth-first order).
If the given node
satisfies the criteria of the rule, then match()
method can return a
source code transformation action (which we’ll talk about more the next section) and the
action will be executed immediately.
import ast
import refactor
class FoldMyConstants(refactor.Rule):
def match(self, node: ast.AST) -> refactor.BaseAction:
As the section header implies, a rule consist of two parts: “contract” and “transformation”. Let’s dive in to the first part.
“contract”#
There isn’t a formal definition of “contract”, but it is generally the name we use for the part inside match()
that does the filtering on a sea of AST nodes to find something that fits the rule’s criteria. Let’s write our
contract as per our specification. We want to ensure that the given node
is a binary operation
where both sides are constants, and the operator of this operation is either an addition or a subtraction.
import ast
import refactor
class FoldMyConstants(refactor.Rule):
def match(self, node: ast.AST) -> refactor.BaseAction:
assert isinstance(node, ast.BinOp)
assert isinstance(op := node.op, (ast.Add, ast.Sub))
assert isinstance(left := node.left, ast.Constant)
assert isinstance(right := node.right, ast.Constant)
assert isinstance(l_val := left.value, (int, float))
assert isinstance(r_val := right.value, (int, float))
import ast
import refactor
from typing import Optional
class FoldMyConstants(refactor.Rule):
def match(self, node: ast.AST) -> Optional[refactor.BaseAction]:
if not (
isinstance(node, ast.BinOp)
and isinstance(op := node.op, (ast.Add, ast.Sub))
and isinstance(left := node.left, ast.Constant)
and isinstance(right := node.right, ast.Constant)
and isinstance(l_val := left.value, (int, float))
and isinstance(r_val := right.value, (int, float))
):
return None
As you might have noticed, we offer two different ways of filtering nodes. The first one, assertion based contracts, is
the recommended way of laying out the conditions in a linearized fashion that is easy to read and operate on. Any AssertionError
that might happen on match()
would be a skip signal, so the engine can proceed forward. It is also possible to do
this without assertions (the second way), and just use good old if statements (or even match statements) but don’t forget to return
None
if the given node doesn’t fit your rule’s criteria.
Tip
Our recommendation is using the assertion based contracts for filtering as much as possible, due to their flexibility (you can have an assertion in anywhere in your code, so instead of handling error paths recursively you can just depend on an assert to fail) and how readable they are with the linearized writing model
“transformation”#
If our contract is validated (or if all the conditions are met), we’ll proceed to the second step: “transformation”. This is where we decide what sort of action we are going to take on top of the source code. It is possible to write your own actions, but for most of the use cases you should be able to simply use a built-in one.
What we need in our example is an action called Replace
which takes a node and a target, and
changes the code belonging to the given node with the re-synthesized version of the target. We’ll use it to replace the whole
binary operation with just its result.
import ast
import refactor
class FoldMyConstants(refactor.Rule):
def match(self, node: ast.AST) -> refactor.BaseAction:
assert isinstance(node, ast.BinOp)
assert isinstance(op := node.op, (ast.Add, ast.Sub))
assert isinstance(left := node.left, ast.Constant)
assert isinstance(right := node.right, ast.Constant)
assert isinstance(l_val := left.value, (int, float))
assert isinstance(r_val := right.value, (int, float))
if isinstance(op, ast.Add):
result = ast.Constant(l_val + r_val)
else:
result = ast.Constant(l_val - r_val)
return refactor.Replace(node, result)
And yes, this is it!
Bonus: running the rules#
Now it is time for us to test our first rule and see how it performs. For more advanced use cases,
we’ll show how you can build your own CLI with a session
but for simple programs
like the one above Refactor itself offers a very convenient mechanism to directly convert your rules into
their own CLI tool.
import ast
import refactor
class FoldMyConstants(refactor.Rule):
def match(self, node: ast.AST) -> refactor.BaseAction:
assert isinstance(node, ast.BinOp)
assert isinstance(op := node.op, (ast.Add, ast.Sub))
assert isinstance(left := node.left, ast.Constant)
assert isinstance(right := node.right, ast.Constant)
assert isinstance(l_val := left.value, (int, float))
assert isinstance(r_val := right.value, (int, float))
if isinstance(op, ast.Add):
result = ast.Constant(l_val + r_val)
else:
result = ast.Constant(l_val - r_val)
return refactor.Replace(node, result)
if __name__ == "__main__":
refactor.run(rules=[FoldMyConstants])
And with this here, we can execute our Python script and pass it any file we would like. By default
the CLI program will display the diff, but you can also pass -a
to apply it in-place. Let’s test this on
our example program.
--- program.py
+++ program.py
@@ -7,10 +7,10 @@
def make_computation(x: int, y: int, z: int) -> float:
result = (
- WEIRD_MATH_CONSTANT * 2 + (5 + 3) # A very complex math equation
+ WEIRD_MATH_CONSTANT * 2 + (8) # A very complex math equation
) + 8 # Don't forget the 8 here
# This would help us find the point of origin
- result += 2 + 1
+ result += 3
- return 3.14 - 6.28 + result + z * 2
+ return -3.14 + result + z * 2
As you can see from the diff, it handled the transformation very delicately and changed only what it needed to change.
Going further#
Fortunately in the snippet above, there are still some stuff left that we can work on.
The one that hits to me in the first glance is, both TAU
and WEIRD_MATH_CONSTANT
could have
been folded if they were used as constants instead of names.
PI = 3.14
TAU = PI + PI
# This is a very interesting constant that can
# solve world's all problems.
WEIRD_MATH_CONSTANT = PI + TAU
def make_computation(x: int, y: int, z: int) -> float:
result = (
WEIRD_MATH_CONSTANT * 2 + (5 + 3) # A very complex math equation
) + 8 # Don't forget the 8 here
So why don’t we start writing a new rule that would resolve all the names with a value that is easy to infer. This is called constant propagation, and I think it will play nicely with our constant folding implementation in the first rule.
from refactor.context import Scope
class PropagateMyConstants(refactor.Rule):
context_providers = (Scope,)
We started as usual, by inheriting from Rule
, but this time we have a new
keyword: context_providers
. Each rule in Refactor has a shared state / knowledge base of the current
file. It is accessible by the context
attribute, and contains stuff like the source code, tree, and
most importantly context providers.
Context providers operate on the module-level, unlike to rules which operate on the node-level, and they are really useful
for gathering knowledge from surroundings and sharing them with rules to do more precise analyses. There
are a bunch of built-in ones, like Ancestry
which can help you to backtrack
a node to its parents. But what what need today is Scope
. It allows you to learn
which semantical scope a node belongs to, and where it can access from it.
from refactor.context import Scope
class PropagateMyConstants(refactor.Rule):
context_providers = (Scope,)
def match(self, node: ast.AST) -> refactor.BaseAction:
assert isinstance(node, ast.Name)
assert isinstance(node.ctx, ast.Load)
scope = self.context.scope.resolve(node)
As you can see, once we are in a spot on our contract where we need scope information to decide
further we access it directly under self.context
. After accessing to it, we call resolve()
with the node we want to learn more and it returns a ScopeInfo
record.
Tip
All initialized context providers can be accessed by their name (actually their name’s snake case
version, so Scope
is scope
and SomethingProvider
is something_provider
) so we can simply say
self.context.scope
The record contains information about your current scope, the node that encloses you (semantically), all
the definitions that are made inside your scope and a very important method: get_definitions()
.
It is how we ask Refactor to tell us where a given name (like PI
, or TAU
) is defined. It goes back in all reachable
scopes (like if you are inside a top-level function, and if the name you are accessing is not defined in there then it will check
the global scope) until it finds a set of definitions.
Let’s retrieve the definitions and check if the assigned value is a constant. Also for keeping this example simple, we won’t be considering names that are defined multiple times (e.g., conditionally by an if statement or overwritten in some part of the program) so we want the number of definitions to be only one. And if all of our checks pass, we’ll return an action to replace it.
from refactor.context import Scope
class PropagateMyConstants(refactor.Rule):
context_providers = (Scope,)
def match(self, node: ast.AST) -> refactor.BaseAction:
assert isinstance(node, ast.Name)
assert isinstance(node.ctx, ast.Load)
scope = self.context.scope.resolve(node)
definitions = scope.get_definitions(node.id)
assert len(definitions) == 1
# The definition might be anything, it might be coming
# from an import or it might be function. So we'll only
# allow assignments.
[definition] = definitions
assert isinstance(definition, ast.Assign)
assert isinstance(defined_value := definition.value, ast.Constant)
return refactor.Replace(node, defined_value)
Let’s test out how our rules now perform in three different ways:
--- program.py
+++ program.py
@@ -1,9 +1,9 @@
PI = 3.14
-TAU = PI + PI
+TAU = 3.14 + 3.14
# This is a very interesting constant that can
# solve world's all problems.
-WEIRD_MATH_CONSTANT = PI + TAU
+WEIRD_MATH_CONSTANT = 3.14 + TAU
--- program.py
+++ program.py
@@ -7,10 +7,10 @@
def make_computation(x: int, y: int, z: int) -> float:
result = (
- WEIRD_MATH_CONSTANT * 2 + (5 + 3) # A very complex math equation
+ WEIRD_MATH_CONSTANT * 2 + (8) # A very complex math equation
) + 8 # Don't forget the 8 here
# This would help us find the point of origin
- result += 2 + 1
+ result += 3
- return 3.14 - 6.28 + result + z * 2
+ return -3.14 + result + z * 2
As you can see, this is where the main difference comes from.
--- program.py
+++ program.py
@@ -1,16 +1,16 @@
PI = 3.14
-TAU = PI + PI
+TAU = 6.28
# This is a very interesting constant that can
# solve world's all problems.
-WEIRD_MATH_CONSTANT = PI + TAU
+WEIRD_MATH_CONSTANT = 9.42
def make_computation(x: int, y: int, z: int) -> float:
result = (
- WEIRD_MATH_CONSTANT * 2 + (5 + 3) # A very complex math equation
+ 9.42 * 2 + (8) # A very complex math equation
) + 8 # Don't forget the 8 here
# This would help us find the point of origin
- result += 2 + 1
+ result += 3
- return 3.14 - 6.28 + result + z * 2
+ return -3.14 + result + z * 2
Note
The full code for this example is available at our GitHub repo