I have a few Airflow tasks, and I want to make sure a specific task always executes as the last task in my DAG. Some tasks might be skipped depending on the input.
this_will_run_for_sure_1 = DummyOperator()
this_will_run_for_sure_2 = DummyOperator()
this_might_not_run_1 = DummyOperator() # Can run in parallel with this_will_run_for_sure_1
this_might_not_run_2 = DummyOperator() # Can run only after this_will_run_for_sure_2 and this_might_not_run_1 finished
final_task = DummyOperator()
How could this be set up with bitshift operators? I was thinking of something like:
this_will_run_for_sure_1 >> this_will_run_for_sure_2
this_might_not_run_1 >> this_might_not_run_2
[this_will_run_for_sure_2, this_might_not_run_2] >> final_task
But this does not work as expected: if this_might_not_run_1 does not start because the input lacks the data it requires, the DAG never reaches final_task. Any help on this is much appreciated.
I think you are just missing this dependency:
this_will_run_for_sure_2 >> this_might_not_run_2
This should resolve your issue:
from datetime import datetime
from airflow.decorators import dag
from airflow.operators.dummy import DummyOperator

@dag(
    schedule_interval=None,
    start_date=datetime(2022, 1, 1)
)
def example():
    this_will_run_for_sure_1 = DummyOperator(task_id='this_will_run_for_sure_1')
    this_will_run_for_sure_2 = DummyOperator(task_id='this_will_run_for_sure_2')
    this_might_not_run_1 = DummyOperator(task_id='this_might_not_run_1')  # Can run in parallel with this_will_run_for_sure_1
    this_might_not_run_2 = DummyOperator(task_id='this_might_not_run_2')  # Can run only after this_will_run_for_sure_2 and this_might_not_run_1 finished
    final_task = DummyOperator(task_id='final_task')
    this_will_run_for_sure_1 >> this_will_run_for_sure_2 >> this_might_not_run_2
    this_might_not_run_1 >> this_might_not_run_2
    [this_will_run_for_sure_2, this_might_not_run_2] >> final_task

dag = example()
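For anyone wondering why a list works on the left-hand side of >>, here is a toy sketch in plain Python. The Task class is hypothetical and is not Airflow's BaseOperator; it only mimics the idea that >> records a dependency, using Python's __rshift__ and __rrshift__ hooks.

```python
class Task:
    """Hypothetical stand-in for an operator; it only records dependencies."""

    def __init__(self, task_id):
        self.task_id = task_id
        self.upstream = set()
        self.downstream = set()

    def _set_downstream(self, other):
        # Record that `self` must run before `other`.
        self.downstream.add(other.task_id)
        other.upstream.add(self.task_id)

    def __rshift__(self, other):
        # Handles `task >> task` and `task >> [task, task]`.
        for target in (other if isinstance(other, list) else [other]):
            self._set_downstream(target)
        return other  # returning the right-hand side lets a >> b >> c chain

    def __rrshift__(self, others):
        # Handles `[task, task] >> task`: a list has no >> of its own,
        # so Python falls back to the right-hand operand.
        for source in others:
            source._set_downstream(self)
        return self


sure_1 = Task("this_will_run_for_sure_1")
sure_2 = Task("this_will_run_for_sure_2")
might_1 = Task("this_might_not_run_1")
might_2 = Task("this_might_not_run_2")
final_task = Task("final_task")

# Same wiring as in the answer above.
sure_1 >> sure_2 >> might_2
might_1 >> might_2
[sure_2, might_2] >> final_task

print(sorted(final_task.upstream))
```

The point of the sketch is only that each >> adds edges; it says nothing about what happens at run time when an upstream task is skipped, which is decided by trigger rules.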
Is it possible to change branches based on an environment variable set in Airflow?
For example:
Variable.get("Environment") returns "dev", "test", or "prod".
Let's say I have a task, Task_B, that I don't want to run in "dev".
Start_dag_task = DummyOperator(task_id="Start_dag_task", ...)
Task_A = PythonOperator(task_id="Task_A", ...)
Task_B = PythonOperator(task_id="Task_B", ...)  # Do not run in dev environment
Task_C = PythonOperator(task_id="Task_C", ...)
End_dag_task = DummyOperator(task_id="End_dag_task", ...)
env = Variable.get("Environment")
Start_dag_task >> Task_A >> End_dag_task
if env != "dev":
    Start_dag_task >> Task_B >> End_dag_task
Start_dag_task >> Task_C >> End_dag_task
Is it possible to run this same code in all three environments and have Task_B not run in dev?
You can achieve that by adding a ShortCircuitOperator before Task_B that checks whether the env variable is dev; if it is, Task_B is skipped. You also need to set ignore_downstream_trigger_rules to False so that End_dag_task and the other downstream tasks can still run, and set the trigger_rule of End_dag_task to NONE_FAILED so that it executes when Task_B is either successful or skipped:
from datetime import datetime
from airflow import DAG
from airflow.models import Variable
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import ShortCircuitOperator, PythonOperator
from airflow.utils.trigger_rule import TriggerRule

with DAG(
    "dag_id",
    schedule_interval=None,
    start_date=datetime(2022, 1, 1),
) as dag:
    Start_dag_task = EmptyOperator(task_id="Start_dag_task")
    Task_A = PythonOperator(task_id="Task_A", python_callable=lambda: print("Task A"))
    Task_B = PythonOperator(task_id="Task_B", python_callable=lambda: print("B"))
    Task_C = PythonOperator(task_id="Task_C", python_callable=lambda: print("C"))
    End_dag_task = EmptyOperator(task_id="End_dag_task", trigger_rule=TriggerRule.NONE_FAILED)
    shortCircuitTaskB = ShortCircuitOperator(
        task_id="short_circuit_for_task_B",
        python_callable=lambda: Variable.get("env", "default_env") != "dev",
        ignore_downstream_trigger_rules=False,
    )
    Start_dag_task >> Task_A >> End_dag_task
    Start_dag_task >> shortCircuitTaskB >> Task_B >> End_dag_task
    Start_dag_task >> Task_C >> End_dag_task
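To make the trigger-rule reasoning concrete, here is a toy model in plain Python. It is not Airflow's scheduler code, and the function name should_run is made up; it only encodes a simplified reading of four trigger rules and answers whether a task would run once all of its upstream tasks have finished.

```python
def should_run(trigger_rule, upstream_states):
    """Simplified model: would a task run, given its finished upstream states?"""
    succeeded = upstream_states.count("success")
    failed = sum(state in ("failed", "upstream_failed") for state in upstream_states)
    if trigger_rule == "all_success":
        # Every upstream must have succeeded; a skipped upstream blocks the task.
        return succeeded == len(upstream_states)
    if trigger_rule == "none_failed":
        # Success and skipped are both acceptable.
        return failed == 0
    if trigger_rule == "one_success":
        return succeeded >= 1
    if trigger_rule == "all_done":
        # Runs regardless of how the upstream tasks ended.
        return True
    raise ValueError("rule not covered by this sketch: " + trigger_rule)


# End_dag_task in "dev": Task_A succeeded, Task_B was skipped, Task_C succeeded.
dev_states = ["success", "skipped", "success"]
print(should_run("all_success", dev_states))
print(should_run("none_failed", dev_states))
```

Under this reading, the default all_success rule is what keeps End_dag_task from running in dev, and none_failed is the rule that lets it run while still reacting to real failures.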
I have an Airflow DAG with a task_group that contains a loop generating two dynamic flows. After the task_group I need to perform other actions.
Inside the task_group I have a branching operator that decides whether the last task of each flow should run. If at least one of the two flows completes successfully, I want to continue my process. For that I'm using the trigger_rule one_success. My code:
with DAG(
    dag_id='hello_world',
    schedule_interval=None,
    start_date=datetime(2022, 8, 25),
    default_args=default_args,
    max_active_runs=1,
    catchup=False,
    concurrency=1,
) as dag:
    task_a = DummyOperator(task_id="task_a")
    with TaskGroup(group_id='task_group') as my_group:
        my_list = ['a', 'b']
        for i in my_list:
            # task ids need a {} placeholder and must be unique within the group
            task_b = PythonOperator(
                task_id="task_b_{}".format(i),
                python_callable=p_task_1)
            var_to_continue = check_status(i)
            is_running = ShortCircuitOperator(
                task_id="is_{}_running".format(i),
                python_callable=lambda x: x in [True],
                op_args=[var_to_continue])
            task_c = PythonOperator(
                task_id="task_c_{}".format(i),
                python_callable=p_task_2)
            task_b >> is_running >> task_c
    task_d = DummyOperator(task_id="task_c", trigger_rule=TriggerRule.ONE_SUCCESS)
    task_a >> my_group >> task_d
My problem is that if one of the iterations is skipped, task_d is always skipped, even when the other flow succeeds.
Do you know how to resolve this?
Thanks!
After some digging, I found the problem.
By default, the ShortCircuitOperator ignores the trigger rules of all downstream tasks: if its callable returns False, it cuts the circuit, which means it skips every downstream task (its direct downstream tasks, their downstream tasks, and so on).
Airflow 2.3.0 added (in this PR) a new argument, ignore_downstream_trigger_rules, whose default value True keeps that behavior; you can turn it off by passing False.
If you are using a version older than 2.3.0, you should replace the ShortCircuitOperator with another solution, for example:
from airflow.exceptions import AirflowSkipException

def check_condition():
    if not condition:  # add your logic
        raise AirflowSkipException()

is_running = PythonOperator(..., python_callable=check_condition)
is_running >> task_c
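The difference between the two settings can be pictured on a plain dictionary graph. This is a sketch under assumptions, not the operator's source: the function name and the graph format (task id mapped to its direct downstream ids) are invented here, and the task ids just mirror the question.

```python
def tasks_skipped_by_short_circuit(downstream, task_id, ignore_downstream_trigger_rules=True):
    """Toy model of which tasks a False short-circuit result marks as skipped."""
    direct = list(downstream.get(task_id, []))
    if not ignore_downstream_trigger_rules:
        # Only the direct children are skipped; tasks further down are left
        # to their own trigger rules.
        return sorted(direct)
    # Default behavior: walk the whole downstream subgraph and skip everything.
    seen = set()
    stack = direct
    while stack:
        current = stack.pop()
        if current not in seen:
            seen.add(current)
            stack.extend(downstream.get(current, []))
    return sorted(seen)


graph = {
    "is_a_running": ["task_c_a"],
    "task_c_a": ["task_d"],
    "task_d": [],
}
print(tasks_skipped_by_short_circuit(graph, "is_a_running"))
print(tasks_skipped_by_short_circuit(graph, "is_a_running", ignore_downstream_trigger_rules=False))
```

With the default, task_d is force-skipped no matter what its trigger_rule says; with False, only task_c_a is skipped and one_success on task_d gets a chance to apply.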
I read "How to use airflow xcoms with MySqlOperator", and while it has a similar title it doesn't really address my issue.
I have the following code:
def branch_func_is_new_records(**kwargs):
    ti = kwargs['ti']
    xcom = ti.xcom_pull(task_ids='query_get_max_order_id')
    string_to_print = 'Value in xcom is: {}'.format(xcom)
    logging.info(string_to_print)
    if int(xcom) > int(LAST_IMPORTED_ORDER_ID):
        return 'import_orders'
    else:
        return 'skip_operation'

query_get_max_order_id = 'SELECT COALESCE(max(orders_id),0) FROM warehouse.orders where orders_id>1 limit 10'
get_max_order_id = MySqlOperator(
    task_id='query_get_max_order_id',
    sql=query_get_max_order_id,
    mysql_conn_id=MyCon,
    xcom_push=True,
    dag=dag)
branch_op_is_new_records = BranchPythonOperator(
    task_id='branch_operation_is_new_records',
    provide_context=True,
    python_callable=branch_func_is_new_records,
    dag=dag)
get_max_order_id >> branch_op_is_new_records >> import_orders
branch_op_is_new_records >> skip_operation
The MySqlOperator returns a number, and based on that number the BranchPythonOperator chooses the next task. The MySqlOperator is guaranteed to return a value greater than 0.
My problem is that nothing is pushed to XCom by the MySqlOperator.
In the UI, when I go to XCom, I see nothing. The BranchPythonOperator obviously reads nothing, so my code fails.
Why doesn't XCom work here?
The MySQL operator currently (airflow 1.10.0 at time of writing) doesn't support returning anything in XCom, so the fix for you for now is to write a small operator yourself. You can do this directly in your DAG file (untested, so there may be silly errors):
from airflow.operators.mysql_operator import MySqlOperator as BaseMySqlOperator
from airflow.hooks.mysql_hook import MySqlHook

class ReturningMySqlOperator(BaseMySqlOperator):
    def execute(self, context):
        self.log.info('Executing: %s', self.sql)
        hook = MySqlHook(mysql_conn_id=self.mysql_conn_id,
                         schema=self.database)
        return hook.get_first(
            self.sql,
            parameters=self.parameters)

def branch_func_is_new_records(**kwargs):
    ti = kwargs['ti']
    xcom = ti.xcom_pull(task_ids='query_get_max_order_id')
    string_to_print = 'Value in xcom is: {}'.format(xcom)
    logging.info(string_to_print)
    if str(xcom) == 'NewRecords':
        return 'import_orders'
    else:
        return 'skip_operation'

query_get_max_order_id = 'SELECT COALESCE(max(orders_id),0) FROM warehouse.orders where orders_id>1 limit 10'
get_max_order_id = ReturningMySqlOperator(
    task_id='query_get_max_order_id',
    sql=query_get_max_order_id,
    mysql_conn_id=MyCon,
    # xcom_push=True,
    dag=dag)
branch_op_is_new_records = BranchPythonOperator(
    task_id='branch_operation_is_new_records',
    provide_context=True,
    python_callable=branch_func_is_new_records,
    dag=dag)
get_max_order_id >> branch_op_is_new_records >> import_orders
branch_op_is_new_records >> skip_operation
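One thing to watch for when combining this with the comparison from the question: if the operator returns the first row of the result set, the value pulled from XCom is likely to be a one-element sequence rather than a bare number. The sketch below shows one way the branching decision could handle that. It is plain Python with no Airflow involved, and the function name choose_branch and its arguments are hypothetical.

```python
def choose_branch(xcom_value, last_imported_order_id):
    """Pick the next task id from the value an upstream task pushed to XCom."""
    if xcom_value is None:
        # Nothing was pushed, so there is nothing to compare against.
        return "skip_operation"
    # Assumption: a single-column query arrives as a one-element row, e.g. (1042,).
    if isinstance(xcom_value, (list, tuple)):
        xcom_value = xcom_value[0]
    if int(xcom_value) > int(last_imported_order_id):
        return "import_orders"
    return "skip_operation"


print(choose_branch((1042,), 1000))
```

Keeping the decision in a small function like this also makes it easy to test without running a DAG.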
I have a task, which I'll call final, that has multiple upstream connections. When one of the upstream tasks gets skipped by a ShortCircuitOperator, this task gets skipped as well. I don't want the final task to get skipped, as it has to report on DAG success.
To avoid it getting skipped I used trigger_rule='all_done', but it still gets skipped.
If I use a BranchPythonOperator instead of the ShortCircuitOperator, the final task doesn't get skipped. A branching workflow could therefore be a solution, even if not an optimal one, but then final will not respect failures of upstream tasks.
How do I get it to run only when the upstream tasks are successful or skipped?
Sample ShortCircuit DAG:
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import ShortCircuitOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'shortcircuit_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

def shortcircuit_fn():
    return randint(0, 1) == 1

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')
work = DummyOperator(dag=dag, task_id='work')
short = ShortCircuitOperator(dag=dag, task_id='short_circuit', python_callable=shortcircuit_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")
task_1 >> short >> work >> final
task_1 >> task_2 >> final
Sample Branch DAG:
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'branch_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

# these two are only here to protect tasks from getting skipped as direct dependencies of branch operator
to_do_work = DummyOperator(dag=dag, task_id='to_do_work')
to_skip_work = DummyOperator(dag=dag, task_id='to_skip_work')

def branch_fn():
    return to_do_work.task_id if randint(0, 1) == 1 else to_skip_work.task_id

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')
work = DummyOperator(dag=dag, task_id='work')
branch = BranchPythonOperator(dag=dag, task_id='branch', python_callable=branch_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")
task_1 >> branch >> to_do_work >> work >> final
branch >> to_skip_work >> final
task_1 >> task_2 >> final
I ended up developing a custom ShortCircuitOperator based on the original one:
# imports for Airflow 1.10.x
from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator

class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Allows a workflow to continue only if a condition is met. Otherwise, the
    workflow "short-circuits" and downstream tasks that only rely on this operator
    are skipped.

    The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
    condition and short-circuits the workflow if the condition is False. Any
    downstream tasks that only rely on this operator are marked with a state of "skipped".
    If the condition is True, downstream tasks proceed as normal.

    The condition is determined by the result of `python_callable`.
    """

    def find_tasks_to_skip(self, task, found_tasks=None):
        if not found_tasks:
            found_tasks = []
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            # a task with a single upstream depends only on the skipped path
            if len(t.upstream_task_ids) == 1:
                found_tasks.append(t)
                self.find_tasks_to_skip(t, found_tasks)
        return found_tasks

    def execute(self, context):
        condition = super(ShortCircuitOperator, self).execute(context)
        self.log.info("Condition result is %s", condition)
        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return
        self.log.info(
            'Skipping downstream tasks that only rely on this path...')
        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)
        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)
        self.log.info("Done.")
This operator makes sure that no downstream task relying on multiple paths gets skipped because of one skipped task.
I'm posting another possible workaround, since this method does not require a custom operator implementation.
I was influenced by the solution in this blog, which uses a PythonOperator that raises an AirflowSkipException; this skips the task itself and then the downstream tasks individually.
https://godatadriven.com/blog/the-zen-of-python-and-apache-airflow/
This then respects the trigger_rule of the final downstream task, which in my case I set to trigger_rule='none_failed'.
Modified example from the blog, extended to include a final task:
from airflow.exceptions import AirflowSkipException

def fn_short_circuit(**context):
    if <<<some condition>>>:
        raise AirflowSkipException("Skip this task and individual downstream tasks while respecting trigger rules.")

task_1 = DummyOperator(task_id="task_1", dag=dag)
task_2 = DummyOperator(task_id="task_2", dag=dag)
work = DummyOperator(dag=dag, task_id='work')
# A plain PythonOperator: a ShortCircuitOperator would treat the callable's
# None return value as False and skip everything downstream.
short = PythonOperator(
    dag=dag,
    task_id='short_circuit',
    python_callable=fn_short_circuit,
    provide_context=True,
)
final_task = DummyOperator(task_id="final_task",
                           trigger_rule='none_failed',
                           dag=dag)
task_1 >> short >> work >> final_task
task_1 >> task_2 >> final_task
This question is still relevant with Airflow 1.10.x.
The following solution works with Airflow 1.10.x; I have not tested it with Airflow 2.x yet.
The ShortCircuitOperator skips all downstream tasks, whatever trigger_rule is set.
The solution by @michael-spector only works for simple cases, and not for this one:
With @michael-spector's solution, task L will not be skipped (only tasks E, F, G, and H will be skipped).
A solution (based on @michael-spector's proposal) is this:
class ShortCircuitOperatorOnlyDirectDownStream(PythonOperator, SkipMixin):
    """
    Works like a ShortCircuitOperator, but it only skips the tasks that have this task in their upstream.
    So if a task has this task in its upstream AND another task, it will not be skipped.

               -> B -> C -> D ------\
              /                      \
        A -> K                        -> Y
         \                           /
          -> F -> G - P ------------/

    If K is a normal ShortCircuitOperator and the condition is False, then B, C, D and Y will be skipped.
    If K is a ShortCircuitOperatorOnlyDirectDownStream and the condition is False, then B, C, D will be skipped, but not Y.

    found_tasks_name contains the names of the previously skipped tasks
    found_tasks contains the airflow_task_id of the previously skipped tasks
    :return found_tasks
    """

    def find_tasks_to_skip(self, task, found_tasks_to_skip=None, found_tasks_to_skip_names=None):
        if not found_tasks_to_skip:  # list of task_id to skip
            found_tasks_to_skip = []
        # necessary because found_tasks do not keep a copy of names but airflow task_id
        if not found_tasks_to_skip_names:
            found_tasks_to_skip_names = set()
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            self.log.info("UPSTREAM : " + str(t.upstream_task_ids))
            self.log.info(
                " Does all skipped task " +
                str(found_tasks_to_skip_names) +
                " contain the upstream tasks" +
                str(t.upstream_task_ids)
            )
            # if len == 1 then the task is only preceded by a skipped task
            # otherwise check if ALL upstream tasks are skipped
            if len(t.upstream_task_ids) == 1 or all(elem in found_tasks_to_skip_names for elem in t.upstream_task_ids):
                found_tasks_to_skip.append(t)
                found_tasks_to_skip_names.add(t.task_id)
                self.find_tasks_to_skip(t, found_tasks_to_skip, found_tasks_to_skip_names)
        return found_tasks_to_skip

    def execute(self, context):
        condition = super(ShortCircuitOperatorOnlyDirectDownStream, self).execute(context)
        self.log.info("Condition result is %s", condition)
        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return
        self.log.info(
            'Skipping downstream tasks that only rely on this path...')
        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)
        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)
        self.log.info("Done.")
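The difference between the two custom operators comes down to the rule used while walking downstream. The sketch below replays both rules on a plain dictionary, without Airflow. The graph, the task names, and the function names are made up for illustration; the traversal order mirrors the recursive find_tasks_to_skip methods above.

```python
def build_upstream(downstream):
    """Invert a {task: [children]} mapping into {task: {parents}}."""
    upstream = {task: set() for task in downstream}
    for task, children in downstream.items():
        for child in children:
            upstream.setdefault(child, set()).add(task)
    return upstream


def skip_single_upstream(downstream, start):
    """First rule: skip a child only if it has exactly one upstream task."""
    upstream = build_upstream(downstream)
    found = []

    def visit(task):
        for child in downstream.get(task, []):
            if len(upstream[child]) == 1:
                found.append(child)
                visit(child)

    visit(start)
    return found


def skip_when_all_upstreams_skipped(downstream, start):
    """Second rule: also skip a child once ALL of its upstream tasks are skipped."""
    upstream = build_upstream(downstream)
    found, names = [], set()

    def visit(task):
        for child in downstream.get(task, []):
            if len(upstream[child]) == 1 or upstream[child] <= names:
                found.append(child)
                names.add(child)
                visit(child)

    visit(start)
    return found


# K fans out to B and C, which join again in D; Y also depends on P.
graph = {
    "A": ["K", "P"],
    "K": ["B", "C"],
    "B": ["D"],
    "C": ["D"],
    "D": ["Y"],
    "P": ["Y"],
    "Y": [],
}
print(skip_single_upstream(graph, "K"))
print(skip_when_all_upstreams_skipped(graph, "K"))
```

In this graph D has two upstream tasks, so the first rule leaves it alone even though both of its parents are skipped, while the second rule skips it. Y keeps running under both rules because P is not on the short-circuited path.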
I've made it work by making the final task check the statuses of the upstream task instances. It's not beautiful, as the only way I found to access their state was by querying the Airflow DB.
# additional imports to the ones in the question code
# from airflow import AirflowException
# from airflow.models import TaskInstance
# from airflow.operators.python_operator import PythonOperator
# from airflow.settings import Session
# from airflow.utils.state import State
# from airflow.utils.trigger_rule import TriggerRule

def all_upstreams_either_succeeded_or_skipped(dag, task, task_instance, **context):
    """
    find directly upstream task instances and count how many are not in preferred statuses.
    return True if we got no instances with non-preferred statuses.
    """
    upstream_task_ids = [t.task_id for t in task.get_direct_relatives(upstream=True)]
    session = Session()
    query = (session
        .query(TaskInstance)
        .filter(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.execution_date.in_([task_instance.execution_date]),
            TaskInstance.task_id.in_(upstream_task_ids)
        )
    )
    upstream_task_instances = query.all()
    unhappy_task_instances = [ti for ti in upstream_task_instances if ti.state not in [State.SUCCESS, State.SKIPPED]]
    print(unhappy_task_instances)
    return len(unhappy_task_instances) == 0

def final_fn(**context):
    """
    fail if upstream task instances have unwanted statuses
    """
    if not all_upstreams_either_succeeded_or_skipped(**context):
        raise AirflowException("Not all upstream tasks succeeded.")
    # Do things

# will run when upstream task instances are done, including failed
final = PythonOperator(
    dag=dag,
    task_id="final",
    trigger_rule=TriggerRule.ALL_DONE,
    python_callable=final_fn,
    provide_context=True)
The ShortCircuitOperator can now be configured to respect the trigger rules of downstream tasks. By default it does not respect them. You can make the operator respect them by setting ignore_downstream_trigger_rules=False.
task = ShortCircuitOperator(
    task_id='task_id',
    python_callable=function,
    ignore_downstream_trigger_rules=False,
)
This may have been added after you asked your initial question, but Airflow now conveniently has a trigger_rule value of none_failed. If you set this on your final task, it should complete whether upstream tasks are skipped or succeeded, just not when they fail.
More info: https://airflow.apache.org/concepts.html#trigger-rules
Problem: I've been trying to find a way to get tasks from a DAG that have no downstream tasks following them.
Why I need it: I'm building an "on success" notification for DAGs. Airflow DAGs have an on_success_callback argument, but the problem is that it gets triggered after every task success instead of just on DAG success. I've seen other people approach this problem by creating a notification task and appending it to the end. The problem I have with this approach is that many of the DAGs we're using have multiple ends, and some are auto-generated.
Making sure that all ends are caught manually is tedious.
I've spent hours digging for a way to access data I need to build this.
Sample DAG setup:
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from datetime import datetime

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 7, 29)}

dag = DAG(
    'append_to_end',
    description='append a task to all tasks without downstream',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')
task_3 = DummyOperator(dag=dag, task_id='task_3')
task_1 >> task_2
task_1 >> task_3
This produces the following DAG:
What I want to achieve is an automated way to add a new task to a DAG that connects to all ends, like in the image below.
I know it's an old post, but I've had a similar need to the one posted above.
You can add a condition to your function so that it doesn't return your "final_task", and so it won't be included in the get_leaf_tasks result, something like:
def get_leaf_tasks(dag):
    return [task for task_id, task in dag.task_dict.items() if len(task.downstream_list) == 0 and task_id != 'final_task']
Additionally, you can change this part:
for task in leaf_tasks:
    task >> final_task
to:
get_leaf_tasks(dag) >> final_task
This works because the function already gives you a list of tasks, and the bitshift operator ">>" will do the loop for you.
What I've got so far is the code below:
def get_leaf_tasks(dag):
    return [task for task_id, task in dag.task_dict.items() if len(task.downstream_list) == 0]

leaf_tasks = get_leaf_tasks(dag)
final_task = DummyOperator(dag=dag, task_id='final_task')
for task in leaf_tasks:
    task >> final_task
It produces the result I want, but what I don't like about this solution is that get_leaf_tasks must be executed before final_task is created; otherwise final_task will be included in the leaf_tasks list and I'll have to find a way to exclude it.
I could wrap assignment in another function:
def append_to_end(dag, task):
    leaf_tasks = get_leaf_tasks(dag)
    dag.add_task(task)
    # connect every former leaf to the task passed in, not to a global
    for leaf in leaf_tasks:
        leaf >> task

final_task = DummyOperator(task_id='final_task')
append_to_end(dag, final_task)
This is not ideal either, as the caller must ensure they've created final_task without a DAG assigned to it.
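The leaf-finding idea can be tried without Airflow on a plain dictionary. This is only a sketch: the graph format (task id mapped to a list of downstream ids) and both function names are made up for illustration, and no real DAG object is involved. Excluding the final task by id is what makes the order of creation stop mattering.

```python
def get_leaf_task_ids(downstream, exclude=()):
    """Return the ids of tasks with no downstream tasks, minus any excluded ids."""
    return sorted(
        task_id
        for task_id, children in downstream.items()
        if not children and task_id not in exclude
    )


def append_to_end(downstream, final_task_id):
    """Connect every current leaf to the final task and return those leaves."""
    leaves = get_leaf_task_ids(downstream, exclude={final_task_id})
    for leaf in leaves:
        downstream[leaf].append(final_task_id)
    # Register the final task itself if it is not in the graph yet.
    downstream.setdefault(final_task_id, [])
    return leaves


graph = {"task_1": ["task_2", "task_3"], "task_2": [], "task_3": []}
first_call = append_to_end(graph, "final_task")
second_call = append_to_end(graph, "final_task")  # nothing left to connect
print(first_call, second_call)
```

Because the final task is filtered out by id, calling the helper a second time finds no new leaves, so it is safe to call whether or not the final task already exists in the graph.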