Source code for llmsql.inference.inference_vllm

"""
LLMSQL vLLM Inference Function
==============================

This module provides a single function `inference_vllm()` that performs
text-to-SQL generation using large language models via the vLLM backend.

Example
-------

.. code-block:: python

    from llmsql.inference import inference_vllm

    results = inference_vllm(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        version="2.0",
        workdir_path="data",
        num_fewshots=5,
        batch_size=8,
        max_new_tokens=256,
        temperature=0.7,
        tensor_parallel_size=1,
        llm_kwargs={"enable_lora": True},
        lora_config={
            "lora_path": "path/to/lora",
            "lora_name": "sql_adapter",
            "lora_scale": 1.0,
        },
    )

Notes
~~~~~

This function uses the vLLM backend. Outputs may differ from the Transformers
backend due to differences in implementation, batching, and numerical precision.

"""

from __future__ import annotations

import os

# Runtime switches applied to the process environment ahead of the `vllm`
# import further down in this module.
# NOTE(review): presumably vLLM / cuBLAS read these at import or init time and
# they are meant to keep generation reproducible -- confirm before reordering.
os.environ.update(
    {
        "VLLM_USE_V1": "0",
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        "VLLM_ENABLE_V1_MULTIPROCESSING": "0",
    }
)

from typing import Any, Literal

from dotenv import load_dotenv
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

from llmsql.config.config import (
    DEFAULT_LLMSQL_VERSION,
    get_repo_id,
)
from llmsql.loggers.logging_config import log
from llmsql.utils.inference_utils import (
    _maybe_download,
    _setup_seed,
    resolve_workdir_path,
)
from llmsql.utils.utils import (
    build_all_requests,
    choose_prompt_builder,
    load_jsonl,
    overwrite_jsonl,
    save_jsonl_lines,
)

# Runs once at import time, before `inference_vllm` can be called.
# NOTE(review): presumably this pulls variables such as HF_TOKEN from a local
# `.env` file into the environment, which `inference_vllm` later reads via
# `os.environ.get("HF_TOKEN")` -- confirm against the python-dotenv docs.
load_dotenv()


def _limit_questions(
    questions: list[dict[str, Any]],
    limit: int | float | None,
) -> list[dict[str, Any]]:
    """Return the leading slice of `questions` selected by `limit`."""
    if limit is None:
        return questions

    if isinstance(limit, float):
        if not (0.0 < limit <= 1.0):
            raise ValueError(
                f"When a float, `limit` must be between 0.0 and 1.0, got {limit}."
            )
        # A fraction always keeps at least one question.
        limit = max(1, int(len(questions) * limit))

    # `bool` is a subclass of `int`; reject it so `limit=True` is not read as 1.
    if isinstance(limit, bool) or not isinstance(limit, int) or limit < 1:
        raise ValueError(
            f"`limit` must be a positive integer or a float in (0.0, 1.0], got {limit!r}."
        )

    log.info(
        f"Limiting evaluation to first {limit} questions out of {len(questions)}"
    )
    return questions[:limit]


def _validate_lora_usage(
    enable_lora: bool,
    lora_config: dict[str, Any] | None,
) -> None:
    """Raise `ValueError` when `enable_lora` and `lora_config` are inconsistent."""
    if lora_config is not None and not enable_lora:
        raise ValueError(
            "LoRA config provided but `enable_lora` is not True in llm_kwargs."
        )
    if enable_lora and lora_config is None:
        raise ValueError("`enable_lora` is True but no `lora_config` was provided.")
    # Checked here so a bad config fails before the (expensive) model load
    # instead of surfacing as a KeyError afterwards.
    if lora_config is not None and not lora_config.get("lora_path"):
        raise ValueError("`lora_config` must contain a non-empty 'lora_path'.")


def _build_lora_request(lora_config: dict[str, Any]) -> LoRARequest:
    """Translate a validated `lora_config` dict into a vLLM `LoRARequest`."""
    request_kwargs: dict[str, Any] = {
        "lora_name": lora_config.get("lora_name", "lora_adapter"),
        # NOTE(review): `lora_int_id` is a required field of vLLM's LoRARequest
        # as far as the public docs show; the original call omitted it.
        "lora_int_id": lora_config.get("lora_int_id", 1),
        "lora_path": lora_config["lora_path"],
    }
    # NOTE(review): `scaling` is forwarded only when the caller asked for it.
    # Confirm the installed vLLM's LoRARequest actually accepts this field.
    if lora_config.get("lora_scale") is not None:
        request_kwargs["scaling"] = lora_config["lora_scale"]
    return LoRARequest(**request_kwargs)


