Notes:

➜  pydantic-airflow-agent poetry add apache-airflow

Current Python version (3.13.1) is not allowed by the project (>=3.12,<3.13).
Please change python executable via the "env use" command.

Solution (install a Python 3.12 interpreter and point the Poetry environment at it):

pyenv install 3.12
poetry env use /Users/marclamberti/.pyenv/versions/3.12.8/bin/python3.12

Link to the Medium post

https://blog.det.life/talk-to-airflow-build-an-ai-agent-using-pydanticai-and-gemini-2-0-fd645cf99fcb

The code

# agent.py

from dataclasses import dataclass
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
import asyncio
import json
import logging
from devtools import pprint
import colorlog
from httpx import AsyncClient, HTTPStatusError

log_format = (
    # %(log_color)s, %(reset)s and the colour names (%(purple)s, %(blue)s) are
    # colorlog placeholders; the remaining fields are standard logging ones.
    '%(log_color)s%(asctime)s [%(levelname)s] %(reset)s'  # timestamp + level
    '%(purple)s[%(name)s] %(reset)s'                      # logger name
    '%(blue)s%(message)s'                                 # the message itself
)
handler = colorlog.StreamHandler()
handler.setFormatter(colorlog.ColoredFormatter(log_format))
# Install the coloured handler on the root logger; INFO and above are emitted.
logging.basicConfig(handlers=[handler], level=logging.INFO)

# Module-level logger used by the tools below.
logger = logging.getLogger(__name__)

@dataclass
class Deps:
    """Connection settings for the Airflow REST API, handed to the agent's tools."""

    # Scheme and host only; the tools build '<base_uri>:<port>/api/v1/...' from it.
    airflow_api_base_uri: str
    # Port that is appended to the base URI.
    airflow_api_port: int
    # User name / password pair the tools pass as the `auth` of each request.
    airflow_api_user: str
    airflow_api_pass: str
    
class DAGStatus(BaseModel):
    # Structured answer the agent has to produce: this model is passed as the
    # `result_type` of `airflow_agent`.
    # The Field descriptions are runtime text (they end up in the model's
    # schema), so they are documented here with comments rather than edited.
    dag_id: str = Field(description='ID of the DAG')
    dag_display_name: str = Field(description='Display name of the DAG')
    is_paused: bool = Field(description='Whether the DAG is paused')
    # NOTE(review): both interval fields are required strings; the Airflow API
    # presumably reports null here for DAGs without a next scheduled run.
    # Confirm how the agent should answer in that case.
    next_dag_run_data_interval_start: str = Field(description='Next DAG run data interval start')
    next_dag_run_data_interval_end: str = Field(description='Next DAG run data interval end')
    # The two "last run" fields fall back to 'No DAG run' when no value is given.
    last_dag_run_id: str = Field(default='No DAG run', description='Last DAG run ID')
    last_dag_run_state: str = Field(default='No DAG run', description='Last DAG run state')
    total_dag_runs: int = Field(description='Total number of DAG runs')

airflow_agent = Agent(
    # Agent that answers questions about Airflow DAGs using the two tools
    # registered below (`list_dags`, `get_dag_status`).
    model='google-gla:gemini-2.0-flash',
    system_prompt=(
        # Real newlines ('\n') separate the numbered steps. The previous '\\n'
        # put a literal backslash + 'n' into the prompt instead of a line break.
        'You are an Airflow monitoring assistant. For each request:\n'
        '1. Use `list_dags` first to get available DAGs\n'
        '2. Match the user request to the most relevant DAG ID\n'
        '3. Use `get_dag_status` to fetch the DAG status details'
    ),
    # The final answer must validate against DAGStatus; tools receive a Deps.
    # NOTE(review): newer pydantic-ai releases appear to rename `result_type`
    # to `output_type` -- confirm against the pinned version before upgrading.
    result_type=DAGStatus,
    deps_type=Deps,
    retries=2
)

