Source code for panda_guard.cli.serve

import sys
import os
import yaml
import typer
import logging
import time
from typing import Optional, Iterator, Dict, Any, List, Union
from pathlib import Path
from rich.console import Console
from fastapi import FastAPI, Request, BackgroundTasks, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn
import json
import asyncio
from pydantic import BaseModel, Field

from panda_guard.pipelines.inference import InferPipeline, InferPipelineConfig
from panda_guard.utils import parse_configs_from_dict

# Typer application object; the `start` callback below is registered on it.
app = typer.Typer(
    help="Language model server and chat interface",
    invoke_without_command=True,
)
# Rich console used for coloured status and error messages.
console = Console()

# Module-level state assigned by `start`:
#   global_pipeline -- the pipeline instance built at server startup
#   verbose_mode    -- copy of the --verbose/--no-verbose CLI flag
global_pipeline = None
verbose_mode = False


# Define API models based on OpenAI API format
# Request body accepted by POST /v1/chat/completions (OpenAI-style fields).
# The chat handler defined in this file reads only `messages`, `temperature`
# and `stream`; the remaining fields are accepted but not used there.
# Documented with comments rather than a docstring so the generated schema
# description is left exactly as it was.
class ChatCompletionRequest(BaseModel):
    # Name of the requested model.
    model: str
    # Conversation so far; each entry maps string keys to string values
    # (the handler reads the "content" key of the last entry).
    messages: List[Dict[str, str]]

    # Sampling parameters.
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1

    # When true the handler answers with a text/event-stream response.
    stream: Optional[bool] = False

    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
# One entry of a chat completion's "choices" list.
class Choice(BaseModel):
    # Position of this choice within the list.
    index: int
    # Message payload as a string-to-string mapping.
    message: Dict[str, str]
    # Why generation ended for this choice.
    finish_reason: str
# Shape of a non-streaming chat completion response.
# NOTE: the handlers visible in this file return plain dicts with these keys
# rather than instances of this class.
class ChatCompletionResponse(BaseModel):
    id: str
    # Object-type tag; fixed default.
    object: str = "chat.completion"
    # Creation time as an integer timestamp.
    created: int
    model: str
    choices: List[Choice]
    # Token counters keyed by name.
    usage: Dict[str, int]
# Description of a single model as listed by GET /v1/models.
class ModelData(BaseModel):
    id: str
    # Object-type tag; fixed default.
    object: str = "model"
    # Creation time as an integer timestamp.
    created: int
    owned_by: str = "panda-guard"
# Envelope for the model list: a type tag plus the model entries.
class ModelsResponse(BaseModel):
    # Object-type tag; fixed default.
    object: str = "list"
    data: List[ModelData]
