Assertion rewriting in Pytest part 4: The implementation

Part 1, part 2 and part 3 looked at some background to why Pytest does something unusual and the principles behind how it works. Now it’s time to look into the implementation.

First it’s worth fleshing out how the AST stuff fits into the original motivation of getting better assertions. Our original problem was that we had something like:

assert number_of_the_counting == 3

but we want more diagnostic information to be written out on failure. We could make the original test author write:

if not (number_of_the_counting == 3):
    print("Failed, expected number_of_the_counting == 3")
    print("Actual value=%s" % number_of_the_counting)

but there’s no chance that test authors are going to write this on every assertion they write. We saw how Python code can be converted into a tree structure, and how that tree structure can be converted into code we can execute. Hopefully you can see the possibility to convert the first form into a tree structure, modify the tree structure to represent the second form, then execute the result. Working on the tree structure makes it possible for us to do this without tearing our hair out.
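As a minimal sketch of that idea (this is not Pytest’s code), the following uses the standard ast module to turn every assert into an if not ...: raise AssertionError(...) that names the failing expression. The names NaiveAssertRewriter and run_rewritten are made up for illustration, and ast.unparse assumes Python 3.9 or later.

```python
import ast


class NaiveAssertRewriter(ast.NodeTransformer):
    """Hypothetical, much simplified stand-in for Pytest's rewriter."""

    def visit_Assert(self, node):
        # Recover the source text of the asserted expression.
        text = ast.unparse(node.test)
        # Build: raise AssertionError("Failed, expected <text>")
        failure = ast.Raise(
            exc=ast.Call(
                func=ast.Name(id="AssertionError", ctx=ast.Load()),
                args=[ast.Constant(value="Failed, expected " + text)],
                keywords=[],
            ),
            cause=None,
        )
        # Build: if not (<test>): <failure>
        negated = ast.UnaryOp(op=ast.Not(), operand=node.test)
        new_node = ast.If(test=negated, body=[failure], orelse=[])
        return ast.copy_location(new_node, node)


def run_rewritten(source, namespace):
    """Parse, rewrite, compile and execute `source` in `namespace`."""
    tree = NaiveAssertRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    exec(compile(tree, "<rewritten>", "exec"), namespace)
```

Calling run_rewritten("assert number_of_the_counting == 3", {"number_of_the_counting": 5}) raises an AssertionError whose message contains the original expression, while the same call with a value of 3 runs silently. The real rewriter goes much further and reports the run-time values as well.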

With that background established, let’s look at Pytest, which actually does this stuff for real. The code we’re interested in is in _pytest.assertion.rewrite, starting with the class AssertionRewritingHook. This class is an import hook as defined in PEP 302. This means that once the hook is registered, any subsequent import statement triggers code in this class.

There are actually two different objects involved in Python imports, finders and loaders. The finder receives a module name and figures out which file this relates to. The loader actually loads the given file and creates a Python module object. As is common, the AssertionRewritingHook does both jobs, via the find_module and load_module methods.
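To make the finder/loader split concrete, here is a small sketch of a single object that does both jobs. It is an illustration rather than Pytest’s implementation: the class InMemoryHook and the module name demo_hooked_module are invented, and it uses the find_spec and exec_module methods that newer Python versions expect instead of the find_module and load_module pair described here.

```python
import importlib.abc
import importlib.util
import sys


class InMemoryHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """Hypothetical hook that acts as both finder and loader."""

    def __init__(self, sources):
        self.sources = sources  # maps module name -> source text

    # Finder half: decide whether this hook handles the module.
    def find_spec(self, fullname, path=None, target=None):
        if fullname in self.sources:
            return importlib.util.spec_from_loader(fullname, self)
        return None  # let the normal import machinery handle it

    # Loader half: create the module and run its code.
    def create_module(self, spec):
        return None  # fall back to default module creation

    def exec_module(self, module):
        source = self.sources[module.__name__]
        # A rewriting hook would transform the AST here before compiling.
        exec(compile(source, module.__name__, "exec"), module.__dict__)


hook = InMemoryHook({"demo_hooked_module": "ANSWER = 42"})
sys.meta_path.insert(0, hook)
try:
    import demo_hooked_module  # served by the hook, not from disk
finally:
    sys.meta_path.remove(hook)
```

Once the hook is at the front of sys.meta_path, an ordinary import statement is enough to trigger it, which is why the code being imported needs no changes.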

The find_module method is where the interesting stuff actually happens. Within this method it checks whether the module needs to be rewritten (only test cases get modified, not product code) and generates the rewritten module. The load_module method just plucks the result from the cache. find_module has some logic for caching, handling encoding issues and so on, but it passes control to the AssertionRewriter class to do the rewriting.

The AssertionRewriter follows the visitor pattern, which is a way of dealing with a tree structure made up of lots of different types of things. In our case, we have an AST with different nodes representing different pieces of code: class definitions, method definitions, variable assignments, function calls etc. The AssertionRewriter walks through the tree looking for assertions. It keeps a list of nodes still to be examined and pops the most recently added one each time, so the traversal is depth-first rather than breadth-first:

nodes = [mod]
while nodes:
    node = nodes.pop()
    for name, field in ast.iter_fields(node):
        if isinstance(field, list):
            new = []
            for i, child in enumerate(field):
                if isinstance(child, ast.Assert):
                    # Transform assert.
                    new.extend(self.visit(child))
                else:
                    new.append(child)
                    if isinstance(child, ast.AST):
                        nodes.append(child)
            setattr(node, name, new)
        elif (isinstance(field, ast.AST) and
              # Don't recurse into expressions as they can't contain
              # asserts.
              not isinstance(field, ast.expr)):
            nodes.append(field)

The key line is here:

if isinstance(child, ast.Assert):
    # Transform assert.
    new.extend(self.visit(child))

This triggers the actual operation. Everything else is just code for searching through the tree and leaving everything unchanged that isn’t part of an assert statement. The Python AST is quite well structured, but it is still designed primarily for execution and not to make our life easy. Therefore the code is a little more complex than we might like.
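The search part can be tried out on its own. The sketch below reuses the same stack-based loop but, as a simplification of mine, only records each assertion’s source text instead of transforming it. The sample source and the found list are invented for the example, and ast.unparse assumes Python 3.9 or later.

```python
import ast

source = """
def test_one():
    assert 1 + 1 == 2
    if True:
        assert "a" in "abc"
"""

found = []
stack = [ast.parse(source)]
while stack:
    node = stack.pop()
    for name, field in ast.iter_fields(node):
        if isinstance(field, list):
            for child in field:
                if isinstance(child, ast.Assert):
                    # Pytest would call self.visit(child) here.
                    found.append(ast.unparse(child.test))
                elif isinstance(child, ast.AST):
                    stack.append(child)
        elif isinstance(field, ast.AST) and not isinstance(field, ast.expr):
            # Expressions are skipped as they can't contain asserts.
            stack.append(field)
```

Both assertions end up in found, including the one nested inside the if block.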

The call to self.visit() actually hits a method in the base ast.NodeVisitor class. This forwards the call to a method visit_<nodename> if one exists (visit_Assert in our case) or generic_visit otherwise. This is part of the visitor pattern, and means we can have different code for handling each different case without having to write a tedious if .. elif .. elif dispatch block.
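A tiny example shows the dispatch in action. The Describer class is hypothetical and only returns labels, but the routing through visit is the standard ast.NodeVisitor behaviour.

```python
import ast


class Describer(ast.NodeVisitor):
    """Hypothetical visitor showing name-based dispatch."""

    def visit_Assert(self, node):
        # Called for Assert nodes because the method name matches.
        return "assert statement"

    def generic_visit(self, node):
        # Fallback for every node type without a visit_<nodename> method.
        return "something else: " + type(node).__name__


module = ast.parse("assert x\ny = 1")
results = [Describer().visit(stmt) for stmt in module.body]
```

The first statement is routed to visit_Assert and the second falls through to generic_visit, without any explicit type checks in our code.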

So we end up calling visit_Assert() with the assertion node. If we have an assertion like:

assert myfunc(x) == 42, 'operation failed'

the AST node will look like:

Assert(
    test=Compare(
        left=Call(
            func=Name(id='myfunc', ctx=Load()),
            args=[Name(id='x', ctx=Load())],
            keywords=[]
        ),
        ops=[Eq()],
        comparators=[Num(n=42)]
    ),
    msg=Str(s='operation failed')
)
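You can produce a dump like this yourself with the standard library. One caveat: newer Python versions represent literals as Constant(value=...) rather than Num and Str, so the output will differ slightly from the tree shown here. The indent argument assumes Python 3.9 or later.

```python
import ast

tree = ast.parse("assert myfunc(x) == 42, 'operation failed'")
assert_node = tree.body[0]

# Pretty-print the Assert node and everything beneath it.
print(ast.dump(assert_node, indent=4))
```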

The test is the thing that we want to annotate, and Pytest deals with it like this:

# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)

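This again re-dispatches to the correct visit_... method. The test here is a Compare node, so visit_Compare is called first, and that in turn visits its operands. The left-hand operand is the function call, which takes us to visit_Call (actually visit_Call_35 on my version of Python, due to a change in Python 3.5 and later).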

The function that handles the function call looks like this:

new_func, func_expl = self.visit(call.func)
arg_expls = []
new_args = []
new_kwargs = []
for arg in call.args:
    res, expl = self.visit(arg)
    arg_expls.append(expl)
    new_args.append(res)
for keyword in call.keywords:
    res, expl = self.visit(keyword.value)
    new_kwargs.append(ast.keyword(keyword.arg, res))
    if keyword.arg:
        arg_expls.append(keyword.arg + "=" + expl)
    else:  # **args have `arg` keywords with an .arg of None
        arg_expls.append("**" + expl)
 
expl = "%s(%s)" % (func_expl, ', '.join(arg_expls))
new_call = ast.Call(new_func, new_args, new_kwargs)
res = self.assign(new_call)
res_expl = self.explanation_param(self.display(res))
outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
return res, outer_expl

Let’s just look at a couple of these lines:

new_func, func_expl = self.visit(call.func)
# ...
new_call = ast.Call(new_func, new_args, new_kwargs)

This operates on the function that is being called, which in our case is a lookup of a local name. In case that’s not clear: in Python, when we have a call like:

my_function(42)

there are actually two steps. First, the name my_function is looked up in the current scope to find out which function it refers to. Then the function is executed. At the moment we’re looking at the first of these steps.

The Name node (which looks up the function) is transformed by calling visit, which forwards to visit_Name. The result of this is new_func, which is packed into a new ast.Call node. The resultant AST still has the same effect of calling the function, but has more functionality wrapped round it. The AST will eventually be compiled and executed and will work like the original code, but behave differently in the case where the assertion fails.

OK, I’ve talked about “additional behaviour”, but what exactly does that amount to? When we call:

new_func, func_expl = self.visit(call.func)

what is getting returned?

The first value returned by visit_Name is just its first parameter, so new_func is simply call.func. But func_expl contains some code that can be used to generate a human-readable name for the function:

def visit_Name(self, name):
    # Display the repr of the name if it's a local variable or
    # _should_repr_global_name() thinks it's acceptable.
    locs = ast_Call(self.builtin("locals"), [], [])
    inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
    dorepr = self.helper("should_repr_global_name", name)
    test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
    expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
    return name, self.explanation_param(expr)

Let’s build this up a step at a time, assuming we’ve passed in the Name(id='myfunc', ctx=Load()) that was in the AST we examined above:

locs = ast_Call(self.builtin("locals"), [], [])

This is just AST-speak for something like:

locals()

Next:

locs = ast_Call(self.builtin("locals"), [], [])
inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])

We already know what locs is, so we can substitute it into the comparison:

"myfunc" in locals()

The next couple of lines:

dorepr = self.helper("should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])

If we ignore the self.helper stuff for now and just assume this is a call to _should_repr_global_name in the rewrite module, we can see this becomes:

("myfunc" in locals()) or (_should_repr_global_name(Name(id='myfunc',...)))

This then gets used as a condition in an if expression:

expr = ast.IfExp(test, self.display(name), ast.Str(name.id))

Note that this is an if expression, not an if statement. In other words, we’re looking at something like this:

repr(myfunc) if (("myfunc" in locals()) or (...)) else "myfunc"

(I glossed over the details of self.display, which actually does a bit more than just generate a call to repr, but that will do for now.)
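To check that reading, we can build a cut-down version of the same expression by hand and evaluate it. This sketch leaves out the _should_repr_global_name branch and uses a plain call to repr in place of self.display, so it is an approximation of what Pytest generates, not a copy. It uses ast.Constant, which is what current Python versions use instead of ast.Str.

```python
import ast

name = ast.Name(id="myfunc", ctx=ast.Load())

# "myfunc" in locals()
locs = ast.Call(func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[])
inlocs = ast.Compare(
    left=ast.Constant(value="myfunc"), ops=[ast.In()], comparators=[locs]
)

# repr(myfunc)
display = ast.Call(func=ast.Name(id="repr", ctx=ast.Load()), args=[name], keywords=[])

# repr(myfunc) if "myfunc" in locals() else "myfunc"
expr = ast.IfExp(test=inlocs, body=display, orelse=ast.Constant(value="myfunc"))

tree = ast.fix_missing_locations(ast.Expression(body=expr))
code = compile(tree, "<generated>", "eval")

# Evaluate once with myfunc defined as a local and once without it.
present = eval(code, {}, {"myfunc": len})
absent = eval(code, {}, {})
```

With myfunc present in the local namespace the expression yields its repr. Without it, the expression falls back to the bare name "myfunc".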

The result of all these shenanigans is to produce two things:

  • an AST that can be evaluated at run time to generate the name of the function
  • func_expl, which is a formatting string with placeholders in it like %(py0)s and %(py1)s. These keys can be looked up in a dict stored on the AssertionRewriter object to obtain the AST that generates the correct diagnostic output.
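The placeholders are ordinary Python %-style named format keys. As a rough illustration, with a template and values that are invented rather than taken from Pytest, the final substitution at run time works along these lines:

```python
# Hypothetical explanation template in the style described above.
template = "%(py0)s == %(py1)s"

# At rewrite time each key maps to an AST. By the time the assertion
# fails, the generated code has evaluated those ASTs, leaving plain
# values to substitute into the template.
values = {"py0": repr(5), "py1": repr(3)}

message = "assert " + template % values
```

Here message ends up as "assert 5 == 3", which is the flavour of output you see in a failing Pytest run.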

It’s worth stepping back at this point and reiterating why we do things in this roundabout way: generating an AST that generates the value rather than just generating the value. The answer is that a lot of the values aren’t known until later on when the code is executed. We’ve focused on the function name as an example because it’s simple, but (in most cases) the function name could just be dumped out directly by examining the source code. However, we can’t print the values of the arguments to the function or the return value from the function, because we don’t know what those values are yet.

All the above gets carried out for various other pieces of the input AST: comparisons, boolean operators, arithmetic etc. When the dust settles and we finish processing the visit_Assert, we need to generate a little more code:

# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
# Create failure message.
body = self.on_failure
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))

By now you’ve hopefully got the trick of interpreting this code: we’re adding an element to self.statements that is the AST for code like the following:

if not (condition):
    explanation

where condition is just the original expression that was inside the assert statement, and explanation is the code that generates the diagnostics.

We’re finally there. Once the AssertionRewriter finishes walking through the tree, we have an AST that can be passed to compile. A little more housework is required to create a module object and put it in the Python modules list, but nothing terribly complicated.
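That housework can be sketched in a few lines. This is a simplified illustration rather than what Pytest does internally (it skips caching and the import machinery’s bookkeeping), and the module name demo_rewritten_module is invented.

```python
import ast
import sys
import types

source = "def double(n):\n    return n * 2\n"
tree = ast.parse(source)  # a rewriter would modify `tree` at this point

# Compile the (possibly modified) AST into a code object.
code = compile(tree, "<rewritten demo_module>", "exec")

# Create an empty module, run the code inside it, and register it.
module = types.ModuleType("demo_rewritten_module")
exec(code, module.__dict__)
sys.modules[module.__name__] = module

# From here on, a normal import picks up the registered module.
import demo_rewritten_module
```

After this, demo_rewritten_module.double(21) returns 42, and nothing about the module reveals that it was built from an AST.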

The amazing thing about this is that the rest of the Python system doesn’t know that there’s anything unusual about this rewritten module. Calling the Python code within works the same as calling any other Python module.

I think this is pretty impressive stuff, given how powerful it can be and how (relatively) little internal knowledge is required. Hopefully in the coming months I’ll dig a little deeper into some of this stuff and tear down some other examples.