@airflow_agent.tool
async def list_dags(ctx: RunContext[Deps]) -> str:
    """
    Get a list of all DAGs from the Airflow instance. Returns DAGs with their IDs and display names.
    """
    # NOTE: the docstring above is left verbatim on purpose -- pydantic-ai uses
    # a tool's docstring as the description presented to the model.
    logger.info('Getting available DAGs...')
    settings = ctx.deps
    endpoint = f'{settings.airflow_api_base_uri}:{settings.airflow_api_port}/api/v1/dags'
    credentials = (settings.airflow_api_user, settings.airflow_api_pass)

    async with AsyncClient() as http:
        reply = await http.get(endpoint, auth=credentials)
        # Turn 4xx/5xx answers into HTTPStatusError instead of parsing them.
        reply.raise_for_status()

        # Keep only the two fields the model needs to pick a DAG.
        summaries = [
            {
                'dag_id': entry['dag_id'],
                'dag_display_name': entry['dag_display_name'],
            }
            for entry in reply.json()['dags']
        ]
        listing = json.dumps(summaries)
        logger.debug('Available DAGs: %s', listing)
        return listing
    
@airflow_agent.tool
async def get_dag_status(ctx: RunContext[Deps], dag_id: str) -> str:
    """
    Get detailed status information for a specific DAG by DAG ID.
    """
    # NOTE: the docstring above is left verbatim on purpose -- pydantic-ai uses
    # a tool's docstring as the description presented to the model.
    #
    # Returns a JSON string with two keys: 'dag_data' (the DAG resource) and
    # 'runs_data' (the dagRuns listing limited to one entry, ordered by
    # '-execution_date'). A 404 from either request is reported back as a plain
    # "not found" sentence; any other HTTP error status is re-raised.
    from urllib.parse import quote  # local import keeps this block self-contained

    logger.info(f'Getting status for DAG with ID: {dag_id}')
    base_url = f'{ctx.deps.airflow_api_base_uri}:{ctx.deps.airflow_api_port}/api/v1'
    auth = (ctx.deps.airflow_api_user, ctx.deps.airflow_api_pass)

    # dag_id is chosen by the model, so percent-encode it (including '/') to
    # keep it a single path segment: it must not be able to add extra path
    # segments or a query string to this authenticated request.
    dag_url = f'{base_url}/dags/{quote(dag_id, safe="")}'

    try:
        async with AsyncClient() as client:
            dag_response = await client.get(dag_url, auth=auth)
            dag_response.raise_for_status()

            runs_response = await client.get(
                f'{dag_url}/dagRuns',
                auth=auth,
                params={'order_by': '-execution_date', 'limit': 1}
            )
            runs_response.raise_for_status()

            # Serialise once; the same string is logged and returned. The lazy
            # %-style argument avoids formatting work when DEBUG is disabled.
            result = json.dumps({
                'dag_data': dag_response.json(),
                'runs_data': runs_response.json()
            })
            logger.debug('DAG status: %s', result)
            return result

    except HTTPStatusError as e:
        if e.response.status_code == 404:
            # Hand the model a readable message rather than failing the run.
            return f'DAG with ID {dag_id} not found'
        raise
    
async def main() -> None:
    """Run one example question through the agent and pretty-print the result."""
    deps = Deps(
        # Must be a bare URL: the tools build '<base_uri>:<port>/api/v1/...'
        # from it. The previous value '<http://localhost>' still carried
        # Markdown autolink brackets, which made every request URL invalid.
        airflow_api_base_uri='http://localhost',
        airflow_api_port=8080,
        airflow_api_user='admin',
        airflow_api_pass='admin'
    )

    user_request = 'What is the status of the DAG for our daily payment report?'
    result = await airflow_agent.run(user_request, deps=deps)
    # `result.data` holds the validated DAGStatus produced by the agent.
    pprint(result.data)

# Script entry point: run the async main() on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())