Airflow Advanced

Airflow Advanced

Overview

This document covers advanced Airflow features including XCom for data sharing between tasks, dynamic DAG generation, Sensors, Hooks, TaskGroups, and more. Leveraging these features allows you to build more flexible and powerful pipelines.


1. XCom (Cross-Communication)

1.1 Basic XCom Usage

XCom is a mechanism for sharing small amounts of data between tasks.

from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime

def push_data(**kwargs):
    """Publish values to XCom in the two supported ways.

    Demonstrates an explicit ``xcom_push`` under a custom key and the
    implicit push of the return value (stored under key='return_value').
    """
    task_instance = kwargs['ti']

    # Explicit push under a custom key other tasks can look up.
    task_instance.xcom_push(key='my_key', value={'status': 'success', 'count': 100})

    # Returning a value stores it automatically with key='return_value'.
    return {'result': 'completed', 'rows': 500}


def pull_data(**kwargs):
    """Read values from XCom in three common ways and print them."""
    task_instance = kwargs['ti']

    # Explicit key lookup scoped to a single upstream task.
    custom_data = task_instance.xcom_pull(key='my_key', task_ids='push_task')
    print(f"Custom data: {custom_data}")

    # Without a key, the default 'return_value' entry is fetched.
    return_value = task_instance.xcom_pull(task_ids='push_task')
    print(f"Return value: {return_value}")

    # Passing a list of task_ids yields one result per task, in order.
    multiple_results = task_instance.xcom_pull(task_ids=['task1', 'task2'])


# Demo DAG wiring the push/pull tasks together.
# schedule_interval=None means the DAG only runs when triggered manually.
with DAG('xcom_example', start_date=datetime(2024, 1, 1), schedule_interval=None) as dag:

    push_task = PythonOperator(
        task_id='push_task',
        python_callable=push_data,
    )

    pull_task = PythonOperator(
        task_id='pull_task',
        python_callable=pull_data,
    )

    # pull_task reads the XCom values push_task wrote, so it must run downstream.
    push_task >> pull_task

1.2 Using XCom in Jinja Templates

from airflow.operators.bash import BashOperator
from airflow.providers.postgres.operators.postgres import PostgresOperator

# Using XCom in Bash
# The Jinja template renders before the command runs, so the pulled value
# is inlined directly into the shell string.
bash_task = BashOperator(
    task_id='bash_with_xcom',
    bash_command='echo "Result: {{ ti.xcom_pull(task_ids="push_task") }}"',
)

# Using XCom in SQL
# PostgresOperator templates the sql field, so an upstream task's XCom
# value (row_count) can be spliced into the INSERT statement.
sql_task = PostgresOperator(
    task_id='sql_with_xcom',
    postgres_conn_id='my_postgres',
    sql="""
        INSERT INTO process_log (task_id, result_count, processed_at)
        VALUES (
            'data_load',
            {{ ti.xcom_pull(task_ids='count_task', key='row_count') }},
            NOW()
        );
    """,
)

1.3 XCom Limitations and Alternatives

# XCom limitation: values are serialized into the metadata database, so the practical size cap depends on the backend (roughly 64 KB on MySQL, ~1 GB on Postgres/SQLite) — keep XCom payloads small

# Handling large data
class LargeDataHandler:
    """Pattern for handling large data.

    Keep the payload in external storage (S3, GCS, local disk) and pass
    only the storage *path* through XCom, since XCom lives in the
    metadata database and is meant for small values.
    """

    @staticmethod
    def save_to_storage(data, path: str):
        """Save data to external storage and pass only the path via XCom.

        :param data: object exposing ``to_parquet`` (e.g. a pandas DataFrame)
        :param path: destination path/URI for the parquet file
        :return: the same ``path``, suitable for storing in XCom
        """
        # Fix: removed an unused local `import pandas as pd` -- `data` is
        # already a DataFrame-like object, so only its own method is needed.
        data.to_parquet(path)
        return path  # Return only the path

    @staticmethod
    def load_from_storage(path: str):
        """Load a DataFrame back from the given parquet path."""
        import pandas as pd
        return pd.read_parquet(path)