def load_yaml(yaml_file):
    """Load a YAML configuration file.

    Args:
        yaml_file: Path of the file to read (anything ``open`` accepts).

    Returns:
        The object produced by ``yaml.safe_load`` for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8: without an encoding argument ``open`` uses
    # the platform's locale encoding, which garbles or rejects non-ASCII
    # configuration values on systems whose default is not UTF-8.
    with open(yaml_file, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
def get_package_config_path(model_type: str) -> Path:
    """Locate a default configuration file shipped inside the package.

    The file is looked up as ``<two directories above this module>/config/
    <model_type>.yaml``.

    Args:
        model_type: Config name relative to the ``config`` directory, without
            the ``.yaml`` suffix.

    Returns:
        Path of the existing configuration file.

    Raises:
        typer.Exit: With exit code 1, after printing the reason to the
            console, when the file is missing or any other error occurs.
    """
    try:
        candidate = Path(__file__).parent.parent / "config" / f"{model_type}.yaml"
        if candidate.exists():
            return candidate
        raise FileNotFoundError(f"Default config file for {model_type} not found at {candidate}")
    except Exception as e:
        console.print(f"[bold red]Error finding default config: {str(e)}[/bold red]")
        raise typer.Exit(1)
def apply_env_vars_to_config(config_dict: Dict[str, Any], model_type: str) -> Dict[str, Any]:
    """Merge provider credentials from environment variables into the config.

    Works on ``config_dict["defender"]["target_llm_config"]``:

    * ``base_url`` is replaced by the provider's ``*_BASE_URL`` variable only
      when the key is already present in the config and the variable is
      non-empty (openai and claude only).
    * ``api_key`` is filled from the provider's ``*_API_KEY`` variable only
      when the key is absent from the config and the variable is non-empty.

    Args:
        config_dict: Full configuration dictionary; modified in place.
        model_type: ``"openai"``, ``"gemini"`` or ``"claude"``; any other
            value leaves the configuration untouched.

    Returns:
        The same ``config_dict`` object.
    """

    def override_base_url(env_name: str) -> None:
        # The config is inspected before the environment, as in the original
        # condition, so a malformed config fails in the same place.
        llm_config = config_dict["defender"]["target_llm_config"]
        if "base_url" in llm_config and os.environ.get(env_name):
            llm_config["base_url"] = os.environ[env_name]

    def fill_api_key(env_name: str) -> None:
        # The config is only touched once the variable is known to be set.
        if not os.environ.get(env_name):
            return
        llm_config = config_dict["defender"]["target_llm_config"]
        if "api_key" not in llm_config:
            llm_config["api_key"] = os.environ[env_name]

    if model_type == "openai":
        override_base_url("OPENAI_BASE_URL")
        fill_api_key("OPENAI_API_KEY")
    elif model_type == "gemini":
        fill_api_key("GEMINI_API_KEY")
    elif model_type == "claude":
        fill_api_key("ANTHROPIC_API_KEY")
        override_base_url("ANTHROPIC_BASE_URL")

    return config_dict
def _sse_event(payload: Dict[str, Any]) -> str:
    """Serialise *payload* as one server-sent-events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _completion_chunk(completion_id: str, created: int, model_name: str,
                      delta: Dict[str, Any], finish_reason: Optional[str] = None) -> Dict[str, Any]:
    """Build one ``chat.completion.chunk`` object carrying *delta*."""
    return {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }
        ],
    }


def _judge_exchange(pipeline_instance, messages, assistant_text):
    """Run the pipeline's judges on the last user message and the reply.

    Returns ``None`` when the pipeline has no (or empty) ``judges``. Errors
    are not handled here; callers decide how to report them.
    """
    if not getattr(pipeline_instance, "judges", None):
        return None
    user_msg = messages[-1]["content"]
    judge_messages = [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": assistant_text},
    ]
    return pipeline_instance.parallel_judging(judge_messages, user_msg)


async def _stream_completion(pipeline_instance, messages):
    """Yield the SSE frames of a streamed chat completion for *messages*.

    Frame order: role header, content chunk(s), optional judge results,
    a final chunk with ``finish_reason="stop"`` and the ``[DONE]`` marker.
    Any failure is logged and reported to the client as a single error frame.
    """
    try:
        # NOTE(review): this call and the chunk iteration below are
        # synchronous, so they occupy the event loop while they run. Left
        # as-is because the pipeline's thread-safety is not visible here.
        result = pipeline_instance(messages)

        logging.debug("Stream result type: %s", type(result))
        if isinstance(result, dict):
            logging.debug("Stream result keys: %s", result.keys())

        # One timestamp feeds both the id and the "created" field.
        created = int(time.time())
        completion_id = f"chatcmpl-{created}"
        model_name = pipeline_instance.defender.target_llm._NAME

        def frame(delta, finish_reason=None):
            return _sse_event(
                _completion_chunk(completion_id, created, model_name, delta, finish_reason)
            )

        yield frame({"role": "assistant"})

        # The pipeline may hand back {"messages": ...} or a bare message list.
        if isinstance(result, dict) and "messages" in result:
            produced = result["messages"]
        elif isinstance(result, list):
            produced = result
        else:
            produced = None

        content_so_far = ""
        if hasattr(produced, "__iter__") and hasattr(produced, "__next__"):
            # An iterator: forward each piece as soon as it is produced.
            for piece in produced:
                content_so_far += piece
                yield frame({"content": piece})
                # Hand control back to the event loop between chunks.
                await asyncio.sleep(0.01)
        elif isinstance(produced, list) and produced:
            # A finished message list: send the last message in one chunk.
            content_so_far = produced[-1]["content"]
            yield frame({"content": content_so_far})

        # Judge evaluation is best-effort: a failure is logged and the
        # stream still terminates normally.
        try:
            judge_results = _judge_exchange(pipeline_instance, messages, content_so_far)
            if judge_results:
                yield _sse_event({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_name,
                    "judge_results": judge_results,
                })
        except Exception as judge_error:
            logging.exception("Error during judge evaluation: %s", judge_error)

        yield frame({}, finish_reason="stop")
        yield "data: [DONE]\n\n"
    except Exception as e:
        logging.exception("Error during streaming: %s", e)
        yield _sse_event({"error": {"message": str(e), "type": "server_error"}})


