Sensor return value cannot be stored/retrieved using PokeReturnValue - airflow

The code below creates the DAG (the graph is also attached), which contains 2 PythonSensors and a PythonOperator.
The first sensor creates a random integer list as data and a random boolean with roughly a 50% chance of success. It logs the generated values and returns a PokeReturnValue.
The second sensor and the Python operator both try to get the data from XCom and log it.
Graph of DAG
# region IMPORTS
import random
import logging
from datetime import datetime, timedelta
from airflow import DAG
from heliocampus.configuration.constants import Constants
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.sensors.python import PythonSensor
from airflow.sensors.base import PokeReturnValue
from airflow.utils.trigger_rule import TriggerRule
from box import Box
# endregion
# region configuration
# Instantiate the project-level Constants object (not referenced again in this snippet).
constants = Constants()
# DAG settings wrapped in a Box so they can be read as attributes; dagconfig.Code is used as the dag_id.
dagconfig = Box({ "Code":"Test" })
# endregion
def main() -> DAG:
    """Build the test DAG.

    Graph: ``start >> CHECK_ALL_EXPIRED_TABLES >> VERIFY_ODS >> end``, plus
    ``CHECK_ALL_EXPIRED_TABLES >> CHECK_ALL_EXPIRED_TABLES_NOTIFICATION``.
    The notification task uses ``TriggerRule.ALL_FAILED``, the second sensor
    uses ``TriggerRule.ALL_SUCCESS``; both are handed the first sensor's
    task id so they can pull its XCom.

    Returns:
        The constructed DAG.
    """
    # region default_args
    args = dict()
    args['start_date'] = datetime(2021, 1, 1)
    # endregion
    # "@once" is the schedule preset; the "#once" in the original text is not a valid schedule.
    with DAG(dag_id=dagconfig.Code, schedule_interval="@once", default_args=args, tags=['test', 'V0.1.4']) as dag:
        start = EmptyOperator(task_id="start")
        # region Sensors
        check_all_expired_tables = PythonSensor(
            task_id="CHECK_ALL_EXPIRED_TABLES",
            poke_interval=timedelta(seconds=20).total_seconds(),
            timeout=timedelta(minutes=1).total_seconds(),
            mode="reschedule",
            python_callable=check_expired_tables,
            trigger_rule=TriggerRule.ALL_SUCCESS
        )
        check_all_expired_tables_notification = PythonOperator(
            task_id="CHECK_ALL_EXPIRED_TABLES_NOTIFICATION",
            python_callable=sensor_result_nofitication,
            op_kwargs={"notification_source":"CHECK_ALL_EXPIRED_TABLES"},
            trigger_rule=TriggerRule.ALL_FAILED
        )
        verify_ods_operator = PythonSensor(
            task_id="VERIFY_ODS",
            poke_interval=timedelta(seconds=30).total_seconds(),
            timeout=timedelta(hours=2).total_seconds(),
            mode="reschedule",
            python_callable=verify_ods,
            op_kwargs={"notification_source":"CHECK_ALL_EXPIRED_TABLES"},
            trigger_rule=TriggerRule.ALL_SUCCESS
        )
        # endregion
        end = EmptyOperator(task_id="end")
        start >> check_all_expired_tables >> verify_ods_operator >> end
        check_all_expired_tables >> check_all_expired_tables_notification
    return dag
# region Notifications
def sensor_result_nofitication(ti, notification_source):
    """Log the XCom of ``notification_source`` twice: pulled without a key, then with ``return_value``.

    Args:
        ti: task instance exposing ``xcom_pull``.
        notification_source: task id whose XCom is read.
    """
    pulled_without_key = ti.xcom_pull(task_ids=[notification_source])
    logging.info(f"sensor_result_nofitication : Sensor without key from {notification_source} is {pulled_without_key}")

    pulled_return_value = ti.xcom_pull(key='return_value', task_ids=[notification_source])
    logging.info(f"sensor_result_nofitication : Sensor return_value from {notification_source} is {pulled_return_value}")
# endregion
def check_expired_tables():
    """Sensor callable: draw a random payload and a random done-flag, log both, and return them.

    Returns:
        PokeReturnValue carrying the done-flag as ``is_done`` and the random
        list as ``xcom_value``.
    """
    # Five distinct integers taken from 10..29.
    payload = random.sample(range(10, 30), 5)
    finished = random.randint(0, 100) > 50
    logging.info(f"check_expired_tables : returning PokeReturnValue(is_done={finished}, xcom_value={payload})")
    return PokeReturnValue(is_done=finished, xcom_value=payload)
def verify_ods(ti, notification_source):
    """Sensor callable: log the XCom of ``notification_source``, then succeed at random.

    Args:
        ti: task instance exposing ``xcom_pull``.
        notification_source: task id whose XCom is read.

    Returns:
        True when a random integer in 0..100 is greater than 20, else False.
    """
    pulled_without_key = ti.xcom_pull(task_ids=[notification_source])
    logging.info(f"verify_ods : Sensor without key from {notification_source} is {pulled_without_key}")

    pulled_return_value = ti.xcom_pull(key='return_value', task_ids=[notification_source])
    logging.info(f"verify_ods : Sensor return_value from {notification_source} is {pulled_return_value}")

    rnd = random.randint(0, 100)
    logging.info("Random Number : {num}".format(num=rnd))
    passed = rnd > 20
    return passed
# Build the DAG when this module is imported; the returned DAG object is not bound to a module-level name here.
main()
Regardless of whether the first sensor is successful or not, the data from XCom cannot be logged in the second sensor or in the Python operator.
I don't know if the problem is on the pushing side or pulling side.
I can not see any rows inserted in airflow database (xcom table).

