lilatomic

Testing Azure Durable Functions in Python


TLDR

Azure Durable Functions (ADFs) allow you to use Azure Serverless Functions to build workflows and to implement a number of standard patterns for enterprise systems. Obviously, we'd like to test our code. But the serverless paradigm does not lend itself well to rapid cycle times and deep testing. Every deploy-test-evaluate cycle against a testing instance in Azure can take about 5 minutes, especially if you've bought into the whole stack and are monitoring with AppInsights. Plus, it's difficult to test error conditions if you have to actually produce them in an environment rather than mocking them in. Wouldn't it be nice if we could test these functions like normal functions, with all the techniques and tools that we've developed for those?

Fortunately, Azure Functions are amenable to this. But most of the examples are wrong.

Background : Python Generators

ADFs make use of Python generators to invoke other functions. Here's an example from the tutorial:

result1 = yield context.call_activity('E1_SayHello', "Tokyo")

If you're only a little familiar with Python generators, you've probably seen them as a way to generate a (possibly infinite) sequence of values. The example from the Python wiki generates a series of numbers up to a value:

def first_n(n):
	num = 0
	while num < n:
		yield num
		num += 1

The syntax here is a bit different from the one used in the ADF, in that the value of the yield expression is discarded. If you're more familiar with Python generators, you might know that a yield expression does evaluate to a value when you use the generator's send method to pass values in and get results out:

def useful():
	n = 0
	while True:
		x = n * 2
		n = yield x

We can then use it with

>>> g = useful()
>>> g.send(None) # prime the generator
0
>>> g.send(42)
84
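For contrast, here's a small sketch of what the yield expression evaluates to when the generator is driven with next() (or a for loop) instead of send: it's always None, which is why the first_n style usually discards it. The generator echo is a hypothetical example, not from the wiki or the tutorial.

```python
def echo():
	received = []
	for i in range(3):
		# the value of the yield expression is whatever the caller sends in
		got = yield i
		received.append(got)
	return received

# next(g) resumes the generator without sending anything, so each yield evaluates to None
g = echo()
print(next(g), next(g), next(g))  # prints "0 1 2"
try:
	next(g)
except StopIteration as stop:
	print(stop.value)  # prints "[None, None, None]"
```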

But this is actually hiding some complexity, which we can see if we instrument this function. We could use print statements, but we're going to need some more sophisticated instrumentation later, so we're instead going to just push things onto a list:

e = []
def log(event):
	e.append(event)

def useful():
	r = 0
	log("init")
	while True:
		log("start of loop")
		x = r * 2
		log(f"computation done: {r} * 2 = {x}")
		r = yield x
		log(f"yielded {x=}, got {r=}")

We can test this out with

g = useful()
log(g.send(None))
log("sending 10")
log(g.send(10))
log("sending 11")
log(g.send(11))

This gives us about what we'd expect:

['init',
 'start of loop',
 'computation done: 0 * 2 = 0',
 0,
 'sending 10',
 'yielded x=0, got r=10',
 'start of loop',
 'computation done: 10 * 2 = 20',
 20,
 'sending 11',
 'yielded x=20, got r=11',
 'start of loop',
 'computation done: 11 * 2 = 22',
 22]

What complexity is that hiding? From one perspective, none. This is exactly what the send method was built for; this is how you build generator-based coroutines. But there is a key difference between a coroutine and an ADF: the ADF is the one doing the driving. That is, in this example we sent values to the generator and got results back; but in an ADF, the generator submits tasks to some unseen executor and gets their results back. This is completely backwards. For example, we normally have a function that looks something like this:

import azure.functions as func
import azure.durable_functions as df


def orchestrator_function(context: df.DurableOrchestrationContext):
	result1 = yield context.call_activity('E1_SayHello', "Tokyo")
	result2 = yield context.call_activity('E1_SayHello', "Seattle")
	result3 = yield context.call_activity('E1_SayHello', "London")
	return [result1, result2, result3]

main = df.Orchestrator.create(orchestrator_function)

So what happens if we make a generator which looks like that?

def workflow():
	log("init")
	x = 0
	r = yield x
	log(f"yielded {x=}, got {r=}")
	x = 1
	r = yield x
	log(f"yielded {x=}, got {r=}")
	x = 2
	r = yield x
	log(f"yielded {x=}, got {r=}")
	return f"return value {r=}"

Let's pretend that we're submitting tasks and we want the number to be added to 100:

g = workflow()
log(g.send(None))
log(g.send(100))
log(g.send(101))
log(g.send(102))

we get

Traceback (most recent call last):
  File "_includes/resources/azure_functions/durable_testing_python/00_generator_like_sample.py", line 25, in <module>
    log(g.send(102))
StopIteration: return value r=102
>>> pprint.pp(e)
['init',
 0,
 'yielded x=0, got r=100',
 1,
 'yielded x=1, got r=101',
 2,
 'yielded x=2, got r=102']

Well, that makes sense. It's a bit weird that the return value arrives inside the StopIteration (as its value attribute), but that's the documented behaviour for generators.
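The StopIteration is simply how a generator's return value gets back to whoever is driving it. A yield from expression, for instance, catches it for you and evaluates to its value. A small sketch using a logging-free copy of workflow and a hypothetical wrapper generator named outer:

```python
def workflow():
	# a logging-free copy of the generator above
	r = yield 0
	r = yield 1
	r = yield 2
	return f"return value {r=}"

def outer():
	# yield from forwards each send() to workflow() and evaluates to its return value
	inner_result = yield from workflow()
	return inner_result.upper()

g = outer()
print(g.send(None), g.send(100), g.send(101))  # prints "0 1 2"
try:
	g.send(102)
except StopIteration as stop:
	print(stop.value)  # prints "RETURN VALUE R=102"
```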

Background : A basic executor

Let's pretend that instead of mocking out the dispatching, we actually wanted to run the ADF. With the knowledge we gained from exploring generators previously, we know that we'll have 3 different conditions to handle:

  • the initial task which will be submitted through the priming (with send(None))
  • all the intermediate executions
  • the return value which will be sent in a StopIteration

We can then bang out the following executor:

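def run_task(arg):
	"""Stand-in for actually performing a submitted task: add 100, as in the example above"""
	return arg + 100

def executor(generator):
	log("init executor")
	arg = generator.send(None)  # prime the generator to get the first task
	try:
		log("entering loop")
		while True:
			value = run_task(arg)
			log(value)
			arg = generator.send(value)  # send the result back, get the next task
	except StopIteration as ret:
		log("stopped iteration")
		log(f"return value is {ret.value}")
		return ret.value

We can also restructure it equivalently as follows, although I'm not sure it's clearer.

def executor(generator):
	log("init executor")
	value = None  # the first send(None) primes the generator
	try:
		log("entering loop")
		while True:
			arg = generator.send(value)
			value = run_task(arg)
			log(value)
	except StopIteration as ret:
		log("stopped iteration")
		log(f"return value is {ret.value}")
		return ret.value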

This simple executor suggests an alternative for mocking the functions that our orchestrator calls out to. Instead of trying to patch those functions, we could simply execute the function-under-test with an executor which will provide the answers we need. This means that we can't use fun mocking libraries out-of-the-box, but it also means that we might not need to use them.
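As a sketch of that idea in plain Python (no Azure involved): the orchestrator-like generator below yields hypothetical (function_name, argument) tuples instead of real tasks, and the executor answers them from a dictionary of fake implementations. The names greet_all and run_with_fakes and the tuple format are illustrative assumptions, not part of the ADF SDK.

```python
def greet_all():
	# stands in for an orchestrator: each yield "submits" a task
	r1 = yield ("SayHello", "Tokyo")
	r2 = yield ("SayHello", "Seattle")
	return [r1, r2]

def run_with_fakes(generator, fakes):
	"""Answer each yielded (name, arg) task with fakes[name](arg); return (value, calls)"""
	calls = []  # record what was requested, for later assertions
	result = None  # the first send(None) primes the generator
	try:
		while True:
			name, arg = generator.send(result)
			calls.append((name, arg))
			result = fakes[name](arg)
	except StopIteration as stop:
		return stop.value, calls

value, calls = run_with_fakes(greet_all(), {"SayHello": lambda city: f"Hello {city}!"})
print(value)  # prints "['Hello Tokyo!', 'Hello Seattle!']"
print(calls)  # prints "[('SayHello', 'Tokyo'), ('SayHello', 'Seattle')]"
```

Nothing is patched here: the fakes are ordinary functions handed to the executor, and the recorded calls can be asserted on directly.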

Background : Examining the source code

Microsoft has been open-sourcing a lot of their stuff. This is very convenient for them: every time their documentation is missing something very basic, if it happens to be important enough to a company, that company might pay someone to dig into the code and write up that documentation.

There are 3 relevant repositories:

The 3rd is the important one. The executor is called TaskOrchestrationExecutor. You'll notice that you can't find any calls which invoke the __next__() method on the generator itself (either through the next() built-in or hidden in iterable operations, like list(generator)). But it does call send on that generator. It also iterates over the history, but it's not as straightforward as that.
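For reference, next(g) is equivalent to g.send(None), so an executor that only ever calls send isn't missing anything; send just also lets it hand a value back. A quick check with a hypothetical counter generator:

```python
def counter():
	n = 0
	while True:
		sent = yield n
		# next() resumes with None; send(x) resumes with x
		n = n + 1 if sent is None else sent

a, b = counter(), counter()
print(next(a), next(a))            # prints "0 1"
print(b.send(None), b.send(None))  # prints "0 1"
print(a.send(10), next(a))         # prints "10 11"
```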

User code generator

We'll start closest to our code: the resume_user_code function. I'll first point out that they catch the StopIteration and use its value as the function output, just like our little executor. The most interesting part is:

# resume orchestration with a resolved task's value
task_value = current_task.result
task_succeeded = current_task.state is TaskState.SUCCEEDED
new_task = self.generator.send(
	task_value) if task_succeeded else self.generator.throw(task_value)
self.context._add_to_open_tasks(new_task)

It's a bit hard to parse, but this code uses send both to pass in the previous result and to get the next task at the same time. Just like our little executor! (If the previous task failed, it uses throw instead, which raises the error inside the orchestrator at the yield.)
Later, it also shuffles the new task into the current task and adds it to the list of actions in the context. But if you trace it, you find that this doesn't actually trigger any further execution.
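The throw branch is worth copying too: it's what lets a failed activity surface inside the orchestrator as an exception at the yield, which is exactly the kind of error condition that's hard to produce in a real environment. Here's a plain-Python sketch of that pattern. It assumes a hypothetical outcome format where each fake returns (succeeded, value_or_exception); it mirrors the send/throw choice in resume_user_code but isn't the SDK's code.

```python
def careful_workflow():
	try:
		result = yield ("Flaky", 1)  # hypothetical (function_name, input) task
	except ValueError as err:
		# the executor raised the task's failure right here, at the yield
		result = f"recovered from {err}"
	return result

def run_with_outcomes(generator, outcomes):
	"""outcomes maps a task name to a fake returning (succeeded, value_or_exception)"""
	succeeded, value = True, None  # the first send(None) primes the generator
	try:
		while True:
			# like resume_user_code: send results back in, throw failures in
			task = generator.send(value) if succeeded else generator.throw(value)
			name, arg = task
			succeeded, value = outcomes[name](arg)
	except StopIteration as stop:
		return stop.value

def fail(arg):
	return (False, ValueError(f"task {arg} failed"))

print(run_with_outcomes(careful_workflow(), {"Flaky": fail}))  # prints "recovered from task 1 failed"
```

If the orchestrator doesn't catch the exception, throw lets it propagate out of run_with_outcomes, which is what a test would want to see.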

Iterating over the tasks

The only iteration we've seen is over the history parameter. This is passed into the TaskOrchestrationExecutor.execute method. You have to dig into the internals of the ADF execution to find where it comes from. Essentially, every time an orchestrator is ready to advance, the entire function is re-run from the start up to that point. For every task that has already completed (created with yield context.call_activity(...), for example), its result has been serialised into the history and is passed back into the function. The orchestrator can then advance until it hits a yield statement which creates a new task.

This replay behaviour is somewhat described in the article on Orchestrator function code constraints. The focus isn't really on the implementation; it's on the natural consequence that only deterministic APIs can be used.
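The replay model can be mimicked in plain Python to see why it forces determinism. In this sketch (the replay helper and the tuple tasks are illustrative assumptions, not the Durable Functions implementation), every "wake-up" creates a brand-new generator and feeds it the recorded history of results until it yields a task that has no result yet:

```python
def orchestrator(trace):
	trace.append([])  # a new span every time the function is (re)started
	results = []
	for city in ["Tokyo", "Seattle", "London"]:
		trace[-1].append(city)
		reply = yield ("Hello", city)  # hypothetical (function_name, input) task
		results.append(reply)
	trace[-1].append("Done")
	return results

def replay(generator, history):
	"""Run a fresh generator, sending it each recorded result in order.
	Returns ("task", next_task) if it needs more work, or ("done", return_value)."""
	try:
		task = generator.send(None)
		for recorded in history:
			task = generator.send(recorded)
		return ("task", task)
	except StopIteration as stop:
		return ("done", stop.value)

trace, history = [], []
while True:
	status, payload = replay(orchestrator(trace), history)
	if status == "done":
		break
	_, city = payload
	history.append(f"Hello {city}!")  # pretend the activity ran and its result was recorded

print(trace)
# prints "[['Tokyo'], ['Tokyo', 'Seattle'], ['Tokyo', 'Seattle', 'London'], ['Tokyo', 'Seattle', 'London', 'Done']]"
print(payload)  # prints "['Hello Tokyo!', 'Hello Seattle!', 'Hello London!']"
```

Because trace lives outside the generator, each replay leaves its own, longer span behind.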

You can see the replay behaviour for yourself by instrumenting a basic function.

import azure.durable_functions as df

class Logotron:
    def __init__(self) -> None:
        self.i = -1
        self.l = []

    def new_span(self):
        self.i +=1
        self.l.append([])
    
    def log(self, e):
        self.l[self.i].append(e)

logotron = Logotron()

def orchestrator_function(context: df.DurableOrchestrationContext):
    logotron.new_span()
    logotron.log("Tokyo")
    result1 = yield context.call_activity('Hello', "Tokyo")
    logotron.log("Seattle")
    result2 = yield context.call_activity('Hello', "Seattle")
    logotron.log("London")
    result3 = yield context.call_activity('Hello', "London")
    logotron.log("Done")
    return logotron.l  # the recorded spans (logotron.log is the method, which isn't serialisable)

main = df.Orchestrator.create(orchestrator_function)

We then get this as the trace. We also get a "Non-Deterministic workflow detected" warning, which is true.

[['Tokyo'],
 ['Tokyo', 'Seattle'],
 ['Tokyo', 'Seattle', 'London'],
 ['Tokyo', 'Seattle', 'London', 'Done']]

Mocking durable functions

With all the background done, it looks like there are 3 parts to the challenge of implementing mocking for ADF:

  1. feeding in a DurableOrchestrationContext
  2. creating a mock executor
  3. mocking the evaluation of remote calls

DurableOrchestrationContext

There's a bit of shimming that you have to do to get one of these created, but you can just fake all the values and it seems to work as well as we need it to.

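import datetime

import azure.durable_functions as df


def make_ctx() -> df.DurableOrchestrationContext:
	"""Create a DurableOrchestrationContext by filling in dummy values """
	# The context appears to read its current time from an OrchestratorStarted
	# history event (EventType 12), so we supply a single fake one
	fakeEvent = {
		"EventType": 12,
		"EventId": -1,
		"IsPlayed": False,
		"Timestamp": datetime.datetime.utcnow().isoformat(),
	}
	# positional args: history, instance id, is-replaying flag, parent instance id
	return df.DurableOrchestrationContext([fakeEvent], "", False, "")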

Creating a mock executor

We've done most of the generator work for the executor. Now we have to build it to actually interpret the Tasks that are given to it. There are a few task types that we have to handle: AtomicTask, WhenAllTask and WhenAnyTask. We could also handle RetryAbleTask (a subclass of WhenAllTask) and TimerTask (a subclass of AtomicTask) separately if we chose to, although I'm going to leave that as an exercise for the reader. Handling these is mostly a matter of dispatching down to the AtomicTasks inside the compound tasks. After that, we just need to unwrap the requested Action and dispatch it to our mocks.

	def execute(self, fn):
		"""Execute an orchestrator function with external calls mocked"""
		ctx = make_ctx()
		g = fn(ctx)
		result = None
		try:
			while True:
				task = g.send(result)
				result = self._handle(task)
		except StopIteration as ret:
			return ret.value
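The _handle dispatch that execute relies on can be sketched without the SDK using stand-in classes. Atomic, WhenAll and WhenAny below are hypothetical substitutes for AtomicTask, WhenAllTask and WhenAnyTask; the real classes carry an Action rather than a bare name and argument, and this sketch recurses into children where the listing below calls _handle_task on them directly.

```python
import copy
import random
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class Atomic:  # hypothetical stand-in for AtomicTask
	name: str
	arg: Any
	result: Any = None


@dataclass
class WhenAll:  # hypothetical stand-in for WhenAllTask
	children: List[Atomic] = field(default_factory=list)


@dataclass
class WhenAny:  # hypothetical stand-in for WhenAnyTask
	children: List[Atomic] = field(default_factory=list)


def handle(task, fakes: Dict[str, Callable], pick=random.choice):
	"""Resolve a stand-in task roughly the way MockExecutor._handle resolves real ones"""
	if isinstance(task, Atomic):
		return fakes[task.name](task.arg)
	if isinstance(task, WhenAll):
		# fan-in: a list of every child's result, in order
		return [handle(child, fakes, pick) for child in task.children]
	if isinstance(task, WhenAny):
		# only the chosen child runs; the caller gets a copy of it with .result filled in
		winner = copy.copy(pick(task.children))
		winner.result = handle(winner, fakes, pick)
		return winner
	raise TypeError(f"unhandled task type: {type(task).__name__}")


fakes = {"Double": lambda x: x * 2}
print(handle(WhenAll([Atomic("Double", 1), Atomic("Double", 2)]), fakes))  # prints "[2, 4]"
first = handle(WhenAny([Atomic("Double", 5), Atomic("Double", 6)]), fakes, pick=lambda c: c[0])
print(first.result)  # prints "10"
```

Passing a deterministic pick, as in the last lines, is the same idea as the listing's _select_winning_task parameter: it keeps tests involving WhenAnyTask repeatable instead of relying on random.choice.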

Mocking the evaluation of remote calls

The Action itself has 3 properties we care about:

  • action_type (ActionType) : the category of thing we want invoked, like an Activity Function or a SubOrchestrator
  • function_name : the name of the thing we want invoked
  • input_ : the payload to send to that function, as a JSON string

So now it's just a matter of building a system which can dispatch on the first two and pass in the third. This isn't hard either. One hiccough is that some of the types are sort of equivalent from a unit-testing point of view (CALL_ACTIVITY vs CALL_ACTIVITY_WITH_RETRY). We don't need to handle WHEN_ANY or WHEN_ALL because we expand those into their component tasks, although we do need to add a way of determining which task wins a WhenAnyTask.
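Dispatching on the first two properties amounts to a two-level lookup, type then name, after folding the retry variants into their plain counterparts. A minimal sketch, where the Kind enum is a hypothetical stand-in for the SDK's ActionType and the payload is decoded from JSON the same way the listing below treats input_:

```python
import json
from enum import Enum, auto


class Kind(Enum):  # hypothetical stand-in for dfactions.ActionType
	CALL_ACTIVITY = auto()
	CALL_ACTIVITY_WITH_RETRY = auto()
	CALL_SUB_ORCHESTRATOR = auto()
	CALL_SUB_ORCHESTRATOR_WITH_RETRY = auto()


# retry variants behave like their plain versions as far as a unit test cares
COLLAPSE = {
	Kind.CALL_ACTIVITY_WITH_RETRY: Kind.CALL_ACTIVITY,
	Kind.CALL_SUB_ORCHESTRATOR_WITH_RETRY: Kind.CALL_SUB_ORCHESTRATOR,
}


def dispatch(handlers, kind, function_name, input_json):
	"""Look up handlers[type][name] and call it with the decoded payload"""
	kind = COLLAPSE.get(kind, kind)
	return handlers[kind][function_name](json.loads(input_json))


handlers = {Kind.CALL_ACTIVITY: {"Hello": lambda city: f"Hello {city}!"}}
print(dispatch(handlers, Kind.CALL_ACTIVITY_WITH_RETRY, "Hello", '"Tokyo"'))  # prints "Hello Tokyo!"
```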

There's honestly nothing special about this code, so I'm going to skip discussing it. You can see the whole thing in the repo.

Making a useful mocking library

This executor is competent enough at walking through a function, but it lacks the ability to intelligently verify calls the way other popular mocking libraries do. Good thing all of our handler methods receive Task objects, which contain the function invocation and its argument.
We should aim for an API similar to unittest.mock's. Implementing it is straightforward, although it does require a bit of gruntwork. Reporting results is where most of the utility comes from, and I'll admit that I skimped on that.

import copy
import datetime
import itertools
import json
import operator
import random
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, NewType, Union

import azure.durable_functions as df
import azure.durable_functions.models.actions as dfactions
import azure.durable_functions.models.Task as dftask


def make_ctx() -> df.DurableOrchestrationContext:
	"""Create a DurableOrchestrationContext by filling in dummy values """
	fakeEvent = {
		"EventType": 12,
		"EventId": -1,
		"IsPlayed": False,
		"Timestamp": datetime.datetime.utcnow().isoformat(),
	}
	return df.DurableOrchestrationContext([fakeEvent], "", False, "")


@dataclass
class AZDFMock:
	type_: dfactions.ActionType
	name: str
	fn: Callable


Handlers = NewType("Handlers", Dict[dfactions.ActionType, Dict[str, Callable]])


def mocks2handlers(mocks: Iterable[AZDFMock]) -> Handlers:
	"""Group mocks into a {type: {name: fn}} handler tree"""
	# accumulate per type; itertools.groupby only groups *adjacent* mocks of the same type
	handlers = defaultdict(dict)
	for mock in mocks:
		handlers[mock.type_][mock.name] = mock.fn
	return Handlers(dict(handlers))


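def combine_handlers(this, that):
	"""Merge two handler trees; handlers in `that` win over those in `this`"""
	# copy.copy keeps a defaultdict's default_factory, so lax executors stay lax
	combined = copy.copy(this)
	for type_, by_name in that.items():
		base = this[type_] if isinstance(this, defaultdict) else this.get(type_, {})
		merged = copy.copy(base)
		merged.update(by_name)
		combined[type_] = merged
	return combined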


class MockExecutor:
	def __init__(
		self, handlers: Handlers, _select_winning_task: Callable[[dftask.WhenAnyTask], dftask.TaskBase] = None
	) -> None:
		self.handlers = handlers
		self._select_winning_task = _select_winning_task or self._select_random_task

		self._calls = []
		self._invocations = []

	@staticmethod
	def _select_random_task(task: dftask.WhenAnyTask):
		"""Resolve a WhenAnyTask by selecting a random Task. This is used as the default"""
		return random.choice(task.children)

	@staticmethod
	def _collapse_types(type_: dfactions.ActionType) -> dfactions.ActionType:
		remapped = {
			dfactions.ActionType.CALL_ACTIVITY_WITH_RETRY: dfactions.ActionType.CALL_ACTIVITY,
			dfactions.ActionType.CALL_SUB_ORCHESTRATOR_WITH_RETRY: dfactions.ActionType.CALL_SUB_ORCHESTRATOR,
		}
		return remapped.get(type_, type_)

	@classmethod
	def create(cls, mocks: Union[AZDFMock, Iterable[AZDFMock]]) -> "MockExecutor":
		"""Create an executor with handlers taken from mocks"""
		if isinstance(mocks, AZDFMock):
			mocks = [mocks]
		return cls(mocks2handlers(mocks))

	@classmethod
	def lax(cls, default: Callable = lambda x: x) -> "MockExecutor":
		"""Create an executor where all not-found handlers will be passed to a default.
		If the `default` is not supplied, it defaults to just returning the arguments."""
		default_by_type = defaultdict(lambda: default)
		default_handler = defaultdict(lambda: default_by_type)
		return cls(default_handler)

	def with_handlers(self, handlers: Handlers) -> "MockExecutor":
		"""Specialise a copy of this executor with another handler tree"""
		# keep the WhenAny selection strategy as well as the handlers
		return MockExecutor(
			combine_handlers(self.handlers, handlers), self._select_winning_task
		)

	def with_mock(self, mock: Union[AZDFMock, Iterable[AZDFMock]]) -> "MockExecutor":
		"""Specialise a copy of this executor with other mocks"""
		if isinstance(mock, AZDFMock):
			mock = [mock]
		return self.with_handlers(mocks2handlers(mock))

	def execute(self, fn):
		"""Execute an orchestrator function with external calls mocked"""
		ctx = make_ctx()
		g = fn(ctx)
		result = None
		try:
			while True:
				task = g.send(result)
				result = self._handle(task)
		except StopIteration as ret:
			return ret.value

	def _handle(self, task: dftask.TaskBase):
		self._invocations.append(task)
		if isinstance(task, dftask.AtomicTask):
			return self._handle_task(task)
		elif isinstance(task, dftask.WhenAllTask):
			return self._handle_task_all(task)
		elif isinstance(task, dftask.WhenAnyTask):
			return self._handle_task_any(task)

	def _handle_task(self, task):
		self._calls.append(task)
		return self._handle_action(task._get_action())

	def _handle_task_all(self, task: dftask.WhenAllTask):
		return [self._handle_task(x) for x in task.children]

	def _handle_task_any(self, task: dftask.WhenAnyTask):
		winning_task = copy.copy(self._select_winning_task(task))
		result = self._handle_task(winning_task)
		winning_task.result = result  # the orchestrator reads the winner's .result
		return winning_task

	def _handle_action(self, action: dfactions.Action):
		return self.handlers[self._collapse_types(action.action_type)][
			action.function_name
		](json.loads(action.input_))

	def invocations(self):
		"""
		RMIs that are submitted for execution through `yield` statements.
		WhenAllTasks will have subcalls nested within it.
		WhenAnyTasks will have _all_ of their subcalls nested as children
		"""
		return self._invocations

	def calls(self):
		"""
		RMIs that are actually executed
		`WhenAllTask`s will not be included, but their children will
		`WhenAnyTask`s will not be included, and only the "winning" task will be included
		"""
		return self._calls

	def _find_calls_matching_action(
		self, predicate: Callable[[dftask.TaskBase, dfactions.Action], bool]
	) -> List[dftask.TaskBase]:
		"""Return the executed tasks for which predicate(task, action) is true"""
		return [t for t in self.calls() if predicate(t, t._get_action())]

	def _find_called(self, action_type, function_name) -> List[dftask.TaskBase]:
		def _p(c, a):
			return a.action_type == action_type and a.function_name == function_name

		return self._find_calls_matching_action(_p)

	def assert_called(self, action_type, function_name):
		assert len(self._find_called(action_type, function_name)) > 0

	def assert_called_once(self, action_type, function_name):
		assert len(self._find_called(action_type, function_name)) == 1

	def _find_called_with(self, action: dfactions.Action):
		action_as_json = action.to_json()

		def _p(c, a):
			return a.to_json() == action_as_json

		return self._find_calls_matching_action(_p)

	def assert_called_with(self, action: dfactions.Action):
		"""Assert that the most recent call matches `action`"""
		assert self.calls(), "no calls were made"
		c = self.calls()[-1]
		assert c._get_action().to_json() == action.to_json()

	def assert_called_once_with(self, action: dfactions.Action):
		self.assert_called_once(action.action_type, action.function_name)
		assert len(self._find_called_with(action)) == 1

	def assert_any_call(self, action: dfactions.Action):
		assert len(self._find_called_with(action)) > 0

	def assert_has_calls(
		self, actions: Iterable[dfactions.Action], any_order: bool = False
	):
		if any_order:
			for action in actions:
				self.assert_any_call(action)
		else:
			remaining_calls = iter(self.calls())
			remaining_actions = iter(actions)

			def find_next_call_matching(a: dfactions.Action):
				while True:
					call = next(remaining_calls)
					if call._get_action().to_json() == a.to_json():
						return

			for action in remaining_actions:
				try:
					find_next_call_matching(action)
				except StopIteration:
					unmatched = [action] + list(remaining_actions)
					raise AssertionError("not all calls matched", unmatched)

	def assert_not_called(self, action_type, function_name):
		assert len(self._find_called(action_type, function_name)) == 0
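Note that the ordered branch of assert_has_calls above accepts gaps: the expected actions only need to appear in order somewhere among the recorded calls. unittest.mock's assert_has_calls with any_order=False is stricter and requires the calls to be sequential (with extra calls allowed only before or after them). If you want that behaviour instead, the check is a sliding-window comparison. Here's a sketch over plain lists, where in the executor you'd compare the actions' to_json() output; is_subsequence and is_contiguous are hypothetical helper names.

```python
from typing import Sequence


def is_subsequence(expected: Sequence, actual: Sequence) -> bool:
	"""True if expected appears in order within actual, gaps allowed"""
	remaining = iter(actual)
	# `x in remaining` consumes the iterator up to the match, so order is enforced
	return all(x in remaining for x in expected)


def is_contiguous(expected: Sequence, actual: Sequence) -> bool:
	"""True if expected appears as one unbroken run within actual"""
	n = len(expected)
	return any(list(actual[i:i + n]) == list(expected) for i in range(len(actual) - n + 1))


calls = ["A", "B", "C", "D"]
print(is_subsequence(["A", "C"], calls), is_contiguous(["A", "C"], calls))  # prints "True False"
print(is_subsequence(["C", "A"], calls))  # prints "False"
```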