Adjust save location of Custom XCom Backend on a per-task basis - airflow

I have posted a discussion question about this here as well: https://github.com/apache/airflow/discussions/19868
Is it possible to specify arguments to a custom XCom backend? If I could have a task return data (a pyarrow table/dataset or a pandas DataFrame) and have the backend save a file in the correct container under a "predictable file location" path, that would be amazing. A lot of my custom operator code deals with creating the blob_path, saving the blob, and pushing a list of blob_paths to XCom.
Since I work with many clients, I would prefer to keep Client A's data inside the client-a container, which uses a different SAS.
When I save a file I consider it a "stage" of the data, so I would prefer to keep it; ideally I could provide a blob_path that matches the folder structure I generally use.
class WasbXComBackend(BaseXCom):
    def __init__(
        self,
        container: str = "airflow-xcom-backend",
        path: str = guid(),
        partition_columns: Optional[list[str]] = None,
        existing_data_behavior: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.container = container
        self.path = path
        self.partition_columns = partition_columns
        self.existing_data_behavior = existing_data_behavior

    @staticmethod
    def serialize_value(self, value: Any):
        if isinstance(value, pd.DataFrame):
            hook = AzureBlobHook(wasb_conn_id="azure_blob")
            with io.StringIO() as buf:
                value.to_csv(path_or_buf=buf, index=False)
                hook.load_string(
                    container_name=self.container,
                    blob_name=f"{self.path}.csv",
                    string_data=buf.getvalue(),
                )
            value = f"{self.container}/{self.path}.csv"
        elif isinstance(value, pa.Table):
            hook = AzureBlobHook(wasb_conn_id="azure_blob")
            write_options = ds.ParquetFileFormat().make_write_options(
                version="2.6", use_dictionary=True, compression="snappy"
            )
            written_files = []
            ds.write_dataset(
                data=value,
                schema=value.schema,
                base_dir=f"{self.container}/{self.path}",
                format="parquet",
                partitioning=self.partition_columns,
                partitioning_flavor="hive",
                existing_data_behavior=self.existing_data_behavior,
                basename_template=f"{self.task_id}-{self.ts_nodash}-{{i}}.parquet",
                filesystem=hook.create_filesystem(),
                file_options=write_options,
                file_visitor=lambda x: written_files.append(x.path),
                use_threads=True,
                max_partitions=2_000,
            )
            value = written_files
        return BaseXCom.serialize_value(value)

    @staticmethod
    def deserialize_value(self, result) -> Any:
        result = BaseXCom.deserialize_value(result)
        if isinstance(result, str) and result.endswith(".csv"):
            hook = AzureBlobHook(wasb_conn_id="azure_blob")
            with io.BytesIO() as input_io:
                hook.get_stream(
                    container_name=self.container,
                    blob_name=str(self.path),
                    input_stream=input_io,
                )
                input_io.seek(0)
                return pd.read_csv(input_io)
        elif isinstance(result, list) and ".parquet" in result:
            hook = AzureBlobHook(wasb_conn_id="azure_blob")
            return ds.dataset(
                source=result, partitioning="hive", filesystem=hook.create_filesystem()
            )
        return result

It's not clear exactly what information you want to be able to retrieve to use as part of your "predictable file location", but there is a PR to pass basic things like dag_id, task_id, etc. on to serialize_value so that you can use them when naming your stored objects.
Until that is merged, you'll have to override BaseXCom.set.
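Once that lands, the backend itself can build a predictable path from those identifiers. A rough sketch, assuming the extra metadata arrives as keyword arguments (the exact names are not fixed until the PR is merged):

from airflow.models.xcom import BaseXCom


class PredictablePathXComBackend(BaseXCom):
    # Hypothetical signature: keyword names are an assumption about the pending PR.
    @staticmethod
    def serialize_value(value, *, key=None, task_id=None, dag_id=None, run_id=None, **kwargs):
        # Build a deterministic, human-readable location per task instance
        # instead of a random GUID, e.g. "client-a/my_dag/<run_id>/extract/return_value".
        blob_path = f"{dag_id}/{run_id}/{task_id}/{key}"
        # ... upload `value` to the client-specific container under blob_path here ...
        return BaseXCom.serialize_value(blob_path)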

You need to override BaseXCom.set.
Here is a working, in-production example:
class MyXComBackend(BaseXCom):
    @classmethod
    @provide_session
    def set(cls, key, value, execution_date, task_id, dag_id, session=None):
        session.expunge_all()
        # logic to use this custom_xcom_backend only with the necessary dag and task
        if cls.is_task_to_custom_xcom(dag_id, task_id):
            value = cls.custom_backend_saving_fn(value, dag_id, execution_date, task_id)
        else:
            value = BaseXCom.serialize_value(value)
        # remove any duplicate XComs
        session.query(cls).filter(
            cls.key == key, cls.execution_date == execution_date, cls.task_id == task_id, cls.dag_id == dag_id
        ).delete()
        session.commit()
        # insert new XCom
        from airflow.models.xcom import XCom  # noqa
        session.add(XCom(key=key, value=value, execution_date=execution_date, task_id=task_id, dag_id=dag_id))
        session.commit()

    @staticmethod
    def is_task_to_custom_xcom(dag_id: str, task_id: str) -> bool:
        return True  # customize your logic here if necessary
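For completeness, Airflow picks up a single backend class globally, via the xcom_backend option in the [core] section of airflow.cfg (or the AIRFLOW__CORE__XCOM_BACKEND environment variable); any per-DAG or per-task behaviour then has to live inside that class, as in is_task_to_custom_xcom above. The module path below is an assumption, substitute your own:

[core]
xcom_backend = my_company.xcom_backends.MyXComBackend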

Related

Airflow Dynamic Task mapping - DagBag import timeout

I have a DAG that fetches a list of items from a source, in batches of 10 at a time, and then does dynamic task mapping on each batch. Here is the code:
def tutorial_taskflow_api():
    @task(multiple_outputs=True)
    def get_items(limit, cur):
        # actual logic is to fetch items and cursor from external API call
        if cur == None:
            cursor = limit + 1
            items = range(0, limit)
        else:
            cursor = cur + limit + 1
            items = range(cur, cur + limit)
        return {'cursor': cursor, 'kinds': items}

    @task
    def process_item(item):
        print(f"Processing item {item}")

    @task
    def get_cursor_from_response(response):
        return response['cursor']

    @task
    def get_items_from_response(response):
        return response['items']

    cursor = None
    limit = 10
    while True:
        response = get_items(limit, cursor)
        items = get_items_from_response(response)
        cursor = get_cursor_from_response(response)
        if cursor:
            process_item.expand(item=items)
        if cursor == None:
            break

tutorial_taskflow_api()
As you can see, I attempt to get a list of items from a source, in batches of 10, and then do dynamic task mapping on each batch.
However, when I import this DAG, I get the DAG import timeout error:
Broken DAG: [/opt/airflow/dags/Test.py] Traceback (most recent call last):
File "/home/airflow/.local/lib/python3.7/site-packages/airflow/decorators/base.py", line 144, in _find_id_suffixes
for task_id in dag.task_ids:
File "/home/airflow/.local/lib/python3.7/site-packages/airflow/utils/timeout.py", line 69, in handle_timeout
raise AirflowTaskTimeout(self.error_message)
airflow.exceptions.AirflowTaskTimeout: DagBag import timeout for /opt/airflow/dags/Test.py after 30.0s.
Please take a look at these docs to improve your DAG import time:
* https://airflow.apache.org/docs/apache-airflow/2.5.1/best-practices.html#top-level-python-code
* https://airflow.apache.org/docs/apache-airflow/2.5.1/best-practices.html#reducing-dag-complexity, PID: 23822
How do I solve this?
I went through the documentation and found that the while-loop logic shouldn't really be at the top level, but in some other task. But if I put it in some other task, how can I perform dynamic task mapping from inside that task?
This code:
while True:
    response = get_items(limit, cursor)
    items = get_items_from_response(response)
    cursor = get_cursor_from_response(response)
    if cursor:
        process_item.expand(item=items)
    if cursor == None:
        break
is running in the DagFileProcessor before any DAG run is created; it executes every min_file_process_interval, and again each time Airflow retries a task in this DAG. Airflow has timeouts such as dagbag_import_timeout, which is the maximum time a DagFileProcessor may spend parsing a DAG file before raising a timeout exception. If you have a big batch, or the API has some latency, you can easily exceed this duration.
Also, you are treating cursor = get_cursor_from_response(response) as a normal Python variable, but it is not one: it is a task output (XComArg) whose value is not available until a DAG run is actually executing.
Solution and best practices:
Dynamic Task Mapping is designed to solve this problem, and it is flexible, so you can use it in different ways:
import pendulum
from airflow.decorators import dag, task


@dag(dag_id="tutorial_taskflow_api", start_date=pendulum.datetime(2023, 1, 1), schedule=None)
def tutorial_taskflow_api():
    @task
    def get_items(limit):
        data = []
        start_ind = 0
        while True:
            end_ind = min(start_ind + limit, 95)  # 95 records in the API
            items = range(start_ind, end_ind) if start_ind <= 90 else None  # a fake end of data
            if items is None:
                break
            data.extend(items)
            start_ind = end_ind
        return data

    @task
    def process_item(item):
        print(f"Processing item {item}")

    process_item.expand(item=get_items(limit=10))

tutorial_taskflow_api()
But if you want to process the data in batches, the best option is mapped task groups. Unfortunately, nested mapped tasks are not supported yet, so you need to process the items of each batch in a loop:
import pendulum
from airflow.decorators import dag, task, task_group


@dag(dag_id="tutorial_taskflow_api", start_date=pendulum.datetime(2023, 1, 1), schedule=None)
def tutorial_taskflow_api():
    @task
    def get_pages(limit):
        start_ind = 0
        pages = []
        while True:
            end_ind = min(start_ind + limit, 95)  # 95 records in the API
            page = dict(start=start_ind, end=end_ind) if start_ind <= 90 else None  # a fake end of data
            if page is None:
                break
            pages.append(page)
            start_ind = end_ind
        return pages

    @task_group()
    def process_batch(start, end):
        @task
        def get_items(start, end):
            return list(range(start, end))

        @task
        def process_items(items):
            for item in items:
                print(f"Processing item {item}")

        process_items(get_items(start=start, end=end))

    process_batch.expand_kwargs(get_pages(10))

tutorial_taskflow_api()
Update:
There is the config option max_map_length, which is the maximum number of mapped tasks/task groups you can have. If your API has some peaks, you can increase this limit (not recommended) or calculate the limit (batch size) dynamically:
import pendulum
from airflow.decorators import dag, task, task_group


@dag(dag_id="tutorial_taskflow_api", start_date=pendulum.datetime(2023, 1, 1), schedule=None)
def tutorial_taskflow_api():
    @task
    def get_limit():
        import math
        max_map_length = 1024
        elements_count = 9999  # get from the API
        preferred_batch_size = 10
        return max(preferred_batch_size, math.ceil(elements_count / max_map_length))

    @task
    def get_pages(limit):
        start_ind = 0
        pages = []
        while True:
            end_ind = min(start_ind + limit, 95)  # 95 records in the API
            page = dict(start=start_ind, end=end_ind) if start_ind <= 90 else None  # a fake end of data
            if page is None:
                break
            pages.append(page)
            start_ind = end_ind
        return pages

    @task_group()
    def process_batch(start, end):
        @task
        def get_items(start, end):
            return list(range(start, end))

        @task
        def process_items(items):
            for item in items:
                print(f"Processing item {item}")

        process_items(get_items(start=start, end=end))

    process_batch.expand_kwargs(get_pages(get_limit()))

tutorial_taskflow_api()

Adding context from django_tables2 library to admin panel

I'm trying to use django_tables2 in the Django admin panel.
The part of the library that I'm interested in is:
class SingleTableMixin(TableMixinBase):
    table_class = None
    table_data = None

    def get_table_class(self):
        if self.table_class:
            return self.table_class
        if self.model:
            return tables.table_factory(self.model)
        raise ImproperlyConfigured(
            "You must either specify {0}.table_class or {0}.model".format(type(self).__name__)
        )

    def get_table(self, **kwargs):
        table_class = self.get_table_class()
        table = table_class(data=self.get_table_data(), **kwargs)
        return RequestConfig(self.request, paginate=self.get_table_pagination(table)).configure(
            table
        )

    def get_table_data(self):
        if self.table_data is not None:
            return self.table_data
        elif hasattr(self, "object_list"):
            return self.object_list
        elif hasattr(self, "get_queryset"):
            return self.get_queryset()
        klass = type(self).__name__
        raise ImproperlyConfigured(
            "Table data was not specified. Define {}.table_data".format(klass)
        )

    def get_table_kwargs(self):
        return {}

    def get_context_data(self, **kwargs):
        context = super().get_context_data(**kwargs)
        table = self.get_table(**self.get_table_kwargs())
        context[self.get_context_table_name(table)] = table
        return context
The problem is that this context never reaches the admin panel. Can I somehow get the context from that last method into the admin?
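One common way to get extra context (including a django_tables2 table) into the admin is to skip SingleTableMixin and override ModelAdmin.changelist_view instead, passing the table through extra_context to a template that extends admin/change_list.html. A minimal sketch, where the Product model, its columns, and the template path are assumptions:

import django_tables2 as tables
from django.contrib import admin

from .models import Product  # assumed model


class ProductTable(tables.Table):
    class Meta:
        model = Product
        fields = ("id", "name")  # assumed columns


@admin.register(Product)
class ProductAdmin(admin.ModelAdmin):
    # The template should {% extends "admin/change_list.html" %} and render {{ table }}.
    change_list_template = "admin/product_change_list.html"  # assumed path

    def changelist_view(self, request, extra_context=None):
        table = ProductTable(Product.objects.all())
        tables.RequestConfig(request, paginate={"per_page": 25}).configure(table)
        extra_context = {**(extra_context or {}), "table": table}
        return super().changelist_view(request, extra_context=extra_context)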

Airflow: get xcom from previous dag run

I am writing a sensor which scans S3 files for a fixed period of time and adds the list of new files that arrived in that period to XCom for the next task. For that, I am trying to access the list of files pushed to XCom by the previous run. I can do that using the snippet below.
context['task_instance'].get_previous_ti(state=State.SUCCESS).xcom_pull(key='new_files', task_ids=self.task_id, dag_id=self.dag_id)
However, the context object is only passed to the poke method, and I want to access it in __init__. Is there another way to do this without using the context?
Note - I do not want to access the underlying database for XCom directly.
Thanks
I found this solution, which (kind of) uses the underlying database, but you don't have to create a SQLAlchemy connection directly to use it.
The trick is using the airflow.models.DagRun object, specifically its find() function, which allows you to grab all runs of a DAG by id between two dates; from those you can pull out the task instances and, from there, access the XComs.
import logging
from datetime import datetime, timedelta, timezone
from random import randint

from airflow import models
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago

default_args = {
    "start_date": days_ago(0),
    "retries": 0,
    "max_active_runs": 1,
}

with models.DAG(
    "prev_xcom_tester",
    catchup=False,
    default_args=default_args,
    schedule_interval="@hourly",
    tags=["testing"],
) as dag:

    def get_new_value(**context):
        num = randint(1, 100)
        logging.info(f"building new value: {num}")
        return num

    def get_prev_xcom(**context):
        try:
            dag_runs = models.DagRun.find(
                dag_id="prev_xcom_tester",
                execution_start_date=(datetime.now(timezone.utc) - timedelta(days=1)),
                execution_end_date=datetime.now(timezone.utc),
            )
            this_val = context["ti"].xcom_pull(task_ids="get_new_value")
            for dr in dag_runs[:-1]:
                prev_val = dr.get_task_instance("get_new_value").xcom_pull(
                    "get_new_value"
                )
                logging.info(f"Checking dag run: {dr}, xcom was: {prev_val}")
                if this_val == prev_val:
                    logging.info(f"we already processed {this_val} in {dr}")
            return (
                dag_runs[-2]
                .get_task_instance("get_new_value")
                .xcom_pull("get_new_value")
            )
        except Exception as e:
            logging.info(e)
            return 0

    def check_vals_match(**context):
        ti = context["ti"]
        prev_run_val = ti.xcom_pull(task_ids="get_prev_xcoms")
        current_run_val = ti.xcom_pull(task_ids="get_new_value")
        logging.info(
            f"Prev Run Val: {prev_run_val}\nCurrent Run Val: {current_run_val}"
        )
        return prev_run_val == current_run_val

    xcom_setter = PythonOperator(task_id="get_new_value", python_callable=get_new_value)
    xcom_getter = PythonOperator(
        task_id="get_prev_xcoms",
        python_callable=get_prev_xcom,
    )
    xcom_checker = PythonOperator(
        task_id="check_xcoms_match", python_callable=check_vals_match
    )

    xcom_setter >> xcom_getter >> xcom_checker
This DAG demonstrates how to:
Set a random int between 1 and 100 and pass it through XCom.
Find all DAG runs by dag_id and time span, then check whether we have processed this value in the past.
Return True if the current value matches the value from the previous run.
Hope this helps!

Flask-WTF: Queries of FormFields in FieldList are none after validate_on_submit

I'm trying to generate dynamic forms using Flask-WTF to create a new product based on some templates. A product will have a list of required key-value pairs based on its type, as well as a list of parts required to build it. The current relevant code looks as follows:
forms.py:
class PartSelectionForm(Form):
    selected_part = QuerySelectField('Part', get_label='serial', allow_blank=True)
    part_type = StringField('Type')
    slot = IntegerField('Slot')
    required = BooleanField('Required')

    def __init__(self, csrf_enabled=False, *args, **kwargs):
        super(PartSelectionForm, self).__init__(csrf_enabled=False, *args, **kwargs)


class NewProductForm(Form):
    serial = StringField('Serial', default='', validators=[DataRequired()])
    notes = TextAreaField('Notes', default='')
    parts = FieldList(FormField(PartSelectionForm))
views.py:
@app.route('/products/new/<prodmodel>', methods=['GET', 'POST'])
@login_required
def new_product(prodmodel):
    try:
        model = db.session.query(ProdModel).filter(ProdModel.id==prodmodel).one()
    except NoResultFound, e:
        flash('No products of model type -' + prodmodel + '- found.', 'error')
        return redirect(url_for('index'))

    keys = db.session.query(ProdTypeTemplate.prod_info_key)\
        .filter(ProdTypeTemplate.prod_type_id==model.prod_type_id)\
        .order_by(ProdTypeTemplate.prod_info_key).all()
    parts_needed = db.session.query(ProdModelTemplate)\
        .filter(ProdModelTemplate.prod_model_id==prodmodel)\
        .order_by(ProdModelTemplate.part_type_id, ProdModelTemplate.slot).all()

    class F(forms.NewProductForm):
        pass

    for key in keys:
        if key.prod_info_key in ['shipped_os', 'factory_os']:
            setattr(F, key.prod_info_key, forms.QuerySelectField(key.prod_info_key, get_label='version'))
        else:
            setattr(F, key.prod_info_key, forms.StringField(key.prod_info_key, validators=[forms.DataRequired()]))

    form = F(request.form)

    if request.method == 'GET':
        for part in parts_needed:
            entry = form.parts.append_entry(forms.PartSelectionForm())
            entry.part_type.data = part.part_type_id
            entry.slot.data = slot = part.slot
            entry.required.data = part.required
            entry.selected_part.query = db.session.query(Part).join(PartModel).filter(
                PartModel.part_type_id==part.part_type_id, Part.status=='inventory')

    if form.__contains__('shipped_os'):
        form.shipped_os.query = db.session.query(OSVersion).order_by(OSVersion.version)
    if form.__contains__('factory_os'):
        form.factory_os.query = db.session.query(OSVersion).order_by(OSVersion.version)

    if form.validate_on_submit():
        ...
Everything works as expected on a GET request, but on validate_on_submit I get errors. The problem is that all of the queries and query_factories for the selected_part QuerySelectFields in the list of PartSelectionForms are None, which causes errors either directly in the WTForms validation code or when Jinja2 attempts to re-render the QuerySelectFields. I'm not sure why this happens on the POST when everything appears to be correct for the GET.
I realized that although I set the required queries on a GET request, I'm not doing it for any of the PartSelectionForm selected_part entries on the POST. Since I already intended part_type, slot, and required to be hidden form fields, I added the following immediately before validate_on_submit and everything works correctly:
for entry in form.parts:
    entry.selected_part.query = db.session.query(Part).join(PartModel).\
        filter(PartModel.part_type_id==entry.part_type.data, Part.status=='inventory')

GTK TreeView with ListStore doesn't display anything (Python and SQLite3)

I use GTK and Python for developing an application.
I want to load TreeView elements (1 column) from a SQLite3 database.
But something goes wrong (without any error)!
Here is the whole code:
#!/usr/bin/python
import sys
import sqlite3 as sqlite
from gi.repository import Gtk
from gi.repository import Notify


def notify(notifer, text, notificationtype=""):
    Notify.init("Application")
    notification = Notify.Notification.new(notifer, text, notificationtype)
    notification.show()


def get_object(gtkname):
    builder = Gtk.Builder()
    builder.add_from_file("main.ui")
    return builder.get_object(gtkname)


def base_connect(basefile):
    return sqlite.connect(basefile)


class Handler:
    def main_destroy(self, *args):
        Gtk.main_quit(*args)

    def hardwaretool_clicked(self, widget):
        baselist = get_object("subjectlist")
        baselist.clear()
        base = base_connect("subjectbase")
        with base:
            cur = base.cursor()
            cur.execute("SELECT * FROM sub")
            while True:
                row = cur.fetchone()
                if row == None:
                    break
                iter = baselist.append()
                print "row ", row[0]
                baselist.set(iter, 0, row[0])
            cur.close()

    def gamestool_clicked(self, widget):
        print("gamestool clicked!!!!! =)")

    def appstool_clicked(self, widget):
        print("appstool clicked!!!!! =)")

    def fixtool_clicked(self, widget):
        notify("Charmix", "Fix Applied", "dialog-ok")

    def brokenfixtool_clicked(self, widget):
        notify("Charmix", "Broken Fix Report Sended", "dialog-error")

    def sendfixtool_clicked(self, widget):
        notify("Charmix", "Fix Sended", "dialog-information")


class CharmixMain:
    def __init__(self):
        builder = Gtk.Builder()
        builder.add_from_file("main.ui")
        self.window = builder.get_object("main")
        self.subject = builder.get_object("subjectlist")
        self.problem = builder.get_object("problemlist")
        self.toolbar = builder.get_object("toolbar")
        self.hardwaretool = builder.get_object("hardwaretool")
        self.gamestool = builder.get_object("gamestool")
        self.appstool = builder.get_object("appstool")
        self.fixtool = builder.get_object("fixtool")
        self.brokenfixtool = builder.get_object("brokenfixtool")
        self.sendfixtool = builder.get_object("sendfixtool")
        builder.connect_signals(Handler())
        context = self.toolbar.get_style_context()
        context.add_class(Gtk.STYLE_CLASS_PRIMARY_TOOLBAR)


if __name__ == "__main__":
    Charmix = CharmixMain()
    Charmix.window.show()
    Gtk.main()
I'm interested in this part (not working normally):
def hardwaretool_clicked(self, widget):
    baselist = get_object("subjectlist")
    baselist.clear()
    base = base_connect("subjectbase")
    with base:
        cur = base.cursor()
        cur.execute("SELECT * FROM sub")
        while True:
            row = cur.fetchone()
            if row == None:
                break
            iter = baselist.append()
            print "row ", row[0]
            baselist.set(iter, 0, row[0])
        cur.close()
The TreeView (subjecttree) doesn't display anything, but print "row ", row[0] works fine and prints all the strings.
Please help me.
Maybe I need to repaint the TreeView or something like that?
Do you know how I can fix this?
The problem is in your get_object method.
When you do:
builder = Gtk.Builder()
builder.add_from_file("main.ui")
you're actually building a brand-new widget hierarchy; even if you are using the same UI file, you get a completely different widget, so you end up filling a ListStore that is not the one attached to the TreeView being displayed.
One way to get access to the widgets you need to work with in your handler is to pass them as parameters of its constructor:
class Handler(object):
    def __init__(self, widget1, widget2):
        self.widget1 = widget1
        self.widget2 = widget2
    ...
You can then use those widgets in the handler's methods.
Another, more 'decoupled' way of accessing the widgets is to add the object you want to use as the last parameter of the connect method when you're connecting signals; the drawback is that you have to do this manually (since Glade doesn't provide this possibility):
self.widget.connect('some-signal', handler.handler_method, object_to_use)
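Either way, the handler ends up operating on the ListStore that belongs to the window actually being shown. A minimal sketch of the first approach applied to this question (it reuses base_connect from the code above and assumes the ListStore holds a single string column):

class Handler:
    def __init__(self, subjectlist):
        # The ListStore created by the same Gtk.Builder that built the visible window
        self.subjectlist = subjectlist

    def hardwaretool_clicked(self, widget):
        self.subjectlist.clear()
        base = base_connect("subjectbase")
        with base:
            cur = base.cursor()
            cur.execute("SELECT * FROM sub")
            for row in cur:
                self.subjectlist.append([row[0]])
            cur.close()

# In CharmixMain.__init__, pass the shared ListStore when connecting signals:
# builder.connect_signals(Handler(self.subject))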
