Airflow: start multiple concurrent generic tasks

Trying to get a few tasks to run concurrently on Cloud Composer:
arr = {}
for i in xrange(3):
    print("i: " + str(i))
    command_formatted = command_template.format(str(i))
    create_training_instance = bash_operator.BashOperator(
        task_id='create_training_instance',
        bash_command=command_formatted)
    arr[i] = create_training_instance
    start_training.set_downstream(arr[i])
Getting the following error:
Broken DAG: [/home/airflow/gcs/dags/scale_simple.py] Dependency , create_training_instance already registered

The task_id must be unique within a DAG. So you can use something like 'create_training_instance_{}'.format(i) as the task_id.

You need to parameterize your task id as well, e.g.,
task_id='create_training_instance' --> 'create_training_instance-{}'.format(i)
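A minimal sketch of the corrected loop (reusing the names from the question; command_template and start_training are assumed to exist elsewhere in the DAG file):
arr = {}
for i in range(3):  # xrange in the original Python 2 snippet above
    command_formatted = command_template.format(str(i))
    create_training_instance = bash_operator.BashOperator(
        task_id='create_training_instance_{}'.format(i),  # unique per generated task
        bash_command=command_formatted)
    arr[i] = create_training_instance
    start_training.set_downstream(arr[i])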

Related

Apache Airflow unit and integration test

I am new to Apache Airflow and I am trying to figure out how to unit/integration test my DAGs/tasks.
Here is my directory structure:
/airflow
/dags
/tests/dags
I created a simple DAG which has a task to read data from a Postgres table:
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.postgres.hooks.postgres import PostgresHook

def read_files(ti):
    sql = "select id from files where status='NEW'"
    pg_hook = PostgresHook(postgres_conn_id="metadata")
    connection = pg_hook.get_conn()
    cursor = connection.cursor()
    cursor.execute(sql)
    files = cursor.fetchall()
    ti.xcom_push(key="files_to_process", value=files)

with DAG(dag_id="check_for_new_files", schedule_interval=timedelta(minutes=30),
         start_date=datetime(2022, 9, 1), catchup=False) as dag:
    check_files = PythonOperator(task_id="read_files",
                                 python_callable=read_files)
Is it possible to test this by mocking the Airflow/Postgres connection, etc.?
Yes, it is possible to test DAGs. Here is an example of basic things you can do:
import unittest

from airflow.models import DagBag

class TestCheckForNewFilesDAG(unittest.TestCase):
    """Check DAG"""

    def setUp(self):
        self.dagbag = DagBag()

    def test_task_count(self):
        """Check task count for a dag"""
        dag_id = 'check_for_new_files'
        dag = self.dagbag.get_dag(dag_id)
        self.assertEqual(len(dag.tasks), 1)

    def test_contain_tasks(self):
        """Check the tasks contained in the check_for_new_files dag"""
        dag_id = 'check_for_new_files'
        dag = self.dagbag.get_dag(dag_id)
        tasks = dag.tasks
        task_ids = list(map(lambda task: task.task_id, tasks))
        self.assertListEqual(task_ids, ['read_files'])

    def test_dependencies_of_read_files_task(self):
        """Check the task dependencies of the read_files task in the check_for_new_files dag"""
        dag_id = 'check_for_new_files'
        dag = self.dagbag.get_dag(dag_id)
        read_files_task = dag.get_task('read_files')

        # to be used in case you have upstream tasks
        upstream_task_ids = list(map(lambda task: task.task_id,
                                     read_files_task.upstream_list))
        self.assertListEqual(upstream_task_ids, [])

        downstream_task_ids = list(map(lambda task: task.task_id,
                                       read_files_task.downstream_list))
        self.assertListEqual(downstream_task_ids, [])

suite = unittest.TestLoader().loadTestsFromTestCase(TestCheckForNewFilesDAG)
unittest.TextTestRunner(verbosity=2).run(suite)
For verifying that the manipulated file data is moved correctly, the documentation suggests self-checks:
https://airflow.apache.org/docs/apache-airflow/2.0.1/best-practices.html#self-checks
Self-Checks
You can also implement checks in a DAG to make sure the tasks are producing the results as expected. As an example, if you have a task that pushes data to S3, you can implement a check in the next task. For example, the check could make sure that the partition is created in S3 and perform some simple checks to determine if the data is correct.
I think this is an excellent and straightforward way to verify a specific task.
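As a minimal sketch of that idea for the DAG above (the check task and its non-empty assertion are illustrative, not part of the original DAG):
from airflow.operators.python import PythonOperator

def check_files_found(ti):
    # Pull what read_files pushed and fail loudly if nothing came back.
    files = ti.xcom_pull(task_ids="read_files", key="files_to_process")
    if not files:
        raise ValueError("read_files did not push any files to process")

# inside the `with DAG(...) as dag:` block from the question
check_files_found_task = PythonOperator(task_id="check_files_found",
                                        python_callable=check_files_found)
check_files >> check_files_found_task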
Here are some other useful links you can use:
https://www.youtube.com/watch?v=ANJnYbLwLjE
The next ones talk about mocking:
https://www.astronomer.io/guides/testing-airflow/
https://medium.com/@montadhar/apache-airflow-testing-guide-7956a3f4bbf5
https://godatadriven.com/blog/testing-and-debugging-apache-airflow/
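For the mocking part of the question specifically, a minimal sketch with unittest.mock could look like this (the module path dags.check_for_new_files and the patch target are assumptions about where the DAG file lives):
from unittest import TestCase, mock

from dags.check_for_new_files import read_files  # assumed module path

class TestReadFiles(TestCase):
    @mock.patch("dags.check_for_new_files.PostgresHook")
    def test_read_files_pushes_rows(self, pg_hook_cls):
        # Make the hook return a canned cursor result instead of hitting Postgres.
        cursor = pg_hook_cls.return_value.get_conn.return_value.cursor.return_value
        cursor.fetchall.return_value = [(1,), (2,)]

        ti = mock.Mock()  # stand-in for the task instance
        read_files(ti)

        ti.xcom_push.assert_called_once_with(key="files_to_process",
                                             value=[(1,), (2,)])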

Airflow how to connect the previous task to the right next dynamic branch with multiple tasks?

I am facing this situation:
I have generated two dynamic branches. Each branch has multiple chained tasks.
This is what I need Airflow to create for me:
taskA1->taskB1 taskC1->taskD1
taskA2->taskB2... taskZ.. taskC2->taskD2
taskA3->taskB3 taskC3->taskD3
and here is my pseudocode:
def create_branch1(task_ids):
    source = []
    for task_id in task_ids:
        source += [
            Operator1(task_id='task_A{0}'.format(task_id)) >>
            Operator2(task_id='task_B{0}'.format(task_id)) ]
    return source

def create_branch2(task_ids):
    source = []
    for task_id in task_ids:
        source += [
            Operator1(task_id='task_C{0}'.format(task_id)) >>
            Operator2(task_id='task_D{0}'.format(task_id)) ]
    return source

create_branch1 >> DummyOperator(Z) >> create_branch2 >> end
However, what Airflow generates looks like this:
taskA1->taskB1 taskD1<-taskC1
taskA2->taskB2...taskZ...taskD2<-taskC2
taskA3->taskB3 taskD3<-taskC3
I mean, in the second branch, DummyOperator(Z) gets connected to the last task of the chain (D) instead of the first task of the chain in the second branch (C).
It seems that, no matter what, DummyOperator(task_Z) will connect to the last task of the chained branches.
Do you have any idea how to tackle this issue?
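One clarifying note, not from the original post: in Airflow, a >> b returns b, so each Operator1(...) >> Operator2(...) expression evaluates to the downstream operator, and the source lists end up holding only the B/D tasks; anything wired to those lists therefore attaches to the tails of the chains. A sketch of collecting the heads and tails separately (Operator1/Operator2 are the asker's placeholders):
def create_branch2(task_ids):
    heads, tails = [], []
    for task_id in task_ids:
        c = Operator1(task_id='task_C{0}'.format(task_id))
        d = Operator2(task_id='task_D{0}'.format(task_id))
        c >> d
        heads.append(c)   # these are what task_Z should fan out to
        tails.append(d)
    return heads, tails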

How to write unit tests for @task decorated Airflow tasks?

I am trying to write unit tests for some of the tasks built with the Airflow TaskFlow API. I tried multiple approaches, for example by creating a dagrun or only running the task function, but nothing is helping.
Here is a task where I download a file from S3; there is more stuff going on, but I removed that for this example.
@task()
def updates_process(files):
    context = get_current_context()
    try:
        updates_file_path = utils.download_file_from_s3_bucket(files.get("updates_file"))
    except FileNotFoundError as e:
        log.error(e)
        return
    # Do something else
Now I was trying to write a test case where I can check this except clause. Following is one example I started with:
class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = account_link_updates_process({"updates_file": "path/to/file.csv"})
        get_current_context.assert_called_once()
        log.error.assert_called_once()
I also tried creating a dagrun as shown in the example here in the docs and fetching the task from the dagrun, but that also didn't help.
I was struggling to do this myself, but I found that the decorated tasks have a .function attribute: https://github.dev/apache/airflow/blob/be7cb1e837b875f44fcf7903329755245dd02dc3/airflow/decorators/base.py#L522
You can then use .function to call the actual function. Using your example:
class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = dags.delta_load.updates.updates_process
        # Call the function for testing
        task.function({"updates_file": "path/to/file.csv"})
        get_current_context.assert_called_once()
        log.error.assert_called_once()
This saves you from having to set up any of the DAG infrastructure and lets you just run the Python function as intended!
This is what I could figure out. Not sure if this is the right thing but it works.
class TestAccountLinkUpdatesProcess(TestCase):
    TASK_ID = "updates_process"

    @classmethod
    def setUpClass(cls) -> None:
        cls.dag = dag_delta_load()

    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = self.dag.get_task(task_id=self.TASK_ID)
        task.op_args = [{"updates_file": "file.csv"}]
        task.execute(context={})
        log.error.assert_called_once()
UPDATE: Based on the answer of @AetherUnbound I did some investigation and found that we can use task.__wrapped__() to call the actual Python function.
class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        updates_process.__wrapped__({"updates_file": "file.csv"})
        log.error.assert_called_once()

Airflow 2.0 - running locally keeps running the function

I have the below task that keeps running. I know this because it runs a query in Snowflake and I keep getting the DUO push notification. Every. 5. Seconds! What can I do to stop this and only have it run when the DAG runs?
This is the task:
create_foreign_keys = SnowflakeQueryOperator(
    dag=dag,
    task_id='check_and_run_foreign_key_query',
    sql=SnowHook().run_fk_alter_statements(schema, query),
    trigger_rule=TriggerRule.ALL_DONE
)
This is the method being called in the sql part:
def run_fk_alter_statements(self, schema, additional_fk):
    fk_query_path = "/fkeys.sql"
    fd = open(f'{fk_query_path}', 'r')
    query = fd.read()
    fd.close()
    additions = []
    for fk in additional_fk:
        additions.append(f""" or (t2.table_name = '{fk['table_name']}' and t2.column_name = '{fk['column_name']}'
        and t1.table_name = '{fk['ref_table_name']}' and t1.column_name = '{fk['ref_column_name']}')\n""".upper())
    raw_out = self.execute_query(query.format(schema=schema, fks=''.join(additions)), fetch_all=True)
    query_jobs = []
    for raw_query in raw_out:
        query_jobs.append(raw_query[0])
    return query_jobs
The sql=SnowHook().run_fk_alter_statements(schema, query) call in your instantiation of the SnowflakeQueryOperator is top-level code, so it will execute every time the DAG is parsed by the Scheduler. You need to find a way to have that function called within an operator's execute() method.
You could add a TaskFlow function/PythonOperator task to push the output from run_fk_alter_statements() to XCom, and then have the SnowflakeQueryOperator use this XCom to execute the generated SQL statement(s).
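A rough sketch of that restructuring (the task names, the templated xcom_pull, and the rendering note are illustrative assumptions, not taken from the original answer):
from airflow.operators.python import PythonOperator

def build_fk_statements(schema, additional_fk):
    # Runs at task runtime, not at DAG-parse time.
    return SnowHook().run_fk_alter_statements(schema, additional_fk)

build_fks = PythonOperator(
    dag=dag,
    task_id='build_fk_statements',
    python_callable=build_fk_statements,
    op_args=[schema, query],
)

create_foreign_keys = SnowflakeQueryOperator(
    dag=dag,
    task_id='check_and_run_foreign_key_query',
    # The return value above lands in XCom and is pulled here via templating.
    # For the list to arrive as a real Python list (not its string repr), the
    # DAG would need render_template_as_native_obj=True, and the operator's
    # sql field must be templated (it is for the stock SnowflakeOperator).
    sql="{{ ti.xcom_pull(task_ids='build_fk_statements') }}",
    trigger_rule=TriggerRule.ALL_DONE,
)

build_fks >> create_foreign_keys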

Airflow task setup with execution date

I want to customize the task to be weekday-dependent in the DAG file. It seems the Airflow macros like {{ next_execution_date }} are not directly available in the Python DAG file. This is my DAG definition:
RG_TASKS = {
    'asia': {
        'start_date': pendulum.datetime(2021, 1, 1, 16, 0, tz='Asia/Tokyo'),
        'tz': 'Asia/Tokyo',
        'files': [
            '/path/%Y%m%d/asia_file1.%Y%m%d.csv',
            '/path/%Y%m%d/asia_file2.%Y%m%d.csv',
            ...], },
    'euro': {
        'start_date': pendulum.datetime(2021, 1, 1, 16, 0, tz='Europe/London'),
        'tz': 'Europe/London',
        'files': [
            '/path/%Y%m%d/euro_file1.%Y%m%d.csv',
            '/path/%Y%m%d/euro_file2.%Y%m%d.csv',
            ...], },
}

dag = DAG(..., start_date=pendulum.datetime(2021, 1, 1, 16, 0, tz='Asia/Tokyo'),
          schedule='00 16 * * 0-6')
for rg, t in RG_TASKS.items():
    tz = t['tz']
    h = t['start_date'].hour
    m = t['start_date'].minute
    target_time = f'{{{{ next_execution_date.replace(tzinfo="{tz}", hour={h}, minute={m}) }}}}'
    time_sensor = DateTimeSensor(dag=dag, task_id=f'wait_for_{rg}', target_time=target_time)

    bash_task = BashOperator(dag=dag, task_id=f'load_{rg}', trigger_rule='all_success',
                             depends_on_past=True, bash_command=...)

    for fname in t['files']:
        fpath = f'{{{{ next_execution_date.strftime("{fname}") }}}}'
        task_id = os.path.basename(fname).split('.')[0]
        file_sensor = FileSensor(dag=dag, task_id=task_id, filepath=fpath, ...)
        file_sensor.set_upstream(time_sensor)
        file_sensor.set_downstream(bash_task)
The above works: the bash_task will be triggered if all files are available, and it is set with depends_on_past=True. However, the files have slightly different schedules: {rg}_file1 is available 6 days/week (not on Saturday), while the rest are available 7 days a week.
One option is to create 2 DAGs, one scheduled to run Sun-Fri and the other scheduled to run Sat only. But with this option, depends_on_past=True is broken on Saturday.
Is there any better way to keep depends_on_past=True 7 days/week? Ideally, in the files loop, I could do something like:
for fname in t['files']:
    dt = ...
    if dt.weekday() == 5 and task_id == f'{rg}_file1':
        continue
Generally I think it's better to accomplish things in a single task when it is easy enough to do, and in this case it seems to me you can.
I'm not entirely sure why you are using a datetime sensor, but it does not seem necessary. As far as I can tell, you just want your process to run every day (ideally after the file is there) and skip once per week.
I think we can do away with the file sensor too.
Option 1: everything in bash
Check for existence in your bash script and fail (with retries) if missing: just return a non-zero exit code when the file is missing.
Then in your bash script you could silently do nothing on the skip day.
On skip days, your bash task will be green even though it did nothing.
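A minimal sketch of that idea (the file path, the skip day, the GNU date call, and the retry settings are placeholders, not from the original answer):
load_euro = BashOperator(
    dag=dag,
    task_id='load_euro',
    depends_on_past=True,
    retries=5,                          # keep retrying while the file is late
    retry_delay=timedelta(minutes=10),
    bash_command="""
    # Silently succeed on the skip day (Saturday = 6 with GNU date's %u), otherwise require the file.
    if [ "$(date -d '{{ next_ds }}' +%u)" = "6" ]; then exit 0; fi
    f='/path/{{ next_ds_nodash }}/euro_file1.{{ next_ds_nodash }}.csv'
    [ -f "$f" ] || exit 1               # fail (and retry) if the file is missing
    # ... actual load goes here ...
    """,
)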
Option 2: subclass bash operator
Subclass BashOperator and add a skip_day parameter. Then your execute is like this:
def execute(self, context):
    next_execution_date = context['next_execution_date']
    if next_execution_date.day_of_week == self.skip_day:
        raise AirflowSkipException(f'we skip on day {self.skip_day}')
    super().execute(context)
With this option your bash script still needs to fail if file missing, but doesn't need to deal with the skip logic. And you'll be able to see that the task skipped in the UI.
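Filled out as a complete operator, that might look like the following sketch (the class name SkipDayBashOperator is hypothetical):
from airflow.exceptions import AirflowSkipException
from airflow.operators.bash import BashOperator

class SkipDayBashOperator(BashOperator):
    """BashOperator that skips itself on one day of the week (sketch)."""

    def __init__(self, skip_day, **kwargs):
        super().__init__(**kwargs)
        self.skip_day = skip_day    # e.g. pendulum.SATURDAY

    def execute(self, context):
        next_execution_date = context['next_execution_date']
        if next_execution_date.day_of_week == self.skip_day:
            raise AirflowSkipException(f'we skip on day {self.skip_day}')
        super().execute(context)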
Either way, no sensors.
Other note
You can simplify your filename templating.
files = [
    '/path/{{ next_ds_nodash }}/euro_file2.{{ next_ds_nodash }}.csv',
    ...
]
Then you don't need to mess with strftime.