def _complete_chat(pipeline_instance, messages) -> Dict[str, Any]:
    """Run the pipeline once and build the non-streaming response body."""
    result = pipeline_instance(messages)

    if isinstance(result, dict) and "messages" in result:
        response_messages = result["messages"]
    else:
        response_messages = result

    if isinstance(response_messages, list) and response_messages:
        assistant_message = response_messages[-1]["content"]
    else:
        assistant_message = "No response generated"

    # Usage is only available on dict results. The previous unconditional
    # ``result.get`` raised AttributeError for the list results accepted
    # just above; a missing or None "usage" now counts as zero tokens.
    usage = (result.get("usage") if isinstance(result, dict) else None) or {}
    prompt_tokens = sum(role.get("prompt_tokens", 0) for role in usage.values())
    completion_tokens = sum(role.get("completion_tokens", 0) for role in usage.values())
    token_usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }

    # Judge evaluation is best-effort: a failure is logged and the answer
    # is still returned, just without judge results.
    judge_results = None
    try:
        judge_results = _judge_exchange(pipeline_instance, messages, assistant_message)
    except Exception as judge_error:
        logging.exception("Error during judge evaluation: %s", judge_error)

    created = int(time.time())
    response = {
        "id": f"chatcmpl-{created}",
        "object": "chat.completion",
        "created": created,
        "model": pipeline_instance.defender.target_llm._NAME,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": assistant_message,
                },
                "finish_reason": "stop",
            }
        ],
        "usage": token_usage,
    }
    if judge_results:
        response["judge_results"] = judge_results
    return response


def create_fastapi_app(pipeline_instance):
    """Create a FastAPI application with OpenAI-compatible endpoints.

    Routes registered:

    * ``POST /v1/chat/completions`` -- runs the pipeline on the request's
      messages; returns a JSON body, or a ``text/event-stream`` response when
      the request sets ``stream``.
    * ``GET /v1/models`` -- lists the single model served by the pipeline.
    * ``GET /health`` -- static health/version payload.

    Args:
        pipeline_instance: Callable pipeline. The handlers use
            ``pipeline_instance(messages)``,
            ``.defender.target_llm_gen_config`` (``temperature`` and
            ``stream`` are assigned per request),
            ``.defender.target_llm._NAME`` and, when present, ``.judges``
            with ``.parallel_judging(...)``.

    Returns:
        The configured ``FastAPI`` application.
    """
    api_app = FastAPI(title="PandaGuard API Server",
                      description="API server compatible with OpenAI API format")

    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # discouraged by the FastAPI CORS documentation. Left unchanged so that
    # existing browser clients keep working; consider restricting origins.
    api_app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @api_app.post("/v1/chat/completions")
    async def chat_completions(request: ChatCompletionRequest):
        try:
            messages = request.messages

            # NOTE(review): these assignments mutate the shared pipeline
            # configuration. ChatCompletionRequest defaults temperature to
            # 1.0, so a request that omits it still resets whatever
            # temperature the server was started with -- confirm intent.
            gen_config = pipeline_instance.defender.target_llm_gen_config
            if request.temperature is not None:
                gen_config.temperature = request.temperature
            gen_config.stream = request.stream

            if request.stream:
                return StreamingResponse(
                    _stream_completion(pipeline_instance, messages),
                    media_type="text/event-stream",
                )
            return _complete_chat(pipeline_instance, messages)
        except Exception as e:
            logging.exception("Error in chat completions endpoint")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @api_app.get("/v1/models")
    async def list_models():
        try:
            model_name = pipeline_instance.defender.target_llm._NAME
            return {
                "object": "list",
                "data": [
                    {
                        "id": model_name,
                        "object": "model",
                        "created": int(time.time()),
                        "owned_by": "panda-guard",
                    }
                ],
            }
        except Exception as e:
            logging.exception("Error in models endpoint")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @api_app.get("/health")
    async def health_check():
        return {"status": "healthy", "version": "1.0.0"}

    return api_app