The problem lives in the PythonSensor which is coercing the return of the python callable to boolean without checking its type first:
return_value = self.python_callable(*self.op_args, **self.op_kwargs)
return PokeReturnValue(bool(return_value))
To get the expected behavior something like this needs to be added to the PythonSensor:
return return_value if isinstance(return_value, PokeReturnValue) else PokeReturnValue(bool(return_value))

Related

Airflow timetable that combines multiple cron expressions?

I have several cron expressions that I need to apply to a single DAG. There is no way to express them with one single cron expression.
Airflow 2.2 introduced Timetable. Is there an implementation that takes a list of cron expressions?
I was looking for the same thing, but didn't find anything. It would be nice if a standard one came with Airflow.
Here's a 0.1 version that I wrote for Airflow 2.2.5.
# This file is <airflow plugins directory>/timetable.py
from typing import Any, Dict, List, Optional
import pendulum
from croniter import croniter
from pendulum import DateTime, Duration, timezone, instance as pendulum_instance
from airflow.plugins_manager import AirflowPlugin
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.exceptions import AirflowTimetableInvalid
class MultiCronTimetable(Timetable):
    """Timetable that schedules a run whenever any one of several cron expressions matches.

    Every run's data interval ends at the matched cron time and starts
    ``period_length`` ``period_unit`` earlier.
    """

    valid_units = ['minutes', 'hours', 'days']

    def __init__(self,
                 cron_defs: List[str],
                 timezone: str = 'Europe/Berlin',
                 period_length: int = 0,
                 period_unit: str = 'hours'):
        """Store the configuration; it is checked later by ``validate``.

        Args:
            cron_defs: cron expressions; a run is scheduled when any of them matches.
            timezone: name of the timezone the cron expressions are evaluated in.
            period_length: length of the data interval, in ``period_unit``.
            period_unit: one of ``valid_units``.
        """
        self.cron_defs = cron_defs
        self.timezone = timezone
        self.period_length = period_length
        self.period_unit = period_unit

    def infer_manual_data_interval(self, run_after: DateTime) -> DataInterval:
        """
        Determines date interval for manually triggered runs.
        This is simply (now - period) to now.
        """
        end = run_after
        start = end if self.period_length == 0 else self.data_period_start(end)
        return DataInterval(start=start, end=end)

    def next_dagrun_info(
            self,
            *,
            last_automated_data_interval: Optional[DataInterval],
            restriction: TimeRestriction) -> Optional[DagRunInfo]:
        """
        Determines when the DAG should be scheduled.

        Returns None when there is no start date, no cron time can be found,
        or the next cron time lies after ``restriction.latest``.
        """
        if restriction.earliest is None:
            # No start_date. Don't schedule.
            return None

        if last_automated_data_interval is None:
            # First automated run of this DAG.
            if restriction.catchup:
                scheduled_time = self.next_scheduled_run_time(restriction.earliest)
            else:
                scheduled_time = self.previous_scheduled_run_time()
                if scheduled_time is None:
                    # No previous cron time matched. Find one in the future.
                    scheduled_time = self.next_scheduled_run_time()
        else:
            last_scheduled_time = last_automated_data_interval.end
            if restriction.catchup:
                scheduled_time = self.next_scheduled_run_time(last_scheduled_time)
            else:
                scheduled_time = self.previous_scheduled_run_time()
                if scheduled_time is None or scheduled_time <= last_scheduled_time:
                    # No previous cron time matched, or the most recent match
                    # has already run: the next run lies in the future.
                    scheduled_time = self.next_scheduled_run_time()
                # Otherwise the match is after the last run but before now: use it.

        if scheduled_time is not None and scheduled_time < restriction.earliest:
            # With catchup disabled the "most recent match" may predate the
            # DAG's start date; never schedule before it.
            scheduled_time = self.next_scheduled_run_time(restriction.earliest)

        if scheduled_time is None:
            return None
        if restriction.latest is not None and scheduled_time > restriction.latest:
            # Over the DAG's scheduled end; don't schedule.
            return None

        start = self.data_period_start(scheduled_time)
        return DagRunInfo(run_after=scheduled_time, data_interval=DataInterval(start=start, end=scheduled_time))

    def data_period_start(self, period_end: DateTime):
        """Return ``period_end`` minus the configured period."""
        return period_end - Duration(**{self.period_unit: self.period_length})

    def _localized_base(self, base_datetime: Optional[DateTime]):
        """Return ``base_datetime`` in the configured timezone, or the current time there when None."""
        tz = timezone(self.timezone)
        if base_datetime:
            return base_datetime.in_timezone(tz)
        return pendulum.now(tz)

    def croniter_values(self, base_datetime=None):
        """Return one croniter per cron expression, all starting at ``base_datetime`` (default: now)."""
        if not base_datetime:
            base_datetime = pendulum.now(timezone(self.timezone))
        return [croniter(expr, base_datetime) for expr in self.cron_defs]

    def next_scheduled_run_time(self, base_datetime: Optional[DateTime] = None):
        """
        Get the earliest time after ``base_datetime`` (default: now) that matches one of the cron schedules
        """
        localized = self._localized_base(base_datetime)
        candidates = [cron.get_next(DateTime) for cron in self.croniter_values(localized)]
        if not candidates:
            return None
        return pendulum_instance(min(candidates))

    def previous_scheduled_run_time(self, base_datetime: Optional[DateTime] = None):
        """
        Get the most recent time in the past that matches one of the cron schedules
        """
        localized = self._localized_base(base_datetime)
        candidates = [cron.get_prev(DateTime) for cron in self.croniter_values(localized)]
        if not candidates:
            return None
        return pendulum_instance(max(candidates))

    def validate(self) -> None:
        """Check the configuration.

        Raises:
            AirflowTimetableInvalid: no cron expression was given, the unit or
                length is invalid, or a croniter could not be built.
        """
        if not self.cron_defs:
            raise AirflowTimetableInvalid("At least one cron definition must be present")
        if self.period_unit not in self.valid_units:
            raise AirflowTimetableInvalid(f'period_unit must be one of {self.valid_units}')
        if self.period_length < 0:
            raise AirflowTimetableInvalid('period_length must not be less than zero')
        try:
            self.croniter_values()
        except Exception as e:
            # Keep the original failure attached as the cause.
            raise AirflowTimetableInvalid(str(e)) from e

    @property
    def summary(self) -> str:
        """A short summary for the timetable.
        This is used to display the timetable in the web UI. A cron expression
        timetable, for example, can use this to display the expression.
        """
        return ' || '.join(self.cron_defs) + f' [TZ: {self.timezone}]'

    def serialize(self) -> Dict[str, Any]:
        """Serialize the timetable for JSON encoding.
        This is called during DAG serialization to store timetable information
        in the database. This should return a JSON-serializable dict that will
        be fed into ``deserialize`` when the DAG is deserialized.
        """
        return dict(cron_defs=self.cron_defs,
                    timezone=self.timezone,
                    period_length=self.period_length,
                    period_unit=self.period_unit)

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> "MultiCronTimetable":
        """Deserialize a timetable from data.
        This is called when a serialized DAG is deserialized. ``data`` will be
        whatever was returned by ``serialize`` during DAG serialization.
        """
        return cls(**data)
