Repairing Code Automatically

So far, we have discussed how to track failures and how to locate defects in code. Let us now discuss how to repair defects – that is, to correct the code such that the failure no longer occurs. We will discuss how to repair code automatically – by systematically searching through possible fixes and evolving the most promising candidates.


Prerequisites

Synopsis

To use the code provided in this chapter, write

>>> from debuggingbook.Repairer import <identifier>

and then make use of the following features.

This chapter provides tools and techniques for automated repair of program code. The Repairer class takes a RankingDebugger instance as input (such as an OchiaiDebugger from the chapter on statistical debugging). A typical setup looks like this:

from debuggingbook.StatisticalDebugger import OchiaiDebugger

debugger = OchiaiDebugger()
for inputs in TESTCASES:
    with debugger:
        test_foo(inputs)
...

repairer = Repairer(debugger)

Here, test_foo() is a function that raises an exception if the tested function foo() fails. If foo() passes, test_foo() should not raise an exception.

The repair() method of a Repairer searches for a repair of the code covered in the debugger (except for functions whose name starts or ends with test, so that foo(), not test_foo(), is repaired). repair() returns the best fix candidate as a pair (tree, fitness), where tree is a Python abstract syntax tree (AST) of the fix candidate and fitness is the fitness of the candidate (a value between 0 and 1). A fitness of 1.0 means that the candidate passed all tests. A typical usage looks like this:

import ast

tree, fitness = repairer.repair()
print(ast.unparse(tree), fitness)

Here is a complete example for the middle() program. This is the original source code of middle():

def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return y
    else:
        if x > y:
            return y
        elif x > z:
            return x
    return z

We set up a function middle_test() that tests it. The middle_debugger collects testcases and outcomes:

>>> middle_debugger = OchiaiDebugger()
>>> for x, y, z in MIDDLE_PASSING_TESTCASES + MIDDLE_FAILING_TESTCASES:
...     with middle_debugger:
...         middle_test(x, y, z)

The repairer is instantiated with the debugger used (middle_debugger):

>>> middle_repairer = Repairer(middle_debugger)

The repair() method of the repairer attempts to repair the function invoked by the test (middle()).

>>> tree, fitness = middle_repairer.repair()

The returned AST tree can be output via ast.unparse():

>>> print(ast.unparse(tree))
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return x
    elif x > y:
        return y
    elif x > z:
        return x
    return z

The fitness value shows how well the repaired program fits the tests. A fitness value of 1.0 shows that the repaired program satisfies all tests.

>>> fitness
1.0

Hence, the above program indeed is a perfect repair in the sense that all previously failing tests now pass – our repair was successful.

Here are the classes defined in this chapter. A Repairer repairs a program, using a StatementMutator and a CrossoverOperator to evolve a population of candidates.

[Class diagram: Repairer (methods __init__(), repair(), default_functions(), evolve(), fitness(), fitness_key(), getsource(), initial_population(), log_tree(), parse(), reduce(), run_test_set(), run_tests(), test_reduce(), toplevel_defs(), validate(), weight()) is a subclass of StackInspector. ConditionMutator (choose_bool_op(), choose_condition(), swap()) is a subclass of StatementMutator (NODE_MAX_LENGTH; mutate(), choose_op(), choose_statement(), delete(), format_node(), insert(), node_suspiciousness(), node_to_be_mutated(), swap(), visit()), which subclasses NodeTransformer, which in turn subclasses NodeVisitor. CrossoverOperator (SKIP_LIST; crossover(), can_cross(), cross_bodies(), crossover_attr(), crossover_branches()) stands on its own.]

Automatic Code Repairs

So far, we have discussed how to locate defects in code, how to track failures back to the defects that caused them, and how to systematically determine failure conditions. Let us now address the last step in debugging – namely, how to automatically fix code.

Already in the introduction to debugging, we have discussed how to fix code manually. Notably, we have established that a diagnosis (which induces a fix) should show causality (i.e., how the defect causes the failure) and incorrectness (how the defect is wrong). Is it possible to obtain such a diagnosis automatically?

In this chapter, we introduce a technique of automatic code repair – that is, for a given failure, automatically determine a fix that makes the failure go away. To do so, we randomly (but systematically) mutate the program code – that is, insert, change, and delete fragments – until we find a change that actually causes the failing test to pass.

If this sounds like an audacious idea, that is because it is. But not only has automated program repair been one of the hottest topics of software research in the last decade, it is also increasingly being deployed in industry. At Facebook, for instance, every failing test report comes with an automatically generated repair suggestion – a suggestion that has already been validated to work. Programmers can apply the suggestion as is or use it as a basis for their own fixes.

The middle() Function

Let us introduce our ongoing example. In the chapter on statistical debugging, we have introduced the middle() function – a function that returns the "middle" of three numbers x, y, and z:

from StatisticalDebugger import middle
708  def middle(x, y, z):  # type: ignore
709      if y < z:
710          if x < y:
711              return y
712          elif x < z:
713              return y
714      else:
715          if x > y:
716              return y
717          elif x > z:
718              return x
719      return z

In most cases, middle() just runs fine:

middle(4, 5, 6)
5

In some other cases, though, it does not work correctly:

middle(2, 1, 3)
1

Validated Repairs

Now, if we only wanted a repair that fixes this one given failure, this would be very easy. All we would have to do is to replace the entire body by a single statement:

def middle_sort_of_fixed(x, y, z):
    return x

You will concur that the failure no longer occurs:

middle_sort_of_fixed(2, 1, 3)
2

But this, of course, is not the aim of automatic fixes, nor of fixes in general: We want our fixes not only to make the given failure go away, but we also want the resulting code to be correct (which, of course, is a lot harder).

Automatic repair techniques therefore assume the existence of a test suite that can check whether an implementation satisfies its requirements. Better yet, one can use the test suite to gradually check how close one is to perfection: A piece of code that satisfies 99% of all tests is better than one that satisfies ~33% of all tests, as middle_sort_of_fixed() would do (assuming the test suite evenly checks the input space).
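Where does the ~33% come from? Here is a quick check, assuming an idealized test suite that contains every ordering of three distinct values exactly once (an assumption for illustration, which need not match the chapter's actual test suites). middle_sort_of_fixed() is only correct when x happens to be the middle value, which is the case in one third of all orderings:

```python
import itertools

def middle_sort_of_fixed(x, y, z):  # the "fix" from above
    return x

# Idealized test suite: every ordering of three distinct values, once each
inputs = list(itertools.permutations([1, 2, 3]))
passed = sum(1 for x, y, z in inputs
             if middle_sort_of_fixed(x, y, z) == sorted([x, y, z])[1])
print(f"{passed} of {len(inputs)} tests pass")  # 2 of 6 tests pass
```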

Genetic Optimization

The common approach for automatic repair follows the principle of genetic optimization. Roughly speaking, genetic optimization is a metaheuristic inspired by the process of natural selection. The idea is to evolve a selection of candidate solutions towards a maximum fitness:

  1. Have a selection of candidates.
  2. Determine the fitness of each candidate.
  3. Retain those candidates with the highest fitness.
  4. Create new candidates from the retained candidates, by applying genetic operations:
    • Mutation mutates some aspect of a candidate.
    • Crossover creates new candidates by combining features of two candidates.
  5. Repeat until an optimal solution is found.
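The five steps above can be sketched as a generic loop. The following is a minimal illustration on a toy problem, not the chapter's Repairer: candidates are bit strings, fitness is the fraction of 1 bits ("OneMax"), mutation flips one bit, and crossover joins a prefix of one parent with the suffix of another. All of these choices, as well as the parameter values, are assumptions made for brevity.

```python
import random
from typing import List, Tuple

Candidate = List[int]

def fitness(candidate: Candidate) -> float:
    # Toy fitness ("OneMax"): the fraction of bits that are 1
    return sum(candidate) / len(candidate)

def mutate(candidate: Candidate, rng: random.Random) -> Candidate:
    # Mutation: flip one randomly chosen bit in a copy
    child = candidate[:]
    i = rng.randrange(len(child))
    child[i] = 1 - child[i]
    return child

def crossover(a: Candidate, b: Candidate, rng: random.Random) -> Candidate:
    # Single-point crossover: prefix of `a` joined with suffix of `b`
    cut = rng.randrange(1, len(a))
    return a[:cut] + b[cut:]

def evolve(length: int = 20, pop_size: int = 20, generations: int = 500,
           seed: int = 0) -> Tuple[Candidate, List[float]]:
    rng = random.Random(seed)
    # 1. Start with a random selection of candidates
    population = [[rng.randint(0, 1) for _ in range(length)]
                  for _ in range(pop_size)]
    history: List[float] = []  # best fitness per generation
    for _ in range(generations):
        # 2. Determine fitness (sorting puts the fittest first)
        population.sort(key=fitness, reverse=True)
        history.append(fitness(population[0]))
        if history[-1] == 1.0:
            break  # 5. Stop once an optimal solution is found
        # 3. Retain the fittest half
        survivors = population[:pop_size // 2]
        # 4. Create new candidates by crossover followed by mutation
        offspring = []
        while len(survivors) + len(offspring) < pop_size:
            a, b = rng.sample(survivors, 2)
            offspring.append(mutate(crossover(a, b, rng), rng))
        population = survivors + offspring
    best = max(population, key=fitness)
    return best, history

best, history = evolve()
print(fitness(best), len(history))
```

Because the fittest half always survives into the next generation, the best fitness recorded in history never decreases; whether the optimum 1.0 is reached depends on the seed and the number of generations.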

Applied to automated program repair, this means the following steps:

  1. Have a test suite with both failing and passing tests that helps to assert correctness of possible solutions.
  2. With the test suite, use fault localization to determine potential code locations to be fixed.
  3. Systematically mutate the code (by adding, changing, or deleting code) and cross code to create possible fix candidates.
  4. Identify the fittest fix candidates – that is, those that satisfy the most tests.
  5. Evolve the fittest candidates until a perfect fix is found, or until time resources are depleted.

Let us illustrate these steps in the following sections.

A Test Suite

In automated repair, the larger and the more thorough the test suite, the higher the quality of the resulting fix (if any). Hence, if we want to repair middle() automatically, we need a good test suite – with good inputs, but also with good checks. Note that running the test suite commonly takes up most of the time in automated repair, so a large test suite also comes with extra cost.

Let us first focus on achieving high-quality repairs. Hence, we will use the extensive test suites introduced in the chapter on statistical debugging:

from StatisticalDebugger import MIDDLE_PASSING_TESTCASES, MIDDLE_FAILING_TESTCASES

The middle_test() function fails whenever middle() returns an incorrect result:

def middle_test(x: int, y: int, z: int) -> None:
    m = middle(x, y, z)
    assert m == sorted([x, y, z])[1]
from ExpectError import ExpectError
with ExpectError():
    middle_test(2, 1, 3)
Traceback (most recent call last):
  File "/var/folders/n2/xd9445p97rb3xh7m1dfx8_4h0006ts/T/ipykernel_15846/3661663124.py", line 2, in <module>
    middle_test(2, 1, 3)
  File "/var/folders/n2/xd9445p97rb3xh7m1dfx8_4h0006ts/T/ipykernel_15846/40742806.py", line 3, in middle_test
    assert m == sorted([x, y, z])[1]
           ^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError (expected)

Locating the Defect

Our next step is to find potential defect locations – that is, those locations in the code our mutations should focus upon. Since we already have two test suites, we can make use of statistical debugging to identify likely faulty locations. Our OchiaiDebugger ranks individual code lines by how frequently they are executed in failing runs (and not in passing runs).
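The ranking can be made concrete with the Ochiai formula from spectrum-based fault localization. If failed(s) and passed(s) are the numbers of failing and passing runs that execute a line s, then suspiciousness(s) = failed(s) / sqrt(totalfailed * (failed(s) + passed(s))). Here is a minimal sketch in terms of such coverage counts; the function and its parameter names are hypothetical, and this is not necessarily how OchiaiDebugger computes its values internally.

```python
import math

def ochiai(failed_cov: int, passed_cov: int, total_failed: int) -> float:
    """Ochiai suspiciousness of a single line (hypothetical helper).
    failed_cov:   number of failing runs that execute the line
    passed_cov:   number of passing runs that execute the line
    total_failed: total number of failing runs
    """
    denominator = math.sqrt(total_failed * (failed_cov + passed_cov))
    if denominator == 0:
        return 0.0  # line never executed, or no failing runs at all
    return failed_cov / denominator
```

For instance, a line executed in all 10 failing runs and in none of the passing runs gets ochiai(10, 0, 10) == 1.0, the maximum; every passing run that also executes the line lowers its score.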

from StatisticalDebugger import OchiaiDebugger, RankingDebugger
middle_debugger = OchiaiDebugger()

for x, y, z in MIDDLE_PASSING_TESTCASES + MIDDLE_FAILING_TESTCASES:
    with middle_debugger:
        middle_test(x, y, z)

We see that the upper half of the middle() code is definitely more suspicious:

middle_debugger
   1 def middle_test(x: int, y: int, z: int) -> None:
   2     m = middle(x, y, z)
   3     assert m == sorted([x, y, z])[1]

 708 def middle(x, y, z):  # type: ignore
 709     if y < z:
 710         if x < y:
 711             return y
 712         elif x < z:
 713             return y
 714     else:
 715         if x > y:
 716             return y
 717         elif x > z:
 718             return x
 719     return z

The most suspicious line is:

713            return y

with a suspiciousness of:

0.9667364890456637

Random Code Mutations

Our third step in automatic code repair is to randomly mutate the code. Specifically, we want to randomly delete, insert, and replace statements in the program to be repaired. However, synthesizing code from scratch is unlikely to yield anything meaningful – the number of combinations is simply far too high. Already for a three-character identifier name, we have more than 200,000 combinations:

import string
string.ascii_letters
'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
len(string.ascii_letters + '_') * \
  len(string.ascii_letters + '_' + string.digits) * \
  len(string.ascii_letters + '_' + string.digits)
210357

Hence, we do not synthesize code from scratch, but instead reuse elements from the program to be fixed, hypothesizing that "a program that contains an error in one area likely implements the correct behavior elsewhere" [C. Le Goues et al, 2012]. This insight has been dubbed the plastic surgery hypothesis: content of new code can often be assembled out of fragments of code that already exist in the code base [E. T. Barr et al, 2014].

For our "plastic surgery", we do not operate on a textual representation of the program, but rather on a structural representation, which by construction allows us to avoid lexical and syntactical errors in the first place.

This structural representation is the abstract syntax tree (AST), which we already have seen in various chapters, such as the chapter on delta debugging, the chapter on tracing, and extensively in the chapter on slicing. The official Python ast reference is complete, but a bit brief; the documentation "Green Tree Snakes - the missing Python AST docs" provides an excellent introduction.

Recapitulating, an AST is a tree representation of the program, showing a hierarchical structure of the program's elements. Here is the AST for our middle() function.

import ast
import inspect
from bookutils import print_content, show_ast
def middle_tree() -> ast.AST:
    return ast.parse(inspect.getsource(middle))
show_ast(middle_tree())
[Figure: AST of middle(). A FunctionDef node "middle" with an arguments node (arg x, arg y, arg z), an If node testing y < z, and a final Return of z. The If's body holds a nested If (x < y, returning y, with an elif x < z returning y); its orelse holds a nested If (x > y, returning y, with an elif x > z returning x).]

You see that it consists of one function definition (FunctionDef) with three arguments and two statements – one If and one Return. Each If subtree has three branches – one for the condition (test), one for the body to be executed if the condition is true (body), and one for the else case (orelse). The body and orelse branches again are lists of statements.

An AST can also be shown as text, which is more compact, yet reveals more information. ast.dump() gives not only the class names of elements, but also how they are constructed – actually, the whole expression can be used to construct an AST.

print(ast.dump(middle_tree()))
Module(body=[FunctionDef(name='middle', args=arguments(posonlyargs=[], args=[arg(arg='x'), arg(arg='y'), arg(arg='z')], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[If(test=Compare(left=Name(id='y', ctx=Load()), ops=[Lt()], comparators=[Name(id='z', ctx=Load())]), body=[If(test=Compare(left=Name(id='x', ctx=Load()), ops=[Lt()], comparators=[Name(id='y', ctx=Load())]), body=[Return(value=Name(id='y', ctx=Load()))], orelse=[If(test=Compare(left=Name(id='x', ctx=Load()), ops=[Lt()], comparators=[Name(id='z', ctx=Load())]), body=[Return(value=Name(id='y', ctx=Load()))], orelse=[])])], orelse=[If(test=Compare(left=Name(id='x', ctx=Load()), ops=[Gt()], comparators=[Name(id='y', ctx=Load())]), body=[Return(value=Name(id='y', ctx=Load()))], orelse=[If(test=Compare(left=Name(id='x', ctx=Load()), ops=[Gt()], comparators=[Name(id='z', ctx=Load())]), body=[Return(value=Name(id='x', ctx=Load()))], orelse=[])])]), Return(value=Name(id='z', ctx=Load()))], decorator_list=[], type_params=[])], type_ignores=[])
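To illustrate that the output of ast.dump() is itself a valid Python expression, here is a small round trip on a toy expression (not middle()): we evaluate the dump with the classes of the ast module in scope and unparse the result. Using eval() this way is only a demonstration and should never be applied to untrusted input.

```python
import ast

tree = ast.parse("x + y")
dumped = ast.dump(tree)

# Evaluate the dump with the `ast` node classes (Module, Expr, BinOp, ...) in scope
rebuilt = eval(dumped, dict(vars(ast)))
rebuilt = ast.fix_missing_locations(rebuilt)

print(ast.unparse(rebuilt))  # x + y
```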

This is the path to the first return statement:

ast.dump(middle_tree().body[0].body[0].body[0].body[0])
"Return(value=Name(id='y', ctx=Load()))"

Picking Statements

For our mutation operators, we want to use statements from the program itself. Hence, we need a means to find those very statements. The StatementVisitor class iterates through an AST, adding all statements it finds in function definitions to its statements list. To do so, it subclasses the Python ast NodeVisitor class, described in the official Python ast reference.

from ast import NodeVisitor
from typing import Callable, List, Optional, Set, Tuple, Type, Union, cast
class StatementVisitor(NodeVisitor):
    """Visit all statements within function defs in an AST"""

    def __init__(self) -> None:
        self.statements: List[Tuple[ast.AST, str]] = []
        self.func_name = ""
        self.statements_seen: Set[Tuple[ast.AST, str]] = set()
        super().__init__()

    def add_statements(self, node: ast.AST, attr: str) -> None:
        elems: List[ast.AST] = getattr(node, attr, [])
        if not isinstance(elems, list):
            elems = [elems]

        for elem in elems:
            stmt = (elem, self.func_name)
            if stmt in self.statements_seen:
                continue

            self.statements.append(stmt)
            self.statements_seen.add(stmt)

    def visit_node(self, node: ast.AST) -> None:
        # Any node other than the ones listed below
        self.add_statements(node, 'body')
        self.add_statements(node, 'orelse')

    def visit_Module(self, node: ast.Module) -> None:
        # Module children are defs, classes and globals - don't add
        super().generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        # Class children are defs and globals - don't add
        super().generic_visit(node)

    def generic_visit(self, node: ast.AST) -> None:
        self.visit_node(node)
        super().generic_visit(node)

    def visit_FunctionDef(self,
                          node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> None:
        if not self.func_name:
            self.func_name = node.name

        self.visit_node(node)
        super().generic_visit(node)
        self.func_name = ""

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        return self.visit_FunctionDef(node)
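StatementVisitor relies on the dispatch mechanism of NodeVisitor: visit() calls a method visit_<ClassName>() if the subclass defines one, and generic_visit() otherwise, which in turn visits all children. As a further, self-contained illustration, here is a hypothetical ReturnCounter (not part of the chapter's code) that counts return statements per function, applied to a copy of the middle() source:

```python
import ast
from collections import Counter

class ReturnCounter(ast.NodeVisitor):
    """Hypothetical visitor: count `return` statements per function."""

    def __init__(self) -> None:
        self.counts: Counter = Counter()
        self.func_name = ""

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        # Called for every function definition; remember its name
        outer = self.func_name
        self.func_name = node.name
        self.generic_visit(node)  # continue into the function body
        self.func_name = outer

    def visit_Return(self, node: ast.Return) -> None:
        # Called for every `return` statement
        self.counts[self.func_name] += 1
        self.generic_visit(node)

SOURCE = '''
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return y
    else:
        if x > y:
            return y
        elif x > z:
            return x
    return z
'''

counter = ReturnCounter()
counter.visit(ast.parse(SOURCE))
print(dict(counter.counts))  # {'middle': 5}
```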

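The function all_statements_and_functions() returns a list of (statement, function name) pairs for all statements in the given AST tree; all_statements() returns just the statements. If an ast class tp is given, both only return instances of that class.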

def all_statements_and_functions(tree: ast.AST, 
                                 tp: Optional[Type] = None) -> \
                                 List[Tuple[ast.AST, str]]:
    """
    Return a list of pairs (`statement`, `function`) for all statements in `tree`.
    If `tp` is given, return only statements of that class.
    """

    visitor = StatementVisitor()
    visitor.visit(tree)
    statements = visitor.statements
    if tp is not None:
        statements = [s for s in statements if isinstance(s[0], tp)]

    return statements
def all_statements(tree: ast.AST, tp: Optional[Type] = None) -> List[ast.AST]:
    """
    Return a list of all statements in `tree`.
    If `tp` is given, return only statements of that class.
    """

    return [stmt for stmt, func_name in all_statements_and_functions(tree, tp)]

Here are all the return statements in middle():

all_statements(middle_tree(), ast.Return)
[<ast.Return at 0x11a827850>,
 <ast.Return at 0x11a740f10>,
 <ast.Return at 0x11a8261d0>,
 <ast.Return at 0x11a827510>,
 <ast.Return at 0x11a826e90>]
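
And here are all If statements in middle(), each paired with the name of the function it occurs in:

all_statements_and_functions(middle_tree(), ast.If)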
[(<ast.If at 0x11a769b50>, 'middle'),
 (<ast.If at 0x109a93890>, 'middle'),
 (<ast.If at 0x11a835810>, 'middle'),
 (<ast.If at 0x11a83f3d0>, 'middle'),
 (<ast.If at 0x11a83c210>, 'middle')]

We can randomly pick an element:

import random
random_node = random.choice(all_statements(middle_tree()))
ast.unparse(random_node)
'return y'

Mutating Statements

The main part in mutation, however, is to actually mutate the code of the program under test. To this end, we introduce a StatementMutator class – a subclass of NodeTransformer, described in the official Python ast reference.

The constructor provides various keyword arguments to configure the mutator.

from ast import NodeTransformer
import copy
class StatementMutator(NodeTransformer):
    """Mutate statements in an AST for automated repair."""

    def __init__(self,
                 suspiciousness_func:
                     Optional[Callable[[Tuple[Callable, int]], float]] = None,
                 source: Optional[List[ast.AST]] = None,
                 log: Union[bool, int] = False) -> None:
        """
        Constructor.
        `suspiciousness_func` is a function that takes a location
        (function, line_number) and returns a suspiciousness value
        between 0 and 1.0. If not given, all locations get the same 
        suspiciousness of 1.0.
        `source` is a list of statements to choose from.
        """

        super().__init__()
        self.log = log

        if suspiciousness_func is None:
            def suspiciousness_func(location: Tuple[Callable, int]) -> float:
                return 1.0
        assert suspiciousness_func is not None

        self.suspiciousness_func: Callable = suspiciousness_func

        if source is None:
            source = []
        self.source = source

        if self.log > 1:
            for i, node in enumerate(self.source):
                print(f"Source for repairs #{i}:")
                print_content(ast.unparse(node), '.py')
                print()
                print()

        self.mutations = 0

Choosing Suspicious Statements to Mutate

We start with deciding which AST nodes to mutate. The method node_suspiciousness() returns the suspiciousness for a given node, by invoking the suspiciousness function suspiciousness_func given during initialization.

import warnings
class StatementMutator(StatementMutator):
    def node_suspiciousness(self, stmt: ast.AST, func_name: str) -> float:
        if not hasattr(stmt, 'lineno'):
            warnings.warn(f"{self.format_node(stmt)}: Expected line number")
            return 0.0

        suspiciousness = self.suspiciousness_func((func_name, stmt.lineno))
        if suspiciousness is None:  # not executed
            return 0.0

        return suspiciousness

    def format_node(self, node: ast.AST) -> str:
        ...

The method node_to_be_mutated() picks a node (statement) to be mutated. It determines the suspiciousness of all statements, and invokes random.choices(), using the suspiciousness as weight. Unsuspicious statements (with zero weight) will not be chosen.

class StatementMutator(StatementMutator):
    def node_to_be_mutated(self, tree: ast.AST) -> ast.AST:
        statements = all_statements_and_functions(tree)
        assert len(statements) > 0, "No statements"

        weights = [self.node_suspiciousness(stmt, func_name) 
                   for stmt, func_name in statements]
        stmts = [stmt for stmt, func_name in statements]

        if self.log > 1:
            print("Weights:")
            for i, stmt in enumerate(statements):
                node, func_name = stmt
                print(f"{weights[i]:.2} {self.format_node(node)}")

        if sum(weights) == 0.0:
            # No suspicious line
            return random.choice(stmts)
        else:
            return random.choices(stmts, weights=weights)[0]
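The effect of using suspiciousness as weight can be checked with plain strings standing in for AST nodes; the weights below are hypothetical values chosen for illustration. A statement with weight zero is never picked. The fallback to random.choice() in node_to_be_mutated() also matters because, since Python 3.9, random.choices() raises a ValueError if all weights are zero.

```python
import random

rng = random.Random(42)
# Stand-ins for statements, with hypothetical suspiciousness weights
statements = ['return y', 'return x', 'return z']
weights = [0.0, 0.25, 0.75]

picks = rng.choices(statements, weights=weights, k=1000)
print({stmt: picks.count(stmt) for stmt in statements})

# Since Python 3.9, choices() rejects weights that are all zero
try:
    rng.choices(statements, weights=[0.0, 0.0, 0.0])
except ValueError as exc:
    print("ValueError:", exc)
```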

Choosing a Mutation Method

The method visit() is invoked on all nodes. For nodes marked with a mutate_me attribute, it randomly chooses a mutation method (choose_op()) and then invokes it on the node.

According to the rules of NodeTransformer, the mutation method can return

  • a new node or a list of nodes, replacing the current node;
  • None, deleting it; or
  • the node itself, keeping things as they are.

import re
RE_SPACE = re.compile(r'[ \t\n]+')
class StatementMutator(StatementMutator):
    def choose_op(self) -> Callable:
        return random.choice([self.insert, self.swap, self.delete])

    def visit(self, node: ast.AST) -> ast.AST:
        super().visit(node)  # Visits (and transforms?) children

        if not node.mutate_me:
            return node

        op = self.choose_op()
        new_node = op(node)
        self.mutations += 1

        if self.log:
            print(f"{node.lineno:4}:{op.__name__ + ':':7} "
                  f"{self.format_node(node)} "
                  f"becomes {self.format_node(new_node)}")

        return new_node
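The three return conventions of NodeTransformer listed above can be seen in isolation in the following small, hypothetical DemoTransformer (not part of the chapter's code): it deletes expression statements by returning None, duplicates assignments by returning a list, and keeps return statements by returning the node itself.

```python
import ast
import copy

class DemoTransformer(ast.NodeTransformer):
    """Hypothetical transformer illustrating the three return conventions."""

    def visit_Expr(self, node: ast.Expr) -> None:
        # Returning None deletes the statement from its parent's list
        return None

    def visit_Assign(self, node: ast.Assign) -> list:
        # Returning a list replaces the node by several statements
        return [node, copy.deepcopy(node)]

    def visit_Return(self, node: ast.Return) -> ast.Return:
        # Returning the node itself keeps it unchanged
        return node

SOURCE = '''
def f(a):
    print(a)
    b = a + 1
    return b
'''

tree = DemoTransformer().visit(ast.parse(SOURCE))
tree = ast.fix_missing_locations(tree)
print(ast.unparse(tree))
```

The printed function no longer contains the print(a) call, and the assignment b = a + 1 now appears twice.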

Swapping Statements

Our first mutator is swap(), which replaces the current node NODE by a random node found in source (using a newly defined choose_statement()).

As a rule of thumb, we try to avoid inserting entire subtrees with all attached statements, and use only the first line (the "header") of a node instead. If the new node has the form

if P:
    BODY

we thus only insert

if P: 
    pass

since the statements in BODY have a later chance to get inserted. The same holds for all constructs that have a BODY, such as while, for, try, and with.

class StatementMutator(StatementMutator):
    def choose_statement(self) -> ast.AST:
        return copy.deepcopy(random.choice(self.source))
class StatementMutator(StatementMutator):
    def swap(self, node: ast.AST) -> ast.AST:
        """Replace `node` with a random node from `source`"""
        new_node = self.choose_statement()

        if isinstance(new_node, ast.stmt):
            # The source `if P: X` is added as `if P: pass`
            if hasattr(new_node, 'body'):
                new_node.body = [ast.Pass()]
            if hasattr(new_node, 'orelse'):
                new_node.orelse = []
            if hasattr(new_node, 'finalbody'):
                new_node.finalbody = []

        # ast.copy_location(new_node, node)
        return new_node
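The "header only" rule can be tried out on its own. The following hypothetical header_only() helper applies the same attribute resets as swap() above to a copy of one of middle()'s if statements, leaving the original statement untouched:

```python
import ast
import copy

def header_only(stmt: ast.stmt) -> ast.stmt:
    """Hypothetical helper: copy `stmt`, keeping only its "header"."""
    new_stmt = copy.deepcopy(stmt)
    if hasattr(new_stmt, 'body'):
        new_stmt.body = [ast.Pass()]  # `if P: X` becomes `if P: pass`
    if hasattr(new_stmt, 'orelse'):
        new_stmt.orelse = []          # drop `else`/`elif` branches
    if hasattr(new_stmt, 'finalbody'):
        new_stmt.finalbody = []       # drop `finally` clauses
    return new_stmt

func = ast.parse("def middle(x, y, z):\n"
                 "    if x > y:\n"
                 "        return y\n"
                 "    elif x > z:\n"
                 "        return x").body[0]
stmt = func.body[0]  # the `if x > y: ... elif x > z: ...` statement

# Prints the `if x > y:` header with a `pass` body; no `elif` branch remains
print(ast.unparse(header_only(stmt)))
```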

Inserting Statements

Our next mutator is insert(), which randomly chooses some node from source and inserts it after the current node NODE. (If NODE is a return statement, then we insert the new node before NODE.)

If the statement to be inserted has the form

if P:
    BODY

we only insert the "header" of the if, resulting in

if P: 
    NODE

Again, this applies to all constructs that have a BODY, such as while, for, try, and with.

class StatementMutator(StatementMutator):
    def insert(self, node: ast.AST) -> Union[ast.AST, List[ast.AST]]:
        """Insert a random node from `source` after `node`"""
        new_node = self.choose_statement()

        if isinstance(new_node, ast.stmt) and hasattr(new_node, 'body'):
            # Inserting `if P: X` as `if P:`
            new_node.body = [node]
            if hasattr(new_node, 'orelse'):
                new_node.orelse = []
            if hasattr(new_node, 'finalbody'):
                new_node.finalbody = []
            # ast.copy_location(new_node, node)
            return new_node

        # Only insert before `return`, not after it
        if isinstance(node, ast.Return):
            if isinstance(new_node, ast.Return):
                return new_node
            else:
                return [new_node, node]

        return [node, new_node]
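The insertion rules can also be tried out deterministically. The following hypothetical insert_stmt() mirrors the logic of insert() above, but takes the statement to be inserted as an argument instead of choosing it at random:

```python
import ast
import copy
from typing import List

def insert_stmt(node: ast.stmt, new_stmt: ast.stmt) -> List[ast.stmt]:
    """Hypothetical variant of `insert()`: insert `new_stmt` at `node`."""
    new_stmt = copy.deepcopy(new_stmt)

    if hasattr(new_stmt, 'body'):
        # Keep only the header of `new_stmt`; `node` becomes its body
        new_stmt.body = [node]
        for attr in ('orelse', 'finalbody'):
            if hasattr(new_stmt, attr):
                setattr(new_stmt, attr, [])
        return [new_stmt]

    if isinstance(node, ast.Return):
        # Never insert after `return`: replace it or insert before it
        if isinstance(new_stmt, ast.Return):
            return [new_stmt]
        return [new_stmt, node]

    return [node, new_stmt]

func = ast.parse("def middle(x, y, z):\n"
                 "    if y < z:\n"
                 "        return y\n"
                 "    return z").body[0]
cond, ret_z = func.body  # `if y < z: ...` and `return z`
ret_y = cond.body[0]     # `return y`

wrapped = insert_stmt(ret_z, cond)    # `return z` guarded by `if y < z:`
replaced = insert_stmt(ret_z, ret_y)  # `return z` replaced by `return y`
print(ast.unparse(wrapped[0]))
print(ast.unparse(replaced[0]))
```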

Deleting Statements

Our last mutator is delete(), which deletes the current node NODE. The standard case is to replace NODE by a pass statement.

If the statement to be deleted has the form

if P:
    BODY

we only delete the "header" of the if, resulting in

BODY

Again, this applies to all constructs that have a BODY, such as while, for, try, and with. If the statement to be deleted has multiple branches, a random branch is chosen (e.g., the else branch of an if statement).

class StatementMutator(StatementMutator):
    def delete(self, node: ast.AST) -> Optional[Union[ast.AST, List[ast.AST]]]:
        """Delete `node`."""

        branches = [attr for attr in ['body', 'orelse', 'finalbody']
                    if hasattr(node, attr) and getattr(node, attr)]
        if branches:
            # Replace `if P: S` by `S`
            branch = random.choice(branches)
            new_node = getattr(node, branch)
            return new_node

        if isinstance(node, ast.stmt):
            # Avoid empty bodies; make this a `pass` statement
            new_node = ast.Pass()
            ast.copy_location(new_node, node)
            return new_node

        return None  # Just delete
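Likewise, here is a hypothetical, standalone delete_stmt() that applies the deletion rules of delete() to two statements of a small middle()-like function. The random generator is passed in explicitly so that the choice among branches is reproducible; in this example, the if statement has only one non-empty branch, so the result does not depend on it.

```python
import ast
import random
from typing import List, Union

def delete_stmt(node: ast.stmt,
                rng: random.Random) -> Union[ast.stmt, List[ast.stmt]]:
    """Hypothetical sketch of the deletion rules of `delete()`."""
    branches = [attr for attr in ('body', 'orelse', 'finalbody')
                if getattr(node, attr, None)]
    if branches:
        # Compound statement: keep the statements of one random branch
        return getattr(node, rng.choice(branches))
    # Simple statement: replace it by `pass` to avoid empty bodies
    return ast.copy_location(ast.Pass(), node)

func = ast.parse("def middle(x, y, z):\n"
                 "    if x < y:\n"
                 "        return y\n"
                 "    return z").body[0]
if_stmt, ret_z = func.body
rng = random.Random(0)

print(ast.unparse(delete_stmt(if_stmt, rng)[0]))  # return y
print(ast.unparse(delete_stmt(ret_z, rng)))       # pass
```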
from bookutils import quiz

Quiz

Why are statements replaced by pass rather than deleted?





Indeed, Python's compile() will fail if any of the bodies is an empty list. Also, it leaves us a statement that can be evolved further.
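This is easy to check: the following snippet empties the body of a small function and hands the AST to compile(), which rejects it with a ValueError; with a pass statement as body, compilation succeeds. (This is also why middle_fitness() below treats a ValueError from compile() as a compilation error.)

```python
import ast

tree = ast.parse("def middle(x, y, z):\n    return z")
func = tree.body[0]

func.body = []  # "delete" the only statement
try:
    compile(tree, '<empty body>', 'exec')
except ValueError as exc:
    print("ValueError:", exc)

func.body = [ast.Pass()]  # replace it by `pass` instead
ast.fix_missing_locations(tree)  # `pass` needs line numbers to compile
code = compile(tree, '<pass body>', 'exec')  # compiles fine
```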

Helpers

For logging purposes, we introduce a helper method format_node() that returns a short string representation of the node.

class StatementMutator(StatementMutator):
    NODE_MAX_LENGTH = 20

    def format_node(self, node: ast.AST) -> str:
        """Return a string representation for `node`."""
        if node is None:
            return "None"

        if isinstance(node, list):
            return "; ".join(self.format_node(elem) for elem in node)

        s = RE_SPACE.sub(' ', ast.unparse(node)).strip()
        if len(s) > self.NODE_MAX_LENGTH - len("..."):
            s = s[:self.NODE_MAX_LENGTH] + "..."
        return repr(s)

All Together

Let us now create the main entry point, which is mutate(). It picks the node to be mutated and marks it with a mutate_me attribute. By calling visit(), it then sets off the NodeTransformer transformation.

class StatementMutator(StatementMutator):
    def mutate(self, tree: ast.AST) -> ast.AST:
        """Return a mutated copy of the given AST `tree`; `tree` itself is not changed."""

        assert isinstance(tree, ast.AST)

        tree = copy.deepcopy(tree)

        if not self.source:
            self.source = all_statements(tree)

        for node in ast.walk(tree):
            node.mutate_me = False

        node = self.node_to_be_mutated(tree)
        node.mutate_me = True

        self.mutations = 0

        tree = self.visit(tree)

        if self.mutations == 0:
            warnings.warn("No mutations found")

        ast.fix_missing_locations(tree)
        return tree

Here are a number of transformations applied by StatementMutator:

mutator = StatementMutator(log=True)
for i in range(10):
    new_tree = mutator.mutate(middle_tree())
   9:insert: 'return y' becomes 'return y'
   8:insert: 'if x > y: return y e...' becomes 'if x < y: if x > y: ...'
  12:insert: 'return z' becomes 'if y < z: return z...'
   3:swap:   'if x < y: return y e...' becomes 'return x'
   3:swap:   'if x < y: return y e...' becomes 'return z'
   3:swap:   'if x < y: return y e...' becomes 'return x'
  11:swap:   'return x' becomes 'return y'
  10:insert: 'if x > z: return x...' becomes 'if x > z: return x...'; 'return z'
  12:delete: 'return z' becomes 'pass'
   8:swap:   'if x > y: return y e...' becomes 'if y < z: pass'

This is the effect of the last mutation applied to middle():

print_content(ast.unparse(new_tree), '.py')
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return y
    elif y < z:
        pass
    return z

Fitness

Now that we can apply random mutations to code, let us find out how good these mutations are. Given our test suites for middle, we can check for a given code candidate how many of the previously passing test cases it passes, and how many of the failing test cases it passes. The more tests pass, the higher the fitness of the candidate.

Not all tests are equally valuable, though. We want to prevent regressions – that is, having a fix that breaks a previously passing test. The values of WEIGHT_PASSING and WEIGHT_FAILING set the relative weight (or importance) of passing vs. failing tests; with the values below, keeping passing tests passing is far more important than fixing failing tests.
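In other words, writing $P$ and $F$ for the sets of originally passing and failing tests, $p$ and $f$ for the number of tests from each set that a candidate passes, and $w_P$ and $w_F$ for WEIGHT_PASSING and WEIGHT_FAILING, the fitness computed by middle_fitness() is

$$\mathit{fitness} = w_P \cdot \frac{p}{|P|} + w_F \cdot \frac{f}{|F|}, \qquad w_P = 0.99,\quad w_F = 0.01$$

For a candidate that passes all of $P$ and none of $F$, this yields $0.99 \cdot 1 + 0.01 \cdot 0 = 0.99$; only a candidate that passes every test reaches $1.0$.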

WEIGHT_PASSING = 0.99
WEIGHT_FAILING = 0.01
def middle_fitness(tree: ast.AST) -> float:
    """Compute fitness of a `middle()` candidate given in `tree`"""
    original_middle = middle

    try:
        code = compile(cast(ast.Module, tree), '<fitness>', 'exec')
    except (SyntaxError, ValueError):
        return 0  # Compilation error

    exec(code, globals())

    passing_passed = 0
    failing_passed = 0

    # Test how many of the passing runs pass
    for x, y, z in MIDDLE_PASSING_TESTCASES:
        try:
            middle_test(x, y, z)
            passing_passed += 1
        except AssertionError:
            pass

    passing_ratio = passing_passed / len(MIDDLE_PASSING_TESTCASES)

    # Test how many of the failing runs pass
    for x, y, z in MIDDLE_FAILING_TESTCASES:
        try:
            middle_test(x, y, z)
            failing_passed += 1
        except AssertionError:
            pass

    failing_ratio = failing_passed / len(MIDDLE_FAILING_TESTCASES)

    fitness = (WEIGHT_PASSING * passing_ratio +
               WEIGHT_FAILING * failing_ratio)

    globals()['middle'] = original_middle
    return fitness

Our faulty middle() program has a fitness of WEIGHT_PASSING (99%), because it passes all the passing tests (but none of the failing ones).

middle_fitness(middle_tree())
0.99

Our "sort of fixed" version of middle() gets a much lower fitness:

middle_fitness(ast.parse("def middle(x, y, z): return x"))
0.4258

In the chapter on statistical debugging, we also defined a fixed version of middle(). This gets a fitness of 1.0, passing all tests. (We won't use this fixed version for automated repairs.)

from StatisticalDebugger import middle_fixed
middle_fixed_source = \
    inspect.getsource(middle_fixed).replace('middle_fixed', 'middle').strip()
middle_fitness(ast.parse(middle_fixed_source))
1.0

Population

We now set up a population of fix candidates to evolve over time. A larger population yields more candidates to check, but also needs more time to test; a smaller population yields fewer candidates, but allows for more evolution steps in the same time. We choose a population size of 40 (from [C. Le Goues et al, 2012]).

POPULATION_SIZE = 40
middle_mutator = StatementMutator()
MIDDLE_POPULATION = [middle_tree()] + \
    [middle_mutator.mutate(middle_tree()) for i in range(POPULATION_SIZE - 1)]

We sort the fix candidates according to their fitness. This actually runs all tests on all candidates.

MIDDLE_POPULATION.sort(key=middle_fitness, reverse=True)

The candidate with the highest fitness is still our original (faulty) middle() code:

print(ast.unparse(MIDDLE_POPULATION[0]),
      middle_fitness(MIDDLE_POPULATION[0]))
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return y
    elif x > y:
        return y
    elif x > z:
        return x
    return z 0.99

At the other end of the spectrum, the candidate with the lowest fitness has some vital functionality removed:

print(ast.unparse(MIDDLE_POPULATION[-1]),
      middle_fitness(MIDDLE_POPULATION[-1]))
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return y
    else:
        return y
    return z 0.5445

Evolution

To evolve our population of candidates, we fill it up with mutants of randomly chosen members, doubling its size; a StatementMutator as described above creates these mutations. Then we reduce the population to its original size, keeping the fittest candidates.

def evolve_middle() -> None:
    global MIDDLE_POPULATION

    source = all_statements(middle_tree())
    mutator = StatementMutator(source=source)

    n = len(MIDDLE_POPULATION)

    offspring: List[ast.AST] = []
    while len(offspring) < n:
        parent = random.choice(MIDDLE_POPULATION)
        offspring.append(mutator.mutate(parent))

    MIDDLE_POPULATION += offspring
    MIDDLE_POPULATION.sort(key=middle_fitness, reverse=True)
    MIDDLE_POPULATION = MIDDLE_POPULATION[:n]
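Stripped of the AST specifics, evolve_middle() follows a simple elitist scheme: create as many offspring as there are members, merge them with their parents, and keep only the fittest half. The sketch below shows the same scheme as a hypothetical generic evolve() function, applied to a toy problem (bit strings whose fitness is their fraction of ones). Since parents compete with their offspring, the best fitness can never decrease from one generation to the next.

```python
import random
from typing import Callable, List, Tuple, TypeVar

T = TypeVar('T')


def evolve(population: List[T],
           mutate: Callable[[T], T],
           fitness: Callable[[T], float],
           rng: random.Random) -> List[T]:
    """One generation: add `n` mutated offspring, keep the fittest `n`."""
    n = len(population)
    offspring = [mutate(rng.choice(population)) for _ in range(n)]
    merged = population + offspring
    merged.sort(key=fitness, reverse=True)
    return merged[:n]


# Toy problem (hypothetical): bit strings, fitness = fraction of ones
LENGTH = 20
rng = random.Random(42)


def flip_one_bit(bits: Tuple[int, ...]) -> Tuple[int, ...]:
    i = rng.randrange(len(bits))
    return bits[:i] + (1 - bits[i],) + bits[i + 1:]


def ones_ratio(bits: Tuple[int, ...]) -> float:
    return sum(bits) / len(bits)


population = [tuple(rng.randint(0, 1) for _ in range(LENGTH))
              for _ in range(10)]
best_per_generation: List[float] = []
for generation in range(30):
    population = evolve(population, flip_one_bit, ones_ratio, rng)
    best_per_generation.append(ones_ratio(population[0]))

print(best_per_generation[-1])
```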

This is what happens when evolving our population for the first time; the original source is still our best candidate.

evolve_middle()
tree = MIDDLE_POPULATION[0]
print(ast.unparse(tree), middle_fitness(tree))
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return y
    elif x > y:
        return y
    elif x > z:
        return x
    return z 0.99

However, nothing keeps us from evolving for a few generations more...

for i in range(50):
    evolve_middle()
    best_middle_tree = MIDDLE_POPULATION[0]
    fitness = middle_fitness(best_middle_tree)
    print(f"\rIteration {i:2}: fitness = {fitness}  ", end="")
    if fitness >= 1.0:
        break
Iteration  1: fitness = 1.0  

Success! We find a candidate that actually passes all tests, including the failing ones. Here is the candidate:

print_content(ast.unparse(best_middle_tree), '.py', start_line_number=1)
 1  def middle(x, y, z):
 2      if y < z:
 3          if x < y:
 4              if x < z:
 5                  return y
 6          elif x < z:
 7              return x
 8      elif x > y:
 9          return y
10      else:
11          if x > z:
12              return x
13          return z
14      return z

... and yes, it passes all tests:

original_middle = middle
code = compile(cast(ast.Module, best_middle_tree), '<string>', 'exec')
exec(code, globals())

for x, y, z in MIDDLE_PASSING_TESTCASES + MIDDLE_FAILING_TESTCASES:
    middle_test(x, y, z)

middle = original_middle

As the candidate has already been validated by hundreds of test cases, it is very valuable for the programmer. Even if the programmer decides not to use the code as is, the location of the change gives very strong hints on which code to examine and where to apply a fix.

However, a closer look at our fix candidate shows that there is some amount of redundancy – that is, superfluous statements.

Quiz

Some of the lines in our fix candidate are redundant. Which are these?





Simplifying

As demonstrated in the chapter on reducing failure-inducing inputs, we can use delta debugging on code to get rid of these superfluous statements.

The trick for simplification is to have the test function (test_middle_lines()) declare a fitness of 1.0 as a "failure". Delta debugging will then simplify the input as long as the "failure" (and hence the maximum fitness obtained) persists.
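The idea of reducing "as long as the failure persists" can be illustrated with a much simpler greedy reducer that tries to drop one line at a time and keeps each removal for which a given property still holds. This is a hypothetical stand-in for illustration, not delta debugging itself: delta debugging removes chunks of decreasing size and typically needs fewer tests on larger inputs. Here, the property f_still_works() plays the role of "fitness is still 1.0".

```python
from typing import Any, Callable, Dict, List


def greedy_reduce(lines: List[str],
                  still_good: Callable[[List[str]], bool]) -> List[str]:
    """Hypothetical sketch: drop single lines as long as `still_good` holds."""
    assert still_good(lines), "the property must hold for the original input"
    reduced = list(lines)
    i = 0
    while i < len(reduced):
        candidate = reduced[:i] + reduced[i + 1:]
        if still_good(candidate):
            reduced = candidate  # line i was superfluous; retry at same index
        else:
            i += 1               # line i is needed; move on
    return reduced


def f_still_works(lines: List[str]) -> bool:
    """Toy property (hypothetical): the code defines f() with f(7) == 7."""
    env: Dict[str, Any] = {}
    try:
        exec("\n".join(lines), env)
        return env["f"](7) == 7
    except Exception:  # syntax errors, missing f, runtime errors
        return False


CODE = ["def f(x):", "    y = x", "    z = 2", "    return x"]
print(greedy_reduce(CODE, f_still_works))  # ['def f(x):', '    return x']
```

As with delta debugging on code lines, many candidates (such as one lacking the def line) simply fail to parse; these count as "property lost" and are never kept.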

from DeltaDebugger import DeltaDebugger
middle_lines = ast.unparse(best_middle_tree).strip().split('\n')
def test_middle_lines(lines: List[str]) -> None:
    source = "\n".join(lines)
    tree = ast.parse(source)
    assert middle_fitness(tree) < 1.0  # "Fail" only while fitness is 1.0
with DeltaDebugger() as dd:
    test_middle_lines(middle_lines)
reduced_lines = dd.min_args()['lines']
reduced_source = "\n".join(reduced_lines)
repaired_source = ast.unparse(ast.parse(reduced_source))  # normalize
print_content(repaired_source, '.py')
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return x
    elif x > y:
        return y
    elif x > z:
        return x
    return z

Success! Delta Debugging has eliminated the superfluous statements. We can present the difference to the original as a patch:

original_source = ast.unparse(ast.parse(middle_source))  # normalize
from ChangeDebugger import diff, print_patch  # minor dependency
for patch in diff(original_source, repaired_source):
    print_patch(patch)
@@ -87,37 +87,37 @@
  x < z:

-            return y

+            return x

     elif

We can present this patch to the programmer, who will then immediately know what to fix in the middle() code.
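If the ChangeDebugger module is not at hand, a line-based patch can also be produced with Python's standard difflib module. This is a swapped-in alternative for illustration: difflib.unified_diff() compares lines rather than characters, so its hunk headers and context differ from the patch shown above. The file labels passed as fromfile and tofile are arbitrary.

```python
import difflib

ORIGINAL = """\
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return y
    elif x > y:
        return y
    elif x > z:
        return x
    return z"""

REPAIRED = """\
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return x
    elif x > y:
        return y
    elif x > z:
        return x
    return z"""

patch = list(difflib.unified_diff(ORIGINAL.splitlines(),
                                  REPAIRED.splitlines(),
                                  fromfile='middle (original)',
                                  tofile='middle (repaired)',
                                  lineterm=''))
print("\n".join(patch))
```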

Crossover

So far, we have only applied one kind of genetic operator – mutation. There is a second one, though, also inspired by natural selection.

The crossover operation combines two strands of genes, as illustrated in the following picture. We have two parents (red and blue), each a sequence of genes. To create "crossed" children, we pick a crossover point and exchange the strands at this very point:

We implement a CrossoverOperator class that implements such an operation on two randomly chosen statement lists of two programs. It is used as

crossover = CrossoverOperator()
crossover.crossover(tree_p1, tree_p2)

where tree_p1 and tree_p2 are two ASTs that are changed in place.
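Before applying crossover to programs, here is the general idea on plain sequences: a minimal one-point crossover sketch. The function one_point_crossover() is hypothetical, and its crossover point is drawn at random; the cross_bodies() method in this chapter, by contrast, simply cuts each statement list in the middle.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar('T')


def one_point_crossover(parent_1: Sequence[T], parent_2: Sequence[T],
                        rng: random.Random) -> Tuple[List[T], List[T]]:
    """Hypothetical sketch: cut both parents at the same random point
    and exchange the tails. Both parents need at least two genes."""
    point = rng.randint(1, min(len(parent_1), len(parent_2)) - 1)
    child_1 = list(parent_1[:point]) + list(parent_2[point:])
    child_2 = list(parent_2[:point]) + list(parent_1[point:])
    return child_1, child_2


red = ['r1', 'r2', 'r3', 'r4', 'r5']
blue = ['b1', 'b2', 'b3', 'b4', 'b5']
child_1, child_2 = one_point_crossover(red, blue, random.Random(0))
print(child_1, child_2)
```

Cutting both parents at the same point keeps the children as long as their (equally long) parents, and no gene is lost or duplicated: the two children together hold exactly the genes of the two parents.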

Implementing Crossover

Crossing Statement Lists

Applied on programs, a crossover mutation takes two parents and "crosses" a list of statements. As an example, if our "parents" p1() and p2() are defined as follows:

def p1():
    a = 1
    b = 2
    c = 3
def p2():
    x = 1
    y = 2
    z = 3

Then a crossover operation would produce one child with a body

a = 1
y = 2
z = 3

and another child with a body

x = 1
b = 2
c = 3

We can easily implement this in a CrossoverOperator class in a method cross_bodies().

class CrossoverOperator:
    """A class for performing statement crossover of Python programs"""

    def __init__(self, log: Union[bool, int] = False):
        """Constructor. If `log` is set, turn on logging."""
        self.log = log

    def cross_bodies(self, body_1: List[ast.AST], body_2: List[ast.AST]) -> \
        Tuple[List[ast.AST], List[ast.AST]]:
        """Crossover the statement lists `body_1` x `body_2`. Return new lists."""

        assert isinstance(body_1, list)
        assert isinstance(body_2, list)

        crossover_point_1 = len(body_1) // 2
        crossover_point_2 = len(body_2) // 2
        return (body_1[:crossover_point_1] + body_2[crossover_point_2:],
                body_2[:crossover_point_2] + body_1[crossover_point_1:])

Here's the cross_bodies() method of CrossoverOperator applied on the bodies of p1 and p2:

tree_p1: ast.Module = ast.parse(inspect.getsource(p1))
tree_p2: ast.Module = ast.parse(inspect.getsource(p2))
body_p1 = tree_p1.body[0].body
body_p2 = tree_p2.body[0].body
body_p1
[<ast.Assign at 0x11a9a44d0>,
 <ast.Assign at 0x11a965390>,
 <ast.Assign at 0x11a965dd0>]
crosser = CrossoverOperator()
tree_p1.body[0].body, tree_p2.body[0].body = crosser.cross_bodies(body_p1, body_p2)
print_content(ast.unparse(tree_p1), '.py')
def p1():
    a = 1
    y = 2
    z = 3
print_content(ast.unparse(tree_p2), '.py')
def p2():
    x = 1
    b = 2
    c = 3

Applying Crossover on Programs

Applying the crossover operation on arbitrary programs is a bit more complex, though. We first have to find lists of statements that we actually can cross over. The can_cross() method returns True if a node has a list of at least two statements that we can cross. Python modules and classes are excluded, because changing the ordering of definitions will not have much impact on the program functionality, other than introducing errors due to dependencies.

class CrossoverOperator(CrossoverOperator):
    # In modules and class defs, the ordering of elements does not matter (much)
    SKIP_LIST = {ast.Module, ast.ClassDef}

    def can_cross(self, tree: ast.AST, body_attr: str = 'body') -> bool:
        if any(isinstance(tree, cls) for cls in self.SKIP_LIST):
            return False

        body = getattr(tree, body_attr, [])
        return body is not None and len(body) >= 2

Here comes our method crossover_attr() which searches for crossover possibilities. It takes two ASTs t1 and t2 and an attribute (typically 'body') and retrieves the attribute lists $l_1$ (from t1.<attr>) and $l_2$ (from t2.<attr>).

If $l_1$ and $l_2$ can be crossed, it crosses them, and is done. Otherwise:

  • If there is a pair of elements $e_1 \in l_1$ and $e_2 \in l_2$ that has the same name – say, functions of the same name –, it applies itself to $e_1$ and $e_2$.
  • Otherwise, it creates random pairs of elements $e_1 \in l_1$ and $e_2 \in l_2$ and applies itself on these very pairs.

crossover_attr() changes t1 and t2 in place and returns True if a crossover was found; it returns False otherwise.

class CrossoverOperator(CrossoverOperator):
    def crossover_attr(self, t1: ast.AST, t2: ast.AST, body_attr: str) -> bool:
        """
        Crossover the bodies `body_attr` of two trees `t1` and `t2`.
        Return True if successful.
        """
        assert isinstance(t1, ast.AST)
        assert isinstance(t2, ast.AST)
        assert isinstance(body_attr, str)

        if not getattr(t1, body_attr, None) or not getattr(t2, body_attr, None):
            return False

        if self.crossover_branches(t1, t2):
            return True

        if self.log > 1:
            print(f"Checking {t1}.{body_attr} x {t2}.{body_attr}")

        body_1 = getattr(t1, body_attr)
        body_2 = getattr(t2, body_attr)

        # If both trees have the attribute, we can cross their bodies
        if self.can_cross(t1, body_attr) and self.can_cross(t2, body_attr):
            if self.log:
                print(f"Crossing {t1}.{body_attr} x {t2}.{body_attr}")

            new_body_1, new_body_2 = self.cross_bodies(body_1, body_2)
            setattr(t1, body_attr, new_body_1)
            setattr(t2, body_attr, new_body_2)
            return True

        # Strategy 1: Find matches in class/function of same name
        for child_1 in body_1:
            if hasattr(child_1, 'name'):
                for child_2 in body_2:
                    if (hasattr(child_2, 'name') and
                           child_1.name == child_2.name):
                        if self.crossover_attr(child_1, child_2, body_attr):
                            return True

        # Strategy 2: Find matches anywhere
        for child_1 in random.sample(body_1, len(body_1)):
            for child_2 in random.sample(body_2, len(body_2)):
                if self.crossover_attr(child_1, child_2, body_attr):
                    return True

        return False

We have a special case for if nodes, where we can cross their body and else branches. (In Python, for and while also have else branches, but swapping these with loop bodies is likely to create havoc.)

class CrossoverOperator(CrossoverOperator):
    def crossover_branches(self, t1: ast.AST, t2: ast.AST) -> bool:
        """Special case:
        `t1` = `if P: S1 else: S2` x `t2` = `if P': S1' else: S2'`
        becomes
        `t1` = `if P: S2' else: S1'` and `t2` = `if P': S2 else: S1`
        Returns True if successful.
        """
        assert isinstance(t1, ast.AST)
        assert isinstance(t2, ast.AST)

        # Restrict to `if` nodes: loops (and `try`) also have `body` and
        # `orelse`, but swapping a loop body with its `else` branch
        # is likely to create havoc
        if isinstance(t1, ast.If) and isinstance(t2, ast.If):

            if self.log:
                print(f"Crossing branches {t1} x {t2}")

            t1.body, t1.orelse, t2.body, t2.orelse = \
                t2.orelse, t2.body, t1.orelse, t1.body
            return True

        return False

The method crossover() is the main entry point. It tries the body, orelse, and finalbody attributes in turn, using crossover_attr() (which first checks for the special if case described above) to search for possible crossover points. It raises CrossoverError if no crossover is found.

class CrossoverOperator(CrossoverOperator):
    def crossover(self, t1: ast.AST, t2: ast.AST) -> Tuple[ast.AST, ast.AST]:
        """Do a crossover of ASTs `t1` and `t2`.
        Raises `CrossoverError` if no crossover is found."""
        assert isinstance(t1, ast.AST)
        assert isinstance(t2, ast.AST)

        for body_attr in ['body', 'orelse', 'finalbody']:
            if self.crossover_attr(t1, t2, body_attr):
                return t1, t2

        raise CrossoverError("No crossover found")
class CrossoverError(ValueError):
    pass

Crossover in Action

Let us put our CrossoverOperator in action. Here is a test case for crossover, involving more deeply nested structures:

def p1():
    if True:
        print(1)
        print(2)
        print(3)
def p2():
    if True:
        print(a)
        print(b)
    else:
        print(c)
        print(d)

We invoke the crossover() method with two ASTs from p1 and p2:

crossover = CrossoverOperator()
tree_p1 = ast.parse(inspect.getsource(p1))
tree_p2 = ast.parse(inspect.getsource(p2))
crossover.crossover(tree_p1, tree_p2);

Here is the crossed offspring, mixing statement lists of p1 and p2:

print_content(ast.unparse(tree_p1), '.py')
def p1():
    if True:
        print(c)
        print(d)
    else:
        print(a)
        print(b)
print_content(ast.unparse(tree_p2), '.py')
def p2():
    if True:
    else:
        print(1)
        print(2)
        print(3)
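Note that the second child is no longer valid Python: its if branch has become empty. Such candidates cannot be compiled. When compile() is given an AST with an empty statement body, it raises a ValueError, which is why the fitness functions in this chapter treat a ValueError during compilation as a compilation error with fitness zero. The following small check demonstrates this on a hand-made tree that mimics the second child.

```python
import ast

# Build `if True: ... else: print(1)` and then empty the `if` branch,
# mimicking the second crossover child above
tree = ast.parse("if True:\n    pass\nelse:\n    print(1)")
tree.body[0].body = []

print(ast.unparse(tree))  # unparsing works, but the result is not valid Python

try:
    compile(tree, '<crossover>', 'exec')
    compiled = True
except ValueError as err:
    compiled = False
    print("ValueError:", err)
```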

Here is our special case for if nodes in action, crossing our middle() tree with p2.

middle_t1, middle_t2 = crossover.crossover(middle_tree(),
                                          ast.parse(inspect.getsource(p2)))

We see how the resulting offspring encompasses elements of both sources:

print_content(ast.unparse(middle_t1), '.py')
def middle(x, y, z):
    if y < z:
        print(c)
        print(d)
    else:
        print(a)
        print(b)
    return z
print_content(ast.unparse(middle_t2), '.py')
def p2():
    if True:
        if x > y:
            return y
        elif x > z:
            return x
    elif x < y:
        return y
    elif x < z:
        return y

A Repairer Class

So far, we have applied all our techniques on the middle() program only. Let us now create a Repairer class that applies automatic program repair on arbitrary Python programs. The idea is that you can apply it on some statistical debugger, for which you have gathered passing and failing test cases, and then invoke its repair() method to find a "best" fix candidate:

debugger = OchiaiDebugger()
with debugger:
    <passing test>
with debugger:
    <failing test>
...
repairer = Repairer(debugger)
repairer.repair()

Implementing Repairer

The main argument to the Repairer constructor is the debugger to get information from. On top of that, it also allows customizing the classes used for mutation, crossover, and reduction. Setting targets defines the set of functions to repair; setting sources defines the set of functions to take repairs from. The constructor then sets up the environment for running tests and repairing, as described below.

from StackInspector import StackInspector  # minor dependency
class Repairer(StackInspector):
    """A class for automatic repair of Python programs"""

    def __init__(self, debugger: RankingDebugger, *,
                 targets: Optional[List[Any]] = None,
                 sources: Optional[List[Any]] = None,
                 log: Union[bool, int] = False,
                 mutator_class: Type = StatementMutator,
                 crossover_class: Type = CrossoverOperator,
                 reducer_class: Type = DeltaDebugger,
                 globals: Optional[Dict[str, Any]] = None):
        """Constructor.
`debugger`: a `RankingDebugger` to take tests and coverage from.
`targets`: a list of functions/modules to be repaired.
    (default: the covered functions in `debugger`, except tests)
`sources`: a list of functions/modules to take repairs from.
    (default: same as `targets`)
`globals`: if given, a `globals()` dict for executing targets
    (default: `globals()` of caller)"""

        assert isinstance(debugger, RankingDebugger)
        self.debugger = debugger
        self.log = log

        # Set up the environment first, since `parse()` and `getsource()`
        # look up targets and sources given as strings in `self.globals`
        if globals is None:
            globals = self.caller_globals()  # see below

        self.globals = globals

        if targets is None:
            targets = self.default_functions()
        if not targets:
            raise ValueError("No targets to repair")

        if sources is None:
            sources = self.default_functions()
        if not sources:
            raise ValueError("No sources to take repairs from")

        if self.debugger.function() is None:
            raise ValueError("Multiple entry points observed")

        self.target_tree: ast.AST = self.parse(targets)
        self.source_tree: ast.AST = self.parse(sources)

        self.log_tree("Target code to be repaired:", self.target_tree)
        if ast.dump(self.target_tree) != ast.dump(self.source_tree):
            self.log_tree("Source code to take repairs from:",
                          self.source_tree)

        self.fitness_cache: Dict[str, float] = {}

        self.mutator: StatementMutator = \
            mutator_class(
                source=all_statements(self.source_tree),
                suspiciousness_func=self.debugger.suspiciousness,
                log=(self.log >= 3))
        self.crossover: CrossoverOperator = crossover_class(log=(self.log >= 3))
        self.reducer: DeltaDebugger = reducer_class(log=(self.log >= 3))

When we access or execute functions, we do so in the caller's environment, not ours. The caller_globals() method from StackInspector acts as a replacement for globals().

Helper Functions

The constructor uses a number of helper functions to create its environment.

class Repairer(Repairer):
    def getsource(self, item: Union[str, Any]) -> str:
        """Get the source for `item`. Can also be a string."""

        if isinstance(item, str):
            item = self.globals[item]
        return inspect.getsource(item)
class Repairer(Repairer):
    def default_functions(self) -> List[Callable]:
        """Return the set of functions to be repaired.
        Functions whose names start or end in `test` are excluded."""
        def is_test(name: str) -> bool:
            return name.startswith('test') or name.endswith('test')

        return [func for func in self.debugger.covered_functions()
                if not is_test(func.__name__)]
class Repairer(Repairer):
    def log_tree(self, description: str, tree: Any) -> None:
        """Print out `tree` as source code prefixed by `description`."""
        if self.log:
            print(description)
            print_content(ast.unparse(tree), '.py')
            print()
            print()
class Repairer(Repairer):
    def parse(self, items: List[Any]) -> ast.AST:
        """Read in a list of items into a single tree"""
        tree = ast.parse("")
        for item in items:
            if isinstance(item, str):
                item = self.globals[item]

            item_lines, item_first_lineno = inspect.getsourcelines(item)

            try:
                item_tree = ast.parse("".join(item_lines))
            except IndentationError:
                # inner function or likewise
                warnings.warn(f"Can't parse {item.__name__}")
                continue

            ast.increment_lineno(item_tree, item_first_lineno - 1)
            tree.body += item_tree.body

        return tree
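parse() shifts the line numbers of each parsed item with ast.increment_lineno(), so that statements keep the line numbers they have in their original source file rather than starting at 1. (Presumably, this is what allows relating statements to the line-based suspiciousness values of the debugger.) Here is the effect on a small hypothetical snippet that is assumed to start at line 10 of its file:

```python
import ast

SNIPPET = "def f(x):\n    return x + 1\n"
FIRST_LINENO = 10  # assumed position of `def f` in its file

tree = ast.parse(SNIPPET)
print([node.lineno for node in ast.walk(tree)
       if isinstance(node, ast.stmt)])  # [1, 2]

# Same shift as in parse(): add (first line number - 1) to every node
ast.increment_lineno(tree, FIRST_LINENO - 1)
linenos = [node.lineno for node in ast.walk(tree)
           if isinstance(node, ast.stmt)]
print(linenos)  # [10, 11]
```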

Running Tests

Now that we have set the environment for Repairer, we can implement one step of automatic repair after the other. The method run_test_set() runs the given test_set (DifferenceDebugger.PASS or DifferenceDebugger.FAIL), returning the number of passed tests. If validate is set, it checks whether the outcomes are as expected.

class Repairer(Repairer):
    def run_test_set(self, test_set: str, validate: bool = False) -> int:
        """
        Run given `test_set`
        (`DifferenceDebugger.PASS` or `DifferenceDebugger.FAIL`).
        If `validate` is set, check expectations.
        Return number of passed tests.
        """
        passed = 0
        collectors = self.debugger.collectors[test_set]
        function = self.debugger.function()
        assert function is not None
        # FIXME: function may have been redefined

        for c in collectors:
            if self.log >= 4:
                print(f"Testing {c.id()}...", end="")

            try:
                function(**c.args())
            except Exception as err:
                if self.log >= 4:
                    print(f"failed ({err.__class__.__name__})")

                if validate and test_set == self.debugger.PASS:
                    raise err.__class__(
                        f"{c.id()} should have passed, but failed")
                continue

            passed += 1
            if self.log >= 4:
                print("passed")

            if validate and test_set == self.debugger.FAIL:
                raise FailureNotReproducedError(
                    f"{c.id()} should have failed, but passed")

        return passed
class FailureNotReproducedError(ValueError):
    pass

Here is how we use run_test_set():

repairer = Repairer(middle_debugger)
assert repairer.run_test_set(middle_debugger.PASS) == \
    len(MIDDLE_PASSING_TESTCASES)
assert repairer.run_test_set(middle_debugger.FAIL) == 0

The method run_tests() runs both passing and failing tests, weighting the ratio of passed tests in each set to obtain the overall fitness.

class Repairer(Repairer):
    def weight(self, test_set: str) -> float:
        """
        Return the weight of `test_set`
        (`DifferenceDebugger.PASS` or `DifferenceDebugger.FAIL`).
        """
        return {
            self.debugger.PASS: WEIGHT_PASSING,
            self.debugger.FAIL: WEIGHT_FAILING
        }[test_set]

    def run_tests(self, validate: bool = False) -> float:
        """Run passing and failing tests, returning weighted fitness."""
        fitness = 0.0

        for test_set in [self.debugger.PASS, self.debugger.FAIL]:
            passed = self.run_test_set(test_set, validate=validate)
            ratio = passed / len(self.debugger.collectors[test_set])
            fitness += self.weight(test_set) * ratio

        return fitness

The method validate() ensures the observed tests can be adequately reproduced.

class Repairer(Repairer):
    def validate(self) -> None:
        fitness = self.run_tests(validate=True)
        assert fitness == self.weight(self.debugger.PASS)
repairer = Repairer(middle_debugger)
repairer.validate()

(Re)defining Functions

Our run_tests() method above does not yet redefine the function to be repaired. This is done by the fitness() method, which compiles and defines the given repair candidate tree before testing it. It caches and returns the fitness.

class Repairer(Repairer):
    def fitness(self, tree: ast.AST) -> float:
        """Test `tree`, returning its fitness"""
        key = cast(str, ast.dump(tree))
        if key in self.fitness_cache:
            return self.fitness_cache[key]

        # Save defs
        original_defs: Dict[str, Any] = {}
        for name in self.toplevel_defs(tree):
            if name in self.globals:
                original_defs[name] = self.globals[name]
            else:
                warnings.warn(f"Couldn't find definition of {repr(name)}")

        assert original_defs, f"Couldn't find any definition"

        if self.log >= 3:
            print("Repair candidate:")
            print_content(ast.unparse(tree), '.py')
            print()

        # Create new definition
        try:
            code = compile(cast(ast.Module, tree), '<Repairer>', 'exec')
        except (SyntaxError, ValueError):  # Compilation error
            code = None

        if code is None:
            if self.log >= 3:
                print(f"Fitness = 0.0 (compilation error)")

            fitness = 0.0
            return fitness

        # Execute new code, defining new functions in `self.globals`
        exec(code, self.globals)

        # Set new definitions in the namespace (`__globals__`)
        # of the function we will be calling.
        function = self.debugger.function()
        assert function is not None
        assert hasattr(function, '__globals__')

        for name in original_defs:
            function.__globals__[name] = self.globals[name]

        fitness = self.run_tests(validate=False)

        # Restore definitions
        for name in original_defs:
            function.__globals__[name] = original_defs[name]
            self.globals[name] = original_defs[name]

        if self.log >= 3:
            print(f"Fitness = {fitness}")

        self.fitness_cache[key] = fitness
        return fitness
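The cache in fitness() is keyed by ast.dump(tree). By default, ast.dump() omits line and column attributes, and comments and formatting never make it into the AST. Candidates that differ only in layout therefore share a single cache entry, while any structural change yields a new key. (Caching this way assumes that the tests are deterministic.) A quick check:

```python
import ast

key_1 = ast.dump(ast.parse("def middle(x, y, z): return x"))
key_2 = ast.dump(ast.parse("def middle(x, y, z):\n    return x  # same code"))
key_3 = ast.dump(ast.parse("def middle(x, y, z): return y"))

print(key_1 == key_2)  # True: same structure, different layout
print(key_1 == key_3)  # False: different return value
```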

The helper function toplevel_defs() helps to save and restore the environment before and after redefining the function under repair.

class Repairer(Repairer):
    def toplevel_defs(self, tree: ast.AST) -> List[str]:
        """Return a list of names of defined functions and classes in `tree`"""
        visitor = DefinitionVisitor()
        visitor.visit(tree)
        assert hasattr(visitor, 'definitions')
        return visitor.definitions
class DefinitionVisitor(NodeVisitor):
    def __init__(self) -> None:
        self.definitions: List[str] = []

    def add_definition(self, node: Union[ast.ClassDef, 
                                         ast.FunctionDef, 
                                         ast.AsyncFunctionDef]) -> None:
        self.definitions.append(node.name)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self.add_definition(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self.add_definition(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self.add_definition(node)
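fitness() saves the current definitions of all names reported by toplevel_defs(), installs the candidate's definitions, runs the tests, and restores the originals. This save/redefine/restore pattern can also be packaged as a context manager. The sketch below uses a hypothetical helper temporarily_defined(), which is not part of Repairer; unlike fitness(), it uses try/finally, so the original definitions come back even if running the tests raises an exception.

```python
import contextlib
from typing import Any, Dict, Iterator


@contextlib.contextmanager
def temporarily_defined(namespace: Dict[str, Any],
                        new_defs: Dict[str, Any]) -> Iterator[None]:
    """Hypothetical helper: bind `new_defs` in `namespace` for the duration
    of the `with` block, then restore the previous bindings."""
    missing = object()  # marks names that were not defined before
    saved = {name: namespace.get(name, missing) for name in new_defs}
    namespace.update(new_defs)
    try:
        yield
    finally:
        for name, old_value in saved.items():
            if old_value is missing:
                del namespace[name]
            else:
                namespace[name] = old_value


namespace: Dict[str, Any] = {'middle': lambda x, y, z: y}  # "faulty" version
with temporarily_defined(namespace, {'middle': lambda x, y, z: x}):
    print(namespace['middle'](1, 2, 3))  # 1: candidate is active
print(namespace['middle'](1, 2, 3))      # 2: original is back
```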

Here's an example for fitness():

repairer = Repairer(middle_debugger, log=1)
Target code to be repaired:
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return y
    elif x > y:
        return y
    elif x > z:
        return x
    return z
good_fitness = repairer.fitness(middle_tree())
good_fitness
0.99
bad_middle_tree = ast.parse("def middle(x, y, z): return x")
bad_fitness = repairer.fitness(bad_middle_tree)
bad_fitness
0.4258

Repairing

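Now for the actual repair() method. It first validates that the recorded test outcomes can be reproduced, then creates a population and evolves it until the fitness is 1.0 or the given number of iterations is spent, and finally simplifies the best candidate found.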

import traceback
class Repairer(Repairer):
    def initial_population(self, size: int) -> List[ast.AST]:
        """Return an initial population of size `size`"""
        return [self.target_tree] + \
            [self.mutator.mutate(copy.deepcopy(self.target_tree))
                for i in range(size - 1)]

    def repair(self, population_size: int = POPULATION_SIZE, iterations: int = 100) -> \
        Tuple[ast.AST, float]:
        """
        Repair the function we collected test runs from.
        Use a population size of `population_size` and
        at most `iterations` iterations.
        Returns a pair (`ast`, `fitness`) where 
        `ast` is the AST of the repaired function, and
        `fitness` is its fitness (between 0 and 1.0)
        """
        self.validate()

        population = self.initial_population(population_size)

        last_key = ast.dump(self.target_tree)

        for iteration in range(iterations):
            population = self.evolve(population)

            best_tree = population[0]
            fitness = self.fitness(best_tree)

            if self.log:
                print(f"Evolving population: "
                      f"iteration{iteration:4}/{iterations} "
                      f"fitness = {fitness:.5}   \r", end="")

            if self.log >= 2:
                best_key = ast.dump(best_tree)
                if best_key != last_key:
                    print()
                    print()
                    self.log_tree(f"New best code (fitness = {fitness}):",
                                  best_tree)
                    last_key = best_key

            if fitness >= 1.0:
                break

        if self.log:
            print()

        if self.log and self.log < 2:
            self.log_tree(f"Best code (fitness = {fitness}):", best_tree)

        best_tree = self.reduce(best_tree)
        fitness = self.fitness(best_tree)

        self.log_tree(f"Reduced code (fitness = {fitness}):", best_tree)

        return best_tree, fitness

Evolving

The evolution of our population takes place in the evolve() method. In contrast to the evolve_middle() function, above, we use crossover to create the offspring, which we still mutate afterwards.

class Repairer(Repairer):
    def evolve(self, population: List[ast.AST]) -> List[ast.AST]:
        """Evolve the candidate population by mutation and crossover."""
        n = len(population)

        # Create offspring as crossover of parents
        offspring: List[ast.AST] = []
        while len(offspring) < n:
            parent_1 = copy.deepcopy(random.choice(population))
            parent_2 = copy.deepcopy(random.choice(population))
            try:
                self.crossover.crossover(parent_1, parent_2)
            except CrossoverError:
                pass  # Just keep parents
            offspring += [parent_1, parent_2]

        # Mutate offspring
        offspring = [self.mutator.mutate(tree) for tree in offspring]

        # Add it to population
        population += offspring

        # Keep the fitter part of the population
        population.sort(key=self.fitness_key, reverse=True)
        population = population[:n]

        return population
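Here is a standalone sketch of one such generation that works on plain lists of integers instead of ASTs. The one-point crossover() function and the CrossoverError exception are simplified stand-ins assumed for illustration; they are not the chapter's CrossoverOperator. As in evolve(), a failed crossover simply keeps the parents.

```python
import copy
import random
from typing import Callable, List

Candidate = List[int]


class CrossoverError(ValueError):
    """Raised if two candidates cannot be crossed over (simplified stand-in)."""


def crossover(parent_1: Candidate, parent_2: Candidate,
              rng: random.Random) -> None:
    """One-point crossover, in place: swap the tails after a random cut."""
    if len(parent_1) < 2 or len(parent_2) < 2:
        raise CrossoverError("candidates are too short to be crossed")
    cut = rng.randrange(1, min(len(parent_1), len(parent_2)))
    parent_1[cut:], parent_2[cut:] = parent_2[cut:], parent_1[cut:]


def mutate(candidate: Candidate, rng: random.Random) -> Candidate:
    """Replace one random element by a random digit (in place)."""
    candidate[rng.randrange(len(candidate))] = rng.randrange(10)
    return candidate


def evolve(population: List[Candidate],
           fitness: Callable[[Candidate], float],
           rng: random.Random) -> List[Candidate]:
    n = len(population)

    # Create offspring as crossover of (copies of) random parents
    offspring: List[Candidate] = []
    while len(offspring) < n:
        parent_1 = copy.deepcopy(rng.choice(population))
        parent_2 = copy.deepcopy(rng.choice(population))
        try:
            crossover(parent_1, parent_2, rng)
        except CrossoverError:
            pass  # just keep the parents
        offspring += [parent_1, parent_2]

    # Mutate offspring; the copied parents keep the population untouched
    offspring = [mutate(c, rng) for c in offspring]

    # Keep the fittest n of parents and offspring (new list, not in place)
    merged = population + offspring
    merged.sort(key=fitness, reverse=True)
    return merged[:n]


rng = random.Random(42)
TARGET = [1, 2, 3, 4, 5]  # hypothetical optimum


def fitness(c: Candidate) -> float:
    return sum(a == b for a, b in zip(c, TARGET)) / len(TARGET)


population = [[rng.randrange(10) for _ in range(5)] for _ in range(8)]
new_population = evolve(population, fitness, rng)
print(max(map(fitness, population)), "->", max(map(fitness, new_population)))
```

There is one design difference. evolve() above extends the list it receives (population += offspring) and then sorts that list in place. This is harmless in repair(), which replaces population with the result anyway. The sketch builds a new list instead, so the caller's list stays unchanged.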

A second difference is that we sort not only by fitness, but also by tree size. With equal fitness, a smaller tree is thus favored. This helps keep fixes and patches small.

class Repairer(Repairer):
    def fitness_key(self, tree: ast.AST) -> Tuple[float, int]:
        """Key to be used for sorting the population"""
        tree_size = len([node for node in ast.walk(tree)])
        return (self.fitness(tree), -tree_size)
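To see the tie-breaking in isolation, the sketch below sorts three ASTs using a (fitness, -tree_size) key, just as fitness_key() does. The helper make_fitness_key() and the constant fitness function are hypothetical placeholders. Under constant fitness all candidates are equally fit, so tree size alone decides the order.

```python
import ast
from typing import Callable, List, Tuple


def make_fitness_key(fitness: Callable[[ast.AST], float]
                     ) -> Callable[[ast.AST], Tuple[float, int]]:
    """Sort key: higher fitness first; among equals, fewer AST nodes first."""
    def key(tree: ast.AST) -> Tuple[float, int]:
        tree_size = sum(1 for _ in ast.walk(tree))
        return (fitness(tree), -tree_size)
    return key


def constant_fitness(tree: ast.AST) -> float:
    return 1.0  # hypothetical: every candidate passes all tests


candidates: List[ast.AST] = [
    ast.parse("def f(x):\n    y = x\n    return x"),
    ast.parse("def f(x):\n    return x"),
    ast.parse("def f(x):\n    pass\n    return x"),
]
candidates.sort(key=make_fitness_key(constant_fitness), reverse=True)
print(ast.unparse(candidates[0]))  # the smallest variant comes first
```

Tuples are compared element by element, so tree size only breaks ties. A fitter but larger candidate still comes first.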

Simplifying

The last step in repairing is simplifying the code. As demonstrated in the chapter on reducing failure-inducing inputs, we can use delta debugging on code to get rid of superfluous statements. To this end, we convert the tree to lines, run delta debugging on them, and then convert the reduced lines back to a tree.

class Repairer(Repairer):
    def reduce(self, tree: ast.AST) -> ast.AST:
        """Simplify `tree` using delta debugging."""

        original_fitness = self.fitness(tree)
        source_lines = ast.unparse(tree).split('\n')

        with self.reducer:
            self.test_reduce(source_lines, original_fitness)

        reduced_lines = self.reducer.min_args()['source_lines']
        reduced_source = "\n".join(reduced_lines)

        return ast.parse(reduced_source)
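The round trip in reduce() (tree to lines, reduce the lines, lines back to tree) can be sketched with a compact, simplified ddmin that only tries removing complements. The chapter on reducing failure-inducing inputs discusses delta debugging in full, and the reducer that Repairer actually uses may differ in its details. The predicate returns_42() is a hypothetical "failure" to be preserved: the code must still define f() returning 42.

```python
import ast
from typing import Callable, List


def ddmin(lines: List[str], fails: Callable[[List[str]], bool]) -> List[str]:
    """Simplified delta debugging: shrink `lines` while `fails(lines)` holds."""
    assert fails(lines), "the unreduced input must 'fail'"
    n = 2  # number of chunks
    while len(lines) >= 2:
        chunk_length = len(lines) // n
        start = 0
        reduced = False
        while start < len(lines):
            # Try removing one chunk
            complement = lines[:start] + lines[start + chunk_length:]
            if fails(complement):
                lines = complement
                n = max(n - 1, 2)
                reduced = True
                break
            start += chunk_length
        if not reduced:
            if n == len(lines):
                break  # cannot split any further
            n = min(n * 2, len(lines))
    return lines


def returns_42(lines: List[str]) -> bool:
    """Hypothetical 'failure': the code still defines f() returning 42."""
    namespace: dict = {}
    try:
        exec("\n".join(lines), namespace)
        return namespace["f"]() == 42
    except Exception:  # syntax errors, missing f, ...
        return False


tree = ast.parse("def f():\n    x = 1\n    y = 2\n    return 42")
source_lines = ast.unparse(tree).split("\n")
reduced_lines = ddmin(source_lines, returns_42)
print(ast.unparse(ast.parse("\n".join(reduced_lines))))
```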

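As discussed above, we simplify the code by having the test function (test_reduce()) treat reaching the maximum fitness obtained so far as a "failure". Technically, test_reduce() asserts that the fitness of the reduced code is below the original fitness. If the reduced code reaches the original fitness, this assertion fails. Delta debugging will then simplify the input as long as the "failure" (and hence the maximum fitness obtained) persists.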

class Repairer(Repairer):
    def test_reduce(self, source_lines: List[str], original_fitness: float) -> None:
        """Test function for delta debugging."""

        try:
            source = "\n".join(source_lines)
            tree = ast.parse(source)
            fitness = self.fitness(tree)
            assert fitness < original_fitness

        except AssertionError:
            raise
        except SyntaxError:
            raise
        except IndentationError:
            raise
        except Exception:
            # traceback.print_exc()  # Uncomment to see internal errors
            raise
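The inverted test can also be shown without running delta debugging. In the sketch below, fitness() computes the fraction of a hypothetical mini test suite (TESTS) that a function f, defined in the candidate source, passes. fitness_preserved() returns True in exactly the cases where test_reduce() raises its AssertionError, namely when a reduced candidate still reaches the original fitness. Returning a Boolean instead of raising is a simplification. Also, unlike test_reduce(), code that fails to parse here simply gets fitness 0.0.

```python
import ast
from typing import List, Tuple

# Hypothetical mini test suite: `f` should return its argument
TESTS: List[Tuple[int, int]] = [(1, 1), (2, 2), (3, 3)]


def fitness(source: str) -> float:
    """Fraction of TESTS passed by `f` as defined in `source`.
    Code that does not parse or run gets fitness 0.0."""
    namespace: dict = {}
    try:
        exec(compile(ast.parse(source), "<candidate>", "exec"), namespace)
        f = namespace["f"]
        return sum(f(x) == y for x, y in TESTS) / len(TESTS)
    except Exception:
        return 0.0


def fitness_preserved(lines: List[str], original_fitness: float) -> bool:
    """The reducer's "failure": the reduced code still reaches the original
    fitness (test_reduce() signals this case by raising AssertionError)."""
    return fitness("\n".join(lines)) >= original_fitness


lines = ["def f(x):", "    y = x + 1", "    return x"]
original_fitness = fitness("\n".join(lines))
print(original_fitness)                                           # 1.0
print(fitness_preserved([lines[0], lines[2]], original_fitness))  # True: the assignment is superfluous
print(fitness_preserved([lines[0], lines[1]], original_fitness))  # False: f() now returns None
```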

Repairer in Action

Let us go and apply Repairer in practice. We initialize it with middle_debugger, which has (still) collected the passing and failing runs for middle_test(). We also set log for some diagnostics along the way.

repairer = Repairer(middle_debugger, log=True)
Target code to be repaired:
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return y
    elif x > y:
        return y
    elif x > z:
        return x
    return z

We now invoke repair() to evolve our population. In the run shown here, a tree with perfect fitness emerges in the very first iteration.

best_tree, fitness = repairer.repair()
Evolving population: iteration   0/100 fitness = 1.0   
Best code (fitness = 1.0):
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return x
    elif x > y:
        return y
    elif x > z:
        return x
    return z

Reduced code (fitness = 1.0):
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return x
    elif x > y:
        return y
    elif x > z:
        return x
    return z
print_content(ast.unparse(best_tree), '.py')
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return x
    elif x > y:
        return y
    elif x > z:
        return x
    return z
fitness
1.0

Again, we have a perfect solution. Here, the final reduction step did not even have to remove anything, as our fitness_key() function already favors smaller implementations during evolution.

Removing HTML Markup

Let us apply Repairer on our other ongoing example, namely remove_html_markup().

def remove_html_markup(s):
    tag = False
    quote = False
    out = ""

    for c in s:
        if c == '<' and not quote:
            tag = True
        elif c == '>' and not quote:
            tag = False
        elif c == '"' or c == "'" and tag:
            quote = not quote
        elif not tag:
            out = out + c

    return out
def remove_html_markup_tree() -> ast.AST:
    return ast.parse(inspect.getsource(remove_html_markup))

To run Repairer on remove_html_markup(), we need a test and a test suite. remove_html_markup_test() raises an exception if applying remove_html_markup() to the given html string does not yield the plain string.

def remove_html_markup_test(html: str, plain: str) -> None:
    outcome = remove_html_markup(html)
    assert outcome == plain, \
        f"Got {repr(outcome)}, expected {repr(plain)}"

Now for the test suite. We use a simple fuzzing scheme to create dozens of passing and failing test cases in REMOVE_HTML_PASSING_TESTCASES and REMOVE_HTML_FAILING_TESTCASES, respectively.

Creating HTML Test Cases
def random_string(length: int = 5, start: int = ord(' '), end: int = ord('~')) -> str:
    return "".join(chr(random.randrange(start, end + 1)) for i in range(length))
random_string()
'@YeYg'
def random_id(length: int = 5) -> str:
    # Pass `length` on; previously it was silently ignored
    return random_string(length, start=ord('a'), end=ord('z'))
random_id()
'eaaem'
def random_plain() -> str:
    return random_string().replace('<', '').replace('>', '')
def random_string_noquotes() -> str:
    return random_string().replace('"', '').replace("'", '')
def random_html(depth: int = 0) -> Tuple[str, str]:
    prefix = random_plain()
    tag = random_id()

    if depth > 0:
        html, plain = random_html(depth - 1)
    else:
        html = plain = random_plain()

    attr = random_id()
    value = '"' + random_string_noquotes() + '"'
    postfix = random_plain()

    return f'{prefix}<{tag} {attr}={value}>{html}</{tag}>{postfix}', \
        prefix + plain + postfix
random_html()
('L1JA|<uidfm lmvie="QNSD:">BvE"8</uidfm>v@rNS', 'L1JA|BvE"8v@rNS')
def remove_html_testcase(expected: bool = True) -> Tuple[str, str]:
    while True:
        html, plain = random_html()
        outcome = (remove_html_markup(html) == plain)
        if outcome == expected:
            return html, plain
REMOVE_HTML_TESTS = 100
REMOVE_HTML_PASSING_TESTCASES = \
    [remove_html_testcase(True) for i in range(REMOVE_HTML_TESTS)]
REMOVE_HTML_FAILING_TESTCASES = \
    [remove_html_testcase(False) for i in range(REMOVE_HTML_TESTS)]

Here is a passing test case:

REMOVE_HTML_PASSING_TESTCASES[0]
('Sg$VT<fqlui ppzww="!EyHN">J9Ji </fqlui>.)!$', 'Sg$VTJ9Ji .)!$')
html, plain = REMOVE_HTML_PASSING_TESTCASES[0]
remove_html_markup_test(html, plain)

Here is a failing test case (containing a double quote in the plain text):

REMOVE_HTML_FAILING_TESTCASES[0]
('3AGe<qcguk yewyq="wA^<S">7"!%H</qcguk>6azh_', '3AGe7"!%H6azh_')
with ExpectError():
    html, plain = REMOVE_HTML_FAILING_TESTCASES[0]
    remove_html_markup_test(html, plain)
Traceback (most recent call last):
  File "/var/folders/n2/xd9445p97rb3xh7m1dfx8_4h0006ts/T/ipykernel_15846/2578453007.py", line 3, in <module>
    remove_html_markup_test(html, plain)
  File "/var/folders/n2/xd9445p97rb3xh7m1dfx8_4h0006ts/T/ipykernel_15846/700130947.py", line 3, in remove_html_markup_test
    assert outcome == plain, \
           ^^^^^^^^^^^^^^^^
AssertionError: Got '3AGe7!%H</qcguk>6azh_', expected '3AGe7"!%H6azh_' (expected)
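Why does this test fail? In Python, and binds more tightly than or. The condition c == '"' or c == "'" and tag is therefore read as c == '"' or (c == "'" and tag). As a result, a double quote toggles quote even outside of tags, and afterwards < and > are no longer recognized until the next quote character. The sketch below repeats remove_html_markup() from above and checks this on a few tiny, hand-picked inputs.

```python
def remove_html_markup(s):
    tag = False
    quote = False
    out = ""

    for c in s:
        if c == '<' and not quote:
            tag = True
        elif c == '>' and not quote:
            tag = False
        elif c == '"' or c == "'" and tag:
            quote = not quote
        elif not tag:
            out = out + c

    return out


# `and` binds more tightly than `or`: even outside of a tag
# (tag == False), a double quote satisfies the condition.
c, tag = '"', False
print(c == '"' or c == "'" and tag)           # True
print(c == '"' or (c == "'" and tag))         # True (same grouping)

print(repr(remove_html_markup('<b>x</b>')))   # 'x'
print(repr(remove_html_markup('a"b')))        # 'ab': the quote is dropped
print(repr(remove_html_markup('"<b>x</b>')))  # '<b>x</b>': markup survives
```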

We run our tests, collecting the outcomes in html_debugger.

html_debugger = OchiaiDebugger()
for html, plain in (REMOVE_HTML_PASSING_TESTCASES + 
                    REMOVE_HTML_FAILING_TESTCASES):
    with html_debugger:
        remove_html_markup_test(html, plain)

The suspiciousness distribution will not be of much help here – pretty much all lines in remove_html_markup() have the same suspiciousness.

html_debugger
 775     def _clean_thread_parent_frames(
 776         self, phase: t.Literal["start", "stop"], info: t.Dict[str, t.Any]
 777     ):
 778         """Clean parent frames of threads which are no longer running.
 779         This is meant to be invoked by garbage collector callback hook.
 780  
 781         The implementation enumerates the threads because there is no "exit" hook yet,
 782         but there might be one in the future: https://bugs.python.org/issue14073
 783  
 784         This is a no-op if the `self._stdout` and `self._stderr` are not
 785         sub-classes of `OutStream`.
 786         """
 787         # Only run before the garbage collector starts
 788         if phase != "start":
 789             return
 790         active_threads = {thread.ident for thread in threading.enumerate()}
 791         for stream in [self._stdout, self._stderr]:
 792             if isinstance(stream, OutStream):
 793                 thread_to_parent_header = stream._thread_to_parent_header
 794                 for identity in list(thread_to_parent_header.keys()):
 795                     if identity not in active_threads:
 796                         try:
 797                             del thread_to_parent_header[identity]
 798                         except KeyError:
 799                             pass
 800                 thread_to_parent = stream._thread_to_parent
 801                 for identity in list(thread_to_parent.keys()):
 802                     if identity not in active_threads:
 803                         try:
 804                             del thread_to_parent[identity]
 805                         except KeyError:
 806                             pass

1198     @property
1199     def ident(self):
1200         """Thread identifier of this thread or None if it has not been started.
1201  
1202         This is a nonzero integer. See the get_ident() function. Thread
1203         identifiers may be recycled when a thread exits and another thread is
1204         created. The identifier is available even after the thread has exited.
1205  
1206         """
1207         assert self._initialized, "Thread.__init__() not called"
1208         return self._ident

   1 def remove_html_markup(s):  # type: ignore
   2     tag = False
   3     quote = False
   4     out = ""
   5  
   6     for c in s:
   7         if c == '<' and not quote:
   8             tag = True
   9         elif c == '>' and not quote:
  10             tag = False
  11         elif c == '"' or c == "'" and tag:
  12             quote = not quote
  13         elif not tag:
  14             out = out + c
  15  
  16     return out

1535 def enumerate():
1536     """Return a list of all Thread objects currently alive.
1537  
1538     The list includes daemonic threads, dummy thread objects created by
1539     current_thread(), and the main thread. It excludes terminated threads and
1540     threads that have not yet been started.
1541  
1542     """
1543     with _active_limbo_lock:
1544         return list(_active.values()) + list(_limbo.values())

   1 def remove_html_markup_test(html: str, plain: str) -> None:
   2     outcome = remove_html_markup(html)
   3     assert outcome == plain, \
   4         f"Got {repr(outcome)}, expected {repr(plain)}"
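Why are the suspiciousness values so uniform? The Ochiai metric rates a line as failed / sqrt(total_failed * (failed + passed)). Here, failed and passed count the failing and passing runs that executed the line. Suppose a line is executed in all 100 failing and all 100 passing runs. This is plausible here, since nearly every generated input contains plain text, a tag, and a quoted attribute value. Every such line then gets the same score. The helper ochiai() below is written just for this illustration.

```python
import math


def ochiai(failed: int, passed: int, total_failed: int) -> float:
    """Ochiai suspiciousness of a line executed in `failed` failing
    and `passed` passing runs, out of `total_failed` failing runs."""
    if failed == 0:
        return 0.0
    return failed / math.sqrt(total_failed * (failed + passed))


# Assumption: a line executed in all 100 failing and all 100 passing runs
print(round(ochiai(failed=100, passed=100, total_failed=100), 3))  # 0.707

# By contrast: a line executed only in the failing runs
print(ochiai(failed=100, passed=0, total_failed=100))              # 1.0
```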

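Let us create our repairer and run it. Note that in this run, the debugger has also recorded a few functions from the Jupyter kernel and from Python's threading module (such as _clean_thread_parent_frames(), ident, and enumerate()), as shown above. These are not part of our example. They show up below as an additional repair target (enumerate()) and in a few warnings.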

html_repairer = Repairer(html_debugger, log=True)
Target code to be repaired:
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    with _active_limbo_lock:
        return list(_active.values()) + list(_limbo.values())
/var/folders/n2/xd9445p97rb3xh7m1dfx8_4h0006ts/T/ipykernel_15846/108387771.py:15: UserWarning: Can't parse _clean_thread_parent_frames
  warnings.warn(f"Can't parse {item.__name__}")
/var/folders/n2/xd9445p97rb3xh7m1dfx8_4h0006ts/T/ipykernel_15846/108387771.py:15: UserWarning: Can't parse ident
  warnings.warn(f"Can't parse {item.__name__}")
best_tree, fitness = html_repairer.repair(iterations=20)
/var/folders/n2/xd9445p97rb3xh7m1dfx8_4h0006ts/T/ipykernel_15846/1291881850.py:14: UserWarning: Couldn't find definition of 'enumerate'
  warnings.warn(f"Couldn't find definition of {repr(name)}")
Evolving population: iteration  19/20 fitness = 0.99   
Best code (fitness = 0.99):
def remove_html_markup(s):
    pass
    quote = False
    out = ''
    tag = False
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    with _active_limbo_lock:
        pass

Reduced code (fitness = 0.99):
def remove_html_markup(s):
    quote = False
    out = ''
    tag = False
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
    return out
    'Return a list of all Thread objects currently alive.\n    '
    with _active_limbo_lock:
        pass

We see that the "best" code for remove_html_markup() is still our original code, except for an irrelevant change in the order of initializations. No matter whether we set iterations to 50, 100, or 200, our Repairer won't be able to repair it.

Quiz

Why couldn't Repairer() repair remove_html_markup()?







You can explore possible explanations, such as too few iterations or too small a population, by changing the appropriate parameters of repair() (iterations and population_size). However, you won't be able to change the outcome. The problem is that, unlike in middle(), there is no statement (or combination thereof) in remove_html_markup() that could be used to make the failure go away. For this, we need to mutate another aspect of the code, which we will explore in the next section.

Mutating Conditions

The Repairer class is very configurable. Each of the individual steps in automated repair can be replaced by passing your own classes as keyword arguments to its __init__() constructor:

  • To change fault localization, pass a different debugger that is a subclass of RankingDebugger.
  • To change the mutation operator, set mutator_class to a subclass of StatementMutator.
  • To change the crossover operator, set crossover_class to a subclass of CrossoverOperator.
  • To change the reduction algorithm, set reducer_class to a subclass of Reducer.
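Passing classes rather than instances is a form of dependency injection. The sketch below shows the pattern in miniature. All class names in it (MutatorSketch, ConditionMutatorSketch, RepairerSketch) are hypothetical stand-ins; this is not how Repairer itself is implemented. One plausible benefit of the pattern is that the repairer can construct each component itself, with arguments it determines.

```python
from typing import Type


class MutatorSketch:
    """Hypothetical stand-in for StatementMutator."""
    def describe(self) -> str:
        return "mutates statements"


class ConditionMutatorSketch(MutatorSketch):
    """Hypothetical stand-in for ConditionMutator."""
    def describe(self) -> str:
        return "mutates statements and conditions"


class RepairerSketch:
    """Hypothetical repairer that receives its components as classes."""

    def __init__(self, mutator_class: Type[MutatorSketch] = MutatorSketch) -> None:
        if not issubclass(mutator_class, MutatorSketch):
            raise TypeError(f"{mutator_class.__name__} is not a MutatorSketch")
        # The repairer instantiates the class itself, so it can decide
        # which constructor arguments to pass
        self.mutator = mutator_class()


print(RepairerSketch().mutator.describe())
print(RepairerSketch(mutator_class=ConditionMutatorSketch).mutator.describe())
```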

In this section, we will explore how to extend the mutation operator such that it can mutate conditions for control constructs such as if and while. To this end, we introduce a new class ConditionMutator subclassing StatementMutator.

Collecting Conditions

Let us start with a few simple supporting functions. The function all_conditions() retrieves all control conditions from an AST.

def all_conditions(trees: Union[ast.AST, List[ast.AST]],
                   tp: Optional[Type] = None) -> List[ast.expr]:
    """
    Return all conditions from the AST (or AST list) `trees`.
    If `tp` is given, return only elements of that type.
    """

    if not isinstance(trees, list):
        assert isinstance(trees, ast.AST)
        trees = [trees]

    visitor = ConditionVisitor()
    for tree in trees:
        visitor.visit(tree)
    conditions = visitor.conditions
    if tp is not None:
        conditions = [c for c in conditions if isinstance(c, tp)]

    return conditions

all_conditions() uses a ConditionVisitor class to walk the tree and collect the conditions:

class ConditionVisitor(NodeVisitor):
    def __init__(self) -> None:
        self.conditions: List[ast.expr] = []
        self.conditions_seen: Set[str] = set()
        super().__init__()

    def add_conditions(self, node: ast.AST, attr: str) -> None:
        elems = getattr(node, attr, [])
        if not isinstance(elems, list):
            elems = [elems]

        elems = cast(List[ast.expr], elems)

        for elem in elems:
            elem_str = ast.unparse(elem)
            if elem_str not in self.conditions_seen:
                self.conditions.append(elem)
                self.conditions_seen.add(elem_str)

    def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:
        self.add_conditions(node, 'values')
        return super().generic_visit(node)

    def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
        if isinstance(node.op, ast.Not):
            self.add_conditions(node, 'operand')
        return super().generic_visit(node)

    def generic_visit(self, node: ast.AST) -> ast.AST:
        if hasattr(node, 'test'):
            self.add_conditions(node, 'test')
        return super().generic_visit(node)
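For comparison, here is a compact, standalone sketch of the same collection strategy. It is built on ast.walk() instead of a NodeVisitor subclass, and collect_conditions() is a hypothetical name. Its documentation says that ast.walk() yields nodes in no specified order, so the result order may differ from all_conditions(). The same kinds of expressions are collected, though: test attributes (as in if and while), operands of and/or, and operands of not, each deduplicated by its source text.

```python
import ast
from typing import List


def collect_conditions(tree: ast.AST) -> List[ast.expr]:
    """Collect conditions (and their boolean subterms) from `tree`,
    dropping duplicates with identical source text."""
    conditions: List[ast.expr] = []
    seen = set()

    def add(expr: ast.expr) -> None:
        text = ast.unparse(expr)
        if text not in seen:
            seen.add(text)
            conditions.append(expr)

    for node in ast.walk(tree):
        test = getattr(node, 'test', None)
        if isinstance(test, ast.expr):    # if, while, assert, conditional expressions
            add(test)
        if isinstance(node, ast.BoolOp):  # operands of `and` / `or`
            for value in node.values:
                add(value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
            add(node.operand)             # operand of `not`
    return conditions


tree = ast.parse("if a and not b:\n    pass\nwhile x:\n    pass\n")
print([ast.unparse(cond) for cond in collect_conditions(tree)])
```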

Here are all the conditions in remove_html_markup(). This is some material to construct new conditions from.

[ast.unparse(cond).strip()
    for cond in all_conditions(remove_html_markup_tree())]
["c == '<' and (not quote)",
 "c == '<'",
 'not quote',
 'quote',
 "c == '>' and (not quote)",
 "c == '>'",
 'c == \'"\' or (c == "\'" and tag)',
 'c == \'"\'',
 'c == "\'" and tag',
 'c == "\'"',
 'tag',
 'not tag']

Mutating Conditions

Here comes our ConditionMutator class. We subclass from StatementMutator and set an attribute self.conditions containing all the conditions in the source. The method choose_condition() randomly picks a condition.

class ConditionMutator(StatementMutator):
    """Mutate conditions in an AST"""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Constructor. Arguments are as with `StatementMutator` constructor."""
        super().__init__(*args, **kwargs)
        self.conditions = all_conditions(self.source)
        if self.log:
            print("Found conditions",
                  [ast.unparse(cond).strip() 
                   for cond in self.conditions])

    def choose_condition(self) -> ast.expr:
        """Return a random condition from source."""
        return copy.deepcopy(random.choice(self.conditions))

The actual mutation takes place in the swap() method. If the node to be replaced has a test attribute (i.e., a controlling predicate), then we pick a random condition cond from the source and randomly choose one of the following:

  • set: We change test to cond.
  • not: We invert test.
  • and: We replace test by cond and test.
  • or: We replace test by cond or test.

Over time, such combined and negated conditions may propagate across the population.

class ConditionMutator(ConditionMutator):
    def choose_bool_op(self) -> str:
        return random.choice(['set', 'not', 'and', 'or'])

    def swap(self, node: ast.AST) -> ast.AST:
        """Replace `node` condition by a condition from `source`"""
        if not hasattr(node, 'test'):
            return super().swap(node)

        node = cast(ast.If, node)

        cond = self.choose_condition()
        new_test = None

        choice = self.choose_bool_op()

        if choice == 'set':
            new_test = cond
        elif choice == 'not':
            new_test = ast.UnaryOp(op=ast.Not(), operand=node.test)
        elif choice == 'and':
            new_test = ast.BoolOp(op=ast.And(), values=[cond, node.test])
        elif choice == 'or':
            new_test = ast.BoolOp(op=ast.Or(), values=[cond, node.test])
        else:
            raise ValueError("Unknown boolean operand")

        if new_test:
            # ast.copy_location(new_test, node)
            node.test = new_test

        return node
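The four choices can be tried out in isolation. The sketch below applies them to a single ast.If node. mutate_test() is a hypothetical helper that, unlike swap(), takes the choice and the new condition as parameters, so that the result is deterministic.

```python
import ast


def mutate_test(node: ast.AST, cond: ast.expr, choice: str) -> ast.AST:
    """Apply one of the four condition mutations to `node.test`."""
    if not hasattr(node, 'test'):
        return node  # not a control construct: leave unchanged
    if choice == 'set':
        new_test = cond
    elif choice == 'not':
        new_test = ast.UnaryOp(op=ast.Not(), operand=node.test)
    elif choice == 'and':
        new_test = ast.BoolOp(op=ast.And(), values=[cond, node.test])
    elif choice == 'or':
        new_test = ast.BoolOp(op=ast.Or(), values=[cond, node.test])
    else:
        raise ValueError(f"Unknown boolean operator {choice!r}")
    node.test = new_test
    return node


for choice in ['set', 'not', 'and', 'or']:
    if_node = ast.parse("if x > 0:\n    y = 1").body[0]
    cond = ast.parse("z", mode="eval").body
    mutated = mutate_test(if_node, cond, choice)
    print(f"{choice:>3}: if {ast.unparse(mutated.test)}:")
```

Note that 'not' ignores cond, exactly as in swap() above.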

We can use the mutator just like StatementMutator, except that some of the mutations will also include new conditions:

mutator = ConditionMutator(source=all_statements(remove_html_markup_tree()),
                           log=True)
Found conditions ["c == '<' and (not quote)", "c == '<'", 'not quote', 'quote', "c == '>' and (not quote)", "c == '>'", 'c == \'"\' or (c == "\'" and tag)', 'c == \'"\'', 'c == "\'" and tag', 'c == "\'"', 'tag', 'not tag']
for i in range(10):
    new_tree = mutator.mutate(remove_html_markup_tree())
   8:swap:   'tag = True' becomes 'tag = False'
   6:delete: 'for c in s: if c == ...' becomes "if c == '<' and (not..."
  13:swap:   'if not tag and (not ...' becomes 'if not tag and (not ...'
   7:insert: "if c == '<' and (not..." becomes "if c == '<' and (not..."; "out = ''"
   4:insert: "out = ''" becomes "if c == '<' and (not..."
   9:insert: "if c == '>' and (not..." becomes "if c == '>' and (not..."; 'quote = not quote'
  12:delete: 'quote = not quote' becomes 'pass'
  10:swap:   'tag = False' becomes 'quote = False'
  10:delete: 'tag = False' becomes 'pass'
  11:insert: 'if c == \'"\' or (c ==...' becomes 'for c in s: if c == ...'

Let us put our new mutator into action, again in a Repairer(). To activate it, all we need to do is pass it as the mutator_class keyword argument.

condition_repairer = Repairer(html_debugger,
                              mutator_class=ConditionMutator,
                              log=2)
Target code to be repaired:
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    with _active_limbo_lock:
        return list(_active.values()) + list(_limbo.values())
/var/folders/n2/xd9445p97rb3xh7m1dfx8_4h0006ts/T/ipykernel_15846/108387771.py:15: UserWarning: Can't parse _clean_thread_parent_frames
  warnings.warn(f"Can't parse {item.__name__}")
/var/folders/n2/xd9445p97rb3xh7m1dfx8_4h0006ts/T/ipykernel_15846/108387771.py:15: UserWarning: Can't parse ident
  warnings.warn(f"Can't parse {item.__name__}")

We might need more iterations for this one. Let us see...

best_tree, fitness = condition_repairer.repair(iterations=200)
Evolving population: iteration   4/200 fitness = 0.99   

New best code (fitness = 0.99):
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    with _active_limbo_lock:
        if c == '<' and (not quote):
            pass

Evolving population: iteration   9/200 fitness = 0.99   

New best code (fitness = 0.99):
def remove_html_markup(s):
    for c in s:
        tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
            'Return a list of all Thread objects currently alive.\n\n    The list includes daemonic threads, dummy thread objects created by\n    current_thread(), and the main thread. It excludes terminated threads and\n    threads that have not yet been started.\n\n    '
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    with _active_limbo_lock:
        pass

Evolving population: iteration  10/200 fitness = 0.99   

New best code (fitness = 0.99):
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
            'Return a list of all Thread objects currently alive.\n\n    The list includes daemonic threads, dummy thread objects created by\n    current_thread(), and the main thread. It excludes terminated threads and\n    threads that have not yet been started.\n\n    '
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
            tag = False
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    with _active_limbo_lock:
        pass

Evolving population: iteration  11/200 fitness = 1.0   

New best code (fitness = 1.0):
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
            'Return a list of all Thread objects currently alive.\n\n    The list includes daemonic threads, dummy thread objects created by\n    current_thread(), and the main thread. It excludes terminated threads and\n    threads that have not yet been started.\n\n    '
        elif c == '>' and (not quote):
            tag = False
        elif tag and (c == '"' or (c == "'" and tag)):
            quote = not quote
        elif not tag:
            out = out + c
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    with _active_limbo_lock:
        return list(_active.values()) + list(_limbo.values())

Reduced code (fitness = 1.0):
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif tag and (c == '"' or (c == "'" and tag)):
            quote = not quote
        elif not tag:
            out = out + c
    return out
    'Return a list of all Thread objects currently alive.\n    '
    with _active_limbo_lock:
        return list(_active.values()) + list(_limbo.values())
repaired_source = ast.unparse(best_tree)
print_content(repaired_source, '.py')
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif tag and (c == '"' or (c == "'" and tag)):
            quote = not quote
        elif not tag:
            out = out + c
    return out
    'Return a list of all Thread objects currently alive.\n    '
    with _active_limbo_lock:
        return list(_active.values()) + list(_limbo.values())

Success again! We have automatically repaired remove_html_markup() – the resulting code passes all tests, including those that were previously failing.

Again, we can present the fix as a patch:

original_source = ast.unparse(remove_html_markup_tree())
for patch in diff(original_source, repaired_source):
    print_patch(patch)
@@ -210,53 +210,63 @@
 lse

-        elif c == '"' or (c == "'" and tag):

+        elif tag and (c == '"' or (c == "'" and tag)):
@@ -343,18 +343,174 @@
 + c

-    return out
+    return out
    'Return a list of all Thread objects currently alive.\n    '
    with _active_limbo_lock:
        return list(_active.values()) + list(_limbo.values())
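The patch above appears to be computed on character offsets. For a line-oriented view, the standard library's difflib.unified_diff() can be applied to the two unparsed sources. In this sketch, both sources are pasted in as strings (leaving out the enumerate() residue). REPAIRED is derived by substituting the changed condition, and line_patch() is a hypothetical helper name.

```python
import difflib
from typing import List

ORIGINAL = """\
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
    return out"""

REPAIRED = ORIGINAL.replace(
    """elif c == '"' or (c == "'" and tag):""",
    """elif tag and (c == '"' or (c == "'" and tag)):""")


def line_patch(old: str, new: str) -> List[str]:
    """Unified diff of `old` and `new`, line by line."""
    return list(difflib.unified_diff(old.splitlines(), new.splitlines(),
                                     fromfile="original", tofile="repaired",
                                     lineterm=""))


print("\n".join(line_patch(ORIGINAL, REPAIRED)))
```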

However, looking at the patch, one may have doubts.

Quiz

Is this actually the best solution?



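Indeed, a fix obtained this way may no longer handle single quotes. (The particular fix shown above happens to still handle them: its new condition tag and (c == '"' or (c == "'" and tag)) keeps the check for single quotes within tags. Since the search is randomized, though, other runs may well come up with fixes that drop this check.) Why could that be?

Quiz

Why aren't single quotes handled in the solution?

Correct! Our test cases do not include single quotes, at least not in the interior of HTML tags. Thus, automatic repair has no reason to preserve their handling.

How can we fix this? An easy way is to include an appropriate test case in our set: a test case that passes with the original remove_html_markup(), yet fails with a "repaired" version that no longer handles single quotes within tags.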

with html_debugger:
    remove_html_markup_test("<foo quote='>abc'>me</foo>", "me")
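The following sketch makes the role of the new test case concrete. remove_markup_with() is a hypothetical variant of remove_html_markup() that takes the quote-toggling condition as a parameter. We compare three conditions: the original one, the one evolved above, and a hypothetical "repair" (double_quotes_only()) that considers only double quotes within tags.

```python
from typing import Callable

QuoteCondition = Callable[[str, bool], bool]


def remove_markup_with(toggles_quote: QuoteCondition, s: str) -> str:
    """remove_html_markup(), with the quote-toggling condition as a parameter."""
    tag = False
    quote = False
    out = ""
    for c in s:
        if c == '<' and not quote:
            tag = True
        elif c == '>' and not quote:
            tag = False
        elif toggles_quote(c, tag):
            quote = not quote
        elif not tag:
            out = out + c
    return out


def original_condition(c: str, tag: bool) -> bool:
    return c == '"' or c == "'" and tag              # the buggy original


def evolved_condition(c: str, tag: bool) -> bool:
    return tag and (c == '"' or (c == "'" and tag))  # as evolved above


def double_quotes_only(c: str, tag: bool) -> bool:
    return tag and c == '"'                          # hypothetical "repair"


html, plain = "<foo quote='>abc'>me</foo>", "me"
for condition in [original_condition, evolved_condition, double_quotes_only]:
    outcome = remove_markup_with(condition, html)
    print(f"{condition.__name__:20} {outcome!r:10} {'PASS' if outcome == plain else 'FAIL'}")
```

Both the original and the evolved condition pass the new test. double_quotes_only() lets the > inside the quoted attribute value close the tag, so it yields "abc'me" and fails.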

Let us repeat the repair with the extended test set:

best_tree, fitness = condition_repairer.repair(iterations=200)
Evolving population: iteration  10/200 fitness = 0.99   

New best code (fitness = 0.99):
def remove_html_markup(s):
    tag = False
    quote = False
    pass
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
            'Return a list of all Thread objects currently alive.\n\n    The list includes daemonic threads, dummy thread objects created by\n    current_thread(), and the main thread. It excludes terminated threads and\n    threads that have not yet been started.\n\n    '
        elif not tag:
            out = out + c
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    with _active_limbo_lock:
        pass

Evolving population: iteration  14/200 fitness = 0.99   

New best code (fitness = 0.99):
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    with _active_limbo_lock:
        pass

Evolving population: iteration  49/200 fitness = 0.99   

New best code (fitness = 0.99):
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
    pass
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    pass

Evolving population: iteration  99/200 fitness = 1.0   

New best code (fitness = 1.0):
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif tag and (c == '"' or (c == "'" and tag)):
            quote = not quote
        elif not tag:
            out = out + c
    pass
    return out

def enumerate():
    """Return a list of all Thread objects currently alive.

    The list includes daemonic threads, dummy thread objects created by
    current_thread(), and the main thread. It excludes terminated threads and
    threads that have not yet been started.

    """
    pass

Reduced code (fitness = 1.0):
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif tag and (c == '"' or (c == "'" and tag)):
            quote = not quote
        elif not tag:
            out = out + c
    return out
    'Return a list of all Thread objects currently alive.\n    '

Here is the final tree:

print_content(ast.unparse(best_tree), '.py')
def remove_html_markup(s):
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif tag and (c == '"' or (c == "'" and tag)):
            quote = not quote
        elif not tag:
            out = out + c
    return out
    'Return a list of all Thread objects currently alive.\n    '

And here is its fitness:

fitness
1.0

The revised candidate now passes all tests, including the tricky quote test we added last. Its condition now checks that we are inside a tag before toggling quote on a double or single quote. (The tag inside the parentheses is still redundant, but so be it.) From this example, we can learn a few lessons about the possibilities and risks of automated repair:

  • First, automatic repair is highly dependent on the quality of the checking tests. The risk is that the repair may overfit to the tests: it passes them without actually fixing the defect.
  • Second, when based on "plastic surgery", automated repair is highly dependent on the sources that program fragments are chosen from. If there is a hint of a solution somewhere in the code, there is a chance that automated repair will pick it up.
  • Third, automatic repair is a deeply heuristic approach. Its behavior will vary widely with any change to the parameters (and the underlying random number generators).
  • Fourth, automatic repair can take a long time. The examples we have in this chapter take less than a minute to compute, and neither Python nor our implementation is exactly fast. But as the search space grows, automated repair will take much longer.
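The first risk, overfitting to the tests, is easy to demonstrate. The following sketch uses a small hypothetical test suite (TESTS, which includes the quote test above) and two hypothetical functions: remove_html_markup_buggy() is the fitness-0.99 candidate from the log above, and remove_html_markup_overfit() is a degenerate "repair" that simply memorizes the expected outputs. The latter passes every test in TESTS, yet still fails on an unseen input.

```python
# Hypothetical test suite: (input, expected output) pairs, assumed for illustration
TESTS = [
    ("abc", "abc"),
    ("<b>abc</b>", "abc"),
    ('"abc"', '"abc"'),
    ("<foo quote='>abc'>me</foo>", "me"),
]


def remove_html_markup_buggy(s):
    """The fitness-0.99 candidate from the log above."""
    tag = False
    quote = False
    out = ''
    for c in s:
        if c == '<' and (not quote):
            tag = True
        elif c == '>' and (not quote):
            tag = False
        elif c == '"' or (c == "'" and tag):
            quote = not quote
        elif not tag:
            out = out + c
    return out


def remove_html_markup_overfit(s):
    """A degenerate "repair" that memorizes the expected outputs of TESTS."""
    memorized = dict(TESTS)
    if s in memorized:
        return memorized[s]
    return remove_html_markup_buggy(s)  # all other inputs: unchanged behavior


def passes(func):
    """True if func produces the expected output for every test in TESTS."""
    return all(func(inp) == expected for inp, expected in TESTS)


print(passes(remove_html_markup_buggy))     # the candidate fails the '"abc"' test
print(passes(remove_html_markup_overfit))   # the "repair" passes all tests...
print(remove_html_markup_overfit('"xyz"'))  # ...but still drops the quotes here
```

Mutation operators that only recombine existing statements cannot produce such a lookup table, but the same effect can show up in subtler forms, for instance as a condition that merely happens to separate the test inputs.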

On the other hand, even an incomplete automated repair candidate can be much better than nothing at all – it may provide all the essential ingredients (such as the location or the involved variables) for a successful fix. When users of automated repair techniques are aware of their limitations and assumptions, there is lots of potential in automated repair. Enjoy!
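Returning to the repaired condition: the remark that the inner tag is redundant can be checked exhaustively. Both variants of the condition compare c only against the two quote characters, so a double quote, a single quote, and one representative other character, each combined with both values of tag, cover all cases. The function names below are hypothetical.

```python
import itertools


def repaired_condition(c, tag):
    # Condition from the repaired remove_html_markup()
    return tag and (c == '"' or (c == "'" and tag))


def simplified_condition(c, tag):
    # The same condition with the inner (redundant) tag dropped
    return tag and (c == '"' or c == "'")


# One representative per case: double quote, single quote, any other character
CASES = list(itertools.product(['"', "'", 'a'], [False, True]))
print(all(repaired_condition(c, tag) == simplified_condition(c, tag)
          for c, tag in CASES))
```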

Limitations

The Repairer class has been tested on our example programs, but not much more. Things that do not work include:

  • Functions with inner functions are not repaired.

Lessons Learned

  • Automated repair based on genetic optimization uses five ingredients:
    1. A test suite to determine passing and failing tests
    2. Defect localization (typically obtained from statistical debugging with the test suite) to determine potential locations to be fixed
    3. Random code mutations and crossover operations to create and evolve a population of fix candidates
    4. A fitness function and a selection strategy to determine the part of the population that should be evolved further
    5. A reducer such as delta debugging to simplify the final candidate with the highest fitness.
  • The result of automated repair is a fix candidate with the highest fitness for the given tests.
  • A fix candidate is not guaranteed to be correct or optimal, but gives important hints on how to fix the program.
  • All the above ingredients offer plenty of settings and alternatives to experiment with.
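The fitness function from ingredient 4 can be sketched as a weighted sum of two pass ratios. The weights below (0.99 for previously passing tests, 0.01 for previously failing ones) and the name weighted_fitness() are assumptions for illustration; with them, a candidate that keeps all passing tests passing but fixes none of the failing ones scores 0.99, which would match the plateau at 0.99 in the log above.

```python
# Assumed weights for illustration; the chapter's actual constants may differ
WEIGHT_PASSING = 0.99
WEIGHT_FAILING = 0.01


def weighted_fitness(passing_outcomes, failing_outcomes):
    """Fitness in [0, 1] of a fix candidate.

    passing_outcomes: for each test that passed on the original program,
                      whether it still passes on the candidate
    failing_outcomes: for each test that failed on the original program,
                      whether it now passes on the candidate
    """
    passing_ratio = sum(passing_outcomes) / len(passing_outcomes)
    failing_ratio = sum(failing_outcomes) / len(failing_outcomes)
    # The weights sum to 1, so the result stays within [0, 1]
    return WEIGHT_PASSING * passing_ratio + WEIGHT_FAILING * failing_ratio


print(weighted_fitness([True] * 10, [False] * 3))  # keeps all passing tests, fixes nothing
print(weighted_fitness([True] * 10, [True] * 3))   # passes everything
```

Weighting previously passing tests this heavily makes the search strongly prefer candidates that keep existing behavior intact; whether that is the right trade-off is one of the settings to experiment with.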

Background

The seminal work in automated repair is GenProg [C. Le Goues et al, 2012], which heavily inspired our Repairer implementation. Major differences between GenProg and Repairer include:

  • GenProg includes its own defect localization (which is also dynamically updated), whereas Repairer builds on earlier statistical debugging.
  • GenProg can apply multiple mutations on programs (or none at all), whereas Repairer applies exactly one mutation.
  • The StatementMutator used by Repairer includes various special cases for program structures (if, for, while...), whereas GenProg operates on statements only.
  • GenProg has been tested on large production programs.

While GenProg is the seminal work in the area (and arguably the most important software engineering research contribution of the 2010s), there have been a number of important extensions of automated repair. These include:

  • AutoFix [Y. Pei et al, 2014] leverages program contracts (pre- and postconditions) to generate tests and assertions automatically. Not only do such assertions help in fault localization, they also allow for much better validation of fix candidates.
  • SemFix [Nguyen et al, 2013] and its successor Angelix [Mechtaev et al, 2016] introduce automated program repair based on symbolic analysis rather than genetic optimization. This allows leveraging program semantics, which GenProg does not consider.

To learn more about automated program repair, see program-repair.org, the community page dedicated to research in program repair.

Exercises

Exercise 1: Automated Repair Parameters

Automated repair is influenced by numerous design choices – the size of the population, the number of iterations, the genetic optimization strategy, and more. How do changes to these design choices affect its effectiveness?

  • Consider the constants defined in this chapter (such as POPULATION_SIZE or WEIGHT_PASSING vs. WEIGHT_FAILING). How do changes affect the effectiveness of automated repair?
  • As an effectiveness metric, consider the number of iterations it takes to produce a fix candidate.
  • Since genetic optimization is a randomized algorithm, you need to average effectiveness over a large number of runs (say, 100).
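Averaging over many runs can be organized as a small harness. In this sketch, toy_repair_run() is a hypothetical stand-in for one repair run in which every iteration independently finds a fix with a given probability. In a real experiment, it would seed the random number generator, run the Repairer with the parameters under study, and return the iteration at which fitness 1.0 is first reached; that wrapper is left as an assumption here.

```python
import random
import statistics


def toy_repair_run(seed, success_probability):
    """Hypothetical stand-in for one repair run: each iteration independently
    finds a fix with the given probability; returns the iteration count."""
    rng = random.Random(seed)
    iteration = 1
    while rng.random() >= success_probability:
        iteration += 1
    return iteration


def effectiveness(run, runs=100, **params):
    """Call run(seed, **params) for seeds 0..runs-1; return mean and stdev of iterations."""
    iterations = [run(seed, **params) for seed in range(runs)]
    return statistics.mean(iterations), statistics.stdev(iterations)


mean_a, sd_a = effectiveness(toy_repair_run, success_probability=0.05)
mean_b, sd_b = effectiveness(toy_repair_run, success_probability=0.20)
print(f"p=0.05: {mean_a:.1f} ± {sd_a:.1f} iterations")
print(f"p=0.20: {mean_b:.1f} ± {sd_b:.1f} iterations")
```

Both settings use the same seeds. In this toy model, each seed then needs no more iterations at the higher success probability than at the lower one, so the comparison is not masked by sampling noise; reusing seeds across settings (common random numbers) can likewise reduce variance when comparing real parameter settings.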

Exercise 2: Elitism

Elitism (also known as elitist selection) is a variant of genetic selection in which a small fraction of the fittest candidates of the last population are included unchanged in the offspring.

  • Implement elitist selection by overriding the evolve() method in a subclass. Experiment with various fractions (5%, 10%, 25%) of "elites" and see whether this improves the results.
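A minimal, standalone sketch of elitist selection follows. It does not use the Repairer's evolve() signature; instead, evolve_with_elitism() (a hypothetical name) works on a plain list of candidates with caller-supplied fitness and mutate functions, keeps the top elite_fraction unchanged, and fills the rest with mutants of parents drawn from the fitter half. Crossover is omitted, and the "fitter half" parent selection is an assumption of this sketch.

```python
import random


def evolve_with_elitism(population, fitness, mutate, elite_fraction=0.1, rng=None):
    """Return the next generation: the fittest elite_fraction of population
    survives unchanged (always at least one elite); the rest are mutants
    of parents drawn from the fitter half."""
    rng = rng or random.Random(0)
    ranked = sorted(population, key=fitness, reverse=True)
    n_elites = max(1, round(len(population) * elite_fraction))
    elites = ranked[:n_elites]
    parents = ranked[:max(1, len(ranked) // 2)]
    offspring = [mutate(rng.choice(parents), rng)
                 for _ in range(len(population) - n_elites)]
    return elites + offspring


def toy_fitness(x):
    return -abs(x - 42)  # toy objective: get as close to 42 as possible


def toy_mutate(x, rng):
    return x + rng.choice([-3, -1, 1, 3])


rng = random.Random(42)
population = [rng.randint(0, 100) for _ in range(20)]
best = []
for generation in range(30):
    population = evolve_with_elitism(population, toy_fitness, toy_mutate,
                                     elite_fraction=0.1, rng=rng)
    best.append(max(map(toy_fitness, population)))
print(best)
```

Because the best candidate always survives, the best fitness per generation can never decrease. Comparing how quickly best reaches 0 for different elite fractions is one way to approach the exercise.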

Exercise 3: Evolving Values

Following the steps of ConditionMutator, implement a ValueMutator class that replaces one constant value by another one found in the source (say, 0 by 1 or True by False).

For validation, consider the following failure in the square_root() function from the chapter on assertions:

from Assertions import square_root  # minor dependency
with ExpectError():
    square_root_of_zero = square_root(0)
Traceback (most recent call last):
  File "/var/folders/n2/xd9445p97rb3xh7m1dfx8_4h0006ts/T/ipykernel_15846/1107282428.py", line 2, in <module>
    square_root_of_zero = square_root(0)
                          ^^^^^^^^^^^^^^
  File "Assertions.ipynb", line 61, in square_root
    guess = (approx + x / approx) / 2
                      ~~^~~~~~~~
ZeroDivisionError: float division by zero (expected)

Can your ValueMutator automatically fix this failure?
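As one possible starting point, here is a standalone sketch of the mutation step. ValueMutatorSketch is a hypothetical name; it does not plug into the chapter's mutator classes, whose interfaces you would follow for the actual exercise. It collects the numeric and Boolean constants of a source tree and, on each call to mutate(), replaces one constant in a copy of the tree by a different one from that pool. Strings are skipped so that docstrings stay untouched.

```python
import ast
import copy
import random


class ValueMutatorSketch:
    """Hypothetical sketch (not the chapter's API): replace one numeric or
    Boolean constant by another constant found in the same source."""

    def __init__(self, source_tree, seed=0):
        # Collect (type, value) pairs so that 0/False and 1/True stay distinct
        self.pool = []
        for node in self._constants(source_tree):
            key = (type(node.value), node.value)
            if key not in self.pool:
                self.pool.append(key)
        self.rng = random.Random(seed)

    @staticmethod
    def _constants(tree):
        # Numbers and Booleans only; strings (e.g. docstrings) are skipped
        return [node for node in ast.walk(tree)
                if isinstance(node, ast.Constant)
                and isinstance(node.value, (bool, int, float))]

    def mutate(self, tree):
        new_tree = copy.deepcopy(tree)  # never modify the original tree
        targets = self._constants(new_tree)
        if not targets:
            return new_tree
        target = self.rng.choice(targets)
        current = (type(target.value), target.value)
        choices = [key for key in self.pool if key != current]
        if choices:
            _, target.value = self.rng.choice(choices)
        return new_tree


original = ast.parse("def f(x):\n    if x > 0:\n        return 1\n    return 0\n")
mutant = ValueMutatorSketch(original, seed=1).mutate(original)
print(ast.unparse(mutant))
```

Note that 0 and False (as well as 1 and True) compare equal in Python; keying the pool on (type, value) pairs keeps them apart.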

Exercise 4: Evolving Variable Names

Following the steps of ConditionMutator, implement an IdentifierMutator class that replaces one identifier by another one found in the source (say, y by x). Does it help to fix the middle() error?
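To get a feeling for whether identifier replacement can fix middle(), the following sketch skips the genetic search and simply enumerates every program that differs from the original in exactly one variable read. It assumes the buggy middle() from the beginning of this chapter (where the x < z branch returns y instead of x); identifier_variants() and passes_all_tests() are hypothetical helpers, and the test oracle compares against sorted([x, y, z])[1] on all triples from 1 to 3.

```python
import ast
import copy
import itertools

# Assumed: the buggy middle() from the beginning of this chapter
BUGGY_MIDDLE = '''
def middle(x, y, z):
    if y < z:
        if x < y:
            return y
        elif x < z:
            return y
    else:
        if x > y:
            return y
        elif x > z:
            return x
    return z
'''


def reads(tree):
    """All Name nodes that read a variable, in ast.walk() order."""
    return [n for n in ast.walk(tree)
            if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)]


def identifier_variants(tree):
    """Yield copies of tree in which exactly one variable read is renamed
    to another variable name occurring in the source."""
    names = sorted({n.id for n in ast.walk(tree) if isinstance(n, ast.Name)})
    for index, node in enumerate(reads(tree)):
        for replacement in names:
            if replacement != node.id:
                variant = copy.deepcopy(tree)
                reads(variant)[index].id = replacement  # same structure, same order
                yield variant


def passes_all_tests(tree):
    """Compile tree and check middle() against all triples from 1 to 3."""
    namespace = {}
    exec(compile(tree, '<variant>', 'exec'), namespace)
    middle = namespace['middle']
    return all(middle(x, y, z) == sorted([x, y, z])[1]
               for x, y, z in itertools.product(range(1, 4), repeat=3))


tree = ast.parse(BUGGY_MIDDLE)
fixes = [v for v in identifier_variants(tree) if passes_all_tests(v)]
for fix in fixes:
    print(ast.unparse(fix), end='\n\n')
```

The list of fixes includes the variant whose elif x < z branch returns x. A real IdentifierMutator would instead pick one such replacement at random, ideally at the locations suggested by the debugger.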

Exercise 5: Parallel Repair

Automated repair is embarrassingly parallel: all tests for one candidate can be run in parallel, and so can the tests for all candidates. Set up an infrastructure for running concurrent tests using Python's asyncio library.
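Here is one way to structure concurrent test execution with asyncio. Since plain Python test functions are not coroutines, the sketch hands each test to a thread pool via loop.run_in_executor() and combines the results with asyncio.gather(): one gather over the tests of a candidate, one over all candidates. run_test(), fitness_of(), and evaluate_population() are hypothetical names, and the fitness here is simply the fraction of passing tests.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


def run_test(candidate, test_input, expected):
    """Run one test; an exception counts as a failure."""
    try:
        return candidate(test_input) == expected
    except Exception:
        return False


async def fitness_of(candidate, tests, executor):
    """Run all tests of one candidate concurrently; return the fraction passed."""
    loop = asyncio.get_running_loop()
    results = await asyncio.gather(*(
        loop.run_in_executor(executor, run_test, candidate, test_input, expected)
        for test_input, expected in tests))
    return sum(results) / len(results)


async def evaluate_population(population, tests, max_workers=4):
    """Evaluate all candidates concurrently; results keep the population order."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return await asyncio.gather(*(
            fitness_of(candidate, tests, executor) for candidate in population))


# Toy example: which candidate implements absolute value?
population = [lambda x: x, lambda x: -x, abs]
tests = [(-2, 2), (0, 0), (3, 3)]
scores = asyncio.run(evaluate_population(population, tests))
print(scores)
```

Threads let tests overlap, but because of Python's global interpreter lock they do not speed up CPU-bound tests; a concurrent.futures.ProcessPoolExecutor would, at the price of requiring picklable candidates and tests (lambdas are not). In a Jupyter notebook, which already runs an event loop, use await evaluate_population(population, tests) instead of asyncio.run().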

The content of this project is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. The source code that is part of the content, as well as the source code used to format and display that content is licensed under the MIT License. Last change: 2024-08-27 17:18:55+02:00