In the Airflow 2 TaskFlow API I can easily push and pull XCom values between tasks, using the following code examples:
@task(task_id="task_one")
def get_height() -> int:
    response = requests.get("https://swapi.dev/api/people/4")
    data = json.loads(response.text)
    height = int(data["height"])
    return height


@task(task_id="task_two")
def check_height(val):
    # Show val:
    print(f"Value passed in is: {val}")


check_height(get_height())
I can see that the val passed into check_height is 202 and is stored under the default XCom key 'return_value', and that's fine some of the time, but I generally prefer to use specific keys.
My question is: how can I push an XCom with a named key? This was really easy previously with ti.xcom_push, where you could just supply the name of the key you wanted the value stored under, but I can't quite put my finger on how to achieve this in the TaskFlow API workflow.
Would appreciate any pointers or (simple, please!) examples on how to do this.
You can get the task instance by declaring ti as a keyword argument of the decorated function; Airflow injects it at runtime, and you can then push under any key you like:

@task(task_id="task_one")
def get_height(ti=None) -> int:
    response = requests.get("https://swapi.dev/api/people/4")
    data = json.loads(response.text)
    height = int(data["height"])
    # Push a named XCom
    ti.xcom_push(key="my_key", value=height)
    return height
For cases where you need the context in a deeply nested function, you can also use get_current_context. I'll use it in my example below just to show it, but it's not really required in your case.
Here is a working example:
import json
from datetime import datetime

import requests

from airflow.decorators import dag, task
from airflow.operators.python import get_current_context

DEFAULT_ARGS = {"owner": "airflow"}


@dag(dag_id="stackoverflow_dag", default_args=DEFAULT_ARGS, schedule_interval=None, start_date=datetime(2020, 2, 2))
def my_dag():
    @task(task_id="task_one")
    def get_height() -> int:
        response = requests.get("https://swapi.dev/api/people/4")
        data = json.loads(response.text)
        height = int(data["height"])
        # Push a named XCom
        context = get_current_context()
        ti = context["ti"]
        ti.xcom_push(key="my_key", value=height)
        return height

    @task(task_id="task_two")
    def check_height(val):
        # Show val:
        print(f"Value passed in is: {val}")
        # Read the named XCom
        context = get_current_context()
        ti = context["ti"]
        my_key_value = ti.xcom_pull(task_ids="task_one", key="my_key")
        print(f"Value pulled from XCom my_key is: {my_key_value}")

    check_height(get_height())


my_dag = my_dag()
Two XComs are pushed (one for the returned value and one under the key we chose):
Printing the two XComs in the downstream task_two:
Is there a built-in facility or some operator that will run a sensor and negate its status? I am writing a workflow that needs to detect that an object does not exist in order to proceed to eventual success. I have a sensor, but it detects when the object does exist.
For instance, I would like my workflow to detect that an object does not exist. I need almost exactly S3KeySensor, except that I need to negate its status.
The use case you are describing is: check for a key in S3; if it exists, wait; otherwise, continue the workflow. As you mentioned, this is a Sensor use case. The S3Hook has a check_for_key function that checks whether a key exists, so all that's needed is to wrap it in a Sensor's poke function.
A simple, basic implementation would be:
from typing import Optional, Sequence, Union

from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.sensors.base import BaseSensorOperator


class S3KeyNotPresentSensor(BaseSensorOperator):
    """Waits for a key to not be present in S3."""

    template_fields: Sequence[str] = ('bucket_key', 'bucket_name')

    def __init__(
        self,
        *,
        bucket_key: str,
        bucket_name: Optional[str] = None,
        aws_conn_id: str = 'aws_default',
        verify: Optional[Union[str, bool]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bucket_name = bucket_name
        self.bucket_key = [bucket_key] if isinstance(bucket_key, str) else bucket_key
        self.aws_conn_id = aws_conn_id
        self.verify = verify
        self.hook: Optional[S3Hook] = None

    def poke(self, context):
        # Succeed only when none of the keys are present in the bucket
        return not any(
            self.get_hook().check_for_key(key, self.bucket_name) for key in self.bucket_key
        )

    def get_hook(self) -> S3Hook:
        """Create and return an S3Hook."""
        if self.hook:
            return self.hook
        self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        return self.hook
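A hedged usage sketch of the sensor above inside a DAG (the DAG id, bucket, and key are hypothetical placeholders; mode and poke_interval are standard BaseSensorOperator arguments):

from datetime import datetime

from airflow import DAG

with DAG(
    dag_id="wait_for_key_absence",       # hypothetical DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    key_absent = S3KeyNotPresentSensor(
        task_id="key_absent",
        bucket_key="path/to/object.csv",  # hypothetical key
        bucket_name="my-bucket",          # hypothetical bucket
        mode="reschedule",                # free the worker slot between pokes
        poke_interval=60,
    )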
I ended up going another way. I can use the trigger_rule argument of (any) Task -- by setting it to one_failed or all_failed on the next task I can play around with the desired status.
For example,
from airflow.operators.smooth import SmoothOperator
from airflow.sensors.filesystem import FileSensor

file_exists = FileSensor(task_id='exists', timeout=3, poke_interval=1, filepath='/tmp/error', mode='reschedule')
sing = SmoothOperator(task_id='sing', trigger_rule='all_failed')
file_exists >> sing
It requires no added code or operator, but has the possible disadvantage of being somewhat surprising.
Replying to myself in the hope that this may be useful to someone else. Thanks!
I've got a task that returns a tuple. Passing one element of that tuple to another task is not working. I can pass the entire tuple, but not an element of the return value:
from airflow.decorators import dag, task
from pendulum import datetime


@task
def create():
    return 1, 2


@task
def consume(one):
    print('arg is', one)


@dag(
    schedule_interval='@once',
    start_date=datetime(2022, 4, 10),
)
def test_dag():
    out = create()
    consume(out[0])  # does not work: the task gets None as argument
    consume(out)     # this works


dag = test_dag()
Within TaskFlow, the object returned from a TaskFlow function is actually an XComArg. These XComArgs are abstractions over the classic task_instance.xcom_pull(...) retrieval of XComs. Additionally, XComArg objects implement __getitem__ for specifying an XCom key other than "return_value" (the default).
So what's going on when using consume(out[0]) is that Airflow is leveraging an XComArg object to retrieve an XCom with a key of 0, not retrieving the output from create() and then taking the first item. What's going on behind the scenes is task_instance.xcom_pull(task_ids="create", key=0).
Yes, this is unexpected in a way, and it's not quite in line with the classic xcom_pull() approach. This issue has been opened to try and achieve feature parity.
In the meantime, you can of course access the whole XComArg like you show by just using consume(out) or you can update the TaskFlow function to return a dictionary and use multiple_outputs to have each key/value pair serialized as their own XComs.
For example:
from pendulum import datetime

from airflow.decorators import dag, task


@task(multiple_outputs=True)
def create():
    return {"one": 1, "two": 2}


@task
def consume(arg):
    print('arg is', arg)


@dag(
    schedule_interval='@once',
    start_date=datetime(2022, 4, 10),
)
def test_dag():
    out = create()
    consume(out["one"])


dag = test_dag()
Separate XComs created from the create task:
consume task log:
Side note: multiple_outputs can also be inferred if the TaskFlow function has a dictionary return type annotation. This will set multiple_outputs=True based on the return annotation:
from typing import Dict


@task
def create() -> Dict[str, int]:
    return {"one": 1, "two": 2}
I have a requirement to compute a value in a PythonOperator and use it in other operators, as shown below. But I'm getting "dag_var does not exist" for the Spark submit and email operators.
I'm declaring dag_var as a global variable in the Python callable, but I'm not able to access it in the other operators.
def get_dag_var(ds, **kwargs):
    global dag_var
    dag_var = kwargs['dag_run'].run_id


with DAG(
    dag_id='sample',
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    default_args=default_args,
    catchup=False
) as dag:

    get_dag_var = PythonOperator(
        task_id='get_dag_id',
        provide_context=True,
        python_callable=get_dag_var)

    spark_submit = SparkSubmitOperator(application="abc".....
        ..
        application_args=[dag_var])

    failure_notification = EmailOperator(
        task_id="failure_notification",
        to='abc@gmail.com',
        subject='Workflow Fails',
        trigger_rule="one_failed",
        html_content=f""" <h3>Failure Mail - {dag_var}</h3> """
    )

    get_dag_var >> spark_submit >> failure_notification
Any help is appreciated. Thank you.
You can share data between operators using XComs. In your get_dag_var function, any returned value is automatically stored as an XCom record in Airflow. You can inspect the values under Admin -> XComs.
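For example, instead of assigning to a global, the callable can simply return the run ID (a minimal sketch of that change, reusing the question's get_dag_var function):

def get_dag_var(**kwargs):
    # The returned value is automatically pushed to XCom under the default "return_value" key
    return kwargs['dag_run'].run_id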
To use an XCom value in a following task, you can apply templating:
spark_submit = SparkSubmitOperator(
application="ABC",
...,
application_args = ["{{ ti.xcom_pull(task_ids='get_dag_id') }}"],
)
The {{ }} define a templated string that is evaluated at runtime. ti.xcom_pull will "pull" the XCom value from the get_dag_id task at runtime.
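The same templating works for the email task from the question, since html_content is also a templated field (a hedged sketch reusing the question's operator):

failure_notification = EmailOperator(
    task_id="failure_notification",
    to="abc@gmail.com",
    subject="Workflow Fails",
    trigger_rule="one_failed",
    # html_content is templated, so the XCom value is substituted at runtime
    html_content="<h3>Failure Mail - {{ ti.xcom_pull(task_ids='get_dag_id') }}</h3>",
)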
One thing to note when using templating: not all operator arguments are template-able. Non-template-able arguments do not evaluate {{ }} at runtime. SparkSubmitOperator.application_args and EmailOperator.html_content are template-able, meaning a templated string is evaluated at runtime and you'll be able to provide an XCom value. Inspect the template_fields property of your operator to know which fields are template-able and which are not.
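For example, a quick way to check (assuming the Spark provider package is installed) is to print the class attribute:

from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

# Lists the constructor arguments that are rendered as Jinja templates
print(SparkSubmitOperator.template_fields)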
And one thing to note when using XComs: be aware the XCom value is stored in the Airflow metastore, so be careful not to return huge values which might not fit in a database record. To store XCom values in a system other than the Airflow metastore, check out custom XCom backends.
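For reference, a rough, hedged sketch of what a custom XCom backend looks like; the class name and storage logic are hypothetical, the hooks to override are BaseXCom.serialize_value and BaseXCom.deserialize_value, and the backend is activated via the xcom_backend setting in the core section:

from airflow.models.xcom import BaseXCom


class MyXComBackend(BaseXCom):
    # Hypothetical backend: offload large values to external storage and keep
    # only a small reference in the Airflow metastore.

    @staticmethod
    def serialize_value(value, **kwargs):
        # ... upload `value` to external storage here and keep only a reference ...
        return BaseXCom.serialize_value(value)

    @staticmethod
    def deserialize_value(result):
        # ... resolve the reference and download the real value here ...
        return BaseXCom.deserialize_value(result)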
My code looks like this:
def etl():
    for item in ['FIRST', 'SECOND', 'THIRD']:
        if item == 'a':
            requests = ['Data1', 'Data3']
        elif item == 'b':
            requests = ['Data1']

        for data_name in requests:
            @task(task_id=f'{item}_{data_name}_task_a')
            def taska():
                a, b = some_func()
                vars_dict = {'a': a,
                             'b': b}
                return vars_dict

            @task(task_id=f'{account}_{data_name}_get_liveops_data')
            def taskb(vars_dict):
                some_other_func()
                return True

            if data_name == 'Data1':
                @task(task_id='last_task')
                def last_task(success):
                    dim_experiments.main()
                    return

            vars_dict = taska()
            success = taskb(vars_dict)
            last_task(success)


myc_dag = etl()
The DAG looks like this:
When it should look like this:
The goal is to have last_task depend on taska and taskb, except for the taska and taskb that download the Data3 requests. I am not able to achieve this using the TaskFlow API.
The parallel dependency is occurring because calling the last_task() TaskFlow function and setting the task dependency to it (implicitly via the TaskFlow API) is done within the same loop which calls the other tasks. Each call of a TaskFlow function will create a new task node. If last_task was pulled outside the loops and only the necessary dependencies were set inside the loops, you would achieve the desired structure.
Let's take a simplified version of your code as an example.
from datetime import datetime

from airflow.decorators import dag, task


@dag(dag_id="__example__", start_date=datetime(2021, 11, 1), schedule_interval=None)
def etl():
    @task(task_id="last_task")
    def last_task(some_input=None):
        ...

    for item in ["a", "b"]:
        @task
        def taska():
            return {"a": "A", "b": "B"}

        @task
        def taskb(input):
            ...

        success = taskb(taska())
        last_task(success)


myc_dag = etl()
In the DAG above, taska(), taskb(), and last_task() TaskFlow functions are all called and their task dependencies set within the loop. So, we see 2 parallel paths:
To have last_task() become a shared downstream task of both paths, we need to pull the call to last_task() out of the loop (so that we only create a task node once) but keep the task dependency between taskb() and last_task() intact. This can be done with a small refactor of the example:
@dag(dag_id="__example__", start_date=datetime(2021, 11, 1), schedule_interval=None)
def etl():
    @task(task_id="last_task")
    def last_task(some_input=None):
        ...

    last_task = last_task()

    for item in ["a", "b"]:
        @task
        def taska():
            return {"a": "A", "b": "B"}

        @task
        def taskb(input):
            ...

        success = taskb(taska())
        success >> last_task


myc_dag = etl()
Notice that the last_task() TaskFlow function is called outside of the loop that creates the other tasks. This ensures that the last_task() task is only created once. The other change is to assign the last_task() call to a variable and use this variable to declare the task dependency to taskb() (similar to what you were doing with the success variable in your original code snippet). With these small changes we get 2 paths with last_task() as a shared final task:
Is there a way to convert a pydantic model to query parameters in fastapi?
Some of my endpoints pass parameters via the body, but others pass them directly in the query. All these endpoints share the same data model, for example:
class Model(BaseModel):
    x: str
    y: str
I would like to avoid duplicating my definition of this model in the definition of my "query-parameters endpoints", like for example test_query in this code:
class Model(BaseModel):
    x: str
    y: str


@app.post("/test-body")
def test_body(model: Model): pass


@app.post("/test-query-params")
def test_query(x: str, y: str): pass
What's the cleanest way of doing this?
The documentation gives a shortcut to avoid this kind of repetition. In this case, it would give:
from fastapi import Depends


@app.post("/test-query-params")
def test_query(model: Model = Depends()): pass
This will allow you to request /test-query-params?x=1&y=2 and will also produce the correct OpenAPI description for this endpoint.
Similar solutions can be used for using Pydantic models as form-data descriptors.
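Along the same lines, a hedged sketch of the form-data variant (the endpoint path and field names are placeholders, and form parsing requires the python-multipart package): use a plain class whose __init__ defaults are Form(...) fields and inject it with Depends:

from fastapi import Depends, FastAPI, Form

app = FastAPI()


class FormModel:
    def __init__(self, x: str = Form(...), y: str = Form(...)):
        # Each attribute is read from the form body rather than the query string
        self.x = x
        self.y = y


@app.post("/test-form")
def test_form(model: FormModel = Depends()):
    return {"x": model.x, "y": model.y}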
There is a special case that isn't mentioned in the documentation: query parameter lists, for example:
/members?member_ids=1&member_ids=2
The answer provided by @cglacet will unfortunately ignore the array for such a model:
class Model(BaseModel):
    member_ids: List[str]
You need to modify your model like so:
class Model(BaseModel):
    member_ids: List[str] = Field(Query([]))
Answer from @fnep on GitHub here
This solution is very apt if your schema is "minimal".
But when it comes to a complicated one, like in Set description for query parameter in swagger doc using Pydantic model, it is better to use a "custom dependency class":
from fastapi import Depends, FastAPI, Query

app = FastAPI()


class Model:
    def __init__(
        self,
        y: str,
        x: str = Query(
            default='default for X',
            title='Title for X',
            deprecated=True
        )
    ):
        self.x = x
        self.y = y


@app.post("/test-body")
def test_body(model: Model = Depends()):
    return model
If you are using this method, you will have more control over the OpenAPI doc.
@cglacet's answer is simple and works, but it will raise a pydantic ValidationError when validation fails and won't pass the error to the client. You can find the reason here.
The following works and passes the error message to the client. Code from here.
import inspect
from datetime import datetime

from fastapi import Query, FastAPI, Depends
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, ValidationError


class QueryBaseModel(BaseModel):
    def __init_subclass__(cls, *args, **kwargs):
        field_default = Query(...)
        new_params = []
        for field in cls.__fields__.values():
            default = Query(field.default) if not field.required else field_default
            annotation = inspect.Parameter.empty
            new_params.append(
                inspect.Parameter(
                    field.alias,
                    inspect.Parameter.POSITIONAL_ONLY,
                    default=default,
                    annotation=annotation,
                )
            )

        async def _as_query(**data):
            try:
                return cls(**data)
            except ValidationError as e:
                raise RequestValidationError(e.raw_errors)

        sig = inspect.signature(_as_query)
        sig = sig.replace(parameters=new_params)
        _as_query.__signature__ = sig  # type: ignore
        setattr(cls, "as_query", _as_query)

    @staticmethod
    def as_query(parameters: list) -> "QueryBaseModel":
        raise NotImplementedError


class ParamModel(QueryBaseModel):
    start_datetime: datetime


app = FastAPI()


@app.get("/api")
def test(q_param: ParamModel = Depends(ParamModel.as_query)):
    start_datetime = q_param.start_datetime
    ...
    return {}
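A hedged usage sketch with FastAPI's TestClient, just to illustrate the expected behavior (the parameter value is arbitrary):

from fastapi.testclient import TestClient

client = TestClient(app)

# A valid query string is parsed into ParamModel by as_query (expect 200)
print(client.get("/api", params={"start_datetime": "2022-01-01T00:00:00"}).status_code)

# A missing parameter surfaces as a validation error to the client (expect 422)
print(client.get("/api").status_code)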