# Endpoint names that map to a packaged ``endpoints/<name>.yaml`` file.
_ENDPOINT_TYPES = ("openai", "gemini", "claude")


def _resolve_config_path(name, package_subdir: str) -> Path:
    """Return *name* if it is an existing file, else the packaged config of that name."""
    candidate = Path(name)
    if candidate.is_file():
        return candidate
    # get_package_config_path prints the reason and exits by itself when the
    # packaged file is missing, so no further existence check is needed.
    return get_package_config_path(f"{package_subdir}/{name}")


def _load_judge_config(judge_name: str) -> Dict[str, Any]:
    """Load one judge config and fill missing LLM credentials from the environment."""
    judge_config = load_yaml(_resolve_config_path(judge_name, "judges")) or {}
    if 'judge_llm_config' in judge_config:
        llm_config = judge_config["judge_llm_config"]
        if llm_config.get("base_url") is None:
            llm_config["base_url"] = os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
        if llm_config.get("api_key") is None:
            if not os.environ.get("OPENAI_API_KEY"):
                raise ValueError("API key not found for judge LLM config")
            llm_config["api_key"] = os.environ["OPENAI_API_KEY"]
    return judge_config


def _load_endpoint_config(endpoint: Optional[Path]):
    """Return ``(endpoint_type, target LLM config)`` for the --endpoint option.

    ``endpoint_type`` is ``None`` when a custom YAML file is used.
    """
    if endpoint is None:
        endpoint_type = "openai"
    else:
        endpoint_name = str(endpoint)
        if endpoint_name.lower() in _ENDPOINT_TYPES:
            endpoint_type = endpoint_name.lower()
        elif endpoint_name.endswith(".yaml"):
            endpoint_path = Path(endpoint_name)
            if not endpoint_path.exists():
                typer.echo(f"Error: Config file {endpoint_path} not found", err=True)
                raise typer.Exit(1)
            return None, load_yaml(endpoint_path) or {}
        else:
            typer.echo(
                f"Error: Invalid endpoint option '{endpoint}'. "
                "Must be a .yaml file or one of 'openai', 'gemini', 'claude'",
                err=True)
            raise typer.Exit(1)
    return endpoint_type, load_yaml(get_package_config_path(f"endpoints/{endpoint_type}")) or {}


