Source code for panda_guard.llms.base

# encoding: utf-8
# Author    : Floyed<Floyed_Shen@outlook.com>
# Datetime  : 2024/8/31 22:00
# User      : yu
# Product   : PyCharm
# Project   : panda-guard
# File      : base.py
# explain   :

import abc

from typing import Dict, List, Union, Any, Tuple
from dataclasses import dataclass, field
import concurrent.futures


[docs]@dataclass class LLMGenerateConfig: """ Configuration for LLM generation. :param max_n_tokens: Maximum number of tokens to generate. :param temperature: Temperature for sampling randomness. :param logprobs: Whether to return log probabilities. :param seed: Seed for reproducibility. :param stream: Whether to use streaming generation. """ max_n_tokens: int = field(default=None) temperature: float = field(default=None) logprobs: bool = field(default=False) seed: int = field(default=None) stream: bool = field(default=False) # Default to non-streaming behavior
[docs]@dataclass class BaseLLMConfig(abc.ABC): """ Base configuration for LLM. :param llm_type: Type of the LLM. :param model_name: Name of the model. """ llm_type: str = field(default=None) model_name: str = field(default=None)
[docs]class BaseLLM(abc.ABC): """ Abstract base class for LLM. :param config: Configuration object for LLM. """ def __init__(self, config: BaseLLMConfig): self._CLS = config.llm_type self._NAME = config.model_name self.prompt_tokens = 0 self.completion_tokens = 0 self.num_requires = 0
[docs] @abc.abstractmethod def generate( self, messages: List[Dict[str, str]], config: LLMGenerateConfig ) -> Union[List[Dict[str, str]], Tuple[List[Dict[str, str]], List[float]]]: """ Abstract method for generating response from LLM. :param messages: List of messages for input. :param config: Configuration for generation. :return: Generated response or responses with log probabilities. """ pass
[docs] @abc.abstractmethod def evaluate_log_likelihood( self, messages: List[Dict[str, str]], config: LLMGenerateConfig, require_grad=False, ) -> List[float]: """ Abstract method for evaluating log likelihood of messages. :param messages: List of messages to evaluate. :param config: Configuration for generation. :param require_grad: Whether grad information is needed. :return: List of log likelihoods. """ pass
[docs] @abc.abstractmethod def continual_generate( self, messages: List[Dict[str, str]], config: LLMGenerateConfig ) -> Union[List[Dict[str, str]], Tuple[List[Dict[str, str]], List[float]]]: """ Remove EOS token in formatted prompt. Manually add generation prompt. :param messages: List of messages for input. :param config: Configuration for generation. :return: Generated response or responses with log probabilities. """ pass
[docs] def batch_generate( self, batch_messages: List[List[Dict[str, str]]], config: LLMGenerateConfig, ) -> List[Union[List[Dict[str, str]], Tuple[List[Dict[str, str]], List[float]]]]: """ Generate responses for a batch of messages concurrently. :param batch_messages: List of batches of messages. :param config: Configuration for generation. :return: List of generated responses. """ configs = [config] * len(batch_messages) if config.seed is not None: seeds = list(range(config.seed, config.seed + len(batch_messages))) for i, c in enumerate(configs): c.seed = seeds[i] with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ executor.submit(self.generate, messages, config) for config, messages in zip(configs, batch_messages) ] results = [ future.result() for future in concurrent.futures.as_completed(futures) ] return results
[docs] def reset(self): """ Reset the token counts and the number of requests. """ self.prompt_tokens = 0 self.completion_tokens = 0 self.num_requires = 0
[docs] def update( self, prompt_tokens: int, completion_tokens: int, num_requires: int, ): """ Update the token counts and number of requests. :param prompt_tokens: Number of tokens in prompt. :param completion_tokens: Number of tokens in completion. :param num_requires: Number of requests made. """ self.prompt_tokens += prompt_tokens self.completion_tokens += completion_tokens self.num_requires += num_requires
@property def total_tokens(self): """ Get the total number of tokens used. :return: Total number of tokens. """ return self.prompt_tokens + self.completion_tokens @property def avg_tokens(self): """ Get the average number of tokens per request. :return: Average number of tokens. """ return self.total_tokens / self.num_requires