Assertion rewriting in Pytest part 4: The implementation

Part 1, part 2 and part 3 looked at some background to why Pytest does something unusual and the principles behind how it works. Now it’s time to look into the implementation.

First it’s worth fleshing out how the AST stuff fits into the original motivation of getting better assertions. Our original problem was that we had something like:

assert number_of_the_counting == 3

but we want more diagnostic information to be written out on failure. We could make the original test author write:

if not (number_of_the_counting == 3):
    print("Failed, expected number_of_the_counting == 3")
    print("Actual value=%s" % number_of_the_counting)

but there’s no chance that test authors are going to write this on every assertion they write. We saw how Python code can be converted into a tree structure, and how that tree structure can be converted into code we can execute. Hopefully you can see the possibility to convert the first form into a tree structure, modify the tree structure to represent the second form, then execute the result. Working on the tree structure makes it possible for us to do this without tearing our hair out.
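
As a rough illustration of that idea (a minimal sketch, not Pytest’s actual code), the standard ast module can do the whole round trip. The AssertToIf class and the message wording are made up for this example, and ast.unparse assumes Python 3.9 or later.

```python
import ast


class AssertToIf(ast.NodeTransformer):
    """Hypothetical transformer: turn `assert X` into `if not X: raise ...`."""

    def visit_Assert(self, node):
        # Recover the text of the condition so it can go into the message.
        text = ast.unparse(node.test)
        failure = ast.Raise(
            ast.Call(
                ast.Name("AssertionError", ast.Load()),
                [ast.Constant("Failed, expected " + text)],
                [],
            ),
            None,
        )
        new_node = ast.If(ast.UnaryOp(ast.Not(), node.test), [failure], [])
        return ast.copy_location(new_node, node)


tree = ast.parse("assert number_of_the_counting == 3")
tree = AssertToIf().visit(tree)
ast.fix_missing_locations(tree)  # the newly created nodes need line numbers

message = None
try:
    exec(compile(tree, "<sketch>", "exec"), {"number_of_the_counting": 5})
except AssertionError as exc:
    message = str(exc)
print(message)
```

This only reports the text of the condition. Reporting the run-time values as well is the harder part, and it is what the rest of this post is about.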

With that background established, let’s look at Pytest, which actually does this stuff for real. The code we’re interested in is in _pytest.assertion.rewrite, starting with the class AssertionRewritingHook. This class is an import hook as defined in PEP 302. This means that once this hook is registered, any subsequent import statement triggers code in this class.

There are actually two different objects involved in Python imports: finders and loaders. The finder receives a module name and figures out which file this relates to. The loader actually loads the given file and creates a Python module object. As is common, the AssertionRewritingHook does both jobs, via the find_module and load_module methods.
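
To make the two roles concrete, here is a small sketch of a single object acting as both finder and loader. It is not Pytest’s hook: it uses find_spec and exec_module, which are the names current versions of importlib use for these two jobs, and the DemoHook class, the SOURCES dict and the module name are all invented for the example.

```python
import importlib.abc
import importlib.util
import sys

# Hypothetical in-memory "files": module name -> source text.
SOURCES = {"demo_virtual_module": "ANSWER = 6 * 7\n"}


class DemoHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """One object doing both jobs: finding and loading."""

    def find_spec(self, fullname, path=None, target=None):
        # Finder role: decide whether this hook is responsible for the name.
        if fullname not in SOURCES:
            return None  # let the normal import machinery carry on
        return importlib.util.spec_from_loader(fullname, self)

    def create_module(self, spec):
        return None  # ask Python to create an ordinary module object

    def exec_module(self, module):
        # Loader role: compile the source and run it in the module namespace.
        code = compile(SOURCES[module.__name__], "<demo>", "exec")
        exec(code, module.__dict__)


sys.meta_path.insert(0, DemoHook())

import demo_virtual_module

print(demo_virtual_module.ANSWER)
```

A rewriting hook would do its work at the point where this sketch calls compile, by modifying the tree before compiling it.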

The find_module method is where the interesting stuff actually happens. Within this method it checks whether the module needs to be rewritten (only test cases get modified, not product code) and generates the rewritten module. The load_module method just plucks the result from the cache. find_module has some logic for caching and handling encoding issues etc., but passes control to the AssertionRewriter class to do the rewriting.

The AssertionRewriter follows the visitor pattern, which deals with a tree structure made up of lots of different types of things. In our case, we have an AST with different nodes representing different pieces of code: class definitions, method definitions, variable assignments, function calls etc. The AssertionRewriter walks through the tree looking for assertions, keeping the nodes it still has to visit on a stack (which makes the walk depth-first):

nodes = [mod]
while nodes:
    node = nodes.pop()
    for name, field in ast.iter_fields(node):
        if isinstance(field, list):
            new = []
            for i, child in enumerate(field):
                if isinstance(child, ast.Assert):
                    # Transform assert.
                    new.extend(self.visit(child))
                else:
                    new.append(child)
                    if isinstance(child, ast.AST):
                        nodes.append(child)
            setattr(node, name, new)
        elif (isinstance(field, ast.AST) and
              # Don't recurse into expressions as they can't contain
              # asserts.
              not isinstance(field, ast.expr)):
            nodes.append(field)

The key line is here:

if isinstance(child, ast.Assert):
    # Transform assert.
    new.extend(self.visit(child))

This triggers the actual operation. Everything else is just code for searching through the tree and leaving everything unchanged that isn’t part of an assert statement. The Python AST is quite well structured, but it is still designed primarily for execution and not to make our life easy. Therefore the code is a little more complex than we might like.

The call to self.visit() actually hits a method in the base ast.NodeVisitor class. This forwards the call to a method visit_<nodename> if one exists (visit_Assert in our case) or generic_visit otherwise. This is part of the Visitor pattern, and means we can have different code for handling each different case without having to write a tedious if .. elif .. elif dispatch block.
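
The dispatch is easy to see in isolation. In this sketch (the AssertCounter class and the sample source are invented for illustration) only visit_Assert is defined, so every other node type falls through to generic_visit, which simply carries on down the tree.

```python
import ast


class AssertCounter(ast.NodeVisitor):
    """Hypothetical visitor that records the line of every assert."""

    def __init__(self):
        self.lines = []

    def visit_Assert(self, node):
        # Called only for Assert nodes, thanks to the name-based dispatch.
        self.lines.append(node.lineno)
        self.generic_visit(node)


source = (
    "def test_a():\n"
    "    assert 1 == 1\n"
    "    x = 2\n"
    "    assert x\n"
)
counter = AssertCounter()
counter.visit(ast.parse(source))
print(counter.lines)
```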

So we end up calling visit_Assert() with the assertion node. If we have an assertion like:

assert myfunc(x) == 42, 'operation failed'

the AST node will look like:

Assert(
    test=Compare(
        left=Call(
            func=Name(id='myfunc', ctx=Load()),
            args=[Name(id='x', ctx=Load())],
            keywords=[]
        ),
        ops=[Eq()],
        comparators=[Num(n=42)]
    ),
    msg=Str(s='operation failed')
)
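
If you want to inspect a node like this yourself, something along these lines should do it. Note that recent Python versions print Constant(value=42) and Constant(value='operation failed') where the dump above shows Num and Str; the structure is otherwise the same.

```python
import ast

module = ast.parse("assert myfunc(x) == 42, 'operation failed'")
assert_node = module.body[0]

# Print the whole node, then pick out a couple of the pieces by hand.
print(ast.dump(assert_node))
called_name = assert_node.test.left.func.id
message_text = assert_node.msg.value
print(called_name, message_text)
```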

The test is the thing that we want to annotate, and Pytest deals with it like this:

# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)

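This again re-dispatches to the correct visit_... method. The test here is a Compare node, so this is visit_Compare, which in turn visits its left operand and so reaches visit_Call (which is actually visit_Call_35 on my version of Python, due to a change in Python 3.5 and later).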

The function that handles the function call looks like this:

new_func, func_expl = self.visit(call.func)
arg_expls = []
new_args = []
new_kwargs = []
for arg in call.args:
    res, expl = self.visit(arg)
    arg_expls.append(expl)
    new_args.append(res)
for keyword in call.keywords:
    res, expl = self.visit(keyword.value)
    new_kwargs.append(ast.keyword(keyword.arg, res))
    if keyword.arg:
        arg_expls.append(keyword.arg + "=" + expl)
    else:  # **args have `arg` keywords with an .arg of None
        arg_expls.append("**" + expl)
 
expl = "%s(%s)" % (func_expl, ', '.join(arg_expls))
new_call = ast.Call(new_func, new_args, new_kwargs)
res = self.assign(new_call)
res_expl = self.explanation_param(self.display(res))
outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
return res, outer_expl

Let’s just look at a couple of these lines:

new_func, func_expl = self.visit(call.func)
# ...
new_call = ast.Call(new_func, new_args, new_kwargs)

This operates on the function that is being called, which in our case is a lookup of a local name. In case that’s not clear: in Python, when we have a call like:

my_function(42)

there are actually two steps: The name my_function is looked up in the local context to find out which function it refers to. Then the function is executed. At the moment we’re looking at the first of these steps.

The Name node (which looks up the function) is transformed by calling visit, which forwards to visit_Name. The result of this is new_func, which is packed into a new ast.Call node. The resultant AST still has the same effect of calling the function, but has more functionality wrapped round it. The AST will eventually be compiled and executed and will work like the original code, but behave differently in the case where the assertion fails.

OK, I’ve talked about “additional behaviour”, but what exactly does that amount to? When we call:

new_func, func_expl = self.visit(call.func)

what is getting returned?

Actually, the first value returned by visit_Name is just its first parameter, so new_func is actually just call.func. But func_expl contains some code that can be used to generate a human-readable name for the function:

def visit_Name(self, name):
    # Display the repr of the name if it's a local variable or
    # _should_repr_global_name() thinks it's acceptable.
    locs = ast_Call(self.builtin("locals"), [], [])
    inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
    dorepr = self.helper("should_repr_global_name", name)
    test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
    expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
    return name, self.explanation_param(expr)

Let’s build this up a step at a time, assuming we’ve passed in the Name(id='myfunc', ctx=Load()) that was in the AST we examined above:

locs = ast_Call(self.builtin("locals"), [], [])

This is just AST-speak for something like:

locals()

Next:

locs = ast_Call(self.builtin("locals"), [], [])
inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])

We already know what locs is, so we can substitute it into the comparison:

"myfunc" in locals()

The next couple of lines:

dorepr = self.helper("should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])

If we ignore the self.helper stuff for now and just assume this is a call to _should_repr_global_name in the rewrite module, we can see this becomes:

("myfunc" in locals()) or (_should_repr_global_name(myfunc))

This then gets used as a condition in an if expression:

expr = ast.IfExp(test, self.display(name), ast.Str(name.id))

Note that this is an if expression, not an if statement. In other words, we’re looking at something like this:

repr(myfunc) if (("myfunc" in locals()) or (...)) else "myfunc"

(I glossed over the details of self.display, which actually does a bit more than just generate a call to repr, but that will do for now).
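
To check that reading, here is a cut-down version you can run. It is a sketch under some assumptions: it leaves out the should_repr_global_name helper, uses a plain call to repr where Pytest uses self.display, and uses ast.Constant where the code above uses ast.Str. The function name build_name_explanation is mine.

```python
import ast


def build_name_explanation(identifier):
    """Build the AST for: repr(<name>) if "<name>" in locals() else "<name>"."""
    name = ast.Name(identifier, ast.Load())
    locs = ast.Call(ast.Name("locals", ast.Load()), [], [])
    inlocs = ast.Compare(ast.Constant(identifier), [ast.In()], [locs])
    display = ast.Call(ast.Name("repr", ast.Load()), [name], [])
    expr = ast.IfExp(inlocs, display, ast.Constant(identifier))
    tree = ast.Expression(expr)
    ast.fix_missing_locations(tree)  # hand-built nodes have no line numbers
    return tree


code = compile(build_name_explanation("answer"), "<sketch>", "eval")

# The same compiled expression gives different text depending on what is
# in scope when it finally runs.
known = eval(code, {}, {"answer": 42})
unknown = eval(code, {}, {})
print(known, unknown)
```

Nothing about the value is decided when the tree is built. The decision is made when the compiled expression runs.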

The result of all these shenanigans is to produce two things:

  • an AST that can be evaluated at run time to generate the name of the function
  • func_expl, which is a formatting string with placeholders in it like %(py0)s%(py1)s. These keys can be looked up in a dict stored on the AssertionRewriter object to obtain the AST that generates the correct diagnostic output.
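
The placeholder mechanism itself is ordinary %-formatting with a dict. The template and values below are invented for illustration; in Pytest the values come from evaluating the stored AST at run time.

```python
# Hypothetical explanation template and the values that fill it in.
template = "%(py0)s == %(py1)s"
values = {"py0": "myfunc(x)", "py1": 42}

message = template % values
print(message)
```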

It’s worth stepping back at this point and reiterating why we do things in this roundabout way: generating an AST that generates the value rather than just generating the value. The answer is that a lot of the values aren’t known until later on when the code is executed. We’ve focused on the function name as an example because it’s simple, but (in most cases) the function name could just be dumped out directly by examining the source code. However, we can’t print the values of the arguments to the function or the return value from the function, because we don’t know what those values are yet.

All the above gets carried out for various other pieces of the input AST: comparisons, boolean operators, arithmetic etc. When the dust settles and we finish processing the visit_Assert, we need to generate a little more code:

# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
# Create failure message.
body = self.on_failure
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))

By now you’ve hopefully got the trick of interpreting this code: We’re adding an element to self.statements that is the AST for code like the following:

if not (condition):
    explanation

where condition is just the original expression that was inside the assert statement, and explanation is the code that generates the diagnostics.
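
Putting the pieces together, here is a much simplified rewriter that handles only a single comparison. It is a sketch of the general technique rather than what Pytest generates: the CompareExplainer class, the _left and _right temporaries and the message layout are my own, and ast.unparse assumes Python 3.9 or later.

```python
import ast

# Skeleton of the code we want to generate; the marked parts get replaced.
TEMPLATE = """\
_left = None
_right = None
if not (_left == _right):
    raise AssertionError("PLACEHOLDER" % (_left, _right))
"""


class CompareExplainer(ast.NodeTransformer):
    """Hypothetical rewriter for asserts of the form `assert a <op> b`."""

    def visit_Assert(self, node):
        test = node.test
        if not (isinstance(test, ast.Compare) and len(test.ops) == 1):
            return node  # leave anything more complicated untouched
        assign_left, assign_right, if_stmt = ast.parse(TEMPLATE).body
        assign_left.value = test.left              # _left = <left operand>
        assign_right.value = test.comparators[0]   # _right = <right operand>
        if_stmt.test.operand.ops = test.ops        # keep the original operator
        fmt = ast.unparse(test).replace("%", "%%")
        fmt += "\n  left  = %r\n  right = %r"
        if_stmt.body[0].exc.args[0].left.value = fmt
        # One assert statement becomes three statements.
        return [assign_left, assign_right, if_stmt]


source = "a = 2\nb = 5\nassert a + 1 == b\n"
tree = CompareExplainer().visit(ast.parse(source))
ast.fix_missing_locations(tree)

message = None
try:
    exec(compile(tree, "<sketch>", "exec"), {})
except AssertionError as exc:
    message = str(exc)
print(message)
```

Each operand is evaluated exactly once and stored in a temporary, so the condition and the diagnostics are guaranteed to see the same values.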

We’re finally there. Once the AssertionRewriter finishes walking through the tree, we have an AST that can be passed to compile. A little more housework is required to create a module object and register it in sys.modules, but nothing terribly complicated.
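
That housework looks roughly like this. It is a bare-bones sketch that skips everything a real loader also has to set up (such as __file__ and caching), and the module name is invented.

```python
import ast
import sys
import types

source = "def double(x):\n    return 2 * x\n"
tree = ast.parse(source)          # a real hook would rewrite the tree here
code = compile(tree, "<sketch>", "exec")

# Create an empty module object, run the code inside it, then register it.
module = types.ModuleType("sketch_rewritten_module")
exec(code, module.__dict__)
sys.modules[module.__name__] = module

# From here on, an ordinary import finds the module we built by hand.
import sketch_rewritten_module

print(sketch_rewritten_module.double(21))
```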

The amazing thing about this is that the rest of the Python system doesn’t know that there’s anything unusual about this rewritten module. Calling the Python code within works the same as calling any other Python module.

I think this is pretty impressive stuff, given how powerful it can be and how (relatively) little internal knowledge is required. Hopefully in the coming months I’ll dig a little deeper into some of this stuff and tear down some other examples.

Assertion rewriting in Pytest part 3: The AST

Part 1 and part 2 looked at why a test framework might need to handle method calls in an unusual way, and how we might go about implementing this by the simplest possible means.

The approach in part 2 has a couple of limitations. One is that we’re using eval, which has a very limited interface. If we have something like:

result = eval("len(some_list) > 10")

then eval lets us find out whether the condition succeeds or fails, which is useful. But ideally we want to go further, and figure out how exactly the condition failed. We know that len(some_list) was less than or equal to 10, but exactly what was it? Maybe it would be nice to see the elements of some_list printed out to aid debugging. eval doesn’t give us any of this.

The problem is that eval is actually doing too much at once. There are broadly three steps:

  • Parsing: What is the structure of len(some_list) > 10? This stage converts a simple string of characters into a richer structure whose shape encodes everything we need to know to execute the code, and nothing more.
  • Compiling: Convert this into a form that the computer can execute.
  • Execution: Find the outcome from executing the code with the specific input values in question.

If we can split these apart, then it’s much easier to write code that manipulates the behaviour of the piece we are interested in.

Let’s look at the first stage for now.

I glibly talked about the structure having “everything we need, and nothing more”, but that’s not really a practical starting point. What should such a structure actually look like?

Without digressing too much, here’s an example of the sort of structure that’s used in Python:

[Diagram of the tree structure for the expression; the image is not reproduced here]

If you look at this from the top down, you’ll first see a comparison operator (> in our example) that has two arguments. The right hand argument is a simple number, but the left-hand argument is more complex: it’s made up of a function (len) applied to a further argument (some_list). Some points about this structure:

Firstly, it’s not necessarily obvious from the trivial example here, but this is a tree structure. As you look further down it can split apart into different branches, but the branches never join up with each other to form loops. This makes the structure much easier to work on at the cost of limiting what we can express with it; it turns out that this particular trade-off is a good one.

Secondly, the value of a node in the tree is only affected by the nodes below it, not all the other nodes in the tree. For the len node, the function has the same value whether it’s appearing in len(some_list) < 10 or len(some_list) > 6. This corresponds to what we expect, and is a good sign that this structure is going to make our life easier than if we used a different structure.

This sort of structure is called an Abstract Syntax Tree or AST and as I say, most languages use something like this. However, Python is relatively unusual in the degree to which it exposes the AST to you as a programmer. You can just import the ast module and start playing around:

>>> import ast
>>> tree = ast.parse('len(some_list) > 10', mode='eval')
>>> ast.dump(tree)
"Expression(body=Compare(left=Call(func=Name(id='len', ctx=Load()),
args=[Name(id='some_list', ctx=Load())], keywords=[]), ops=[Gt()],
comparators=[Num(n=10)]))"

If I rearrange that dump a little bit (and cut out some details), we can see:

Expression(
    body=Compare(
        left=Call(
            func=Name(id='len', ...),
            args=[
                Name(id='some_list', ...)
            ],
            ....
        ),
        ops=[Gt()],
        comparators=[
            Num(n=10)
        ]
    )
)

It doesn’t take too much imagination to see that the diagram above is a rough approximation to this.

So we’re peering into the first stage of executing the code; we’re actually seeing inside the workings of eval. But that’s not all. This is real working Python code, and we can prove this by carrying on with the compile and execute stages:

>>> bytecode = compile(tree, '<string>', mode='eval')
>>> eval(bytecode, {}, {'some_list': [1, 2, 3]})
False

We just did the compile and execute (second and third) stages ourselves, and got the same output as if we had let Python do it for us inside of eval.

It’s worth reiterating: this isn’t just a curiosity. This is a practical tool that you can use in real Python code, and you can do this with just the tools in the standard library. I’ll look at how this might be used in practice in the next part.
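
As a small taste of why splitting the stages apart is useful, here is a sketch that answers the earlier question of what len(some_list) actually was. The explain_comparison function is hypothetical, and ast.unparse assumes Python 3.9 or later.

```python
import ast


def explain_comparison(source, variables):
    """Evaluate a comparison and also report its left-hand side."""
    tree = ast.parse(source, mode="eval")
    compare = tree.body
    if not isinstance(compare, ast.Compare):
        raise TypeError("expected a comparison such as 'a > b'")
    outcome = eval(compile(tree, "<sketch>", "eval"), {}, dict(variables))
    # Wrap just the left operand in its own Expression and evaluate that too.
    left_tree = ast.Expression(compare.left)
    left_value = eval(compile(left_tree, "<sketch>", "eval"), {}, dict(variables))
    return outcome, ast.unparse(compare.left), left_value


result = explain_comparison("len(some_list) > 10", {"some_list": [1, 2, 3]})
print(result)
```

This evaluates the left-hand side twice, which would be a problem for expressions with side effects, so treat it as a demonstration rather than something to build on.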

Assertion rewriting in Pytest part 2: Simple workarounds

Part 1 looked at why a test framework can’t offer a convenient interface using ordinary functions alone. In this part I’ll look at some ad hoc code that might be used to implement something better. Hopefully this illustrates some of the things that Python can do and sets the boundaries for what a test framework might want to do.

Assertion rewriting in Pytest part 1: Why it’s needed

Pytest is fast becoming the de facto standard for Python unit testing. One of its most distinctive features is that it allows (indeed, encourages) you to use plain old Python assert rather than having to use library-specific assertSomeLongChainOfConditions methods.

For example, instead of writing:

def test_generate_list(self):
    results = generate_list()
    self.assertIn('foo', results)

You can write:

def test_generate_list(self):
    results = generate_list()
    assert 'foo' in results

Not only is there less to remember for the test writer, it’s easier to understand for the test reader.

Why didn’t the Python unittest library do this from the start? The obvious answer is that most unit test libraries take their inspiration from JUnit, and in Java the assert statement is disabled by default at run time and raises an Error that isn’t meant to be caught. But the same isn’t true in Python, where you can catch an AssertionError just the same as you can catch any other exception.

In fact, in Python’s unittest the assertion methods raise an AssertionError when a condition fails, so type-wise you handle the same exception from:

self.assertEqual(1, 2)

as from:

assert 1 == 2

Why did the unittest authors bother with all those extra methods? Surely it’s not just copying Java?

Thinking a little more, it becomes clear that assert doesn’t actually give us all we need. Suppose we have a test assertion that’s failing:

assert the_number_of_the_counting == 3

When this fails, all we get is an AssertionError(). At this point you want your unit test library to give some useful diagnostics, such as the true value of the_number_of_the_counting, but it doesn’t have any data to go on. There isn’t even any message:

try:
    assert the_number_of_the_counting == 3
except AssertionError as e:
    print("Message: " + str(e))

This just prints:

Message:
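
By contrast, a unittest assertion method can build a message, because it receives the two values as separate arguments. A small sketch (the Demo class is only scaffolding so that we can call the method outside a test run):

```python
import unittest


class Demo(unittest.TestCase):
    def runTest(self):
        pass


case = Demo()
unittest_message = None
try:
    case.assertEqual(1, 2)
except AssertionError as exc:
    unittest_message = str(exc)

plain_message = None
try:
    assert 1 == 2
except AssertionError as exc:
    plain_message = str(exc)

print("unittest: %r" % unittest_message)
print("plain assert: %r" % plain_message)
```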

We seem to be out of options here. We can’t change what the assert statement does, because it’s built into the language. We can’t create our own replacement with a function, because when you pass an expression to a function the function only gets to see the resultant value; it can’t see the whole expression. For example, if we have:

my_assert(1 == 2)

then the my_assert function just receives a False value, and can’t see the two numbers that are being compared. No matter what happens inside the function, the information is already lost.
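
You can see this for yourself with a throwaway my_assert that just records what it was given:

```python
received = []


def my_assert(condition):
    # By the time we get here, `1 == 2` has already been evaluated.
    received.append(condition)


my_assert(1 == 2)
print(received)
```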

How do other languages cope with this? Most of them just rely on explicit assertion methods like assertEquals, but C++ has an approach that’s worth thinking about.

In C++ (derived from the C language it grew out of) you can declare a macro, which defines an operation to be applied to the text of the program before any compilation happens. There’s an example of this in the Boost library. You write something like:

BOOST_TEST(a + b < 10);

And before the compiler compiles the code it’s processed by a rule that converts it into something like:

real_assert("a + b < 10", a + b < 10);

This is starting to get us somewhere, because the function now at least has enough information to log a decent error message. We can implement this function with something like:

void real_assert(string expression, bool value) {
    if (!value) {
        report_error("Expression " + expression + " wasn't satisfied");
    }
}

As an aside, this sort of text macro substitution is pretty uncommon outside of C and C++. It can do some useful things, but it’s pretty easy to make things unmaintainable and it’s often hard to achieve what you want.

In this particular case, the macro approach still imposes serious limits on what we can do. A really great test library, when faced with the failure of the expression a + b < 10 would print out not just the whole expression but the values of a and b, so that the developer can see which is wrong. This kind of thing just isn't possible with the compile-time string substitution.

Pytest has a really interesting solution to this problem, and I think it illustrates a key reason why Python is such an interesting language: user-defined libraries can do things that would need compiler modifications in other languages. I'll dig into this in the next part.