# Usage example
def produce_large_data(**kwargs):
    """Generate a large DataFrame, persist it to S3, and return its path.

    Only the path goes into XCom; the data itself stays in object storage.
    """
    import pandas as pd

    # A million-row frame -- far too big to serialize into XCom directly.
    frame = pd.DataFrame({'col': range(1000000)})

    # Partition the output by the logical run date (ds).
    path = f"s3://bucket/data/{kwargs['ds']}/output.parquet"
    frame.to_parquet(path)
    return path  # Store only path in XCom


def consume_large_data(**kwargs):
    """Read back the DataFrame whose path the producer task put in XCom."""
    import pandas as pd

    # XCom holds only the parquet path written by produce_task.
    path = kwargs['ti'].xcom_pull(task_ids='produce_task')

    df = pd.read_parquet(path)
    print(f"Loaded {len(df)} rows from {path}")

2. Dynamic DAG Generation

2.1 Configuration-Based Dynamic DAGs

# dags/dynamic_dag_factory.py
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime

# Define configuration
# One entry per generated DAG. Schedules are staggered by an hour
# (cron fields: minute hour day month weekday) so the ETLs don't all
# start at the same moment.
DAG_CONFIGS = [
    {
        'dag_id': 'etl_customers',
        'table': 'customers',
        'schedule': '0 1 * * *',
    },
    {
        'dag_id': 'etl_orders',
        'table': 'orders',
        'schedule': '0 2 * * *',
    },
    {
        'dag_id': 'etl_products',
        'table': 'products',
        'schedule': '0 3 * * *',
    },
]


def create_dag(config: dict) -> DAG:
    """Build a two-step (extract -> load) ETL DAG from one config entry.

    Expects the keys 'dag_id', 'schedule' (cron string), and 'table'.
    """

    def extract_table(table_name: str, **kwargs):
        # Placeholder extract step; kwargs['ds'] is the logical run date.
        print(f"Extracting {table_name} for {kwargs['ds']}")

    def load_table(table_name: str, **kwargs):
        # Placeholder load step.
        print(f"Loading {table_name} for {kwargs['ds']}")

    with DAG(
        dag_id=config['dag_id'],
        schedule_interval=config['schedule'],
        start_date=datetime(2024, 1, 1),
        catchup=False,
        tags=['dynamic', 'etl'],
    ) as dag:
        extract_op = PythonOperator(
            task_id='extract',
            python_callable=extract_table,
            op_kwargs={'table_name': config['table']},
        )
        load_op = PythonOperator(
            task_id='load',
            python_callable=load_table,
            op_kwargs={'table_name': config['table']},
        )

        extract_op >> load_op

    return dag


# Register DAGs in globals() (so Airflow can discover them)
# The DAG-file processor only picks up DAG objects reachable from module
# globals, hence the explicit registration per config entry.
for cfg in DAG_CONFIGS:
    globals()[cfg['dag_id']] = create_dag(cfg)

2.2 YAML/JSON-Based Dynamic DAGs

# dags/yaml_driven_dag.py
import yaml
from pathlib import Path
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime

# Load YAML configuration
# Resolved relative to this DAG file, so the lookup works wherever the
# dags/ folder happens to be deployed.
config_path = Path(__file__).parent / 'configs' / 'dag_configs.yaml'

# Example configs/dag_configs.yaml:
"""
dags:
  - id: sales_etl
    schedule: "0 6 * * *"
    tasks:
      - name: extract
        type: python
        function: extract_sales
      - name: transform
        type: python
        function: transform_sales
      - name: load
        type: python
        function: load_sales
"""

def load_config():
    """Parse the YAML DAG configuration file and return it as a dict."""
    with config_path.open('r') as f:
        return yaml.safe_load(f)


