Source code for llmsql.evaluation.evaluate

"""
LLMSQL Evaluation Module
=========================

Provides the `evaluate()` function to benchmark Text-to-SQL model outputs
on the LLMSQL benchmark.

See the documentation for full usage details.
"""

from datetime import datetime, timezone
import uuid

from rich.progress import track

from llmsql.config.config import (
    DEFAULT_LLMSQL_VERSION,
    get_repo_id,
)
from llmsql.utils.evaluation_utils import (
    connect_sqlite,
    evaluate_sample,
)
from llmsql.utils.inference_utils import _maybe_download, resolve_workdir_path
from llmsql.utils.rich_utils import log_mismatch, print_summary
from llmsql.utils.utils import load_jsonl, load_jsonl_dict_by_key, save_json_report


[docs] def evaluate( outputs: str | list[dict[int, str | int]], *, version: str = DEFAULT_LLMSQL_VERSION, workdir_path: str | None = None, save_report: str | None = None, show_mismatches: bool = True, max_mismatches: int = 5, ) -> dict: """ Evaluate predicted SQL queries against the LLMSQL benchmark. Args: version: LLMSQL version outputs: Either a JSONL file path or a list of dicts. workdir_path: Directory to store downloaded benchmark files. If omitted, a temporary directory is created automatically. save_report: Optional manual save path. If None → auto-generated. show_mismatches: Print mismatches while evaluating. max_mismatches: Max mismatches to print. Returns: dict: Metrics and mismatches. """ # Determine input type input_mode = "jsonl_path" if isinstance(outputs, str) else "dict_list" workdir = resolve_workdir_path(workdir_path) repo_id = get_repo_id(version) questions_path = _maybe_download(repo_id, "questions.jsonl", workdir) db_path = _maybe_download(repo_id, "sqlite_tables.db", workdir) # --- Load benchmark questions --- questions = load_jsonl_dict_by_key(questions_path, key="question_id") # --- Load predictions (path or list) --- if isinstance(outputs, str): outputs_list = load_jsonl(outputs) elif isinstance(outputs, list): outputs_list = outputs else: raise TypeError( "outputs must be file path or list of dicts in format {'question_id': int, 'completion': str}" ) # --- Connect to DB --- conn = connect_sqlite(db_path) # --- Evaluation loop --- metrics = { "total": 0, "matches": 0, "pred_none": 0, "gold_none": 0, "sql_errors": 0, } mismatches: list[dict] = [] for item in track(outputs_list, description="Evaluating"): metrics["total"] += 1 is_match, mismatch_info, m = evaluate_sample(item, questions, conn) metrics["matches"] += is_match metrics["pred_none"] += m["pred_none"] metrics["gold_none"] += m["gold_none"] metrics["sql_errors"] += m["sql_error"] if mismatch_info: mismatches.append(mismatch_info) if show_mismatches and len(mismatches) <= max_mismatches: log_mismatch(**mismatch_info) print_summary( metrics["total"], metrics["matches"], metrics["pred_none"], metrics["gold_none"], metrics["sql_errors"], ) # --- Build report structure --- report = { **metrics, "accuracy": metrics["matches"] / metrics["total"] if metrics["total"] else 0, "mismatches": mismatches, "timestamp": datetime.now(timezone.utc).isoformat(), "input_mode": input_mode, } # --- Auto-generate report filename (if not provided) --- if save_report is None: save_report = f"evaluation_results_{uuid.uuid4()}.json" save_json_report(save_report, report) conn.close() return report