# Airflow plugin that lists MultiCronTimetable in its `timetables` attribute.
class CustomTimetablePlugin(AirflowPlugin):
# Plugin name.
name = "custom_timetable_plugin"
# Timetable classes contributed by this plugin.
timetables = [MultiCronTimetable]
To use it, you provide a list of cron expressions, optionally a timezone string, optionally a period length and period unit.
For my use case I don't actually need the period length + unit, which are used to determine the DAG's data_interval. You can just leave them at their defaults (period_length=0, period_unit='hours') if your DAG doesn't care about the data_interval.
I tried to imitate standard schedule_interval behaviour. For example if catchup = False and the DAG could have potentially been triggered several times since the last run (for whatever reason, for example the DAG ran longer than expected, or the scheduler wasn't running, or it's the DAG's very first time being scheduled), then the DAG will be scheduled to run for the latest previous matching time.
I haven't really tested it with catchup = True, but in theory it would run for every matching cron time since the DAG's start_date (but only once per distinct time, for example with */30 * * * * and 0 * * * * the DAG would run twice per hour, not three times).
Example DAG file:
from time import sleep
import airflow
from airflow.operators.python import PythonOperator
import pendulum
from timetable import MultiCronTimetable
# Task body: just sleep for 660 seconds (11 minutes).
def sleepy_op():
sleep(660)
# Example DAG driven by MultiCronTimetable with three cron expressions evaluated in America/New_York.
with airflow.DAG(
dag_id='timetable_test',
start_date=pendulum.datetime(2022, 6, 2, tz=pendulum.timezone('America/New_York')),
# Data interval for each run is the 10 minutes before the matched cron time.
timetable=MultiCronTimetable(['*/5 * * * *', '*/3 * * * fri,sat', '1 12 3 * *'], timezone='America/New_York', period_length=10, period_unit='minutes'),
catchup=False,
# Only one run at a time; the 11-minute task outlasts the 5-minute and 3-minute cron expressions.
max_active_runs=1) as dag:
sleepy = PythonOperator(
task_id='sleepy',
python_callable=sleepy_op
)

Dynamic glue operators in task group on runtime

I am trying to create a DAG for an ETL solution where we have 4 stages in the pipeline: source_2_s3, staging, dim_population and fact_loads.
Each of these stages has multiple Glue jobs that can run in parallel. I am trying to run these stages as task groups, passing a list of job names as parameters so that the Glue operators are created at runtime, but I am only getting one Glue job for each task group. What am I doing wrong? My code is as follows:
The yaml file contains all jobs for each task group and the gluejob script location as keyvalue pair
from os import path
from airflow import DAG
from airflow.providers.amazon.aws.operators.glue import AwsGlueJobOperator
import yaml
from airflow.models import Variable
from airflow.utils.dates import days_ago
from airflow.operators.dummy import DummyOperator
from airflow.utils.task_group import TaskGroup
# Override the ui_color class attribute for every AwsGlueJobOperator.
AwsGlueJobOperator.ui_color = "#F3D5C5"
# The config file lives next to this DAG file and is selected by the "environment" Airflow Variable.
DAG_FOLDER_PATH = path.dirname(__file__)
ENVIRONMENT = Variable.get("environment")
CONFIG_FILE_NAME = f"dwh_config_{ENVIRONMENT}.yml"
# Load the "dwh" section of the YAML config.
with open(path.join(DAG_FOLDER_PATH, CONFIG_FILE_NAME), 'r') as fl:
cfg = yaml.safe_load(fl)["dwh"]
# NOTE(review): aws_conn_id is read here, but glue_operator hard-codes 'aws_default' instead of using it.
aws_conn_id = cfg["aws"]["aws_conn_id"]
glue_scripts_bucket_name = cfg["glue"]["app"]["bucket_name"]
# One mapping per task group; glue_operator iterates each with .items() as (job name, script path).
src1_tg = cfg["glue"]["app"]["src1_taskgroup"]
src2_tg = cfg["glue"]["app"]["src2_taskgroup"]
staging_tg = cfg["glue"]["app"]["staging_taskgroup"]
dim_tg = cfg["glue"]["app"]["dim_taskgroup"]
fact_tg = cfg["glue"]["app"]["fact_taskgroup"]
# Create a Glue operator from a {job name: script path} mapping.
# NOTE(review): the `return` is the loop body, so the function exits on the first
# item and only one operator is ever created per call - this is the behaviour the
# question describes.
def glue_operator(list_of_jobs):
for jobname, script_path in list_of_jobs.items():
return AwsGlueJobOperator(
task_id=jobname,
# `dag` is the module-level name bound by the `with DAG(...) as dag` block.
dag=dag,
aws_conn_id='aws_default',
region_name='eu-west-2',
job_name=jobname,
script_location=path.join("s3://", glue_scripts_bucket_name,
script_path)
)
# Daily DWH refresh: both source groups run in parallel, then staging, dimensions and facts in sequence.
# "@daily" is the schedule preset; the "#daily" in the original text is not a valid schedule.
with DAG(dag_id="CLIENTUSAGE_DAILY_DWH_REFRESH", schedule_interval="@daily", start_date=days_ago(1),
         tags=['Clientusage']) as dag:
    batch_start_job = DummyOperator(task_id="START")
    batch_close_job = DummyOperator(task_id="END")

    # Operators created inside a TaskGroup context join that group, so the
    # return value of glue_operator is not needed (and must not be bound to
    # the builtin name `list`).
    with TaskGroup(group_id='src1_tg') as src1_taskgroup:  # no trailing space in the group id
        glue_operator(src1_tg)
    with TaskGroup(group_id='src2_tg') as src2_taskgroup:
        glue_operator(src2_tg)
    with TaskGroup(group_id='staging_tg') as staging_taskgroup:
        glue_operator(staging_tg)
    with TaskGroup(group_id='dim_tg') as dim_taskgroup:
        glue_operator(dim_tg)
    with TaskGroup(group_id='fact_tg') as fact_taskgroup:
        glue_operator(fact_tg)

    # END was created but never attached; close the chain with it.
    (batch_start_job
     >> [src1_taskgroup, src2_taskgroup]
     >> staging_taskgroup
     >> dim_taskgroup
     >> fact_taskgroup
     >> batch_close_job)
'''
The issue was my function for generating dags. I forgot the basics of functions :D. Had to append all operators to an array before returning it.
def glue_operator(list_of_jobs):
    """Create one AwsGlueJobOperator per entry of ``list_of_jobs`` and return them all.

    Args:
        list_of_jobs: mapping of Glue job name -> script path inside the
            scripts bucket.

    Returns:
        List of the created operators, in the mapping's iteration order.
    """
    # Build the S3 URL by hand: os.path.join would drop the bucket if a script
    # path started with "/", and would use "\\" as separator on Windows.
    bucket = glue_scripts_bucket_name.strip("/")
    return [
        AwsGlueJobOperator(
            task_id=jobname,
            dag=dag,
            aws_conn_id='aws_default',
            region_name='eu-west-2',
            job_name=jobname,
            script_location=f"s3://{bucket}/{script_path.lstrip('/')}",
        )
        for jobname, script_path in list_of_jobs.items()
    ]

Airflow dynamically generated tasks do not run in order

I have created a DAG that generates its tasks dynamically. The tasks are generated correctly, but they are not triggered in order and do not behave consistently.
I have noticed that they are triggered in alphanumeric order.
Take the run_modification_ tasks: I generated tasks 0 to 29, and I noticed that they are triggered in the order below.
run_modification_0
run_modification_1
run_modification_10
run_modification_11
run_modification_12
run_modification_13
run_modification_14
run_modification_15
run_modification_16
run_modification_17
run_modification_18
run_modification_19
run_modification_2
run_modification_21
run_modification_23....
But i need to run it on tasks order like
run_modification_0
run_modification_1
run_modification_2
run_modification_3
run_modification_4
run_modification_5..
Please help me run those tasks in the order in which they were created.
from datetime import date, timedelta, datetime
from airflow.utils.dates import days_ago
from airflow.models import DAG
from airflow.operators.python_operator import PythonOperator
from airflow.operators.bash_operator import BashOperator
from airflow.operators.postgres_operator import PostgresOperator
from airflow.hooks.postgres_hook import PostgresHook
from airflow.models import Variable
import os
# Default arguments applied to every task of the DAG.
args = {
'owner': 'Airflow',
'start_date': days_ago(2),
}
# schedule_interval=None: this DAG has no schedule of its own.
dag = DAG(
dag_id='tastOrder',
default_args=args,
schedule_interval=None,
tags=['task']
)
# Jinja template for the bash task: `cd` into the value that run_modification_<params.i>
# pushed to XCom under the key 'taskDateFolder'.
modification_processXcom = """ cd {{ ti.xcom_pull(task_ids=\'run_modification_\'+params.i, key=\'taskDateFolder\') }} """
def modificationProcess(ds, **kwargs):
    """Compute 2021-01-01 minus ``i`` days, print it and push it to XCom as 'taskDateFolder'.

    Args:
        ds: unused.
        **kwargs: must contain ``i`` (day offset, anything ``int(str(...))``
            accepts) and ``ti`` (object exposing ``xcom_push``).
    """
    reference_day = datetime.strptime('2021-01-01', '%Y-%m-%d').date()
    offset_days = int(str(kwargs['i']))
    folder = str(reference_day - timedelta(days=offset_days))
    print(folder)
    kwargs["ti"].xcom_push("taskDateFolder", folder)
def getDays():
    """Return ``(span, reference_day)``: a 30-day timedelta and the fixed date 2021-01-01."""
    reference_day = datetime.strptime('2021-01-01', '%Y-%m-%d').date()
    window_start = reference_day - timedelta(days=30)
    return reference_day - window_start, reference_day
# Create one (python task, bash task) pair per day offset, from the highest offset down to 0.
day_Diff, today = getDays()
for i in reversed(range(0, day_Diff.days)):
    run_modification = PythonOperator(
        task_id='run_modification_'+str(i),
        provide_context=True,
        python_callable=modificationProcess,
        op_kwargs={'i': str(i)},
        dag=dag,
    )
    # Bound to its own name: assigning the operator to `modification_processXcom`
    # would overwrite the template string, and every later iteration would then
    # pass a BashOperator as bash_command.
    process_xcom_task = BashOperator(
        task_id='modification_processXcom_'+str(i),
        bash_command=modification_processXcom,
        params={'i': str(i)},
        dag=dag,
    )
    run_modification >> process_xcom_task
To get the dependency as:
run_modification_1 -> modification_processXcom_1 ->
run_modification_2 -> modification_processXcom_2 -> ... - >
run_modification_29 -> modification_processXcom_29
You can do:
from datetime import datetime
from airflow import DAG
from airflow.operators.bash import BashOperator
# Unscheduled DAG whose tasks form one strictly sequential chain.
dag = DAG(
    dag_id='my_dag',
    schedule_interval=None,
    start_date=datetime(2021, 8, 10),
    catchup=False,
    is_paused_upon_creation=False,
)

mylist1 = []  # run_modification_<i> tasks, in creation order
mylist2 = []  # modification_processXcom_<i> tasks, in creation order
for i in range(1, 30):
    run_task = BashOperator(  # Replace with your requested operator
        task_id=f'run_modification_{i}',
        bash_command=f"""echo executing run_modification_{i}""",
        dag=dag,
    )
    mylist1.append(run_task)

    xcom_task = BashOperator(  # Replace with your requested operator
        task_id=f'modification_processXcom_{i}',
        bash_command=f"""echo executing modification_processXcom_{i}""",
        dag=dag,
    )
    mylist2.append(xcom_task)

    # Within a pair: run_modification_<i> runs before modification_processXcom_<i>.
    run_task >> xcom_task
    # Across pairs: the previous modification_processXcom runs before this run_modification.
    if len(mylist2) >= 2:
        mylist2[-2] >> run_task
This code create a list of operators and set them to run one after another as:
Tree view:

Run an airflow task after a task in a loop, not after all tasks in a loop

Let's say we have these tasks:
# For every endpoint create a latest-only task, an export task and a short-circuit check.
for endpoint in ENDPOINTS:
# NOTE(review): latest_only is created but not chained to any other task in this snippet.
latest_only = LatestOnlyOperator(
task_id=f'{endpoint.name}_latest_only',
)
# Task id is "<endpoint.name>_to_S3"; `s3` is rebound on every iteration, so after
# the loop it only refers to the last endpoint's task.
s3 = SnowflakeQOperator(
task_id=f'{endpoint.name}_to_S3',
boostr_conn_id='boostr_default',
s3_conn_id='aws_default',
partition=endpoint.partition,
endpoint=endpoint
)
# Calls check_file_exists with the endpoint and the AWS connection id as keyword arguments.
short_circuit = ShortCircuitOperator(
task_id=f"short_circuit_missing_{endpoint.name}",
op_kwargs={'endpoint_to_check': endpoint, 'aws_conn_id': 'aws_default'},
python_callable=check_file_exists,
provide_context=True
)
s3 >> short_circuit
and let's say I want to add one task that runs after nbc_to_s3, which is one of the tasks created from '{endpoint.name}' by the s3 operator in the loop.
We're importing ENDPOINTS, which contains several classes, each with a 'name' property:
@property
def name(self) -> str:
    """Short endpoint identifier; the DAG loop interpolates it into task ids such as ``f'{endpoint.name}_to_S3'``."""
    return 'nbc'
I've tried to add it outside of the loop like this:
nbc_to_s3 >> new_task but that doesn't work because 'nbc_to_s3' is not defined
You could apply some logic within the loop to set a new dependency for new_task like so (apologies for the quick mockup):
from airflow.decorators import dag
from airflow.operators.dummy import DummyOperator
from datetime import datetime
# Endpoint names; each one gets its own pair of tasks in run_task_after_loop.
ENDPOINTS = ["nbc", "cbs", "bravo", "espn"]
# Task-level defaults, passed to the DAG through DAG_ARGS["default_args"].
DEFAULT_ARGS = dict(owner="airflow", start_date=datetime(2021, 6, 9))
# Keyword arguments unpacked into dag(); schedule_interval=None gives the DAG no schedule.
DAG_ARGS = dict(schedule_interval=None, default_args=DEFAULT_ARGS, catchup=False)
@dag(**DAG_ARGS)
def run_task_after_loop():
    """Create ``<endpoint>_to_S3 >> short_circuit_missing_<endpoint>`` for every endpoint.

    For the "nbc" endpoint only, an extra ``new_task_nbc`` is added downstream
    of ``nbc_to_S3``. The dependency is set inside the loop because ``s3`` is
    rebound on every iteration.
    """
    for endpoint in ENDPOINTS:
        s3 = DummyOperator(
            task_id=f"{endpoint}_to_S3",
        )
        short_circuit = DummyOperator(
            task_id=f"short_circuit_missing_{endpoint}",
        )
        s3 >> short_circuit
        if endpoint == "nbc":
            new_task = DummyOperator(task_id=f"new_task_{endpoint}")
            s3 >> new_task
# Call the decorated function and keep its result in the module-level name `dag`.
dag = run_task_after_loop()

Fetch datastore entity by id inside of a Dataflow transform

I have 2 datastore models:
# Datastore model referenced by KindB.key_to_kind_a; it has two string fields.
class KindA(ndb.Model):
field_a1 = ndb.StringProperty()
field_a2 = ndb.StringProperty()
# Datastore model with two string fields and a key property declared against KindA.
class KindB(ndb.Model):
field_b1 = ndb.StringProperty()
field_b2 = ndb.StringProperty()
# Key property for a KindA entity; the surrounding text says a KindB entity may or may not point to one.
key_to_kind_a = ndb.KeyProperty(KindA)
I want to query KindB and output it to a csv file, but if an entity of KindB points to an entity in KindA I want those fields to be present in the csv as well.
If I was able to use ndb inside of a transform I would setup my pipeline like this
def format(element):  # element is an `entity_pb2` object of KindB
    """Render one KindB entity, plus the KindA entity it references (if any), as a CSV row.

    Args:
        element: entity whose ``properties`` mapping holds ``field_b1``,
            ``field_b2`` and optionally ``key_to_kind_a``.

    Returns:
        ``"field_b1,field_b2,field_a1,field_a2"``; the last two columns are
        empty when no KindA entity is referenced.
    """
    try:
        obj_a_key_id = element.properties.get('key_to_kind_a', None).key_value.path[0]
    except (AttributeError, IndexError):
        # Property missing (``.get`` returned None) or the key path is empty.
        obj_a_key_id = None
    # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<< HOW DO I DO THIS
    obj_a = ndb.Key(KindA, obj_a_key_id).get() if obj_a_key_id else None
    return ",".join([
        element.properties.get('field_b1', None).string_value,
        element.properties.get('field_b2', None).string_value,
        obj_a.properties.get('field_a1', None).string_value if obj_a else '',
        obj_a.properties.get('field_a2', None).string_value if obj_a else '',
    ])
# Build (without running) a pipeline that reads KindB entities whose field_b1 lies strictly
# between start_date and end_date, formats each with `format`, and writes one CSV shard.
def build_pipeline(project, start_date, end_date, export_path):
# Query over kind 'KindB'.
query = query_pb2.Query()
query.kind.add().name = 'KindB'
# field_b1 > start_date AND field_b1 < end_date
filter_1 = datastore_helper.set_property_filter(query_pb2.Filter(), 'field_b1', PropertyFilter.GREATER_THAN, start_date)
filter_2 = datastore_helper.set_property_filter(query_pb2.Filter(), 'field_b1', PropertyFilter.LESS_THAN, end_date)
datastore_helper.set_composite_filter(query.filter, CompositeFilter.AND, filter_1, filter_2)
# NOTE(review): pipeline_options is not defined in this snippet; presumably a module-level value.
p = beam.Pipeline(options=pipeline_options)
_ = (p
| 'read from datastore' >> ReadFromDatastore(project, query, None)
| 'format' >> beam.Map(format)
| 'write' >> apache_beam.io.WriteToText(
file_path_prefix=export_path,
file_name_suffix='.csv',
header='field_b1,field_b2,field_a1,field_a2',
num_shards=1)
)
# The pipeline is returned unexecuted.
return p
I suppose I could use ReadFromDatastore to query all entities of KindA and then use CoGroupByKey to merge them, but KindA has millions of records and that would be very inefficient.
Per the recommendations in this answer: https://stackoverflow.com/a/49130224/4458510
I created the following utils, which were inspired by the source code of
DatastoreWriteFn in apache_beam.io.gcp.datastore.v1.datastoreio
write_mutations and fetch_entities in apache_beam.io.gcp.datastore.v1.helper
import logging
import time
from socket import error as _socket_error
from apache_beam.metrics import Metrics
from apache_beam.transforms import DoFn, window
from apache_beam.utils import retry
from apache_beam.io.gcp.datastore.v1.adaptive_throttler import AdaptiveThrottler
from apache_beam.io.gcp.datastore.v1.helper import make_partition, retry_on_rpc_error, get_datastore
from apache_beam.io.gcp.datastore.v1.util import MovingSum
from apache_beam.utils.windowed_value import WindowedValue
from google.cloud.proto.datastore.v1 import datastore_pb2, query_pb2
from googledatastore.connection import Datastore, RPCError
# Batch-size tuning values read by _DynamicBatchSizer.get_batch_size.
# Size returned while no latency data has been recorded yet.
_WRITE_BATCH_INITIAL_SIZE = 200
# Upper and lower clamp for the computed batch size.
_WRITE_BATCH_MAX_SIZE = 500
_WRITE_BATCH_MIN_SIZE = 10
# Target latency per request in milliseconds; divided by the recent mean per-entity latency.
_WRITE_BATCH_TARGET_LATENCY_MS = 5000
def _fetch_keys(project_id, keys, datastore, throttler, rpc_stats_callback=None, throttle_delay=1):
    """Look up ``keys`` in Datastore with a single LookupRequest.

    The request is throttled on the client side and retried with exponential
    backoff on RPC errors.

    Args:
        project_id: project id set on the request.
        keys: iterable of key protobufs, copied into the request.
        datastore: client object exposing ``lookup(request)``.
        throttler: object exposing ``throttle_request(now_ms)`` and
            ``successful_request(start_ms)``.
        rpc_stats_callback: optional callable accepting the keyword arguments
            ``successes``, ``errors`` and ``throttled_secs``.
        throttle_delay: seconds to sleep each time the throttler rejects the request.

    Returns:
        ``(response, commit_time_ms)``: the lookup response and the duration
        of the successful call in whole milliseconds.

    Raises:
        RPCError, socket.error: re-raised after the error has been counted.
    """
    req = datastore_pb2.LookupRequest()
    req.project_id = project_id
    for key in keys:
        req.keys.add().CopyFrom(key)

    @retry.with_exponential_backoff(num_retries=5, retry_filter=retry_on_rpc_error)
    def run(request):
        # Client-side throttling.
        while throttler.throttle_request(time.time() * 1000):
            logging.info("Delaying request for %ds due to previous failures", throttle_delay)
            time.sleep(throttle_delay)
            if rpc_stats_callback:
                rpc_stats_callback(throttled_secs=throttle_delay)
        try:
            start_time = time.time()
            response = datastore.lookup(request)
            end_time = time.time()
            if rpc_stats_callback:
                rpc_stats_callback(successes=1)
            throttler.successful_request(start_time * 1000)
            commit_time_ms = int((end_time - start_time) * 1000)
            return response, commit_time_ms
        except (RPCError, _socket_error):
            if rpc_stats_callback:
                rpc_stats_callback(errors=1)
            raise

    return run(req)
# Copied from _DynamicBatchSizer in apache_beam.io.gcp.datastore.v1.datastoreio
class _DynamicBatchSizer(object):
"""Determines request sizes for future Datastore RPCS."""
def __init__(self):
self._commit_time_per_entity_ms = MovingSum(window_ms=120000, bucket_ms=10000)
def get_batch_size(self, now):
"""Returns the recommended size for datastore RPCs at this time."""
if not self._commit_time_per_entity_ms.has_data(now):
return _WRITE_BATCH_INITIAL_SIZE
recent_mean_latency_ms = (self._commit_time_per_entity_ms.sum(now) / self._commit_time_per_entity_ms.count(now))
return max(_WRITE_BATCH_MIN_SIZE,
min(_WRITE_BATCH_MAX_SIZE,
_WRITE_BATCH_TARGET_LATENCY_MS / max(recent_mean_latency_ms, 1)))
def report_latency(self, now, latency_ms, num_mutations):
"""Reports the latency of an RPC to Datastore.
Args:
now: double, completion time of the RPC as seconds since the epoch.
latency_ms: double, the observed latency in milliseconds for this RPC.
num_mutations: int, number of mutations contained in the RPC.
"""
self._commit_time_per_entity_ms.add(now, latency_ms / num_mutations)
class LookupKeysFn(DoFn):
    """A `DoFn` that looks up keys in the Datastore.

    Incoming keys are buffered and fetched in batches; each found entity is
    emitted as an output element.
    """

    def __init__(self, project_id, fixed_batch_size=None):
        """
        Args:
            project_id: project id passed to ``get_datastore`` and to every lookup.
            fixed_batch_size: when truthy, always use this batch size instead
                of adapting it to the observed latency.
        """
        self._project_id = project_id
        self._datastore = None
        self._fixed_batch_size = fixed_batch_size
        self._rpc_successes = Metrics.counter(self.__class__, "datastoreRpcSuccesses")
        self._rpc_errors = Metrics.counter(self.__class__, "datastoreRpcErrors")
        self._throttled_secs = Metrics.counter(self.__class__, "cumulativeThrottlingSeconds")
        self._throttler = AdaptiveThrottler(window_ms=120000, bucket_ms=1000, overload_ratio=1.25)
        self._elements = []
        self._batch_sizer = None
        self._target_batch_size = None

    def _update_rpc_stats(self, successes=0, errors=0, throttled_secs=0):
        """Callback function, called by _fetch_keys()"""
        self._rpc_successes.inc(successes)
        self._rpc_errors.inc(errors)
        self._throttled_secs.inc(throttled_secs)

    def start_bundle(self):
        """(re)initialize: connection with datastore, _DynamicBatchSizer obj"""
        self._elements = []
        self._datastore = get_datastore(self._project_id)
        if self._fixed_batch_size:
            self._target_batch_size = self._fixed_batch_size
        else:
            self._batch_sizer = _DynamicBatchSizer()
            self._target_batch_size = self._batch_sizer.get_batch_size(time.time()*1000)

    def process(self, element):
        """Collect elements and process them as a batch.

        Returns the fetched results once the batch is full, otherwise None.
        """
        self._elements.append(element)
        if len(self._elements) >= self._target_batch_size:
            return self._flush_batch()

    def finish_bundle(self):
        """Flush any remaining elements"""
        if self._elements:
            objs = self._flush_batch()
            for obj in objs:
                yield WindowedValue(obj, window.MAX_TIMESTAMP, [window.GlobalWindow()])

    def _flush_batch(self):
        """Fetch all of the collected keys from datastore.

        Returns:
            List of the entities found; keys without an entity are left out.
        """
        if not self._elements:
            # Nothing buffered: skip the RPC and the latency report, which
            # would divide by the batch length.
            return []
        response, latency_ms = _fetch_keys(
            project_id=self._project_id,
            keys=self._elements,
            datastore=self._datastore,
            throttler=self._throttler,
            rpc_stats_callback=self._update_rpc_stats)
        logging.info("Successfully read %d keys in %dms.", len(self._elements), latency_ms)
        if not self._fixed_batch_size:
            now = time.time()*1000
            self._batch_sizer.report_latency(now, latency_ms, len(self._elements))
            self._target_batch_size = self._batch_sizer.get_batch_size(now)
        self._elements = []
        return [entity_result.entity for entity_result in response.found]
class LookupEntityFieldFn(LookupKeysFn):
    """
    Looks-up a field on an EntityPb2 object
    Expects a EntityPb2 object as input
    Outputs a tuple, where the first element is the input object and the second element is the object found during the
    lookup (None when the field holds no key or the key was not found)
    """

    def __init__(self, project_id, field_name, fixed_batch_size=None):
        """
        Args:
            project_id: forwarded to LookupKeysFn.
            field_name: name of the key property to follow on each input entity.
            fixed_batch_size: forwarded to LookupKeysFn.
        """
        super(LookupEntityFieldFn, self).__init__(project_id=project_id, fixed_batch_size=fixed_batch_size)
        self._field_name = field_name

    @staticmethod
    def _pb2_key_value_to_tuple(kv):
        """Converts a key_value object into a tuple, so that it can be a dictionary key"""
        path = []
        for p in kv.path:
            path.append(p.name)
            path.append(p.id)
        return tuple(path)

    def _referenced_key(self, element):
        """Return the key stored in ``element``'s configured field, or None when it is absent or empty."""
        kv = element.properties.get(self._field_name, None)
        if kv and kv.key_value and kv.key_value.path:
            return kv.key_value
        return None

    def _flush_batch(self):
        """Fetch the entities referenced by the buffered inputs and pair each input with its result.

        Returns:
            List of ``(input_entity, referenced_entity_or_None)``, in input order.
        """
        inputs = self._elements
        referenced = [self._referenced_key(element) for element in inputs]
        keys_to_fetch = [key for key in referenced if key is not None]

        if keys_to_fetch:
            # The parent implementation fetches whatever is in self._elements
            # and resets the buffer afterwards.
            self._elements = keys_to_fetch
            fetched = super(LookupEntityFieldFn, self)._flush_batch()
        else:
            # No input references a key: skip the lookup and clear the buffer ourselves.
            self._elements = []
            fetched = []

        by_key = {self._pb2_key_value_to_tuple(entity.key): entity for entity in fetched}
        return [
            (input_obj,
             by_key.get(self._pb2_key_value_to_tuple(key), None) if key is not None else None)
            for input_obj, key in zip(inputs, referenced)
        ]
The Key to this is the line response = datastore.lookup(request), where:
datastore = get_datastore(project_id) (from apache_beam.io.gcp.datastore.v1.helper.get_datastore)
request is a LookupRequest from google.cloud.proto.datastore.v1.datastore_pb2
response is LookupResponse from google.cloud.proto.datastore.v1.datastore_pb2
The rest of the above code does things like:
using a single connection to the datastore for a dofn bundle
batches keys together before performing a lookup request
throttles interactions with the datastore if requests start to fail
(honestly I don't know how critical these bits are, I just came across them when browsing the apache_beam source code)
The resulting util function LookupEntityFieldFn(project_id, field_name) is a DoFn that takes in an entity_pb2 object as input, extracts and fetches/gets the key_property that resides on the field field_name, and outputs the result as a tuple (the fetch-result is paired with the input object)
My Pipeline code then became
def format(element):  # element is a tuple `entity_pb2` objects
    """Render a ``(KindB entity, KindA entity or None)`` pair as one CSV row.

    Args:
        element: 2-tuple as produced by LookupEntityFieldFn.

    Returns:
        ``"field_b1,field_b2,field_a1,field_a2"``; a column is empty when its
        entity is missing or does not carry that property.
    """
    kind_b_element, kind_a_element = element

    def _text(entity, field_name):
        # '' for a missing entity or property instead of failing on None.
        if not entity:
            return ''
        prop = entity.properties.get(field_name, None)
        return prop.string_value if prop is not None else ''

    return ",".join([
        _text(kind_b_element, 'field_b1'),
        _text(kind_b_element, 'field_b2'),
        _text(kind_a_element, 'field_a1'),
        _text(kind_a_element, 'field_a2'),
    ])
def build_pipeline(project, start_date, end_date, export_path):
    """Assemble (without running) the export pipeline.

    Reads KindB entities whose ``field_b1`` lies strictly between
    ``start_date`` and ``end_date``, pairs each with the entity referenced by
    ``key_to_kind_a``, formats the pairs with ``format`` and writes one CSV shard.

    Args:
        project: Datastore project id.
        start_date: exclusive lower bound for ``field_b1``.
        end_date: exclusive upper bound for ``field_b1``.
        export_path: file path prefix of the CSV output.

    Returns:
        The constructed, not yet executed, pipeline.
    """
    kind_b_query = query_pb2.Query()
    kind_b_query.kind.add().name = 'KindB'
    after_start = datastore_helper.set_property_filter(
        query_pb2.Filter(), 'field_b1', PropertyFilter.GREATER_THAN, start_date)
    before_end = datastore_helper.set_property_filter(
        query_pb2.Filter(), 'field_b1', PropertyFilter.LESS_THAN, end_date)
    datastore_helper.set_composite_filter(kind_b_query.filter, CompositeFilter.AND, after_start, before_end)

    pipeline = beam.Pipeline(options=pipeline_options)
    entities = pipeline | 'read from datastore' >> ReadFromDatastore(project, kind_b_query, None)
    pairs = entities | 'extract field' >> apache_beam.ParDo(
        LookupEntityFieldFn(project_id=project, field_name='key_to_kind_a'))
    rows = pairs | 'format' >> beam.Map(format)
    _ = rows | 'write' >> apache_beam.io.WriteToText(
        file_path_prefix=export_path,
        file_name_suffix='.csv',
        header='field_b1,field_b2,field_a1,field_a2',
        num_shards=1)
    return pipeline

Resources