def create_task_callable(func_name: str):
    """Build a task callable that logs which configured function would run.

    The returned closure captures ``func_name`` so each generated task
    prints its own identity along with the logical run date.
    """
    def _runner(**kwargs):
        print(f"Executing {func_name} for {kwargs['ds']}")
    return _runner


def create_dag_from_yaml(dag_config: dict) -> DAG:
    """Create a DAG from one parsed YAML entry ('id', 'schedule', 'tasks').

    Tasks are chained sequentially in the order they appear in the config.
    """

    dag = DAG(
        dag_id=dag_config['id'],
        schedule_interval=dag_config['schedule'],
        start_date=datetime(2024, 1, 1),
        catchup=False,
    )

    with dag:
        # One PythonOperator per task spec, preserving config order.
        ordered_tasks = [
            PythonOperator(
                task_id=spec['name'],
                python_callable=create_task_callable(spec['function']),
            )
            for spec in dag_config['tasks']
        ]

        # Chain each task to its successor.
        for upstream, downstream in zip(ordered_tasks, ordered_tasks[1:]):
            upstream >> downstream

    return dag


# Create and register DAGs
try:
    config = load_config()
    for dag_config in config.get('dags', []):
        dag_id = dag_config['id']
        globals()[dag_id] = create_dag_from_yaml(dag_config)
except Exception as e:
    # Broad catch is deliberate: a missing or malformed config file must
    # not crash parsing of the whole DAG file. NOTE(review): prefer
    # logging over print so the message reaches the scheduler logs.
    print(f"Error loading DAG config: {e}")

2.3 Dynamic Task Generation

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.operators.empty import EmptyOperator
from datetime import datetime

# List of tables to process
TABLES = ['users', 'orders', 'products', 'reviews', 'inventory']

with DAG(
    dag_id='dynamic_tasks_example',
    start_date=datetime(2024, 1, 1),
    schedule_interval='@daily',
    catchup=False,
) as dag:

    # start/end anchors give the fan-out a single entry and exit point.
    start = EmptyOperator(task_id='start')
    end = EmptyOperator(task_id='end')

    # Dynamically create tasks
    # One task per table; all run in parallel between start and end.
    for table in TABLES:
        # table_name=table binds the current loop value as a default at
        # definition time, sidestepping Python's late-binding-closure pitfall.
        def process_table(table_name=table, **kwargs):
            print(f"Processing table: {table_name}")

        task = PythonOperator(
            task_id=f'process_{table}',
            python_callable=process_table,
            op_kwargs={'table_name': table},
        )

        start >> task >> end

3. Sensors

3.1 Built-in Sensors

from airflow import DAG
from airflow.sensors.filesystem import FileSensor
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.sensors.time_delta import TimeDeltaSensor
from airflow.providers.http.sensors.http import HttpSensor
from airflow.providers.postgres.sensors.postgres import SqlSensor
from datetime import datetime, timedelta

# Demo DAG showing one instance of each commonly used built-in sensor.
with DAG('sensor_examples', start_date=datetime(2024, 1, 1), schedule_interval='@daily') as dag:

    # 1. FileSensor - wait for file existence
    wait_for_file = FileSensor(
        task_id='wait_for_file',
        filepath='/data/input/{{ ds }}/data.csv',  # templated with the run date
        poke_interval=60,           # Check interval (seconds)
        timeout=3600,               # Timeout (seconds)
        mode='poke',                # poke or reschedule
    )

    # 2. ExternalTaskSensor - wait for another DAG's task completion
    wait_for_upstream = ExternalTaskSensor(
        task_id='wait_for_upstream',
        external_dag_id='upstream_dag',
        external_task_id='final_task',
        execution_delta=timedelta(hours=0),  # Same execution_date
        timeout=7200,
        mode='reschedule',          # Return worker and reschedule
    )

    # 3. HttpSensor - check HTTP endpoint
    wait_for_api = HttpSensor(
        task_id='wait_for_api',
        http_conn_id='my_api',
        endpoint='/health',
        request_params={},
        # Succeeds only once the health endpoint answers 200.
        response_check=lambda response: response.status_code == 200,
        poke_interval=30,
        timeout=600,
    )

    # 4. SqlSensor - check SQL condition
    # Succeeds when the query's first cell is truthy (rows exist for ds).
    wait_for_data = SqlSensor(
        task_id='wait_for_data',
        conn_id='my_postgres',
        sql="""
            SELECT COUNT(*) > 0
            FROM staging_table
            WHERE date = '{{ ds }}'
        """,
        poke_interval=300,
        timeout=3600,
    )

    # 5. TimeDeltaSensor - wait for time duration
    wait_30_minutes = TimeDeltaSensor(
        task_id='wait_30_minutes',
        delta=timedelta(minutes=30),
    )

