# encoding: utf-8
# Author : Floyed<Floyed_Shen@outlook.com>
# Datetime : 2024/9/4 14:37
# User :
# Product : PyCharm
# Project : panda-guard
# File : chat.py
# explain :
import sys
import os
import yaml
import typer
import logging
import time
from typing import Optional, Iterator, Dict, Any, List
from pathlib import Path
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
from rich import box
from panda_guard.pipelines.inference import InferPipeline, InferPipelineConfig
from panda_guard.utils import parse_configs_from_dict
# Typer application object for this CLI; the `start` function in this module is
# registered on it via `@app.callback`.
app = typer.Typer(help="Interactive chat with language models", invoke_without_command=True)
# Shared Rich console used by the display helpers and the chat loop for styled output.
console = Console()
def is_iterator(obj: Any) -> bool:
    """Return True when *obj* is an iterator rather than a materialised sequence.

    An object qualifies when it exposes both ``__iter__`` and ``__next__``.
    Lists, dicts, strings, bytes and tuples are rejected up front so that an
    already-built response is never mistaken for a stream.

    Args:
        obj: Any object, typically the ``"messages"`` value returned by the
            pipeline.

    Returns:
        True if *obj* can be consumed with ``next()``, otherwise False.
    """
    # Reject the common concrete containers first.
    if isinstance(obj, (list, dict, str, bytes, tuple)):
        return False
    return hasattr(obj, "__iter__") and hasattr(obj, "__next__")
def get_package_config_path(model_type: str) -> Path:
    """Locate a YAML config bundled with the package.

    The file is looked up as ``config/<model_type>.yaml`` relative to the
    directory two levels above this source file.

    Args:
        model_type: Config name relative to the ``config`` directory, without
            the ``.yaml`` suffix (may contain sub-directories).

    Returns:
        Path to the existing config file.

    Raises:
        typer.Exit: With exit code 1 if the file is missing or the lookup
            fails; the reason is printed to the console first.
    """
    try:
        candidate = Path(__file__).parent.parent.joinpath("config", f"{model_type}.yaml")
        if candidate.exists():
            return candidate
        raise FileNotFoundError(f"Default config file for {model_type} not found at {candidate}")
    except Exception as err:
        console.print(f"[bold red]Error finding default config: {err}[/bold red]")
        raise typer.Exit(1)
# Endpoint type -> (API-key env var, base-URL env var or None when the endpoint
# has no base-URL override).
_ENDPOINT_ENV_VARS = {
    "openai": ("OPENAI_API_KEY", "OPENAI_BASE_URL"),
    "gemini": ("GEMINI_API_KEY", None),
    "claude": ("ANTHROPIC_API_KEY", "ANTHROPIC_BASE_URL"),
}


def apply_env_vars_to_config(config_dict: Dict[str, Any], model_type: str) -> Dict[str, Any]:
    """Fill the target LLM config from provider environment variables.

    For a known ``model_type`` ('openai', 'gemini', 'claude'):

    * the API-key variable is copied into ``api_key`` only when the config has
      no usable key (missing, ``None`` or empty), so an explicit key wins;
    * the base-URL variable, when set, replaces ``base_url`` only if the config
      already declares a ``base_url`` entry.

    Args:
        config_dict: Full configuration; ``config_dict["defender"]["target_llm_config"]``
            is updated in place.
        model_type: Endpoint type; any other value leaves the config untouched.

    Returns:
        The same ``config_dict`` object, for call chaining.

    Raises:
        KeyError: If a relevant environment variable is set but the config has
            no ``defender`` / ``target_llm_config`` section.
    """
    env_names = _ENDPOINT_ENV_VARS.get(model_type)
    if env_names is None:
        return config_dict

    api_key_var, base_url_var = env_names
    api_key = os.environ.get(api_key_var)
    base_url = os.environ.get(base_url_var) if base_url_var else None
    if not api_key and not base_url:
        return config_dict

    llm_config = config_dict["defender"]["target_llm_config"]
    # Treat `api_key: null` (or an empty string) in the YAML as "not provided".
    if api_key and not llm_config.get("api_key"):
        llm_config["api_key"] = api_key
    if base_url and "base_url" in llm_config:
        llm_config["base_url"] = base_url
    return config_dict
