Is it possible to change branches based on an environment variable set in Airflow?
For example:
Variable.get('Environment') returns "dev", "test", or "prod".
Let's say I have a task, Task_B, that I don't want to run in "dev".
Start_dag_task = DummyOperator(task_id="Start_dag_task", ...)
Task_A = PythonOperator(task_id="Task_A", ...)
Task_B = PythonOperator(task_id="Task_B", ...)  # Do not run in dev environment
Task_C = PythonOperator(task_id="Task_C", ...)
End_dag_task = DummyOperator(task_id="End_dag_task", ...)

env = Variable.get("Environment")

Start_dag_task >> Task_A >> End_dag_task
if env != "dev":
    Start_dag_task >> Task_B >> End_dag_task
Start_dag_task >> Task_C >> End_dag_task
Is it possible to run this same code in all three environments and have Task_B not run in dev?
You can achieve that by adding a ShortCircuitOperator before Task_B that checks whether the env variable value is "dev"; if it is, Task_B will be skipped. But you need to set ignore_downstream_trigger_rules to False so that End_dag_task and the other downstream tasks still execute, and set End_dag_task's trigger_rule to NONE_FAILED so that it runs when Task_B's state is either success or skipped:
from datetime import datetime
from airflow import DAG
from airflow.models import Variable
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import ShortCircuitOperator, PythonOperator
from airflow.utils.trigger_rule import TriggerRule
with DAG(
    "dag_id",
    schedule_interval=None,
    start_date=datetime(2022, 1, 1),
) as dag:
    Start_dag_task = EmptyOperator(task_id="Start_dag_task")
    Task_A = PythonOperator(task_id="Task_A", python_callable=lambda: print("Task A"))
    Task_B = PythonOperator(task_id="Task_B", python_callable=lambda: print("B"))
    Task_C = PythonOperator(task_id="Task_C", python_callable=lambda: print("C"))
    End_dag_task = EmptyOperator(task_id="End_dag_task", trigger_rule=TriggerRule.NONE_FAILED)

    shortCircuitTaskB = ShortCircuitOperator(
        task_id="short_circuit_for_task_B",
        python_callable=lambda: Variable.get("env", "default_env") != "dev",
        ignore_downstream_trigger_rules=False,
    )

    Start_dag_task >> Task_A >> End_dag_task
    Start_dag_task >> shortCircuitTaskB >> Task_B >> End_dag_task
    Start_dag_task >> Task_C >> End_dag_task
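As a usage note: the short-circuit callable above reads the Airflow Variable named env, so seeding that Variable per environment is all that is needed to flip the behaviour. A small sketch, assuming your deployment (or a local test) sets the Variable like this:

from airflow.models import Variable

# Hypothetical seeding for a local test; real environments would set this at deploy time.
Variable.set("env", "dev")   # the callable returns False -> Task_B is skipped
Variable.set("env", "prod")  # the callable returns True  -> Task_B runs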
I have a few Airflow tasks, and I want to make sure a specific task always executes as the last task in my DAG. There might be some tasks that will be skipped depending on the input.
this_will_run_for_sure_1 = DummyOperator()
this_will_run_for_sure_2 = DummyOperator()
this_might_not_run_1 = DummyOperator() # Can run in parallel with this_will_run_for_sure_1
this_might_not_run_2 = DummyOperator() # Can run only after this_will_run_for_sure_2 and this_might_not_run_1 finished
final_task = DummyOperator()
How could this be set up with bitwise operators? I was thinking of something like:
this_will_run_for_sure_1 >> this_will_run_for_sure_2
this_might_not_run_1 >> this_might_not_run_2
[this_will_run_for_sure_2, this_might_not_run_2] >> final_task
But this does not work as expected: if this_might_not_run_1 does not start because the input lacks the data it requires, the DAG never gets to final_task. Any help on this is much appreciated.
I think you are just missing this dependency as well:
this_will_run_for_sure_2 >> this_might_not_run_2
This should resolve your issue:
from datetime import datetime
from airflow.decorators import dag
from airflow.operators.dummy import DummyOperator
@dag(
    schedule_interval=None,
    start_date=datetime(2022, 1, 1)
)
def example():
    this_will_run_for_sure_1 = DummyOperator(task_id='this_will_run_for_sure_1')
    this_will_run_for_sure_2 = DummyOperator(task_id='this_will_run_for_sure_2')
    this_might_not_run_1 = DummyOperator(task_id='this_might_not_run_1')  # Can run in parallel with this_will_run_for_sure_1
    this_might_not_run_2 = DummyOperator(task_id='this_might_not_run_2')  # Can run only after this_will_run_for_sure_2 and this_might_not_run_1 finished
    final_task = DummyOperator(task_id='final_task')

    this_will_run_for_sure_1 >> this_will_run_for_sure_2 >> this_might_not_run_2
    this_might_not_run_1 >> this_might_not_run_2
    [this_will_run_for_sure_2, this_might_not_run_2] >> final_task

dag = example()
I have a DAG on GCP Airflow with tasks like the below:
with DAG(dag_name, schedule_interval='0 6 * * *', default_args=default_dag_args) as dag:
    notify_start = po.PythonOperator(
        task_id='notify-on-start',
        python_callable=slack,
        op_kwargs={'msg': slack_start}
    )
    create_dataproc_cluster = d.create_cluster(default_dag_args['cluster_name'], service_account, num_workers)

    [assorted dataproc tasks]

    notify_on_fail = po.PythonOperator(
        task_id='notify-on-task-failure',
        python_callable=slack,
        op_kwargs={'msg': slack_error, 'err': True},
        trigger_rule=trigger_rule.TriggerRule.ONE_FAILED
    )
    delete_cluster = d.delete_cluster(default_dag_args['cluster_name'])
    notify_finish = po.PythonOperator(
        task_id='notify-on-completion',
        python_callable=slack,
        op_kwargs={'msg': slack_finish},
        trigger_rule=trigger_rule.TriggerRule.ALL_DONE
    )

    notify_start >> create_dataproc_cluster >> [assorted dataproc tasks] >> delete_cluster >> notify_on_fail >> notify_finish
The problem I am facing is that if one of the dataproc tasks fails, the notify_on_fail task does not trigger, despite having the ONE_FAILED trigger rule. Rather, it spins down the cluster and sends the all-clear message (notify_finish). Are my tasks in the wrong order, or is something else wrong?
Based on your expectation, I think the DAG dependencies should look like this (see the sketch after these lines):
notify_start >> create_dataproc_cluster >> [assorted dataproc tasks]
[assorted dataproc tasks] >> notify_on_fail >> delete_cluster
[assorted dataproc tasks] >> notify_finish >> delete_cluster
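For concreteness, here is a minimal, untested sketch of that wiring. The EmptyOperator stand-ins, the notify() callable and the task names are assumptions replacing the Slack helper and Dataproc operators from the question; the trigger rules on the notify tasks are the ones the question already uses, plus ALL_DONE on delete_cluster so the cluster is still torn down when notify_on_fail is skipped:

from datetime import datetime

from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.utils.trigger_rule import TriggerRule


def notify(msg, err=False):
    # Hypothetical stand-in for the slack callable in the question.
    print(("ERROR: " if err else "") + msg)


with DAG(
    "dataproc_notify_sketch",
    schedule_interval=None,
    start_date=datetime(2022, 1, 1),
    catchup=False,
) as dag:
    notify_start = PythonOperator(
        task_id="notify-on-start",
        python_callable=notify,
        op_kwargs={"msg": "DAG started"},
    )
    create_dataproc_cluster = EmptyOperator(task_id="create-dataproc-cluster")
    dataproc_tasks = [EmptyOperator(task_id=f"dataproc-task-{i}") for i in range(3)]

    notify_on_fail = PythonOperator(
        task_id="notify-on-task-failure",
        python_callable=notify,
        op_kwargs={"msg": "a Dataproc task failed", "err": True},
        trigger_rule=TriggerRule.ONE_FAILED,  # fires as soon as any upstream task fails
    )
    notify_finish = PythonOperator(
        task_id="notify-on-completion",
        python_callable=notify,
        op_kwargs={"msg": "DAG finished"},
        trigger_rule=TriggerRule.ALL_DONE,  # as in the question; use ALL_SUCCESS if the all-clear should only go out on success
    )
    delete_cluster = EmptyOperator(
        task_id="delete-dataproc-cluster",
        trigger_rule=TriggerRule.ALL_DONE,  # spin down even when notify_on_fail was skipped
    )

    notify_start >> create_dataproc_cluster >> dataproc_tasks
    dataproc_tasks >> notify_on_fail >> delete_cluster
    dataproc_tasks >> notify_finish >> delete_cluster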
I read this: How to use airflow xcoms with MySqlOperator, and while it has a similar title, it doesn't really address my issue.
I have the following code:
def branch_func_is_new_records(**kwargs):
    ti = kwargs['ti']
    xcom = ti.xcom_pull(task_ids='query_get_max_order_id')
    string_to_print = 'Value in xcom is: {}'.format(xcom)
    logging.info(string_to_print)
    if int(xcom) > int(LAST_IMPORTED_ORDER_ID):
        return 'import_orders'
    else:
        return 'skip_operation'
query_get_max_order_id = 'SELECT COALESCE(max(orders_id),0) FROM warehouse.orders where orders_id>1 limit 10'

get_max_order_id = MySqlOperator(
    task_id='query_get_max_order_id',
    sql=query_get_max_order_id,
    mysql_conn_id=MyCon,
    xcom_push=True,
    dag=dag)

branch_op_is_new_records = BranchPythonOperator(
    task_id='branch_operation_is_new_records',
    provide_context=True,
    python_callable=branch_func_is_new_records,
    dag=dag)

get_max_order_id >> branch_op_is_new_records >> import_orders
branch_op_is_new_records >> skip_operation
The MySqlOperator returns a number; based on that number, the BranchPythonOperator chooses the next task. It's guaranteed that the MySqlOperator returns a value greater than 0.
My problem is that nothing is pushed to XCom by the MySqlOperator.
On the UI, when I go to XCom, I see nothing. The BranchPythonOperator obviously reads nothing, so my code fails.
Why doesn't the XCom work here?
The MySQL operator currently (airflow 1.10.0 at time of writing) doesn't support returning anything in XCom, so the fix for you for now is to write a small operator yourself. You can do this directly in your DAG file (untested, so there may be silly errors):
from airflow.operators.mysql_operator import MySqlOperator as BaseMySqlOperator
from airflow.hooks.mysql_hook import MySqlHook


class ReturningMySqlOperator(BaseMySqlOperator):
    def execute(self, context):
        self.log.info('Executing: %s', self.sql)
        hook = MySqlHook(mysql_conn_id=self.mysql_conn_id,
                         schema=self.database)
        return hook.get_first(
            self.sql,
            parameters=self.parameters)
def branch_func_is_new_records(**kwargs):
    ti = kwargs['ti']
    xcom = ti.xcom_pull(task_ids='query_get_max_order_id')
    string_to_print = 'Value in xcom is: {}'.format(xcom)
    logging.info(string_to_print)
    if str(xcom) == 'NewRecords':
        return 'import_orders'
    else:
        return 'skip_operation'


query_get_max_order_id = 'SELECT COALESCE(max(orders_id),0) FROM warehouse.orders where orders_id>1 limit 10'

get_max_order_id = ReturningMySqlOperator(
    task_id='query_get_max_order_id',
    sql=query_get_max_order_id,
    mysql_conn_id=MyCon,
    # xcom_push=True,
    dag=dag)

branch_op_is_new_records = BranchPythonOperator(
    task_id='branch_operation_is_new_records',
    provide_context=True,
    python_callable=branch_func_is_new_records,
    dag=dag)

get_max_order_id >> branch_op_is_new_records >> import_orders
branch_op_is_new_records >> skip_operation
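As a side note (not part of the original answer), a similar result can be had without subclassing by querying inside the branch callable with MySqlHook; this untested sketch reuses MyCon, LAST_IMPORTED_ORDER_ID and query_get_max_order_id from the question:

from airflow.hooks.mysql_hook import MySqlHook  # airflow.providers.mysql.hooks.mysql in Airflow 2


def branch_func_is_new_records(**kwargs):
    hook = MySqlHook(mysql_conn_id=MyCon)
    # get_first returns the first result row as a tuple, e.g. (12345,)
    max_order_id = hook.get_first(query_get_max_order_id)[0]
    if int(max_order_id) > int(LAST_IMPORTED_ORDER_ID):
        return 'import_orders'
    return 'skip_operation'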
I have a task that I'll call final that has multiple upstream connections. When one of the upstreams gets skipped by a ShortCircuitOperator, this task gets skipped as well. I don't want the final task to get skipped, as it has to report on DAG success.
To avoid it getting skipped I used trigger_rule='all_done', but it still gets skipped.
If I use a BranchPythonOperator instead of a ShortCircuitOperator, the final task doesn't get skipped. It would seem that a branching workflow could be a solution, even though not optimal, but now final will not respect failures of upstream tasks.
How do I get it to only run when upstreams are successful or skipped?
Sample ShortCircuit DAG:
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import ShortCircuitOperator
from datetime import datetime
from random import randint
default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'shortcircuit_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)


def shortcircuit_fn():
    return randint(0, 1) == 1


task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')
work = DummyOperator(dag=dag, task_id='work')
short = ShortCircuitOperator(dag=dag, task_id='short_circuit', python_callable=shortcircuit_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> short >> work >> final
task_1 >> task_2 >> final
Sample Branch DAG:
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
from datetime import datetime
from random import randint
default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'branch_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

# these two are only here to protect tasks from getting skipped as direct dependencies of branch operator
to_do_work = DummyOperator(dag=dag, task_id='to_do_work')
to_skip_work = DummyOperator(dag=dag, task_id='to_skip_work')


def branch_fn():
    return to_do_work.task_id if randint(0, 1) == 1 else to_skip_work.task_id


task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')
work = DummyOperator(dag=dag, task_id='work')
branch = BranchPythonOperator(dag=dag, task_id='branch', python_callable=branch_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> branch >> to_do_work >> work >> final
branch >> to_skip_work >> final
task_1 >> task_2 >> final
I ended up developing a custom ShortCircuitOperator based on the original one:
# Airflow 1.10-era imports for the base class and skip mixin
from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator


class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Allows a workflow to continue only if a condition is met. Otherwise, the
    workflow "short-circuits" and downstream tasks that only rely on this operator
    are skipped.

    The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
    condition and short-circuits the workflow if the condition is False. Any
    downstream tasks that only rely on this operator are marked with a state of "skipped".
    If the condition is True, downstream tasks proceed as normal.

    The condition is determined by the result of `python_callable`.
    """

    def find_tasks_to_skip(self, task, found_tasks=None):
        if not found_tasks:
            found_tasks = []
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            if len(t.upstream_task_ids) == 1:
                found_tasks.append(t)
                self.find_tasks_to_skip(t, found_tasks)
        return found_tasks

    def execute(self, context):
        condition = super(ShortCircuitOperator, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info('Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)

        self.log.info("Done.")
This operator makes sure that no downstream task that relies on multiple paths gets skipped because of one skipped path.
I'm posting another possible workaround for this, since it is a method that does not require a custom operator implementation.
I was influenced by the solution in this blog post, which uses a PythonOperator that raises an AirflowSkipException, skipping the task itself and then its downstream tasks individually:
https://godatadriven.com/blog/the-zen-of-python-and-apache-airflow/
This then respects the trigger_rule of the final downstream task, which in my case I set to trigger_rule='none_failed'.
Modified example, as per the blog, to include a final task:
from airflow.exceptions import AirflowSkipException
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator, ShortCircuitOperator


def fn_short_circuit(**context):
    if <<<some condition>>>:
        raise AirflowSkipException("Skip this task and individual downstream tasks while respecting trigger rules.")


check_date = PythonOperator(
    task_id="check_if_min_date",
    python_callable=_check_date,
    provide_context=True,
    dag=dag,
)
task_1 = DummyOperator(task_id="task_1", dag=dag)
task_2 = DummyOperator(task_id="task_2", dag=dag)
work = DummyOperator(dag=dag, task_id='work')
# the callable raises AirflowSkipException itself, so a plain PythonOperator would behave the same here
short = ShortCircuitOperator(dag=dag, task_id='short_circuit', python_callable=fn_short_circuit)
final_task = DummyOperator(task_id="final_task",
                           trigger_rule='none_failed',
                           dag=dag)

task_1 >> short >> work >> final_task
task_1 >> task_2 >> final_task
This question is still relevant with Airflow 1.10.x.
The following solution works with Airflow 1.10.x; it is not yet tested with Airflow 2.x.
ShortCircuitOperator will skip all downstream tasks, whatever trigger_rule is set.
The solution of @michael-spector only works in simple cases, not in a case like this:
with @michael-spector the task L will not be skipped (only the E, F, G and H tasks will be skipped).
A solution is this (based on @michael-spector's proposition):
# Airflow 1.10-era imports for the base class and skip mixin
from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator


class ShortCircuitOperatorOnlyDirectDownStream(PythonOperator, SkipMixin):
    """
    Works like a ShortCircuitOperator, but it only skips the tasks that have this task in their upstream.
    So if a task has this task in its upstream AND another task, it will not be skipped.

             -> B -> C -> D -----\
            /                     \
       A -> K                      -> Y
        \                         /
         -> F -> G -> P ---------/

    If K is a normal ShortCircuitOperator and the condition is False, then B, C, D and Y will be skipped.
    If K is a ShortCircuitOperatorOnlyDirectDownStream and the condition is False, then B, C, D will be skipped, but not Y.

    found_tasks_to_skip_names contains the task_ids of the previously skipped tasks.
    found_tasks_to_skip contains the task objects of the previously skipped tasks.
    :return: found_tasks_to_skip
    """

    def find_tasks_to_skip(self, task, found_tasks_to_skip=None, found_tasks_to_skip_names=None):
        if not found_tasks_to_skip:  # list of tasks to skip
            found_tasks_to_skip = []
        # necessary because found_tasks_to_skip keeps task objects, not task_ids
        if not found_tasks_to_skip_names:
            found_tasks_to_skip_names = set()
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            self.log.info("UPSTREAM : " + str(t.upstream_task_ids))
            self.log.info(
                " Do all skipped tasks " +
                str(found_tasks_to_skip_names) +
                " contain the upstream tasks " +
                str(t.upstream_task_ids)
            )
            # if len == 1 then the task is only preceded by a skipped task
            # otherwise check if ALL upstream tasks are skipped
            if len(t.upstream_task_ids) == 1 or all(elem in found_tasks_to_skip_names for elem in t.upstream_task_ids):
                found_tasks_to_skip.append(t)
                found_tasks_to_skip_names.add(t.task_id)
                self.find_tasks_to_skip(t, found_tasks_to_skip, found_tasks_to_skip_names)
        return found_tasks_to_skip

    def execute(self, context):
        condition = super(ShortCircuitOperatorOnlyDirectDownStream, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info('Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)

        self.log.info("Done.")
I've made it work by having the final task check the statuses of the upstream instances. It is not beautiful, as the only way I found to access their state was by querying the Airflow DB.
# # additional imports to ones in question code
# from airflow import AirflowException
# from airflow.models import TaskInstance
# from airflow.operators.python_operator import PythonOperator
# from airflow.settings import Session
# from airflow.utils.state import State
# from airflow.utils.trigger_rule import TriggerRule


def all_upstreams_either_succeeded_or_skipped(dag, task, task_instance, **context):
    """
    Find directly upstream task instances and count how many are not in preferred statuses.
    Return True if we got no instances with non-preferred statuses.
    """
    upstream_task_ids = [t.task_id for t in task.get_direct_relatives(upstream=True)]

    session = Session()
    query = (session
             .query(TaskInstance)
             .filter(
                 TaskInstance.dag_id == dag.dag_id,
                 TaskInstance.execution_date.in_([task_instance.execution_date]),
                 TaskInstance.task_id.in_(upstream_task_ids)
             )
             )
    upstream_task_instances = query.all()
    unhappy_task_instances = [ti for ti in upstream_task_instances if ti.state not in [State.SUCCESS, State.SKIPPED]]
    print(unhappy_task_instances)
    return len(unhappy_task_instances) == 0


def final_fn(**context):
    """
    Fail if upstream task instances have unwanted statuses.
    """
    if not all_upstreams_either_succeeded_or_skipped(**context):
        raise AirflowException("Not all upstream tasks succeeded.")
    # Do things


# will run when upstream task instances are done, including failed
final = PythonOperator(
    dag=dag,
    task_id="final",
    trigger_rule=TriggerRule.ALL_DONE,
    python_callable=final_fn,
    provide_context=True)
The ShortCircuitOperator can now be configured to respect downstream tasks' trigger rules. The default behavior is not to respect them; you can make the operator respect them by setting ignore_downstream_trigger_rules to False:
task = ShortCircuitOperator(
    task_id='task_id',
    python_callable=function,
    ignore_downstream_trigger_rules=False,
)
This may have been added after you asked your initial question, but Airflow now conveniently has a trigger_rule value of none_failed. If you set this on your final task, it should complete whether upstream tasks are skipped or succeeded, just not when they fail.
More info: https://airflow.apache.org/concepts.html#trigger-rules
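Applied to the sample ShortCircuit DAG from the question, that would look roughly like this (a sketch; only the trigger rule on final changes, the rest of the DAG stays as posted):

from airflow.utils.trigger_rule import TriggerRule

# final now runs when every upstream task is either success or skipped, but not when one failed
final = DummyOperator(dag=dag, task_id="final", trigger_rule=TriggerRule.NONE_FAILED)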
I have the following DAG with 3 tasks:
start --> special_task --> end
The task in the middle can succeed or fail, but end must always be executed (imagine this is a task for cleanly closing resources). For that, I used the trigger rule ALL_DONE:
end.trigger_rule = trigger_rule.TriggerRule.ALL_DONE
Using that, end is properly executed if special_task fails. However, since end is the last task and succeeds, the DAG is always marked as SUCCESS.
How can I configure my DAG so that if one of the tasks failed, the whole DAG is marked as FAILED?
Example to reproduce
import datetime

from airflow import DAG
from airflow.operators.bash_operator import BashOperator
from airflow.utils import trigger_rule

dag = DAG(
    dag_id='my_dag',
    start_date=datetime.datetime.today(),
    schedule_interval=None
)

start = BashOperator(
    task_id='start',
    bash_command='echo start',
    dag=dag
)

special_task = BashOperator(
    task_id='special_task',
    bash_command='exit 1',  # force failure
    dag=dag
)

end = BashOperator(
    task_id='end',
    bash_command='echo end',
    dag=dag
)

end.trigger_rule = trigger_rule.TriggerRule.ALL_DONE

start.set_downstream(special_task)
special_task.set_downstream(end)
This post seems to be related, but the answer does not suit my needs, since the downstream task end must be executed (hence the mandatory trigger_rule).
I thought it was an interesting question and spent some time figuring out how to achieve it without an extra dummy task. It became a bit of a superfluous task, but here's the end result:
This is the full DAG:
import airflow
from airflow import AirflowException
from airflow.models import DAG, TaskInstance, BaseOperator
from airflow.operators.bash_operator import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
from airflow.utils.db import provide_session
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule

default_args = {"owner": "airflow", "start_date": airflow.utils.dates.days_ago(3)}

dag = DAG(
    dag_id="finally_task_set_end_state",
    default_args=default_args,
    schedule_interval="0 0 * * *",
    description="Answer for question https://stackoverflow.com/questions/51728441",
)

start = BashOperator(task_id="start", bash_command="echo start", dag=dag)
failing_task = BashOperator(task_id="failing_task", bash_command="exit 1", dag=dag)


@provide_session
def _finally(task, execution_date, dag, session=None, **_):
    upstream_task_instances = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.execution_date == execution_date,
            TaskInstance.task_id.in_(task.upstream_task_ids),
        )
        .all()
    )
    upstream_states = [ti.state for ti in upstream_task_instances]
    fail_this_task = State.FAILED in upstream_states

    print("Do logic here...")

    if fail_this_task:
        raise AirflowException("Failing task because one or more upstream tasks failed.")


finally_ = PythonOperator(
    task_id="finally",
    python_callable=_finally,
    trigger_rule=TriggerRule.ALL_DONE,
    provide_context=True,
    dag=dag,
)

succesful_task = DummyOperator(task_id="succesful_task", dag=dag)

start >> [failing_task, succesful_task] >> finally_
Look at the _finally function, which is called by the PythonOperator. There are a few key points here:
Annotate with @provide_session and add the argument session=None, so you can query the Airflow DB with session.
Query all upstream task instances for the current task:
upstream_task_instances = (
    session.query(TaskInstance)
    .filter(
        TaskInstance.dag_id == dag.dag_id,
        TaskInstance.execution_date == execution_date,
        TaskInstance.task_id.in_(task.upstream_task_ids),
    )
    .all()
)
From the returned task instances, get the states and check if State.FAILED is in there:
upstream_states = [ti.state for ti in upstream_task_instances]
fail_this_task = State.FAILED in upstream_states
Perform your own logic:
print("Do logic here...")
And finally, fail the task if fail_this_task=True:
if fail_this_task:
    raise AirflowException("Failing task because one or more upstream tasks failed.")
As @JustinasMarozas explained in a comment, a solution is to create a dummy task like:
dummy = DummyOperator(
    task_id='test',
    dag=dag
)

and bind it downstream to special_task:

special_task.set_downstream(dummy)
Thus, the DAG is marked as failed, and the dummy task is marked as upstream_failed.
I hope an out-of-the-box solution will come along, but until then, this solution does the job.
To expand on Bas Harenslak's answer, a simpler _finally function, which will check the state of all tasks (not only the upstream ones), can be:
def _finally(**kwargs):
    for task_instance in kwargs['dag_run'].get_task_instances():
        if task_instance.current_state() != State.SUCCESS and \
                task_instance.task_id != kwargs['task_instance'].task_id:
            raise Exception("Task {} failed. Failing this DAG run".format(task_instance.task_id))