3.2 Custom Sensor

from airflow.sensors.base import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
import boto3

class S3KeySensorCustom(BaseSensorOperator):
    """Custom sensor that waits until an S3 object exists.

    ``poke`` succeeds once a HEAD request on the key succeeds, keeps
    waiting while S3 answers 404, and re-raises any other client error.

    :param bucket_name: S3 bucket to check
    :param bucket_key: object key; Jinja-templated ('{{ ds }}' works)
    :param aws_conn_id: stored for reference -- NOTE(review): poke() builds
        a plain boto3 client, so credentials come from boto3's default
        chain, not from this Airflow connection; wire it through an S3Hook
        if the connection's credentials should be used.
    """

    # Only bucket_key is rendered through Jinja templating.
    template_fields = ['bucket_key']

    # Fix: dropped @apply_defaults -- it has been a deprecated no-op since
    # Airflow 2.0 (BaseOperator applies default_args itself) and the import
    # was removed from airflow.utils.decorators in later releases.
    def __init__(
        self,
        bucket_name: str,
        bucket_key: str,
        aws_conn_id: str = 'aws_default',
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.bucket_name = bucket_name
        self.bucket_key = bucket_key
        self.aws_conn_id = aws_conn_id

    def poke(self, context) -> bool:
        """Check condition (returns True on success)."""
        self.log.info(f"Checking for s3://{self.bucket_name}/{self.bucket_key}")

        # Create S3 client using boto3's default credential chain.
        s3 = boto3.client('s3')

        try:
            s3.head_object(Bucket=self.bucket_name, Key=self.bucket_key)
            self.log.info("File found!")
            return True
        except s3.exceptions.ClientError as e:
            # 404 just means "not there yet" -> keep poking.
            if e.response['Error']['Code'] == '404':
                self.log.info("File not found, waiting...")
                return False
            raise


# Usage
# poke_interval/timeout/mode are inherited from BaseSensorOperator;
# mode='reschedule' frees the worker slot between pokes.
wait_for_s3 = S3KeySensorCustom(
    task_id='wait_for_s3_file',
    bucket_name='my-bucket',
    bucket_key='data/{{ ds }}/input.parquet',  # rendered per run date
    poke_interval=60,
    timeout=3600,
    mode='reschedule',
)

3.3 Sensor Modes

# poke vs reschedule mode comparison
# Plain data -- kept as inline documentation of the trade-offs.
sensor_modes = {
    'poke': {
        'description': 'Occupies worker slot while waiting',
        'pros': 'Fast response time',
        'cons': 'Wastes worker resources',
        'use_case': 'Short wait time expected'
    },
    'reschedule': {
        'description': 'Returns worker and reschedules',
        'pros': 'Efficient worker resource usage',
        'cons': 'Slightly slower response time',
        'use_case': 'Long wait time expected'
    }
}

# Recommended configuration
# For long waits: reschedule mode, generous timeout, and soft_fail so a
# missing file marks the task skipped rather than failed.
wait_for_file = FileSensor(
    task_id='wait_for_file',
    filepath='/data/input.csv',
    poke_interval=300,      # Check every 5 minutes
    timeout=86400,          # 24 hour timeout
    mode='reschedule',      # Use reschedule for long waits
    soft_fail=True,         # Skip on timeout (instead of failing)
)

4. Hooks and Connections

4.1 Connection Configuration

# Configure Connection via Airflow UI or CLI
# Admin > Connections > Add

# Add Connection via CLI
# NOTE(review): avoid putting real credentials in shell history or code;
# prefer a secrets backend or environment variables in production.
"""
airflow connections add 'my_postgres' \
    --conn-type 'postgres' \
    --conn-host 'localhost' \
    --conn-port '5432' \
    --conn-login 'user' \
    --conn-password 'password' \
    --conn-schema 'mydb'

airflow connections add 'my_s3' \
    --conn-type 'aws' \
    --conn-extra '{"aws_access_key_id": "xxx", "aws_secret_access_key": "yyy", "region_name": "us-east-1"}'
"""

# Configure Connection via environment variable
# Name pattern is AIRFLOW_CONN_{CONN_ID_UPPERCASE}; value is a URI.
# AIRFLOW_CONN_MY_POSTGRES='postgresql://user:password@localhost:5432/mydb'

4.2 Using Hooks

from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.http.hooks.http import HttpHook

def use_postgres_hook(**kwargs):
    """Demonstrate the main PostgresHook access patterns.

    The intermediate variables are illustrative only; this demo does not
    return anything.
    """
    hook = PostgresHook(postgres_conn_id='my_postgres')

    # Execute SQL and fetch all rows as tuples
    records = hook.get_records("SELECT * FROM users LIMIT 10")

    # Return as DataFrame
    df = hook.get_pandas_df("SELECT * FROM users")

    # Insert rows
    hook.insert_rows(
        table='users',
        rows=[(1, 'John'), (2, 'Jane')],
        target_fields=['id', 'name']
    )

    # Use connection directly
    # NOTE(review): demo statement updates every row (no WHERE clause).
    conn = hook.get_conn()
    cursor = conn.cursor()
    cursor.execute("UPDATE users SET active = true")
    conn.commit()


def use_s3_hook(**kwargs):
    """Demonstrate upload, download, and listing with the S3 hook.

    Credentials come from the 'my_s3' Airflow connection.
    """
    hook = S3Hook(aws_conn_id='my_s3')

    # Upload file (replace=True overwrites an existing key)
    hook.load_file(
        filename='/tmp/data.csv',
        key='data/output.csv',
        bucket_name='my-bucket',
        replace=True
    )

    # Download file
    hook.download_file(
        key='data/input.csv',
        bucket_name='my-bucket',
        local_path='/tmp/input.csv'
    )

    # List files under a prefix
    keys = hook.list_keys(
        bucket_name='my-bucket',
        prefix='data/',
        delimiter='/'
    )


def use_http_hook(**kwargs):
    """Issue a GET via HttpHook and return the decoded JSON response."""
    hook = HttpHook(http_conn_id='my_api', method='GET')

    auth_headers = {'Authorization': 'Bearer token'}
    params = {'param': 'value'}

    response = hook.run(
        endpoint='/api/data',
        headers=auth_headers,
        data=params
    )
    return response.json()

4.3 Custom Hook

from airflow.hooks.base import BaseHook
from typing import Any
import requests

class MyCustomHook(BaseHook):
    """Custom hook for a bearer-token-authenticated JSON API.

    Connection mapping: ``conn.host`` -> API hostname,
    ``conn.password`` -> API key used as a bearer token.
    """

    conn_name_attr = 'my_custom_conn_id'
    default_conn_name = 'my_custom_default'
    conn_type = 'http'
    hook_name = 'My Custom Hook'

    def __init__(self, my_custom_conn_id: str = default_conn_name):
        super().__init__()
        self.my_custom_conn_id = my_custom_conn_id
        # Populated lazily by get_conn().
        self.base_url = None
        self.api_key = None

    def get_conn(self):
        """Load connection configuration"""
        conn = self.get_connection(self.my_custom_conn_id)
        self.base_url = f"https://{conn.host}"
        self.api_key = conn.password
        return conn

    def make_request(self, endpoint: str, method: str = 'GET', data: dict = None,
                     timeout: float = 30.0) -> Any:
        """Make an API request and return the decoded JSON body.

        :param endpoint: path appended to the base URL (e.g. '/users')
        :param method: HTTP method name
        :param data: optional dict sent as the JSON body
        :param timeout: seconds before the request aborts. Fix: `requests`
            has no default timeout, so the previous code could hang a
            worker slot forever on an unresponsive API.
        :raises requests.HTTPError: on non-2xx responses
        """
        self.get_conn()

        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }

        url = f"{self.base_url}{endpoint}"

        response = requests.request(
            method=method,
            url=url,
            headers=headers,
            json=data,
            timeout=timeout,
        )

        response.raise_for_status()
        return response.json()


