Airflow File Sensor for sensing files on my local drive - airflow

does anybody have any idea on FileSensor ? I came through it while i was researching on sensing files on my local directory. The code is as follows:
task= FileSensor(
task_id="senseFile"
filepath="etc/hosts",
fs_conn_id='fs_local',
_hook=self.hook,
dag=self.dag,)
I have also set my conn_id and conn type as File (path) and gave the {'path':'mypath'} but even though i set a non existing path or if the file isnt there in the specified path, the task is completed and the dag is successful. The FileSensor doesnt seem to sense files at all.

I found the community contributed FileSenor a little bit underwhelming so wrote my own.
I got it working for files locally to where the server/scheduler was running however ran into problems when using network paths.
The trick for network paths I found was to mount the network drive to my Linux Box.
This is my DAG used to sensor_task >> proccess_task >> archive_task >> trigger rerun
Note: We use variables (sourcePath, filePattern & archivePath) entered via the WebGUI
from airflow import DAG
from airflow.operators import PythonOperator, OmegaFileSensor, ArchiveFileOperator, TriggerDagRunOperator
from datetime import datetime, timedelta
from airflow.models import Variable
default_args = {
'owner': 'glsam',
'depends_on_past': False,
'start_date': datetime(2017, 6, 26),
'provide_context': True,
'retries': 100,
'retry_delay': timedelta(seconds=30)
}
task_name = 'my_first_file_sensor_task'
filepath = Variable.get("soucePath")
filepattern = Variable.get("filePattern")
archivepath = Variable.get("archivePath")
dag = DAG(
'task_name',
default_args=default_args,
schedule_interval=None,
catchup=False,
max_active_runs=1,
concurrency=1)
sensor_task = OmegaFileSensor(
task_id=task_name,
filepath=filepath,
filepattern=filepattern,
poke_interval=3,
dag=dag)
def process_file(**context):
file_to_process = context['task_instance'].xcom_pull(
key='file_name', task_ids=task_name)
file = open(filepath + file_to_process, 'w')
file.write('This is a test\n')
file.write('of processing the file')
file.close()
proccess_task = PythonOperator(
task_id='process_the_file', python_callable=process_file, dag=dag)
archive_task = ArchiveFileOperator(
task_id='archive_file',
filepath=filepath,
task_name=task_name,
archivepath=archivepath,
dag=dag)
trigger = TriggerDagRunOperator(
task_id='trigger_dag_rerun', trigger_dag_id=task_name, dag=dag)
sensor_task >> proccess_task >> archive_task >> trigger
And then this is my FileSenor
import os
import re
from datetime import datetime
from airflow.models import BaseOperator
from airflow.plugins_manager import AirflowPlugin
from airflow.utils.decorators import apply_defaults
from airflow.operators.sensors import BaseSensorOperator
class ArchiveFileOperator(BaseOperator):
#apply_defaults
def __init__(self, filepath, archivepath, task_name, *args, **kwargs):
super(ArchiveFileOperator, self).__init__(*args, **kwargs)
self.filepath = filepath
self.archivepath = archivepath
self.task_name = task_name
def execute(self, context):
file_name = context['task_instance'].xcom_pull(self.task_name, key='file_name')
os.rename(self.filepath + file_name, self.archivepath + file_name)
class OmegaFileSensor(BaseSensorOperator):
#apply_defaults
def __init__(self, filepath, filepattern, *args, **kwargs):
super(OmegaFileSensor, self).__init__(*args, **kwargs)
self.filepath = filepath
self.filepattern = filepattern
def poke(self, context):
full_path = self.filepath
file_pattern = re.compile(self.filepattern)
directory = os.listdir(full_path)
for files in directory:
if not re.match(file_pattern, files):
# do nothing
else:
context['task_instance'].xcom_push('file_name', files)
return True
return False
class OmegaPlugin(AirflowPlugin):
name = "omega_plugin"
operators = [OmegaFileSensor, ArchiveFileOperator]

Related

View on_failure_callback DAG logger