def display_token_info(usage, response_time=None):
    """Print a dimmed one-line summary of token usage, plus speed if timed.

    Args:
        usage: Mapping whose values are dicts carrying ``"prompt_tokens"`` and
            ``"completion_tokens"`` counts; the counts are summed over all
            entries.
        response_time: Elapsed seconds for the response. When given and
            positive, a second line reports the time and completion
            tokens per second.
    """
    per_role = list(usage.values())
    prompt_sum = sum(entry["prompt_tokens"] for entry in per_role)
    completion_sum = sum(entry["completion_tokens"] for entry in per_role)
    grand_total = prompt_sum + completion_sum

    console.print(
        f"[dim]Token usage:[/dim] Prompt: {prompt_sum} | Completion: {completion_sum} | Total: {grand_total}"
    )

    # Speed is only meaningful with a strictly positive elapsed time.
    if response_time is None or not response_time > 0:
        return
    rate = completion_sum / response_time
    console.print(f"[dim]Response time: {response_time:.2f}s ({rate:.2f} tokens/sec)[/dim]")
def display_judge_results(results):
    """Print one dimmed line per judge with its verdict.

    Args:
        results: Mapping of judge name to result. A dict result contributes its
            ``"verdict"`` entry (falling back to the dict's string form); any
            other result is shown via ``str()``. Nothing is printed when the
            mapping is empty or None.
    """
    if not results:
        return

    console.print("[dim]Judge evaluations:[/dim]")
    for judge_name, outcome in results.items():
        if isinstance(outcome, dict):
            verdict = outcome.get("verdict", str(outcome))
        else:
            verdict = str(outcome)
        # A verdict that renders as "10" gets a warning marker in front of it.
        marker = '⚠️ ' if str(verdict) == '10' else ''
        console.print(f"[dim]{judge_name}:[/dim] {marker}{verdict} ")
def display_help():
    """Print a rounded table describing the chat's slash commands."""
    command_rows = (
        ("/exit", "Exit the chat"),
        ("/reset", "Reset the conversation"),
        ("/help", "Display this help message"),
        ("/verbose", "Toggle verbose mode (show token usage)"),
        ("/save [filename]", "Save conversation to file"),
        ("/stats", "Display current conversation statistics"),
    )

    table = Table(title="Available Commands", box=box.ROUNDED)
    for heading, colour in (("Command", "cyan"), ("Description", "green")):
        table.add_column(heading, style=colour)
    for command, description in command_rows:
        table.add_row(command, description)
    console.print(table)
# Endpoint names that have a packaged default config under config/endpoints/.
_BUILTIN_ENDPOINTS = ("openai", "gemini", "claude", "hf")


def _redact_secrets(value):
    """Return a copy of a nested config with non-empty ``api_key`` values masked for logging."""
    if isinstance(value, dict):
        return {
            key: "***" if key == "api_key" and item else _redact_secrets(item)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_redact_secrets(item) for item in value]
    return value


def _resolve_config_file(spec, package_subdir):
    """Return *spec* as a path if it exists on disk, else the packaged config of that name."""
    candidate = Path(spec)
    if candidate.exists():
        return candidate
    # get_package_config_path reports the error and exits when nothing is found.
    return get_package_config_path(f"{package_subdir}/{spec}")