# Usage
def call_custom_api(**kwargs):
    """Fetch /users through the custom hook and return the JSON payload."""
    hook = MyCustomHook(my_custom_conn_id='my_api')
    return hook.make_request('/users', method='GET')

5. TaskGroup

5.1 Basic TaskGroup Usage

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import TaskGroup
from datetime import datetime

with DAG('taskgroup_example', start_date=datetime(2024, 1, 1), schedule_interval='@daily') as dag:

    start = EmptyOperator(task_id='start')

    # Group related tasks with TaskGroup
    # Each group renders as a collapsible box in the Graph view.
    with TaskGroup(group_id='extract_group') as extract_group:
        extract_users = PythonOperator(
            task_id='extract_users',
            python_callable=lambda: print("Extracting users")
        )
        extract_orders = PythonOperator(
            task_id='extract_orders',
            python_callable=lambda: print("Extracting orders")
        )
        extract_products = PythonOperator(
            task_id='extract_products',
            python_callable=lambda: print("Extracting products")
        )

    with TaskGroup(group_id='transform_group') as transform_group:
        transform_users = PythonOperator(
            task_id='transform_users',
            python_callable=lambda: print("Transforming users")
        )
        transform_orders = PythonOperator(
            task_id='transform_orders',
            python_callable=lambda: print("Transforming orders")
        )

    with TaskGroup(group_id='load_group') as load_group:
        load_warehouse = PythonOperator(
            task_id='load_warehouse',
            python_callable=lambda: print("Loading to warehouse")
        )

    end = EmptyOperator(task_id='end')

    # Dependencies between TaskGroups
    # Using a group in >> applies the dependency to all its member tasks.
    start >> extract_group >> transform_group >> load_group >> end

