How to write unit tests for @task-decorated Airflow tasks? - airflow

I am trying to write unit tests for some of the tasks built with the Airflow TaskFlow API. I tried multiple approaches, for example creating a DAG run or calling only the task function, but nothing has worked.
Here is a task where I download a file from S3, there is more stuff going on but I removed that for this example.
@task()
def updates_process(files):
    """Download the updates file referenced by ``files`` and process it.

    ``files`` is read with ``.get("updates_file")`` and the value is passed to
    ``utils.download_file_from_s3_bucket``. If that call raises
    ``FileNotFoundError``, the error is logged and the task returns ``None``
    without doing any further work.
    """
    context = get_current_context()
    try:
        updates_file_path = utils.download_file_from_s3_bucket(files.get("updates_file"))
    except FileNotFoundError as e:
        log.error(e)
        return
    # Do something else
Now I am trying to write a test case that exercises this except clause. The following is one of the examples I started with:
class TestAccountLinkUpdatesProcess(TestCase):
    """First attempt: invoke the decorated task object directly."""

    # mock.patch decorators apply bottom-up, so the innermost patch
    # (download_file_from_s3_bucket) is the first mock argument.
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        """The task should log an error when the S3 download raises FileNotFoundError."""
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = updates_process({"updates_file": "path/to/file.csv"})
        get_current_context.assert_called_once()
        log.error.assert_called_once()
I also tried creating a DAG run, as shown in the example in the docs, and fetching the task from the DAG run, but that didn't help either.

I was struggling to do this myself, but I found that decorated tasks have a .function attribute: https://github.dev/apache/airflow/blob/be7cb1e837b875f44fcf7903329755245dd02dc3/airflow/decorators/base.py#L522
You can then use .function to call the actual function. Using your example:
class TestAccountLinkUpdatesProcess(TestCase):
    """Exercise the task body through the decorated task's ``.function`` attribute."""

    # mock.patch decorators apply bottom-up, so the innermost patch
    # (download_file_from_s3_bucket) is the first mock argument.
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        """The task should log an error when the S3 download raises FileNotFoundError."""
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = dags.delta_load.updates.updates_process
        # Call the function for testing
        task.function({"updates_file": "path/to/file.csv"})
        get_current_context.assert_called_once()
        log.error.assert_called_once()
This prevents you from having to set up any of the DAG infrastructure and just run the python function as intended!

This is what I could figure out. Not sure if this is the right thing but it works.
class TestAccountLinkUpdatesProcess(TestCase):
    """Run the task by fetching its operator from the DAG and calling ``execute``."""

    TASK_ID = "updates_process"

    @classmethod
    def setUpClass(cls) -> None:
        # Build the DAG once for the whole test class.
        cls.dag = dag_delta_load()

    # mock.patch decorators apply bottom-up, so the innermost patch
    # (download_file_from_s3_bucket) is the first mock argument.
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        """The task should log an error when the S3 download raises FileNotFoundError."""
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = self.dag.get_task(task_id=self.TASK_ID)
        # Inject the positional argument the task function expects.
        task.op_args = [{"updates_file": "file.csv"}]
        task.execute(context={})
        log.error.assert_called_once()
UPDATE: Based on the answer of @AetherUnbound, I did some investigation and found that we can use task.__wrapped__() to call the actual Python function.
class TestAccountLinkUpdatesProcess(TestCase):
    """Exercise the task body through the decorated task's ``__wrapped__`` attribute."""

    # mock.patch decorators apply bottom-up, so the innermost patch
    # (download_file_from_s3_bucket) is the first mock argument.
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        """The task should log an error when the S3 download raises FileNotFoundError."""
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        updates_process.__wrapped__({"updates_file": "file.csv"})
        log.error.assert_called_once()

Related

Apache Airflow unit and integration test

I am new to Apache Airflow and I am trying to figure out how to unit/integration test my dags/tasks
Here is my directory structure
/airflow
/dags
/tests/dags
I created a simple DAG which has a task that reads data from a Postgres table:
def read_files(ti):
    """Query the ids of ``files`` rows with status ``NEW`` and push them to XCom.

    The rows returned by ``fetchall()`` are pushed through ``ti.xcom_push``
    under the key ``files_to_process``.
    """
    from contextlib import closing

    sql = "select id from files where status='NEW'"
    pg_hook = PostgresHook(postgres_conn_id="metadata")
    # closing() releases the cursor and the connection even if the query raises.
    with closing(pg_hook.get_conn()) as connection, closing(connection.cursor()) as cursor:
        cursor.execute(sql)
        files = cursor.fetchall()
    ti.xcom_push(key="files_to_process", value=files)
