Source code for panda_guard.pipelines.inference

# encoding: utf-8
# Author    : Floyed<Floyed_Shen@outlook.com>
# Datetime  : 2024/9/4 14:37
# User      : yu
# Product   : PyCharm
# Project   : panda-guard
# File      : inference.py
# explain   : Inference pipeline for executing attack and defense phases.

import abc
from typing import Dict, List, Union, Any, Iterator
from dataclasses import dataclass, field
from concurrent.futures import ThreadPoolExecutor, as_completed

from panda_guard.llms import BaseLLM
from panda_guard.role.attacks import create_attacker, BaseAttacker, BaseAttackerConfig
from panda_guard.role.attacks.transfer import TransferAttacker
from panda_guard.role.defenses import create_defender, BaseDefender, BaseDefenderConfig
from panda_guard.role.judges import create_judge, BaseJudge, BaseJudgeConfig


[docs]@dataclass class InferPipelineConfig: """ Configuration class for the inference pipeline. :param attacker_config: Configuration for the attacker, detailing its parameters and behavior. :param defender_config: Configuration for the defender, detailing its parameters and behavior. :param judge_configs: List of configurations for any judges used in evaluating the effectiveness of the defense. """ attacker_config: BaseAttackerConfig = field(default=None) defender_config: BaseDefenderConfig = field(default=None) judge_configs: List[BaseJudgeConfig] = field(default_factory=list)
[docs]def llm_register(attr: Union[BaseAttacker, BaseDefender, BaseJudge]) -> List[BaseLLM]: """ Register any LLM instances associated with the given attribute (attacker, defender, or judge). :param attr: An instance of an attacker, defender, or judge. :return: List of registered LLM instances. """ results = [] for name, value in vars(attr).items(): if isinstance(value, BaseLLM): results.append(value) return results
[docs]class InferPipeline: """ Inference pipeline for managing and executing attacks and defenses. :param config: Configuration for the inference pipeline, including attacker, defender, and judges. :param verbose: Whether to enable verbose logging for debugging and transparency. """ def __init__(self, config: InferPipelineConfig, verbose: bool = False): # Create instances of the attacker, defender, and any judges based on the provided configurations. # 根据提供的配置创建攻击者、防御者和评审者的实例。 self.attacker = create_attacker(config.attacker_config) self.defender = create_defender(config.defender_config) self.judges = [ create_judge(judge_config) for judge_config in config.judge_configs ] # Register LLMs used by the attacker, defender, and judges. # 注册攻击者、防御者和评审者所使用的LLM。 self.attack_llms = llm_register(self.attacker) self.defense_llms = llm_register(self.defender) self.judges_llms = [llm_register(judge) for judge in self.judges] self.verbose = verbose def __call__(self, messages: List[Dict[str, str]], **kwargs) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]: """ Execute the full inference process, including attack and defense phases. :param messages: Input messages to be processed through the pipeline. :param kwargs: Additional keyword arguments, such as reformulated requests for transfer attacks. :return: A dictionary containing the final defended messages and the token usage details, or an iterator yielding streaming chunks. """ # Check if streaming is enabled in the defender's config is_streaming = getattr(self.defender.target_llm_gen_config, "stream", False) # Check for TransferAttacker and validate the required argument if isinstance(self.attacker, TransferAttacker): assert ( "request_reformulated" in kwargs ), "TransferAttacker requires a 'request_reformulated' parameter." self.log(messages, "[INPUT]") # Perform the attack phase using the attacker instance attack = self.attacker.attack( messages, request_reformulated=kwargs.get("request_reformulated", None) ) self.log(attack, "[ATTACK]") # Perform the defense phase defense_result = self.defender.defense(attack) # Check if the result is an iterator (streaming) or regular messages if is_streaming and hasattr(defense_result, "__iter__") and not isinstance(defense_result, list): # It's a streaming iterator, yield chunks as they come def stream_generator(): final_messages = None for chunk, current_messages in defense_result: # Yield each chunk with the current state of messages yield { "chunk": chunk, "messages": current_messages, "complete": False } final_messages = current_messages # After all chunks are processed, yield a final complete message with usage if final_messages: self.log(final_messages, "[DEFENSE]") yield { "messages": final_messages, "complete": True, "usage": self.calc_tokens() } return stream_generator() else: # It's a regular result (non-streaming) self.log(defense_result, "[DEFENSE]") usage = self.calc_tokens() return {"messages": defense_result, "usage": usage}
[docs] def parallel_judging( self, defense: List[Dict[str, str]], request: str ) -> Dict[str, Any]: """ Perform parallel evaluation of the defense using multiple judges. :param defense: Messages that have been defended and need evaluation. :param request: The original request for context in evaluation. :return: A dictionary containing the results from each judge. """ judge_results = {} # Use ThreadPoolExecutor for parallel judge evaluations. # 使用ThreadPoolExecutor进行评审者的并行评估。 with ThreadPoolExecutor() as executor: future_to_judge = { executor.submit(judge.judge, defense, request): judge for judge in self.judges } for future in as_completed(future_to_judge): judge = future_to_judge[future] try: result = future.result() judge_results[judge._NAME] = result except Exception as exc: print(f"Judge {judge._NAME} generated an exception: {exc}") return judge_results
[docs] def calc_tokens(self) -> Dict[str, Dict[str, int]]: """ Calculate the token usage for both the attacker and defender, summarizing prompt and completion tokens. :return: Dictionary containing detailed token usage for both roles. """ # Calculate token usage for the attacker. # 计算攻击者的Token使用量。 attacker_prompt_tokens = sum(llm.prompt_tokens for llm in self.attack_llms) attacker_completion_tokens = sum( llm.completion_tokens for llm in self.attack_llms ) # Calculate token usage for the defender. # 计算防御者的Token使用量。 defender_prompt_tokens = sum(llm.prompt_tokens for llm in self.defense_llms) defender_completion_tokens = sum( llm.completion_tokens for llm in self.defense_llms ) return { "attacker": { "prompt_tokens": attacker_prompt_tokens, "completion_tokens": attacker_completion_tokens, }, "defender": { "prompt_tokens": defender_prompt_tokens, "completion_tokens": defender_completion_tokens, }, }
[docs] def reset(self): """ Reset the state of all LLMs in the pipeline, preparing them for a new inference run. """ for llm in self.attack_llms + self.defense_llms: llm.reset()
[docs] def log(self, messages: List[Dict[str, str]], stage: str = None): """ Log messages for each stage of the pipeline if verbose logging is enabled. :param messages: The messages to be logged at the current stage. :param stage: A label for the current stage of the pipeline. """ if self.verbose: pass
# print(stage) # print(messages) # print()