5.2 Nested TaskGroups

from airflow.utils.task_group import TaskGroup

# NOTE: '...' placeholders stand in for the usual DAG/operator arguments.
with DAG('nested_taskgroup', ...) as dag:

    with TaskGroup(group_id='data_processing') as data_processing:

        # Task ids are prefixed with their group id, so both nested groups
        # can reuse the ids 'extract' and 'transform' without clashing.
        with TaskGroup(group_id='source_a') as source_a:
            extract_a = PythonOperator(task_id='extract', ...)
            transform_a = PythonOperator(task_id='transform', ...)
            extract_a >> transform_a

        with TaskGroup(group_id='source_b') as source_b:
            extract_b = PythonOperator(task_id='extract', ...)
            transform_b = PythonOperator(task_id='transform', ...)
            extract_b >> transform_b

        # Parallel execution then join
        join = EmptyOperator(task_id='join')
        [source_a, source_b] >> join

5.3 Dynamic TaskGroups

from airflow.utils.task_group import TaskGroup

SOURCES = ['mysql', 'postgres', 'mongodb']

with DAG('dynamic_taskgroup', ...) as dag:

    start = EmptyOperator(task_id='start')

    # One extract->load group per source system.
    task_groups = []
    for source in SOURCES:
        with TaskGroup(group_id=f'process_{source}') as tg:
            extract = PythonOperator(
                task_id='extract',
                # s=source binds the loop value at definition time
                # (avoids the late-binding closure pitfall).
                python_callable=lambda s=source: print(f"Extract from {s}")
            )
            load = PythonOperator(
                task_id='load',
                python_callable=lambda s=source: print(f"Load {s}")
            )
            extract >> load

        task_groups.append(tg)

    end = EmptyOperator(task_id='end')

    # A list fans out/in: start precedes every group, every group precedes end.
    start >> task_groups >> end