# DAG with a single PythonOperator task that runs read_files every 30 minutes.
with DAG(
    dag_id="check_for_new_files",
    schedule_interval=timedelta(minutes=30),
    start_date=datetime(2022, 9, 1),
    catchup=False,
) as dag:
    check_files = PythonOperator(task_id="read_files", python_callable=read_files)
Is it possible to test this by mocking the Airflow/Postgres connection, etc.?
Yes, it is possible to test DAGs. Here is an example of the basic things you can do:
import unittest
from airflow.models import DagBag
class TestCheckForNewFilesDAG(unittest.TestCase):
    """Structural checks for the ``check_for_new_files`` DAG."""

    DAG_ID = 'check_for_new_files'

    def setUp(self):
        # A fresh DagBag is built before every test.
        self.dagbag = DagBag()

    def test_task_count(self):
        """Check the task count of the check_for_new_files DAG."""
        dag = self.dagbag.get_dag(self.DAG_ID)
        self.assertEqual(len(dag.tasks), 1)

    def test_contain_tasks(self):
        """Check which tasks the check_for_new_files DAG contains."""
        dag = self.dagbag.get_dag(self.DAG_ID)
        task_ids = [task.task_id for task in dag.tasks]
        self.assertListEqual(task_ids, ['read_files'])

    def test_dependencies_of_read_files_task(self):
        """Check the dependencies of the read_files task in the check_for_new_files DAG."""
        dag = self.dagbag.get_dag(self.DAG_ID)
        read_files_task = dag.get_task('read_files')
        # Useful when the task has upstream tasks; here none are expected.
        upstream_task_ids = [task.task_id for task in read_files_task.upstream_list]
        self.assertListEqual(upstream_task_ids, [])
        downstream_task_ids = [task.task_id for task in read_files_task.downstream_list]
        self.assertListEqual(downstream_task_ids, [])
# Build the suite from TestCheckForNewFilesDAG and run it with verbose output.
suite = unittest.TestLoader().loadTestsFromTestCase(TestCheckForNewFilesDAG)
unittest.TextTestRunner(verbosity=2).run(suite)
For verifying that the data handled by a task has been moved correctly, the documentation suggests:
https://airflow.apache.org/docs/apache-airflow/2.0.1/best-practices.html#self-checks
Self-Checks
You can also implement checks in a DAG to make sure the tasks are producing the results as expected. As an example, if you have a task that pushes data to S3, you can implement a check in the next task. For example, the check could make sure that the partition is created in S3 and perform some simple checks to determine if the data is correct.
I think this is an excellent and straightforward way to verify a specific task.
Here there are other useful links you can use:
https://www.youtube.com/watch?v=ANJnYbLwLjE
In the next ones, they talk about mock
https://www.astronomer.io/guides/testing-airflow/
https://medium.com/@montadhar/apache-airflow-testing-guide-7956a3f4bbf5
https://godatadriven.com/blog/testing-and-debugging-apache-airflow/

Why does a pythonoperator callable not need to accept parameters in airflow?

I do not understand how callables (the functions called as specified by PythonOperator) in Airflow should have their parameter lists set. I have seen them with no parameters, with named params, or with **kwargs. It seems I can always add "ti" or **allargs as parameters; ti seems to be used for task instance info, and ds for the execution date. But my callables apparently do not NEED params. They can simply be "def function():". If I wrote a regular Python function func() instead of func(**kwargs), it would fail at runtime when called unless no params were passed. Airflow seems to pass ti all the time, so how can the callable's signature not require it? The example below is from a training site, where _process_data gets ti but _extract_bitcoin_price() does not. I was thinking that is because of the xcom push, but ti is ALWAYS available it seems, so how can "def somefunc()" ever work? I tried looking at the PythonOperator source code, but I am unclear how this works or what the best practices are for including parameters in a callable. Thanks!!
from airflow import DAG
from airflow.operators.python_operator import PythonOperator
from datetime import datetime
import json
from typing import Dict
import requests
import logging