def _fill_judge_llm_credentials(judges):
    """Default each judge's ``judge_llm_config`` base URL / API key from the OPENAI_* env vars."""
    for judge_cfg in judges or []:
        if "judge_llm_config" not in judge_cfg:
            continue
        llm_cfg = judge_cfg["judge_llm_config"]
        if llm_cfg.get("base_url") is None:
            llm_cfg["base_url"] = os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
        if llm_cfg.get("api_key") is None:
            api_key = os.environ.get("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("API key not found for judge LLM config")
            llm_cfg["api_key"] = api_key


def _load_endpoint_config(endpoint):
    """Load the target-LLM config chosen by ``--endpoint``.

    Returns ``(endpoint_type, config)``; ``endpoint_type`` is None when the
    config came from a user-supplied YAML file instead of a built-in type.
    """
    endpoint_name = "openai" if endpoint is None else str(endpoint).lower()
    if endpoint_name in _BUILTIN_ENDPOINTS:
        return endpoint_name, load_yaml(get_package_config_path(f"endpoints/{endpoint_name}"))

    endpoint_path = Path(endpoint)
    if endpoint_path.suffix.lower() not in (".yaml", ".yml"):
        typer.echo(
            f"Error: Invalid endpoint option '{endpoint}'. Must be a .yaml file or one of "
            "'openai', 'gemini', 'claude', 'hf'",
            err=True)
        raise typer.Exit(1)
    if not endpoint_path.exists():
        typer.echo(f"Error: Endpoint config file {endpoint_path} not found", err=True)
        raise typer.Exit(1)
    return None, load_yaml(endpoint_path)


def _save_history(messages, destination):
    """Write the chat history to *destination* as UTF-8 JSON and confirm on the console."""
    import json

    with open(destination, 'w', encoding="utf-8") as f:
        json.dump(messages, f, indent=2, ensure_ascii=False)
    console.print(f"[bold]Chat history saved to {destination}[/bold]")


def _print_stats(pipe, messages, total_response_time, total_responses):
    """Print cumulative token usage and the average response time for the session."""
    console.print("[bold]Cumulative Conversation Statistics:[/bold]")
    total_prompt = 0
    total_completion = 0
    if messages:
        current_usage = pipe.calc_tokens()
        total_prompt = sum(role["prompt_tokens"] for role in current_usage.values())
        total_completion = sum(role["completion_tokens"] for role in current_usage.values())
    total_all = total_prompt + total_completion
    console.print(
        f"Token usage: Prompt: {total_prompt} | Completion: {total_completion} | Total: {total_all}")
    if total_responses > 0:
        avg_response_time = total_response_time / total_responses
        console.print(f"Average response time: {avg_response_time:.2f}s")


@app.callback(invoke_without_command=True)
def start(
        config: Optional[str] = typer.Argument(None, help="Path to YAML configuration file"),
        defense: Optional[Path] = typer.Option(None, "--defense", "-d",
                                               help="Path to defense configuration file or defense type (goal_priority/icl/none/rpo/self_reminder/smoothllm)"),
        judge: Optional[str] = typer.Option(None, "--judge", "-j",
                                            help="Path to judge configuration file or judge type (llm_based/rule_based). Multiple judges can be specified using comma separation."),
        endpoint: Optional[Path] = typer.Option(None, "--endpoint", "-e",
                                                help="Path to endpoint configuration file or endpoint type (openai/gemini/claude)"),
        # A model name is not a filesystem path; typing it as Path would rewrite
        # separators in names such as "org/model" on Windows.
        model: Optional[str] = typer.Option(None, "--model", "-m", help="model name"),
        temperature: Optional[float] = typer.Option(None, "--temperature", "-t", help="Override temperature setting"),
        device: Optional[str] = typer.Option(None, "--device", help="Device to run the model on (e.g., 'cuda:0')"),
        log_level: str = typer.Option("WARNING", "--log-level", help="Logging level (DEBUG, INFO, WARNING, ERROR)"),
        output: Optional[Path] = typer.Option(None, "--output", "-o", help="Save chat history to file"),
        stream: bool = typer.Option(True, "--stream/--no-stream", help="Enable/disable streaming output"),
        verbose: bool = typer.Option(False, "--verbose/--no-verbose",
                                     help="Enable/disable verbose mode with token usage info"),
):
    """
    Start an interactive chat session.

    CONFIG is the path to a YAML task configuration; when omitted, the packaged
    default chat configuration is used. The target model is selected with
    --endpoint, either a built-in type ('openai', 'gemini', 'claude', 'hf') or a
    path to a YAML file. For built-in types the matching API-key / base-URL
    environment variables are applied.
    """
    # Local import: only this function needs it, and it keeps the edit self-contained.
    from rich.markup import escape

    # Accept the level case-insensitively and fail with a clean message instead
    # of an AttributeError traceback.
    level = getattr(logging, str(log_level).upper(), None)
    if not isinstance(level, int):
        typer.echo(f"Error: Invalid log level '{log_level}'", err=True)
        raise typer.Exit(1)
    logging.basicConfig(level=level, format='%(asctime)s - %(levelname)s - %(message)s')

    try:
        # ---- Base configuration -------------------------------------------------
        if config is None:
            config_dict = load_yaml(get_package_config_path('tasks/chat'))
        else:
            config_path = Path(config)
            if not config_path.exists():
                typer.echo(f"Error: Config file {config} not found", err=True)
                raise typer.Exit(1)
            config_dict = load_yaml(config_path)

        # A defense config replaces the whole "defender" section.
        if defense:
            config_dict["defender"] = load_yaml(_resolve_config_file(defense, "defenses"))

        # Comma-separated judges given on the command line replace configured ones.
        if judge:
            config_dict["judges"] = [
                load_yaml(_resolve_config_file(judge_name.strip(), "judges"))
                for judge_name in judge.split(',')
                if judge_name.strip()
            ]
        _fill_judge_llm_credentials(config_dict.get("judges"))

        # ---- Target endpoint ----------------------------------------------------
        endpoint_type, endpoint_config = _load_endpoint_config(endpoint)
        config_dict["defender"]["target_llm_config"] = endpoint_config
        config_dict = apply_env_vars_to_config(config_dict, endpoint_type)

        # ---- Command-line overrides ---------------------------------------------
        defender_cfg = config_dict["defender"]
        target_cfg = defender_cfg["target_llm_config"]
        if model:
            target_cfg['model_name'] = str(model)
        if device:
            target_cfg["device_map"] = device

        gen_cfg = defender_cfg.get("target_llm_gen_config")
        if gen_cfg is None:
            # e.g. a --defense file that does not declare generation settings.
            gen_cfg = defender_cfg["target_llm_gen_config"] = {}
        if temperature is not None:
            gen_cfg["temperature"] = temperature
        gen_cfg["stream"] = stream

        # Never echo the raw config: it contains API keys at this point.
        logging.debug("Configuration loaded: %s", _redact_secrets(config_dict))

        # ---- Pipeline -----------------------------------------------------------
        attacker_config, defender_config, judge_config = parse_configs_from_dict(config_dict)
        pipe = InferPipeline(
            InferPipelineConfig(
                attacker_config=attacker_config,
                defender_config=defender_config,
                judge_configs=judge_config,
            ),
            verbose=verbose  # Pass verbose flag to pipeline
        )
        console.print(
            f"[bold green]Chat initialized with {target_cfg.get('model_name', 'unknown')}[/bold green]")
        console.print("[bold]Type your message (or '/help' for available commands)[/bold]")

        messages = []
        total_response_time = 0
        total_responses = 0

        # ---- Chat loop ----------------------------------------------------------
        while True:
            try:
                user_input = console.input("[bold blue]User:[/bold blue] ")
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C at the prompt ends the session like /exit so
                # that --output is still honoured.
                console.print()
                break

            command = user_input.strip().lower()
            if not command:
                continue  # nothing to send

            if command.startswith("/"):
                if command == "/exit":
                    break
                if command == "/reset":
                    messages = []
                    pipe.reset()
                    total_response_time = 0
                    total_responses = 0
                    console.print("[bold yellow]Conversation reset[/bold yellow]")
                elif command == "/help":
                    display_help()
                elif command == "/verbose":
                    verbose = not verbose
                    console.print(f"[bold yellow]Verbose mode {'enabled' if verbose else 'disabled'}[/bold yellow]")
                elif command == "/save" or command.startswith("/save "):
                    parts = user_input.split(maxsplit=1)
                    filename = parts[1] if len(parts) > 1 else "conversation.json"
                    try:
                        _save_history(messages, filename)
                    except (OSError, TypeError, ValueError) as e:
                        # A failed save must not end the chat session.
                        console.print(f"[bold red]Error: {escape(str(e))}[/bold red]")
                elif command == "/stats":
                    _print_stats(pipe, messages, total_response_time, total_responses)
                else:
                    display_help()
                continue

            messages.append({"role": "user", "content": user_input})

            try:
                start_time = time.time()
                result = pipe(messages)
                run_judges = len(pipe.judges) > 0

                if stream and isinstance(result, dict) and "messages" in result and is_iterator(result["messages"]):
                    # Streaming response: print chunks as they arrive.
                    console.print("[bold green]Assistant:[/bold green]")
                    chunks = []
                    try:
                        for text_chunk in result["messages"]:
                            chunks.append(text_chunk)
                            # markup=False: model text may contain "[...]" that
                            # Rich would otherwise parse as style tags.
                            console.print(text_chunk, end="", markup=False)
                            sys.stdout.flush()
                        console.print()
                    except Exception as e:
                        console.print(f"\n[bold red]Error during streaming: {escape(str(e))}[/bold red]")
                        logging.exception("Exception during streaming")

                    # Record the streamed reply so later turns and the judges
                    # see it. Skipped if the last entry is already an assistant
                    # message (in case the pipeline appended it to this list).
                    full_response = "".join(chunks)
                    last_message = messages[-1] if messages else None
                    already_recorded = isinstance(last_message, dict) and last_message.get("role") == "assistant"
                    if full_response and not already_recorded:
                        messages.append({"role": "assistant", "content": full_response})
                else:
                    # Regular response: either {"messages": [...]} or the list itself.
                    if isinstance(result, dict) and "messages" in result:
                        messages = result["messages"]
                    else:
                        messages = result
                    assistant_response = messages[-1]["content"]
                    console.print("[bold green]Assistant:[/bold green]")
                    console.print(Markdown(assistant_response))

                response_time = time.time() - start_time
                total_response_time += response_time
                total_responses += 1

                if verbose:
                    console.print("")  # Add an empty line for better readability
                    display_token_info(pipe.calc_tokens(), response_time)

                # Judge the latest user/assistant exchange.
                if run_judges and isinstance(messages, list) and len(messages) >= 2:
                    user_msg = messages[-2]["content"]
                    assistant_msg = messages[-1]["content"]
                    judge_messages = [
                        {"role": "user", "content": user_msg},
                        {"role": "assistant", "content": assistant_msg}
                    ]
                    judge_results = pipe.parallel_judging(judge_messages, user_msg)
                    if judge_results:
                        console.print("")  # Add an empty line for better readability
                        display_judge_results(judge_results)
            except Exception as e:
                # A failed turn is reported but does not end the session.
                console.print(f"[bold red]Error: {escape(str(e))}[/bold red]")
                logging.exception("Exception during chat")
                continue

        # Save chat history if requested
        if output:
            _save_history(messages, output)
    except (typer.Exit, typer.Abort):
        # These are RuntimeError subclasses; let deliberate exits through
        # instead of reporting them as chat errors with a traceback.
        raise
    except Exception as e:
        typer.echo(f"Error during chat: {str(e)}", err=True)
        logging.exception("Exception during chat")
        raise typer.Exit(1)
def load_yaml(yaml_file):
    """Parse a YAML configuration file with the safe loader.

    Args:
        yaml_file: Path (string or path-like) of the file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.
        NOTE(review): an empty file presumably yields ``None`` rather than an
        empty dict; callers that index the result should confirm.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default locale encoding.
    with open(yaml_file, 'r', encoding="utf-8") as file:
        return yaml.safe_load(file)
# Run the Typer application when this file is executed directly as a script.
if __name__ == "__main__":
    app()