6. Branching and Conditional Execution

6.1 BranchPythonOperator

from airflow.operators.python import BranchPythonOperator
from airflow.operators.empty import EmptyOperator

def choose_branch(**kwargs):
    """Return the task_id of the branch to follow, based on the row count
    pushed by the count_data task."""
    count = kwargs['ti'].xcom_pull(task_ids='count_data')

    # Guard-clause style: largest bucket first, empty input last.
    if count > 1000:
        return 'process_large'
    if count > 0:
        return 'process_small'
    return 'skip_processing'


# NOTE: '...' placeholder stands in for the usual DAG arguments.
with DAG('branch_example', ...) as dag:

    count_data = PythonOperator(
        task_id='count_data',
        python_callable=lambda: 500,  # Example return value
    )

    # choose_branch returns the task_id(s) to run; the rest are skipped.
    branch = BranchPythonOperator(
        task_id='branch',
        python_callable=choose_branch,
    )

    process_large = EmptyOperator(task_id='process_large')
    process_small = EmptyOperator(task_id='process_small')
    skip_processing = EmptyOperator(task_id='skip_processing')

    # Join after branching
    # Without this trigger rule the join would be skipped too, because the
    # default all_success rule fails when any parent was skipped.
    join = EmptyOperator(
        task_id='join',
        trigger_rule='none_failed_min_one_success'  # Execute if at least one succeeds
    )

    count_data >> branch >> [process_large, process_small, skip_processing] >> join

6.2 ShortCircuitOperator

from airflow.operators.python import ShortCircuitOperator

def check_condition(**kwargs):
    """Return True on weekdays so downstream tasks run; False skips them.

    Used with ShortCircuitOperator: a falsy return short-circuits the rest
    of the DAG run (e.g. on weekends).
    """
    run_date = datetime.strptime(kwargs['ds'], '%Y-%m-%d')
    # weekday(): Monday..Friday are 0..4, Saturday/Sunday are 5/6.
    return run_date.weekday() < 5


# NOTE: '...' placeholders stand in for the usual DAG/operator arguments.
with DAG('shortcircuit_example', ...) as dag:

    check = ShortCircuitOperator(
        task_id='check_weekday',
        python_callable=check_condition,
    )

    # Tasks below are skipped if check returns False
    process = PythonOperator(task_id='process', ...)
    load = PythonOperator(task_id='load', ...)

    check >> process >> load

Practice Problems

Problem 1: Using XCom

Write a DAG with two tasks that each return a number, and a third task that calculates the sum of the two numbers.

Problem 2: Dynamic DAG

Write a DAG that dynamically generates ETL tasks for each table in a list (users, orders, products).

Problem 3: Using Sensors

Write a DAG that waits for a file to be created before processing it.


Summary

| Feature     | Description                                            |
|-------------|--------------------------------------------------------|
| XCom        | Mechanism for sharing data between tasks               |
| Dynamic DAG | Dynamically generate DAGs/tasks based on configuration |
| Sensor      | Operator that waits until a condition is met           |
| Hook        | Interface for connecting to external systems           |
| TaskGroup   | Group related tasks for better visualization           |
| Branch      | Conditional branching based on criteria                |

References

Use the section headings above to navigate between lessons.