API = "https://api.coingecko.com/api/v3/simple/price?ids=bitcoin&vs_currencies=usd&include_market_cap=true&include_24hr_vol=true&include_24hr_change=true&include_last_updated_at=true"


def _extract_bitcoin_price():
    """Fetch the price payload from API and return its ``bitcoin`` entry."""
    return requests.get(API).json()['bitcoin']


def _process_data(ti):
    """Pull the extracted payload from XCom, reduce it, and push the result.

    The reduced dict keeps the ``usd`` price and the ``usd_24h_change`` value
    (stored as ``change``) and is pushed under the key ``processed_data``.
    """
    response = ti.xcom_pull(task_ids='extract_bitcoin_price')
    logging.info(response)
    processed_data = {'usd': response['usd'], 'change': response['usd_24h_change']}
    ti.xcom_push(key='processed_data', value=processed_data)


def _store_data(ti):
    """Pull ``processed_data`` from the process_data task and log it."""
    data = ti.xcom_pull(task_ids='process_data', key='processed_data')
    logging.info(f"Store: {data['usd']} with change {data['change']}")


with DAG('classic_dag', schedule_interval='@daily', start_date=datetime(2021, 12, 1), catchup=False) as dag:
    extract_bitcoin_price = PythonOperator(
        task_id='extract_bitcoin_price',
        python_callable=_extract_bitcoin_price
    )
    process_data = PythonOperator(
        task_id='process_data',
        python_callable=_process_data
    )
    store_data = PythonOperator(
        task_id='store_data',
        python_callable=_store_data
    )
    # Linear pipeline: extract, then process, then store.
    extract_bitcoin_price >> process_data >> store_data
Tried callables with no params somefunc() expecting to get error saying too many params passed, but it succeeded. Adding somefunc(ti) also works! How can both work?
I think what you are missing is that Airflow allows to pass the context of the task to the python callable (as you can see one of them is the ti). These are additional useful parameters that Airflow provides and you can use them in your task.
In older Airflow versions, users had to set provide_context=True for that to work:
# Illustrative fragment: the leading ``...`` stands for the operator's other arguments.
process_data = PythonOperator(
    ...,
    provide_context=True,
)
Since Airflow>=2.0 there is no need to use provide_context. Airflow handles it under the hood.
When you see in the Python Callable signatures like:
def func(ti, **kwargs):
    """Placeholder callable: ``ti`` is bound by name, everything else lands in ``kwargs``."""
    ...
This means that the ti is "unpacked" from the kwargs. You can also do:
def func(**kwargs):
    """Placeholder callable that looks ``ti`` up in the keyword arguments."""
    ti = kwargs['ti']
EDIT:
I think what you are missing is that while you write:
def func():
    """Placeholder callable that takes no parameters."""
    ...


store_data = PythonOperator(
    task_id='task',
    python_callable=func
)
Airflow does more than just calling func. The code being executed is the execute() function of PythonOperator and this function calls the python callable you provided with args and kwargs.

dagster solid Parallel Run Test example

I tried to run parallel, but it didn't work as I expected
The progress bar doesn't work the way I thought it would.
I think that both operations should be executed at the same time.
Instead, find_highest_calorie_cereal and find_highest_protein_cereal run one after the other.
import csv
import time
import requests
from dagster import pipeline, solid
# start_complex_pipeline_marker_0
@solid
def download_cereals():
    """Download the cereal CSV and return its rows as a list of dicts."""
    response = requests.get("https://docs.dagster.io/assets/cereal.csv")
    lines = response.text.split("\n")
    return list(csv.DictReader(lines))
@solid
def find_highest_calorie_cereal(cereals):
    """Return the name of the cereal that sorts last by its ``calories`` value."""
    # Presumably an artificial delay so that (non-)parallel execution is visible.
    time.sleep(5)
    sorted_cereals = sorted(cereals, key=lambda cereal: cereal["calories"])
    return sorted_cereals[-1]["name"]
@solid
def find_highest_protein_cereal(context, cereals):
    """Return the name of the cereal that sorts last by its ``protein`` value."""
    # Presumably an artificial delay so that (non-)parallel execution is visible.
    time.sleep(10)
    sorted_cereals = sorted(cereals, key=lambda cereal: cereal["protein"])
    # for i in range(1, 11):
    #     context.log.info(str(i) + '~~~~~~~~')
    #     time.sleep(1)
    return sorted_cereals[-1]["name"]
