Overview

Callbacks in Flow Core provide hooks into the workflow execution lifecycle, enabling monitoring, logging, debugging, and custom integrations.

BaseCallbackHandler

All callbacks inherit from BaseCallbackHandler, imported from nadoo_flow. Its interface, shown here for reference, defines one async hook per lifecycle event:
from nadoo_flow import BaseCallbackHandler

class BaseCallbackHandler:
    """Base class for workflow callbacks"""

    async def on_workflow_start(self, workflow_id, input_data):
        """Called when workflow starts"""
        pass

    async def on_workflow_end(self, workflow_id, output_data):
        """Called when workflow completes"""
        pass

    async def on_node_start(self, node_id, input_data):
        """Called before node execution"""
        pass

    async def on_node_end(self, node_id, output_data):
        """Called after node execution"""
        pass

    async def on_error(self, error_info):
        """Called on any error"""
        pass
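
A handler takes effect once it is registered with a workflow. A minimal sketch, assuming the add_callback method used in the testing example later on this page:

class PrintCallback(BaseCallbackHandler):
    """Print node starts as they happen"""

    async def on_node_start(self, node_id, input_data):
        print(f"node {node_id} starting")

workflow.add_callback(PrintCallback())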

Creating Custom Callbacks

Logging Callback

import logging
from nadoo_flow import BaseCallbackHandler

class LoggingCallback(BaseCallbackHandler):
    """Log workflow execution details"""

    def __init__(self, logger=None):
        self.logger = logger or logging.getLogger(__name__)

    async def on_workflow_start(self, workflow_id, input_data):
        self.logger.info(f"Workflow {workflow_id} started")
        self.logger.debug(f"Input: {input_data}")

    async def on_workflow_end(self, workflow_id, output_data):
        self.logger.info(f"Workflow {workflow_id} completed")
        self.logger.debug(f"Output: {output_data}")

    async def on_node_start(self, node_id, input_data):
        self.logger.debug(f"Node {node_id} starting")

    async def on_node_end(self, node_id, output_data):
        self.logger.debug(f"Node {node_id} completed")

    async def on_error(self, error_info):
        self.logger.error(f"Error: {error_info}")
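
To surface the debug-level messages, configure the standard logging module before attaching the handler (add_callback is assumed, as above):

import logging

logging.basicConfig(level=logging.DEBUG)
workflow.add_callback(LoggingCallback())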

Metrics Callback

import time
from collections import defaultdict

class MetricsCallback(BaseCallbackHandler):
    """Collect execution metrics"""

    def __init__(self):
        self.metrics = {
            "workflow_count": 0,
            "node_count": 0,
            "error_count": 0,
            "total_duration": 0,
            "node_durations": defaultdict(list),
        }
        self.start_times = {}

    async def on_workflow_start(self, workflow_id, input_data):
        self.metrics["workflow_count"] += 1
        self.start_times[f"workflow_{workflow_id}"] = time.time()

    async def on_workflow_end(self, workflow_id, output_data):
        start = self.start_times.pop(f"workflow_{workflow_id}", None)
        if start:
            duration = time.time() - start
            self.metrics["total_duration"] += duration
            self.metrics["avg_workflow_duration"] = (
                self.metrics["total_duration"] / self.metrics["workflow_count"]
            )

    async def on_node_start(self, node_id, input_data):
        self.metrics["node_count"] += 1
        self.start_times[f"node_{node_id}"] = time.time()

    async def on_node_end(self, node_id, output_data):
        start = self.start_times.pop(f"node_{node_id}", None)
        if start:
            duration = time.time() - start
            self.metrics["node_durations"][node_id].append(duration)

    async def on_error(self, error_info):
        self.metrics["error_count"] += 1
        self.metrics["last_error"] = error_info

    def get_report(self):
        """Generate metrics report"""
        return {
            "total_workflows": self.metrics["workflow_count"],
            "total_nodes": self.metrics["node_count"],
            "error_rate": self.metrics["error_count"] / max(self.metrics["node_count"], 1),
            "avg_workflow_duration": self.metrics.get("avg_workflow_duration", 0),
            "slowest_nodes": self._get_slowest_nodes(),
        }

    def _get_slowest_nodes(self):
        """Identify slowest nodes"""
        avg_durations = {}
        for node_id, durations in self.metrics["node_durations"].items():
            if durations:
                avg_durations[node_id] = sum(durations) / len(durations)

        return sorted(avg_durations.items(), key=lambda x: x[1], reverse=True)[:5]
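
After one or more runs, get_report() summarizes what was collected. A usage sketch, again assuming add_callback:

metrics = MetricsCallback()
workflow.add_callback(metrics)

await workflow.run(input_data)

report = metrics.get_report()
print(report["error_rate"], report["slowest_nodes"])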

CallbackManager

Manage multiple callbacks with CallbackManager, imported from nadoo_flow. Its interface, shown here for reference, is:
import logging

from nadoo_flow import CallbackManager

logger = logging.getLogger(__name__)

class CallbackManager:
    """Manage workflow callbacks"""

    def __init__(self):
        self.handlers = []

    def add_handler(self, handler: BaseCallbackHandler):
        """Add a callback handler"""
        self.handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler):
        """Remove a callback handler"""
        self.handlers.remove(handler)

    async def trigger(self, event: str, **kwargs):
        """Trigger callbacks for an event"""
        for handler in self.handlers:
            method = getattr(handler, event, None)
            if method:
                try:
                    await method(**kwargs)
                except Exception as e:
                    logger.error(f"Callback error in {handler}: {e}")

Advanced Callbacks

Database Callback

import json
from datetime import datetime

class DatabaseCallback(BaseCallbackHandler):
    """Store execution history in database"""

    def __init__(self, db_connection):
        self.db = db_connection

    async def on_workflow_start(self, workflow_id, input_data):
        await self.db.execute(
            """
            INSERT INTO workflow_executions (workflow_id, status, input_data, started_at)
            VALUES ($1, $2, $3, $4)
            """,
            workflow_id, "running", json.dumps(input_data), datetime.now()
        )

    async def on_workflow_end(self, workflow_id, output_data):
        await self.db.execute(
            """
            UPDATE workflow_executions
            SET status = $1, output_data = $2, completed_at = $3
            WHERE workflow_id = $4
            """,
            "completed", json.dumps(output_data), datetime.now(), workflow_id
        )

    async def on_node_start(self, node_id, input_data):
        await self.db.execute(
            """
            INSERT INTO node_executions (node_id, status, input_data, started_at)
            VALUES ($1, $2, $3, $4)
            """,
            node_id, "running", json.dumps(input_data), datetime.now()
        )

    async def on_error(self, error_info):
        await self.db.execute(
            """
            INSERT INTO execution_errors (error_type, error_message, stack_trace, occurred_at)
            VALUES ($1, $2, $3, $4)
            """,
            error_info.get("type"), error_info.get("message"),
            error_info.get("stack_trace"), datetime.now()
        )
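
The $1-style placeholders match PostgreSQL drivers such as asyncpg. A hedged wiring sketch, assuming that driver and pre-created tables:

import asyncpg

async def attach_database_callback(workflow):
    # DSN is a placeholder; point it at your own database
    db = await asyncpg.connect("postgresql://localhost/flow_core")
    workflow.add_callback(DatabaseCallback(db))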

Notification Callback

class NotificationCallback(BaseCallbackHandler):
    """Send notifications on events"""

    def __init__(self, notification_service):
        self.notifier = notification_service

    async def on_workflow_start(self, workflow_id, input_data):
        # Notify on long-running workflows
        if input_data.get("estimated_duration", 0) > 300:
            await self.notifier.send(
                channel="slack",
                message=f"Long-running workflow {workflow_id} started"
            )

    async def on_error(self, error_info):
        # Send alert on critical errors
        if error_info.get("severity") == "critical":
            await self.notifier.send(
                channel="pagerduty",
                message=f"Critical error in workflow: {error_info['message']}",
                urgency="high"
            )

    async def on_workflow_end(self, workflow_id, output_data):
        # Notify on completion
        if output_data.get("notify_on_complete"):
            await self.notifier.send(
                channel="email",
                to=output_data.get("user_email"),
                subject=f"Workflow {workflow_id} completed",
                body=self._format_completion_email(output_data)
            )

    def _format_completion_email(self, output_data):
        """Render a plain-text completion summary (minimal placeholder)"""
        return f"Your workflow finished with output: {output_data}"

Progress Callback

class ProgressCallback(BaseCallbackHandler):
    """Track and report progress"""

    def __init__(self, progress_store):
        self.store = progress_store
        self.node_count = {}
        self.completed_nodes = {}

    async def on_workflow_start(self, workflow_id, input_data):
        # Initialize progress
        total_nodes = input_data.get("total_nodes", 0)
        self.node_count[workflow_id] = total_nodes
        self.completed_nodes[workflow_id] = 0

        await self.store.set(f"progress:{workflow_id}", {
            "status": "started",
            "progress": 0,
            "total": total_nodes
        })

    async def on_node_end(self, node_id, output_data):
        # Update progress
        workflow_id = self._get_workflow_id(node_id)
        self.completed_nodes[workflow_id] += 1

        progress = (
            self.completed_nodes[workflow_id] / self.node_count[workflow_id] * 100
            if self.node_count[workflow_id] > 0 else 0
        )

        await self.store.set(f"progress:{workflow_id}", {
            "status": "running",
            "progress": progress,
            "completed": self.completed_nodes[workflow_id],
            "total": self.node_count[workflow_id],
            "last_completed_node": node_id
        })

    async def on_workflow_end(self, workflow_id, output_data):
        await self.store.set(f"progress:{workflow_id}", {
            "status": "completed",
            "progress": 100,
            "completed": self.node_count[workflow_id],
            "total": self.node_count[workflow_id]
        })

    def _get_workflow_id(self, node_id):
        """Map a node ID to its workflow ID.

        Illustrative assumption: node IDs are namespaced as
        "<workflow_id>:<node_name>"; adapt this to your ID scheme.
        """
        return node_id.split(":", 1)[0]
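
Clients can poll the store to render a progress bar. A sketch, assuming the store also exposes an async get matching the set used above:

async def print_progress(store, workflow_id):
    # Read the latest snapshot written by ProgressCallback
    state = await store.get(f"progress:{workflow_id}")
    if state:
        print(f"{state['status']}: {state['progress']:.0f}%")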

Conditional Callbacks

Conditional Execution

class ConditionalCallback(BaseCallbackHandler):
    """Execute callbacks based on conditions"""

    def __init__(self, condition_fn, wrapped_callback):
        self.condition_fn = condition_fn
        self.wrapped = wrapped_callback

    async def on_workflow_start(self, workflow_id, input_data):
        if self.condition_fn("workflow_start", workflow_id, input_data):
            await self.wrapped.on_workflow_start(workflow_id, input_data)

    async def on_error(self, error_info):
        if self.condition_fn("error", error_info):
            await self.wrapped.on_error(error_info)

# Usage
import os

def only_production(event_type, *args):
    """Only trigger in production environment"""
    return os.environ.get("ENV") == "production"

production_callback = ConditionalCallback(
    condition_fn=only_production,
    wrapped_callback=AlertingCallback()  # any BaseCallbackHandler subclass
)

Filtering Callback

class FilteringCallback(BaseCallbackHandler):
    """Filter events before processing"""

    def __init__(self, filters=None):
        self.filters = filters or {}

    async def on_node_start(self, node_id, input_data):
        # Only log specific nodes
        if self._should_process("node_start", node_id):
            logger.info(f"Monitored node {node_id} started")

    def _should_process(self, event_type, *args):
        """Check if event should be processed"""
        if event_type not in self.filters:
            return True

        filter_fn = self.filters[event_type]
        return filter_fn(*args)

# Usage
callback = FilteringCallback(
    filters={
        "node_start": lambda node_id: node_id.startswith("critical_"),
        "error": lambda error_info: error_info.get("severity") == "high"
    }
)

Callback Composition

Chaining Callbacks

class ChainedCallback(BaseCallbackHandler):
    """Chain multiple callbacks"""

    def __init__(self, *callbacks):
        self.callbacks = callbacks

    async def on_workflow_start(self, workflow_id, input_data):
        for callback in self.callbacks:
            await callback.on_workflow_start(workflow_id, input_data)

    async def on_workflow_end(self, workflow_id, output_data):
        for callback in self.callbacks:
            await callback.on_workflow_end(workflow_id, output_data)

    async def on_error(self, error_info):
        for callback in self.callbacks:
            try:
                await callback.on_error(error_info)
            except Exception as e:
                logger.error(f"Error in chained callback: {e}")

# Usage
combined_callback = ChainedCallback(
    LoggingCallback(),
    MetricsCallback(),
    DatabaseCallback(db),
    NotificationCallback(notifier)
)

Priority Callbacks

class PriorityCallbackManager(CallbackManager):
    """Execute callbacks by priority"""

    def add_handler(self, handler, priority=0):
        """Add handler with priority (higher = earlier)"""
        # Handlers are stored as (priority, handler) tuples, so the
        # inherited remove_handler would need the same tuple to match
        self.handlers.append((priority, handler))
        self.handlers.sort(key=lambda x: x[0], reverse=True)

    async def trigger(self, event: str, **kwargs):
        """Trigger callbacks in priority order"""
        for priority, handler in self.handlers:
            method = getattr(handler, event, None)
            if method:
                try:
                    await method(**kwargs)
                except Exception as e:
                    logger.error(f"Callback error (priority {priority}): {e}")

Testing Callbacks

Mock Callback

class MockCallback(BaseCallbackHandler):
    """Callback for testing"""

    def __init__(self):
        self.events = []

    async def on_workflow_start(self, workflow_id, input_data):
        self.events.append(("workflow_start", workflow_id, input_data))

    async def on_workflow_end(self, workflow_id, output_data):
        self.events.append(("workflow_end", workflow_id, output_data))

    async def on_node_start(self, node_id, input_data):
        self.events.append(("node_start", node_id, input_data))

    async def on_error(self, error_info):
        self.events.append(("error", error_info))

    def assert_called(self, event_type, times=None):
        """Assert callback was called"""
        count = sum(1 for e in self.events if e[0] == event_type)
        if times is not None:
            assert count == times, f"Expected {times} calls, got {count}"
        else:
            assert count > 0, f"Expected {event_type} to be called"

# Test usage
async def test_workflow_callbacks():
    mock_callback = MockCallback()
    workflow.add_callback(mock_callback)

    await workflow.run(input_data)

    mock_callback.assert_called("workflow_start", times=1)
    mock_callback.assert_called("workflow_end", times=1)

Performance Considerations

Async Callbacks

import asyncio

class AsyncCallback(BaseCallbackHandler):
    """Non-blocking async callbacks"""

    async def on_node_start(self, node_id, input_data):
        # Don't block execution; hand the work to a background task.
        # Keep a reference if you need to await or cancel it later,
        # since bare tasks can be garbage-collected.
        asyncio.create_task(self._async_process(node_id, input_data))

    async def _async_process(self, node_id, input_data):
        """Process in background"""
        await asyncio.sleep(0)  # Yield control
        # Do the expensive operation (send_to_analytics is illustrative)
        await self.send_to_analytics(node_id, input_data)

Batched Callbacks

import asyncio

class BatchedCallback(BaseCallbackHandler):
    """Batch events before processing"""

    def __init__(self, batch_size=100, flush_interval=5.0):
        self.batch = []
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        # create_task needs a running event loop, so instantiate
        # this handler from within async code
        self._start_flush_timer()

    async def on_node_end(self, node_id, output_data):
        self.batch.append(("node_end", node_id, output_data))

        if len(self.batch) >= self.batch_size:
            await self._flush()

    async def _flush(self):
        """Process batched events"""
        if not self.batch:
            return

        # Hand the accumulated events to the sink; process_batch is
        # expected to be provided by a subclass
        await self.process_batch(self.batch)
        self.batch.clear()

    def _start_flush_timer(self):
        """Periodic flush"""
        async def flush_loop():
            while True:
                await asyncio.sleep(self.flush_interval)
                await self._flush()

        # Keep a reference so the task isn't garbage-collected
        self._flush_task = asyncio.create_task(flush_loop())
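
Because the flush timer starts in __init__, construct the handler inside running async code. A usage sketch, with process_batch supplied by a subclass (an assumption; the base snippet leaves it undefined):

class AnalyticsBatchCallback(BatchedCallback):
    async def process_batch(self, batch):
        # Replace with a bulk write to your analytics backend
        print(f"flushing {len(batch)} events")

async def main():
    workflow.add_callback(AnalyticsBatchCallback(batch_size=50))
    await workflow.run(input_data)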

Best Practices

Keep callbacks fast: they should complete quickly so they don't block workflow execution; offload heavy work to background tasks.
Isolate failures: wrap callback bodies in try/except so a broken callback cannot disrupt the workflow itself.
Stay focused: write small callbacks with a single purpose rather than one large callback that does everything.
Test independently: exercise callbacks with mock events, separately from full workflow runs.
Document behavior: state what each callback does and when it is triggered.

Next Steps