Let's take an example DAG.
Here is the code for it.
import logging
from airflow import DAG
from datetime import datetime, timedelta
from airflow.models import TaskInstance
from airflow.operators.python import PythonOperator
from airflow.operators.dummy import DummyOperator
def task_failure_notification_alert(context):
logging.info("Task context details: %s", str(context))
def dag_failure_notification_alert(context):
logging.info("DAG context details: %s", str(context))
def red_exception_task(ti: TaskInstance, **kwargs):
raise Exception('red')
default_args = {
"owner": "analytics",
"start_date": datetime(2021, 12, 12),
'retries': 0,
'retry_delay': timedelta(),
"schedule_interval": "#daily"
}
dag = DAG('logger_dag',
default_args=default_args,
catchup=False,
on_failure_callback=dag_failure_notification_alert
)
start_task = DummyOperator(task_id="start_task", dag=dag, on_failure_callback=task_failure_notification_alert)
red_task = PythonOperator(
dag=dag,
task_id='red_task',
python_callable=red_exception_task,
provide_context=True,
on_failure_callback=task_failure_notification_alert
)
end_task = DummyOperator(task_id="end_task", dag=dag, on_failure_callback=task_failure_notification_alert)
start_task >> red_task >> end_task
We can see two functions i.e. task_failure_notification_alert and dag_failure_notification_alert are being called in case of failures.
We can see logs in case of Task failure by the below steps.
We can see logs for the task as below.
but I am unable to find logs for the on_failure_callback of DAG anywhere in UI. Where can we see it?
Under airflow/logs find the "scheduler" folder, under it look for the specific date you ran the Dag for example 2022-12-03 and there you will see name of the dag_file.log.

Using airflow dag_run.conf inside custom operator