@app.callback(invoke_without_command=True)
def start(
        config: Optional[str] = typer.Argument(None, help="Path to YAML configuration file"),
        defense: Optional[Path] = typer.Option(None, "--defense", "-d",
                                               help="Path to defense configuration file or defense type (goal_priority/icl/none/rpo/self_reminder/smoothllm)"),
        judge: Optional[str] = typer.Option(None, "--judge", "-j",
                                            help="Path to judge configuration file or judge type (llm_based/rule_based). Multiple judges can be specified using comma separation."),
        endpoint: Optional[Path] = typer.Option(None, "--endpoint", "-e",
                                                help="Path to endpoint configuration file or endpoint type (openai/gemini/claude)"),
        model: Optional[Path] = typer.Option(None, "--model", "-m", help="model name"),
        temperature: Optional[float] = typer.Option(None, "--temperature", "-t", help="Override temperature setting"),
        device: Optional[str] = typer.Option(None, "--device", help="Device to run the model on (e.g., 'cuda:0')"),
        log_level: str = typer.Option("WARNING", "--log-level", help="Logging level (DEBUG, INFO, WARNING, ERROR)"),
        port: int = typer.Option(8000, "--port", "-p", help="Port to run the server on"),
        host: str = typer.Option("0.0.0.0", "--host", help="Host to bind the server to"),
        verbose: bool = typer.Option(False, "--verbose/--no-verbose",
                                     help="Enable/disable verbose mode with token usage info"),
):
    """
    Start an API server compatible with OpenAI API format.

    Accepts the same configuration options as the chat interface, plus host and port settings.
    """
    # The docstring above is kept short on purpose; developer notes live in
    # comments instead.
    #
    # Configuration is assembled in this order:
    #   1. base config: `config` file, or the packaged `tasks/chat` default
    #   2. `--defense` replaces the whole "defender" section
    #   3. `--judge` sets the "judges" list (empty when not given)
    #   4. `--endpoint` sets defender.target_llm_config (default: openai)
    #   5. environment variables, then --model / --device / --temperature
    # Every failure ends in typer.Exit(1) after an error message.
    global global_pipeline, verbose_mode
    verbose_mode = verbose

    # Accept level names case-insensitively and reject unknown ones with a
    # usage error instead of an AttributeError traceback.
    level = getattr(logging, log_level.upper(), None)
    if not isinstance(level, int):
        raise typer.BadParameter(f"Unknown log level '{log_level}'", param_hint="--log-level")
    logging.basicConfig(level=level, format='%(asctime)s - %(levelname)s - %(message)s')

    try:
        # 1. Base configuration. An explicitly given file is now actually
        #    loaded; previously it was accepted but never read.
        if config is None:
            config_path = get_package_config_path('tasks/chat')
        else:
            config_path = Path(config)
            if not config_path.is_file():
                typer.echo(f"Error: Config file {config_path} not found", err=True)
                raise typer.Exit(1)
        config_dict = load_yaml(config_path) or {}

        # 2. Defense: an existing file wins, otherwise a packaged defense name.
        if defense is not None:
            config_dict["defender"] = load_yaml(_resolve_config_path(defense, "defenses"))
        if not isinstance(config_dict.get("defender"), dict):
            config_dict["defender"] = {}
        defender_dict = config_dict["defender"]

        # 3. Judges: comma-separated names or paths; blanks from stray commas
        #    are ignored.
        judge_names = [part.strip() for part in judge.split(',') if part.strip()] if judge else []
        config_dict["judges"] = [_load_judge_config(judge_name) for judge_name in judge_names]

        # 4. Endpoint: chosen from --endpoint (the old code inspected `config`
        #    here, so every explicit --endpoint value was rejected).
        endpoint_type, defender_dict["target_llm_config"] = _load_endpoint_config(endpoint)

        # 5. Environment variables and command-line overrides.
        config_dict = apply_env_vars_to_config(config_dict, endpoint_type)
        target_llm_config = config_dict["defender"]["target_llm_config"]
        if model:
            target_llm_config['model_name'] = str(model)
        if device:
            target_llm_config["device_map"] = device
        if temperature is not None:
            config_dict["defender"].setdefault("target_llm_gen_config", {})["temperature"] = temperature

        # Build the pipeline and expose it at module level.
        attacker_config, defender_config, judge_config = parse_configs_from_dict(config_dict)
        pipeline = InferPipeline(
            InferPipelineConfig(
                attacker_config=attacker_config,
                defender_config=defender_config,
                judge_configs=judge_config,
            ),
            verbose=verbose
        )
        global_pipeline = pipeline

        api_app = create_fastapi_app(pipeline)

        model_name = target_llm_config.get("model_name", "unknown")
        console.print(f"[bold green]Starting API server with model: {model_name}[/bold green]")
        console.print(f"[bold green]Server running at http://{host}:{port}[/bold green]")
        console.print("[bold green]API is compatible with OpenAI API format[/bold green]")

        # Blocks until the server shuts down.
        uvicorn.run(api_app, host=host, port=port)
    except typer.Exit:
        # typer.Exit derives from RuntimeError; without this clause the
        # deliberate exits above were caught below, logged with a traceback
        # and reported a second time.
        raise
    except Exception as e:
        typer.echo(f"Error starting server: {str(e)}", err=True)
        logging.exception("Exception during server startup")
        raise typer.Exit(1) from e
# Run the Typer application when this module is executed as a script.
if __name__ == "__main__":
    app()