def inference_vllm(
    model_name: str,
    *,
    # === Model Loading Parameters ===
    trust_remote_code: bool = True,
    tensor_parallel_size: int = 1,
    hf_token: str | None = None,
    llm_kwargs: dict[str, Any] | None = None,
    use_chat_template: bool = True,
    # === LoRA Parameters ===
    lora_config: dict[str, Any] | None = None,
    # === Generation Parameters ===
    max_new_tokens: int = 256,
    temperature: float = 1.0,
    do_sample: bool = True,
    sampling_kwargs: dict[str, Any] | None = None,
    # === Benchmark Parameters ===
    version: Literal["1.0", "2.0"] = DEFAULT_LLMSQL_VERSION,
    output_file: str = "llm_sql_predictions.jsonl",
    workdir_path: str | None = None,
    limit: int | float | None = None,
    num_fewshots: int = 5,
    batch_size: int = 8,
    seed: int = 42,
) -> list[dict[str, str]]:
    """
    Run SQL generation using vLLM.

    Args:
        model_name: Hugging Face model name or path.

        # Model Loading:
        trust_remote_code: Whether to trust remote code (default: True).
        tensor_parallel_size: Number of GPUs for tensor parallelism (default: 1).
        hf_token: Hugging Face authentication token. Falls back to the
            `HF_TOKEN` environment variable. When a token is available it is
            exported as `HF_TOKEN` so that Hub downloads can pick it up.
        llm_kwargs: Additional arguments for vllm.LLM().
            Note: keys given here are applied last and therefore take
            precedence over 'model', 'tokenizer', 'tensor_parallel_size' and
            'trust_remote_code' derived from the dedicated arguments.
        use_chat_template: Whether to build prompts with the tokenizer's chat
            template. Ignored when the tokenizer defines no chat template.
        lora_config: Optional dict with LoRA parameters:
            - lora_path: Path to the pretrained LoRA adapter (required)
            - lora_name: Logical name for the LoRA adapter (optional)
            - lora_int_id: Positive integer id of the adapter (optional, default 1)
            - lora_scale: Scaling factor for LoRA weights (optional)
            - max_lora_rank: Maximum LoRA rank, forwarded to vllm.LLM()
              unless `llm_kwargs` already sets it (optional)

            LoRA usage rules:
            - If `lora_config` is provided, `enable_lora` must be True in `llm_kwargs`.
            - If `enable_lora` is True, a valid `lora_config` must be provided.
            - Otherwise, an exception is raised to prevent inconsistent configuration.

        # Generation:
        max_new_tokens: Maximum tokens to generate per sequence.
        temperature: Sampling temperature (0.0 = greedy).
        do_sample: Whether to use sampling vs greedy decoding. When False the
            temperature is forced to 0.0.
        sampling_kwargs: Additional arguments for vllm.SamplingParams().
            Note: keys given here are applied last and therefore take
            precedence over 'temperature' and 'max_tokens'.

        # Benchmark:
        version: LLMSQL version
        output_file: Path to write outputs (will be overwritten).
        workdir_path: Directory to store downloaded benchmark files.
            If omitted, a temporary directory is created automatically.
        num_fewshots: Number of few-shot examples (0, 1, or 5).
        batch_size: Number of questions per generation batch (must be >= 1).
        seed: Random seed for reproducibility.
        limit: Limit the number of questions to evaluate. If an integer, evaluates
            the first N samples. If a float between 0.0 and 1.0, evaluates the
            first X*100% of samples. If None, evaluates all samples (default).

    Returns:
        List of dicts containing `question_id` and generated `completion`.

    Raises:
        ValueError: If `batch_size` is not positive, `limit` is out of range,
            or `lora_config` / `enable_lora` are inconsistent.
    """
    # --- setup ---
    if batch_size < 1:
        # range(0, total, 0) raises an opaque error and a negative step would
        # silently produce zero results, so reject both up front.
        raise ValueError(
            f"`batch_size` must be a positive integer, got {batch_size!r}."
        )

    llm_kwargs = llm_kwargs or {}
    sampling_kwargs = sampling_kwargs or {}
    _setup_seed(seed=seed)

    hf_token = hf_token or os.environ.get("HF_TOKEN")
    if hf_token:
        # The resolved token used to be dropped on the floor. Exporting it lets
        # Hub-backed downloads (benchmark files and model weights) authenticate.
        os.environ["HF_TOKEN"] = hf_token

    # --- Validate LoRA usage (before any download or model load) ---
    enable_lora = bool(llm_kwargs.get("enable_lora", False))
    _validate_lora_usage(enable_lora, lora_config)

    # --- load input data ---
    log.info("Preparing questions and tables...")
    workdir = resolve_workdir_path(workdir_path)
    repo_id = get_repo_id(version)

    questions_path = _maybe_download(repo_id, "questions.jsonl", workdir)
    tables_path = _maybe_download(repo_id, "tables.jsonl", workdir)

    questions = load_jsonl(questions_path)
    tables_list = load_jsonl(tables_path)
    tables = {t["table_id"]: t for t in tables_list}

    # --- Apply limit ---
    questions = _limit_questions(questions, limit)

    # --- init model ---
    llm_init_args: dict[str, Any] = {
        "model": model_name,
        "tokenizer": model_name,
        "tensor_parallel_size": tensor_parallel_size,
        "trust_remote_code": trust_remote_code,
    }
    if lora_config is not None and lora_config.get("max_lora_rank") is not None:
        llm_init_args["max_lora_rank"] = lora_config["max_lora_rank"]
    # User-supplied kwargs are applied last so they win over everything above.
    llm_init_args.update(llm_kwargs)

    log.info(f"Loading vLLM model '{model_name}' (tp={tensor_parallel_size})...")
    llm = LLM(**llm_init_args)

    # --- LoRA request ---
    lora_request = None
    if enable_lora and lora_config is not None:
        log.info(f"Loading LoRA adapter from {lora_config['lora_path']}")
        lora_request = _build_lora_request(lora_config)

    # Only apply a chat template when requested AND the tokenizer defines one.
    tokenizer = llm.get_tokenizer()
    apply_chat_template = bool(
        use_chat_template and getattr(tokenizer, "chat_template", None)
    )

    # --- prepare output file ---
    overwrite_jsonl(output_file)
    log.info(f"Output will be written to {output_file}")

    # --- prompt builder and sampling params ---
    prompt_builder = choose_prompt_builder(num_fewshots)

    effective_temperature = temperature if do_sample else 0.0
    sampling_params_args = {
        "temperature": effective_temperature,
        "max_tokens": max_new_tokens,
        **sampling_kwargs,  # user overrides
    }
    sampling_params = SamplingParams(**sampling_params_args)

    # --- build all requests ---
    prompts = build_all_requests(
        questions,
        tables,
        prompt_builder,
        tokenizer=tokenizer if apply_chat_template else None,
        use_chat_template=apply_chat_template,
    )

    # --- main inference loop ---
    all_results: list[dict[str, str]] = []
    total = len(questions)

    for batch_start in tqdm(range(0, total, batch_size), desc="Generating"):
        batch_prompts = prompts[batch_start : batch_start + batch_size]
        batch_questions = questions[batch_start : batch_start + batch_size]

        outputs = llm.generate(
            batch_prompts,
            sampling_params,
            lora_request=lora_request,
        )

        batch_results: list[dict[str, str]] = [
            {
                "question_id": q.get("question_id", q.get("id", "")),
                "completion": out.outputs[0].text,
            }
            for q, out in zip(batch_questions, outputs, strict=False)
        ]

        # Persist each batch immediately so partial progress survives a crash.
        save_jsonl_lines(output_file, batch_results)
        all_results.extend(batch_results)
        log.info(
            f"Saved batch {batch_start // batch_size + 1}: {len(all_results)}/{total}"
        )

    log.info(f"Generation completed. {len(all_results)} results saved to {output_file}")
    return all_results