@solid
def display_results(context, most_calories, most_protein):
    """Log both result names through the solid's context logger."""
    context.log.info(f"Most caloric cereal 테스트: {most_calories}")
    context.log.info(f"Most protein-rich cereal: {most_protein}")
@pipeline
def complex_pipeline():
    """Wire the solids: one download feeds both finders, whose outputs feed display_results."""
    cereals = download_cereals()
    display_results(
        most_protein=find_highest_protein_cereal(cereals),
        most_calories=find_highest_calorie_cereal(cereals),
    )
I am not sure, but I think you should set up an executor that supports parallelism. You could use multiprocess_executor.
"Executors are responsible for executing steps within a pipeline run.
Once a run has launched and the process for the run, or run worker,
has been allocated and started, the executor assumes responsibility
for execution."
modes provide the possible set of executors one can use. Use the executor_defs property on ModeDefinition.
MODE_DEV = ModeDefinition(name="dev", executor_defs=[multiprocess_executor])
@pipeline(mode_defs=[MODE_DEV], preset_defs=[Preset_test])
the execution config section of the run config determines the actual executor.
in the yml file or run_config, set:
# Select the multiprocess executor and allow up to 4 concurrent steps.
execution:
  multiprocess:
    config:
      max_concurrent: 4
retrieved from : https://docs.dagster.io/deployment/executors

mocked function is called but fails assert_called test

I have a test like:
import somemodule
import somemodule2
class SomeTestCase(unittest.TestCase):
    """Test case from the question; the final assert_called() is the one that fails."""

    def setUp(self):
        super().setUp()
        self.ft_mock = mock.MagicMock(spec=somemodule.SomeClass,
                                      name='mock_it')

    def tearDown(self) -> None:
        self.ft_mock.reset_mock(return_value=True, side_effect=True)

    # NOTE: this patches the name inside somemodule2 only. If somemodule imported
    # someFunk2 by name, FooClass.foo() resolves somemodule.someFunk2 and never
    # reaches this mock; the patch target would then need to be
    # 'somemodule.someFunk2'.
    @mock.patch('somemodule2.someFunk2',
                name='mock_2',
                spec=True,
                side_effect='some val')
    def testSomething(self, mock_rr):
        m = self.ft_mock
        m.addMemberSomething()
        self.assertEqual(len(m.addMember.call_args_list), 1)
        m.addSomethingElse.return_value = 'RETURN IT'
        m.addSomethingElse()
        m.addSomethingElse.assert_called_once()
        res = somemodule.FooClass.foo()
        mock_rr.assert_called()
Here somemodule.FooClass.foo() internally calls someFunk2 from somemodule2 which has been mocked as mock_rr
In the test debug it does call it as i print out a line from someFunk2 but testing it when mock_rr.assert_called() is called, it throws:
AssertionError: Expected 'mock_2' to have been called.
I've tried several ways using patch and patch.object
The issue was that somemodule also imported someFunk2 from somemodule2.
So when patching/mocking, the patched object was somemodule2.someFunk2, while the object that needed to be patched was somemodule.someFunk2.

About tornado.gen.Task's usage - asynchronous request

below is my source:
class Get_Salt_Handler(tornado.web.RequestHandler):
#tornado.web.asynchronous
#tornado.gen.coroutine
def get(self):
#yield tornado.gen.Task(tornado.ioloop.IOLoop.instance().add_timeout, time.time() + 5)
yield tornado.gen.Task(self.get_salt_from_db, 123)
self.write("when i sleep 5s")
def get_salt_from_db(self, params):
print params
and I run it; the console reported that:
TypeError: get_salt_from_db() got an unexpected keyword argument 'callback'
and I don't know why?
gen.Task is used to adapt a callback-based function to the coroutine style; it cannot be used to call synchronous functions. What you probably want is a ThreadPoolExecutor (standard in Python 3.2+, available with pip install futures on Python 2):
# global
executor = concurrent.futures.ThreadPoolExecutor(NUM_THREADS)


@gen.coroutine
def get(self):
    """Run the blocking lookup on the thread pool and wait for its result."""
    # get_salt_from_db takes a ``params`` argument; forward it through submit().
    salt = yield executor.submit(self.get_salt_from_db, 123)

Resources