We created a custom airflow based on EMRContainerOperator and we need to take a decision based on a config passed using the airflow UI.
My custom operator:
from airflow.providers.amazon.aws.operators.emr_containers import EMRContainerOperator
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from uuid import uuid4
from airflow.utils.decorators import apply_defaults
class EmrBatchProcessorOperator(EMRContainerOperator):
template_fields: Sequence[str] = (
"name",
"virtual_cluster_id",
"execution_role_arn",
"release_label",
"job_driver",
"operation_type"
)
#apply_defaults
def __init__(
self,
operation_type,
*args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.operation_type = operation_type
if self.operation_type == 'full':
number_of_pods=10
else:
number_of_pods=5
BASE_CONSUMER_DRIVER_ARG = {
"sparkSubmitJobDriver": {"entryPoint": "s3://bucket/batch_processor_engine/batch-processor-engine_2.12-3.0.1_0.28.jar","entryPointArguments": ["group_name=courier_api_group01"], "sparkSubmitParameters": f"--conf spark.executor.instances={ number_of_pods } --conf spark.executor.memory=32G --conf spark.executor.cores=5 --conf spark.driver.cores=1 --conf spark.driver.memory=12G --conf spark.sql.broadcastTimeout=2000 --class TableProcessorWrapper"}
}
self.job_driver = BASE_CONSUMER_DRIVER_ARG
This is the way that I call my operator:
with DAG(
dag_id="batch_processor_model_dag",
schedule_interval="#daily",
default_args=default_args,
catchup=False
) as dag:
start = DummyOperator(task_id='start', dag=dag)
end = DummyOperator(task_id='end', dag=dag, trigger_rule='none_failed')
base_consumer = EmrBatchProcessorOperator(
task_id="base_consumer",
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
execution_role_arn=JOB_ROLE_ARN,
configuration_overrides=CONFIGURATION_OVERRIDES_ARG,
release_label="emr-6.5.0-latest",
job_driver={},
name="pi.py",
operation_type= '{{dag_run.conf["operation_type"]}}'
)
start >> base_consumer >> end
But this code didn't work, I can't use the dag_run.conf value.
could you help me?

Airflow Custom Sensor: AttributeError: 'NoneType' object has no attribute 'get_records'

I am running Airflow v1.9.0 with Celery Executor. I have configured different workers with different queue names like DEV, QA, UAT, PROD. I have written a custom sensor which polls a source db connection and a target db connection and run different queries and do some checks before triggering downstream-tasks. This has been running fine for multiple workers. In one of the workers, this sensor is giving an AttributeError Issue:
$ airflow test PDI_Incr_20190407_v1 checkCCWatermarkDt 2019-04-09
[2019-04-09 10:02:57,769] {configuration.py:206} WARNING - section/key [celery/celery_ssl_active] not found in config
[2019-04-09 10:02:57,770] {default_celery.py:41} WARNING - Celery Executor will run without SSL
[2019-04-09 10:02:57,771] {__init__.py:45} INFO - Using executor CeleryExecutor
[2019-04-09 10:02:57,817] {models.py:189} INFO - Filling up the DagBag from /home/airflow/airflow/dags
/usr/local/lib/python2.7/site-packages/airflow/models.py:2160: PendingDeprecationWarning: Invalid arguments were passed to ExternalTaskSensor. Support for passing such arguments will be dropped in Airflow 2.0. Invalid arguments were:
*args: ()
**kwargs: {'check_existence': True}
category=PendingDeprecationWarning
[2019-04-09 10:02:57,989] {base_hook.py:80} INFO - Using connection to: 172.16.20.11:1521/GWPROD
[2019-04-09 10:02:57,991] {base_hook.py:80} INFO - Using connection to: dmuat.cwmcwghvymd3.us-east-1.rds.amazonaws.com:1521/DMUAT
Traceback (most recent call last):
File "/usr/local/bin/airflow", line 27, in <module>
args.func(args)
File "/usr/local/lib/python2.7/site-packages/airflow/bin/cli.py", line 528, in test
ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
File "/usr/local/lib/python2.7/site-packages/airflow/utils/db.py", line 50, in wrapper
result = func(*args, **kwargs)
File "/usr/local/lib/python2.7/site-packages/airflow/models.py", line 1584, in run
session=session)
File "/usr/local/lib/python2.7/site-packages/airflow/utils/db.py", line 50, in wrapper
result = func(*args, **kwargs)
File "/usr/local/lib/python2.7/site-packages/airflow/models.py", line 1493, in _run_raw_task
result = task_copy.execute(context=context)
File "/usr/local/lib/python2.7/site-packages/airflow/operators/sensors.py", line 78, in execute
while not self.poke(context):
File "/home/airflow/airflow/plugins/PDIPlugin.py", line 29, in poke
wm_dt_src = hook_src.get_records(self.sql)
AttributeError: 'NoneType' object has no attribute 'get_records'
Although when I run the same test command from Scheduler CLI, it is running fine. The above issue looks like a database connection issue.
For debugging, I checked the DB Connections from Airflow UI:
Data Profiling -> Ad Hoc Query
Query: Select 1 from dual; -- This worked fine
I also did telnet from the worker node to the DB Host and port and that also went fine.
Custom Sensor Code:
from airflow.plugins_manager import AirflowPlugin
from airflow.hooks.base_hook import BaseHook
from airflow.operators.sensors import SqlSensor
class SensorWatermarkDt(SqlSensor):
def __init__(self, conn_id, sql, conn_id_tgt, sql_tgt, *args, **kwargs):
self.sql = sql
self.conn_id = conn_id
self.sql_tgt = sql_tgt
self.conn_id_tgt = conn_id_tgt
super(SqlSensor, self).__init__(*args, **kwargs)
def poke(self, context):
hook_src = BaseHook.get_connection(self.conn_id).get_hook()
hook_tgt = BaseHook.get_connection(self.conn_id_tgt).get_hook()
self.log.info('Poking: %s', self.sql)
self.log.info('Poking: %s', self.sql_tgt)
wm_dt_src = hook_src.get_records(self.sql)
wm_dt_tgt = hook_tgt.get_records(self.sql_tgt)
if wm_dt_src <= wm_dt_tgt:
return False
else:
return True
class PDIPlugin(AirflowPlugin):
name = "PDIPlugin"
operators = [SensorWatermarkDt]
Airflow DAG Snippet:
import airflow
from airflow import DAG
from airflow.operators.bash_operator import BashOperator
from airflow.operators.email_operator import EmailOperator
from datetime import timedelta,datetime
from airflow.operators import SensorWatermarkDt
from airflow.operators.sensors import ExternalTaskSensor
from airflow.operators.dummy_operator import DummyOperator
default_args = {
'owner': 'SenseTeam',
#'depends_on_past': True,
'depends_on_past' : False,
'start_date': datetime(2019, 4, 7, 17, 00),
'email': [],
'email_on_failure': False,
'email_on_retry': False,
'queue': 'PENTAHO_UAT'
}
dag = DAG(dag_id='PDI_Incr_20190407_v1',
default_args=default_args,
max_active_runs=1,
concurrency=1,
catchup=False,
schedule_interval=timedelta(hours=24),
dagrun_timeout=timedelta(minutes=23*60))
checkCCWatermarkDt = \
SensorWatermarkDt(task_id='checkCCWatermarkDt',
conn_id='CCUSER_SOURCE_GWPROD_RPT',
sql="SELECT MAX(CC_WM.CREATETIME) as CURRENT_WATERMARK_DATE FROM CCUSER.CCX_CAPTUREREASON_ETL CC_WM INNER JOIN CCUSER.CCTL_CAPTUREREASON_ETL CC_WMLKP ON CC_WM.CAPTUREREASON_ETL = CC_WMLKP.ID AND UPPER(CC_WMLKP.DESCRIPTION)= 'WATERMARK'",
conn_id_tgt = 'RDS_DMUAT_DMCONFIG',
sql_tgt = "SELECT MAX(CURRENT_WATERMARK_DATE) FROM DMCONFIG.PRESTG_DM_WMD_WATERMARKDATE WHERE SCHEMA_NAME = 'CCUSER'",
poke_interval=60,
dag=dag)
...
I have restarted web server, scheduler and airflow worker after adding this plugin in this worker node.
What am I missing here?
I have met this problem as well when I tried to use airflow's hook to connect to Teradata database, so I have read the code in airflow, we can see the get_hook() function in the path: /<your python path(may like: /usr/lib64/python2.7/)>/site-packages/airflow/models/connection.py:
def get_hook(self):
try:
if self.conn_type == 'mysql':
from airflow.hooks.mysql_hook import MySqlHook
return MySqlHook(mysql_conn_id=self.conn_id)
elif self.conn_type == 'google_cloud_platform':
from airflow.contrib.hooks.bigquery_hook import BigQueryHook
return BigQueryHook(bigquery_conn_id=self.conn_id)
elif self.conn_type == 'postgres':
from airflow.hooks.postgres_hook import PostgresHook
return PostgresHook(postgres_conn_id=self.conn_id)
elif self.conn_type == 'hive_cli':
from airflow.hooks.hive_hooks import HiveCliHook
return HiveCliHook(hive_cli_conn_id=self.conn_id)
elif self.conn_type == 'presto':
from airflow.hooks.presto_hook import PrestoHook
return PrestoHook(presto_conn_id=self.conn_id)
elif self.conn_type == 'hiveserver2':
from airflow.hooks.hive_hooks import HiveServer2Hook
return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
elif self.conn_type == 'sqlite':
from airflow.hooks.sqlite_hook import SqliteHook
return SqliteHook(sqlite_conn_id=self.conn_id)
elif self.conn_type == 'jdbc':
from airflow.hooks.jdbc_hook import JdbcHook
return JdbcHook(jdbc_conn_id=self.conn_id)
elif self.conn_type == 'mssql':
from airflow.hooks.mssql_hook import MsSqlHook
return MsSqlHook(mssql_conn_id=self.conn_id)
elif self.conn_type == 'oracle':
from airflow.hooks.oracle_hook import OracleHook
return OracleHook(oracle_conn_id=self.conn_id)
elif self.conn_type == 'vertica':
from airflow.contrib.hooks.vertica_hook import VerticaHook
return VerticaHook(vertica_conn_id=self.conn_id)
elif self.conn_type == 'cloudant':
from airflow.contrib.hooks.cloudant_hook import CloudantHook
return CloudantHook(cloudant_conn_id=self.conn_id)
elif self.conn_type == 'jira':
from airflow.contrib.hooks.jira_hook import JiraHook
return JiraHook(jira_conn_id=self.conn_id)
elif self.conn_type == 'redis':
from airflow.contrib.hooks.redis_hook import RedisHook
return RedisHook(redis_conn_id=self.conn_id)
elif self.conn_type == 'wasb':
from airflow.contrib.hooks.wasb_hook import WasbHook
return WasbHook(wasb_conn_id=self.conn_id)
elif self.conn_type == 'docker':
from airflow.hooks.docker_hook import DockerHook
return DockerHook(docker_conn_id=self.conn_id)
elif self.conn_type == 'azure_data_lake':
from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook
return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
elif self.conn_type == 'azure_cosmos':
from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook
return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
elif self.conn_type == 'cassandra':
from airflow.contrib.hooks.cassandra_hook import CassandraHook
return CassandraHook(cassandra_conn_id=self.conn_id)
elif self.conn_type == 'mongo':
from airflow.contrib.hooks.mongo_hook import MongoHook
return MongoHook(conn_id=self.conn_id)
elif self.conn_type == 'gcpcloudsql':
from airflow.contrib.hooks.gcp_sql_hook import CloudSqlDatabaseHook
return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
except Exception:
pass
It means if you don't have this kind of type connection then the get_hook will return a 'NoneType' Type. So that is the reason why it happened.
How to Resolve:
Add a your own hook is best way in airflow , I had a sample for Teradata here:
# cat teradata_hook.py
from builtins import str
import jaydebeapi
from airflow.hooks.dbapi_hook import DbApiHook
class TeradataJdbcHook(DbApiHook):
conn_name_attr = 'teradata_conn_id'
default_conn_name = 'teradata_default'
supports_autocommit = True
def get_conn(self):
conn = self.get_connection(getattr(self, self.conn_name_attr))
host = conn.host
url = 'jdbc:teradata://' + host + '/TMODE=TERA'
login = conn.login
psw = conn.password
jdbc_driver_loc = '/opt/spark-2.3.1-bin-without-hadoop/jars/terajdbc4-16.20.00.06.jar,/opt/spark-2.3.1-bin-without-hadoop/jars/tdgssconfig-16.20.00.06.jar'
jdbc_driver_name = "com.teradata.jdbc.TeraDriver"
conn = jaydebeapi.connect(jclassname=jdbc_driver_name,
url=url,
driver_args=[str(login), str(psw)],
jars=jdbc_driver_loc.split(","))
return conn
def set_autocommit(self, conn, autocommit):
"""
Enable or disable autocommit for the given connection.
:param conn: The connection
:return:
"""
conn.jconn.setAutoCommit(autocommit)
Then you can call this hook to connect teradata database (or other database which has jdbc driver ):
[root#myhost transfer]# cat h.py
import util
from airflow.hooks.base_hook import BaseHook
from teradata_hook import TeradataJdbcHook
sql = "SELECT COUNT(*) FROM TERADATA_TABLE where month_key='202009'"
conn_id='teradata_account#dbname' # this is my environment's id format
hook = TeradataJdbcHook(conn_id)
records = hook.get_records(sql)
print(records)
if str(records[0][0]) in ('0', '',):
print("No Records")
else:
print("Has Records")
It return's result: [(7734133,)]

Airflow - Broken DAG - Timeout

I have a DAG that executes a function that connects to a Postgres DB, deletes the contents in the table and then inserts a new data set.
I am trying this in my local and I see when I try to run this, the web server takes a long time to connect and in most cases doesn't succeed. However as part of the connecting process it seems to be executing the queries from the back-end. Since I have a delete function I see the data getting deleted from the table(basically one of the functions gets executed) even though I have not scheduled the script or manually started. Could someone advice as to what I am doing wrong in this.
One error that pops out in the UI is
Broken DAG: [/Users/user/airflow/dags/dwh_sample23.py] Timeout
Also see an i next to the dag id in the UI that says This is DAG isn't available in the web server's DAG object.
Given below is the code I am using:
## Third party Library Imports
import pandas as pd
import psycopg2
import airflow
from airflow import DAG
from airflow.operators import BashOperator
from datetime import datetime, timedelta
from sqlalchemy import create_engine
import io
# Following are defaults which can be overridden later on
default_args = {
'owner': 'admin',
'depends_on_past': False,
'start_date': datetime(2018, 5, 21),
'retries': 1,
'retry_delay': timedelta(minutes=1),
}
dag = DAG('dwh_sample23', default_args=default_args)
#######################
## Login to DB
def db_login():
''' This function connects to the Data Warehouse and returns the cursor to execute queries '''
global dwh_connection
try:
dwh_connection = psycopg2.connect(" dbname = 'dbname' user = 'user' password = 'password' host = 'hostname' port = '5439' sslmode = 'require' ")
except:
print("I am unable to connect to the database.")
print('Success')
return(dwh_connection)
def tbl1_del():
''' This function takes clears all rows from tbl1 '''
cur = dwh_connection.cursor()
cur.execute("""DELETE FROM tbl1;""")
dwh_connection.commit()
def pop_tbl1():
''' This function populates all rows in tbl1 '''
cur = dwh_connection.cursor()
cur.execute(""" INSERT INTO tbl1
select id,name,price from tbl2;""")
dwh_connection.commit()
db_login()
tbl1_del()
pop_tbl1()
dwh_connection.close()
##########################################
t1 = BashOperator(
task_id='DB_Connect',
python_callable=db_login(),
bash_command='python3 ~/airflow/dags/dwh_sample23.py',
dag=dag)
t2 = BashOperator(
task_id='del',
python_callable=tbl1_del(),
bash_command='python3 ~/airflow/dags/dwh_sample23.py',
dag=dag)
t3 = BashOperator(
task_id='populate',
python_callable=pop_tbl1(),
bash_command='python3 ~/airflow/dags/dwh_sample23.py',
dag=dag)
t1.set_downstream(t2)
t2.set_downstream(t3)
Could anyone assist? Thanks.
Instead of using BashOperator you can use PythonOperator and call db_login(), tbl1_del(), pop_tbl1() in PythonOperator
## Third party Library Imports
import pandas as pd
import psycopg2
import airflow
from airflow import DAG
from airflow.operators import PythonOperator
from datetime import datetime, timedelta
from sqlalchemy import create_engine
import io
# Following are defaults which can be overridden later on
default_args = {
'owner': 'admin',
'depends_on_past': False,
'start_date': datetime(2018, 5, 21),
'retries': 1,
'retry_delay': timedelta(minutes=1),
}
dag = DAG('dwh_sample23', default_args=default_args)
#######################
## Login to DB
def db_login():
''' This function connects to the Data Warehouse and returns the cursor to execute queries '''
global dwh_connection
try:
dwh_connection = psycopg2.connect(" dbname = 'dbname' user = 'user' password = 'password' host = 'hostname' port = '5439' sslmode = 'require' ")
except:
print("I am unable to connect to the database.")
print('Success')
return(dwh_connection)
def tbl1_del():
''' This function takes clears all rows from tbl1 '''
cur = dwh_connection.cursor()
cur.execute("""DELETE FROM tbl1;""")
dwh_connection.commit()
def pop_tbl1():
''' This function populates all rows in tbl1 '''
cur = dwh_connection.cursor()
cur.execute(""" INSERT INTO tbl1
select id,name,price from tbl2;""")
dwh_connection.commit()
db_login()
tbl1_del()
pop_tbl1()
dwh_connection.close()
##########################################
t1 = PythonOperator(
task_id='DB_Connect',
python_callable=db_login(),
dag=dag)
t2 = PythonOperator(
task_id='del',
python_callable=tbl1_del(),
dag=dag)
t3 = PythonOperator(
task_id='populate',
python_callable=pop_tbl1(),
dag=dag)
t1.set_downstream(t2)
t2.set_downstream(t3)
This is really old by now, but we got this error in prod and I found this question, and think its nice that it would have an answer.
Some of the code is getting executed during DAG load, i.e. you actually run
db_login()
tbl1_del()
pop_tbl1()
dwh_connection.close()
##########################################
inside webserver and scheduler loop, when they load dag definition from the file.
I believe you didn't intend that to happen.
Everything should work just fine if you just remove these 4 lines.
Generally don't place function you want executors to execute on file/module level, because when interpreter of the scheduler/webserver loads the file to get dag definition, it would invoke them.
Just try putting this in your dag file and see check webserver logs to see what happens.
from time import sleep
def do_some_printing():
print(1111111)
sleep(60)
do_some_printing()

Triggering A SubDag

EDITED
I have edited this question by considering the inputs from #tobi6
I copied the subdag operator from Airflow source code
Source code: https://github.com/apache/incubator-airflow/blob/master/airflow/operators/subdag_operator.py
I modified a few things in the execute method. The changes were made to trigger the SubDag and wait until the SubDag completes execution. The trigger is working great but the tasks are not being executed (DAG is in the running/Green state while the tasks are in the null/White state).
Please refer below for the changes I made:
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, Pool
from airflow.utils.decorators import apply_defaults
from airflow.utils.db import provide_session
from airflow.utils.state import State
from airflow.executors import GetDefaultExecutor
from time import sleep
import logging
from datetime import datetime
class SubDagOperator(BaseOperator):
template_fields = tuple()
ui_color = '#555'
ui_fgcolor = '#fff'
#provide_session
#apply_defaults
def __init__(
self,
subdag,
executor=GetDefaultExecutor(),
*args, **kwargs):
"""
Yo dawg. This runs a sub dag. By convention, a sub dag's dag_id
should be prefixed by its parent and a dot. As in `parent.child`.
:param subdag: the DAG object to run as a subdag of the current DAG.
:type subdag: airflow.DAG
:param dag: the parent DAG
:type subdag: airflow.DAG
"""
import airflow.models
dag = kwargs.get('dag') or airflow.models._CONTEXT_MANAGER_DAG
if not dag:
raise AirflowException('Please pass in the `dag` param or call '
'within a DAG context manager')
session = kwargs.pop('session')
super(SubDagOperator, self).__init__(*args, **kwargs)
# validate subdag name
if dag.dag_id + '.' + kwargs['task_id'] != subdag.dag_id:
raise AirflowException(
"The subdag's dag_id should have the form "
"'{{parent_dag_id}}.{{this_task_id}}'. Expected "
"'{d}.{t}'; received '{rcvd}'.".format(
d=dag.dag_id, t=kwargs['task_id'], rcvd=subdag.dag_id))
# validate that subdag operator and subdag tasks don't have a
# pool conflict
if self.pool:
conflicts = [t for t in subdag.tasks if t.pool == self.pool]
if conflicts:
# only query for pool conflicts if one may exist
pool = (
session
.query(Pool)
.filter(Pool.slots == 1)
.filter(Pool.pool == self.pool)
.first()
)
if pool and any(t.pool == self.pool for t in subdag.tasks):
raise AirflowException(
'SubDagOperator {sd} and subdag task{plural} {t} both '
'use pool {p}, but the pool only has 1 slot. The '
'subdag tasks will never run.'.format(
sd=self.task_id,
plural=len(conflicts) > 1,
t=', '.join(t.task_id for t in conflicts),
p=self.pool
)
)
self.subdag = subdag
self.executor = executor
def execute(self, context):
dag_run = self.subdag.create_dagrun(
conf=context['dag_run'].conf,
state=State.RUNNING,
execution_date=context['execution_date'],
run_id='trig__' + str(datetime.utcnow()),
external_trigger=True
)
while True:
if dag_run.get_state() == State.FAILED or dag_run.get_state() == State.SUCCESS:
break
else:
sleep(10)
continue
Below is the code that shows how I'm using the same
from airflow import DAG
from operators.sd_operator import SubDagOperator # My SubDag Operator
from airflow.operators.python_operator import PythonOperator
import logging
from datetime import datetime
default_args = {
'owner': 'airflow',
'depends_on_past': False,
'start_date': datetime(2017, 7, 17),
'email': ['airflow#example.com'],
'email_on_failure': False,
'email_on_retry': False,
}
def print_dag_details(**kwargs):
logging.info(str(kwargs['dag_run'].conf))
with DAG('example_dag', schedule_interval=None, catchup=False, default_args=default_args) as dag:
task_1 = SubDagOperator(
subdag=sub_dag_func('example_dag', 'sub_dag_1'),
task_id='sub_dag_1'
)
task_2 = SubDagOperator(
subdag=sub_dag_func('example_dag', 'sub_dag_2'),
task_id='sub_dag_2',
)
print_kwargs = PythonOperator(
task_id='print_kwargs',
python_callable=print_dag_details,
provide_context=True
)
print_kwargs >> task_1 >> task_2
Any information you provide would be helpful. Thanks in advance.
It is a bit hard to understand your question without context.
"I copied the subdag operator and modified a few things in the execute method."
From where was this copied?
"The trigger is working great ..."
How does this look like?
There are a few things I saw in the code:
It might be helpful to add assigned fields to the function call of sub_dag_func, e.g. sub_dag_func(subdag='parent_dag'...).
In the binary shift definition, used to set upstream / downstream there are tasks defined I cannot find in the DAG (df_job_1, df_job_2). This might be connected to SubDAGs (haven't looked into them yet).
The name of the sub dag seems inconsistent with the comment in the code saying By convention, a sub dag's dag_id should be prefixed by its parent and a dot but it is sub_dag_1, sub_dag_2

Resources