# Source code for panda_guard.llms.oai

# encoding: utf-8
# Author    : Floyed<Floyed_Shen@outlook.com>
# Datetime  : 2024/8/31 22:01
# User      : yu
# Product   : PyCharm
# Project   : panda-guard
# File      : oai.py
# explain   :

import os
import time
import warnings
from typing import Dict, List, Union, Any, Tuple, Generator
from dataclasses import dataclass, field

import openai

from panda_guard.llms import BaseLLM, BaseLLMConfig, LLMGenerateConfig
from panda_guard.utils import process_end_eos


@dataclass
class OpenAiLLMConfig(BaseLLMConfig):
    """
    Configuration record for :class:`OpenAiLLM`.

    :param llm_type: Type of LLM, default is "OpenAiLLM".
    :param model_name: Name of the model.
    :param base_url: Base URL for the OpenAI API.
    :param api_key: API key for accessing OpenAI.
    """

    # Plain class-level defaults; equivalent to ``field(default=...)``.
    llm_type: str = "OpenAiLLM"
    model_name: str = None
    base_url: str = None
    api_key: str = "KEY HERE"
@dataclass
class OpenAiChatLLMConfig(BaseLLMConfig):
    """
    Configuration record for :class:`OpenAiChatLLM`.

    :param llm_type: Type of LLM, default is "OpenAiChatLLM".
    :param model_name: Name of the model.
    :param base_url: Base URL for the OpenAI API.
    :param api_key: API key for accessing OpenAI.
    """

    # Plain class-level defaults; equivalent to ``field(default=...)``.
    llm_type: str = "OpenAiChatLLM"
    model_name: str = None
    base_url: str = None
    api_key: str = "KEY HERE"
class OpenAiChatLLM(BaseLLM):
    """
    OpenAI Chat LLM Implementation.

    Sends conversations to ``chat.completions`` of an ``openai.OpenAI`` client
    built from the configured base URL and API key.

    :param config: Configuration for OpenAI Chat LLM.
    """

    def __init__(self, config: OpenAiChatLLMConfig):
        super().__init__(config)
        self.client = openai.OpenAI(
            base_url=config.base_url,
            api_key=config.api_key,
        )

    @staticmethod
    def _fold_system_prompt(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Return a new message list with the leading system prompt merged into
        the message that follows it.

        The merged message is a copy, so the caller's dictionaries are left
        untouched and passing the same list again does not prepend the system
        prompt a second time.

        :param messages: Messages whose first entry is the system prompt.
        :return: New list without the system entry.
        :raises IndexError: If nothing follows the system message.
        """
        merged = dict(messages[1])
        merged["content"] = messages[0]["content"] + "\n\n" + merged["content"]
        return [merged, *messages[2:]]

    def _stream_chat(self, stream, messages: List[Dict[str, str]]) -> Generator[str, None, None]:
        """
        Yield text pieces from a streaming chat response.

        Once the stream is exhausted, the concatenated text is appended to
        ``messages`` as the assistant turn and usage is recorded through
        ``self.update``. If the consumer stops iterating early, neither happens.

        :param stream: Iterable of chunks returned by the streaming request.
        :param messages: Message list that receives the assistant turn.
        """
        pieces = []
        prompt_tokens = 0
        completion_tokens = 0
        for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                piece = chunk.choices[0].delta.content
                pieces.append(piece)
                yield piece
            if hasattr(chunk, 'usage') and chunk.usage:
                if hasattr(chunk.usage, 'prompt_tokens'):
                    prompt_tokens = chunk.usage.prompt_tokens
                if hasattr(chunk.usage, 'completion_tokens'):
                    completion_tokens = chunk.usage.completion_tokens

        full_content = "".join(pieces)
        messages.append({"role": "assistant", "content": full_content})
        # When the stream carried no usage numbers, fall back to a rough
        # estimate of one token per four characters.
        self.update(
            prompt_tokens or len(str(messages[:-1])) // 4,
            completion_tokens or len(full_content) // 4,
            1,
        )

    def generate(
        self, messages: List[Dict[str, str]], config: LLMGenerateConfig
    ) -> list[dict[str, str]] | tuple[list[dict[str, str]], list[Any]] | Generator[str, None, None] | None:
        """
        Generate a response for a given input using OpenAI Chat API.

        :param messages: List of input messages.
        :param config: Configuration for LLM generation.
        :return: The message list with the assistant reply appended; a
            ``(messages, logprobs)`` tuple when ``config.logprobs`` is set; or a
            generator of text pieces when ``config.stream`` is set.
        :raises RuntimeError: If the request still fails after all retries.
        """
        name = self._NAME

        # Clamp into a local so a config object shared between models is not
        # permanently rewritten.
        max_tokens = config.max_n_tokens
        if ('4k' in name or 'gemma-2-2b-it' in name) and max_tokens > 2048:
            max_tokens = 2048
            warnings.warn(f"Model {self._NAME} only supports max_n_tokens up to 4096, setting response tokens to 2048.")

        # For gemma models the system prompt is sent as part of the first
        # following message instead of as a separate system turn.
        if "gemma" in name.lower() and messages[0]["role"] == "system":
            messages = self._fold_system_prompt(messages)

        # Some model names are requested from the endpoint under another name.
        if name == 'deepseek-reasoner':
            model_name = 'DeepSeek-R1'
        elif name == 'deepseek-ai/DeepSeek-V3':
            model_name = 'DeepSeek-V3'
        else:
            model_name = name

        request_kwargs = {
            "model": model_name,
            "messages": messages,
            "temperature": config.temperature,
            "seed": config.seed,
        }
        # No max_tokens is sent for model names containing 'o1' or 'o3'.
        if 'o1' not in name and 'o3' not in name:
            request_kwargs["max_tokens"] = max_tokens

        max_retries = 50
        for attempt in range(1, max_retries + 1):
            try:
                if config.stream:
                    stream = self.client.chat.completions.create(stream=True, **request_kwargs)
                    return self._stream_chat(stream, messages)

                response = self.client.chat.completions.create(
                    logprobs=config.logprobs, **request_kwargs
                )

                if response.choices is None:
                    print("Cannot find choices in Response:", response)
                    messages.append({"role": "assistant", "content": "I'm sorry, but I can't fulfill this request."})
                    # There are no logprobs to read here; return an empty list
                    # rather than failing on ``choices[0]`` and retrying.
                    return (messages, []) if config.logprobs else messages

                # Read everything that can raise BEFORE touching ``messages``,
                # so a failure followed by a retry never leaves a stray
                # assistant turn in the conversation being re-sent.
                choice = response.choices[0]
                content = choice.message.content
                logs = [tok.logprob for tok in choice.logprobs.content] if config.logprobs else None
                prompt_tokens = response.usage.prompt_tokens
                completion_tokens = response.usage.completion_tokens

                self.update(prompt_tokens, completion_tokens, 1)
                messages.append({"role": "assistant", "content": content})
                if config.logprobs:
                    return messages, logs
                return messages

            except Exception as e:
                # Errors mentioning "安全" (safety) or "敏感" (sensitive) are
                # treated as a provider-side content block, not a transient
                # failure: answer with a refusal and do not retry.
                # NOTE(review): this path returns the bare message list even
                # when logprobs or streaming was requested.
                if "安全" in str(e) or "敏感" in str(e):
                    messages.append({"role": "assistant", "content": "I'm sorry, I can't help with that."})
                    print(f"API request Safety Issue, {self._NAME}, Error: {e}, returning safety message.")
                    return messages

                if attempt >= max_retries:
                    # Raise instead of exit(0): a library must not end the
                    # process, least of all with a success status.
                    raise RuntimeError(
                        f"API request failed when testing model {self._NAME}, tried: {max_retries}, Error: {e}"
                    ) from e

                print(
                    f"API request failed when testing model {self._NAME},retrying {attempt}/{max_retries}... Error: {e}")
                # Linear back-off: wait one more second after each failure.
                time.sleep(attempt)

    def continual_generate(self, messages: List[Dict[str, str]], config: LLMGenerateConfig):
        """
        Continual generation is not available through the chat endpoint.

        :param messages: List of messages for input.
        :param config: Configuration for LLM generation.
        :raises NotImplementedError: OpenAiChatLLM does not support continual generation.
        """
        raise NotImplementedError(
            "OpenAiChatLLM does not support continual generation, please use OpenAiLLM instead."
        )

    def evaluate_log_likelihood(self, messages: List[Dict[str, str]], config: LLMGenerateConfig) -> List[float]:
        """
        Evaluate the log likelihood of the given messages.

        :param messages: List of messages for evaluation.
        :param config: Configuration for LLM generation.
        :raises NotImplementedError: OpenAI Chat does not support log likelihood evaluation.
        """
        raise NotImplementedError(
            "OpenAI Chat does not support log likelihood evaluation."
        )
class OpenAiLLM(BaseLLM):
    """
    OpenAI LLM Implementation.

    Renders conversations to a text prompt with the model's Hugging Face chat
    template and sends that prompt to ``completions`` of an ``openai.OpenAI``
    client.

    :param config: Configuration for OpenAI LLM.
    """

    def __init__(self, config: OpenAiLLMConfig):
        super().__init__(config)
        # Imported here so transformers is only needed when this class is used.
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(
            config.model_name,
            token=os.getenv("HF_TOKEN"),
            trust_remote_code=True,
        )
        self.client = openai.OpenAI(
            base_url=config.base_url,
            api_key=config.api_key,
        )

    @staticmethod
    def _fold_system_prompt(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Return a new message list with the leading system prompt merged into
        the message that follows it.

        The merged message is a copy, so the caller's dictionaries are left
        untouched and passing the same list again does not prepend the system
        prompt a second time.

        :param messages: Messages whose first entry is the system prompt.
        :return: New list without the system entry.
        :raises IndexError: If nothing follows the system message.
        """
        merged = dict(messages[1])
        merged["content"] = messages[0]["content"] + "\n\n" + merged["content"]
        return [merged, *messages[2:]]

    def _truncate_prompt(self, prompt: str, limit: int = 3840) -> Tuple[str, int]:
        """
        Cut ``prompt`` down to its first ``limit`` tokens when it is longer.

        :param prompt: Rendered prompt text.
        :param limit: Maximum number of tokens kept.
        :return: ``(prompt, n_tokens)`` where ``n_tokens`` is the token count
            of the prompt BEFORE truncation.
        """
        tokens = self.tokenizer(prompt, return_tensors="pt")
        input_ids = tokens["input_ids"]
        if input_ids.shape[1] > limit:
            # Decoding drops special tokens, so a truncated prompt no longer
            # carries the template's special markers.
            prompt = self.tokenizer.decode(
                input_ids[:, :limit][0], skip_special_tokens=True
            )
        return prompt, len(input_ids[0])

    def _stream_completion(
        self, stream, messages: List[Dict[str, str]], prompt_tokens: int
    ) -> Generator[str, Any, None]:
        """
        Yield text pieces from a streaming completion response.

        Once the stream is exhausted, the concatenated text is appended to
        ``messages`` as the assistant turn and usage is recorded through
        ``self.update``. If the consumer stops iterating early, neither happens.

        :param stream: Iterable of chunks returned by the streaming request.
        :param messages: Message list that receives the assistant turn.
        :param prompt_tokens: Prompt token count to report.
        """
        pieces = []
        completion_tokens = 0
        for chunk in stream:
            if chunk.choices and chunk.choices[0].text:
                piece = chunk.choices[0].text
                pieces.append(piece)
                # Rough estimate: count one token per non-empty chunk.
                completion_tokens += 1
                yield piece
            if hasattr(chunk, 'usage') and chunk.usage:
                if hasattr(chunk.usage, 'completion_tokens'):
                    # A reported count replaces the estimate.
                    completion_tokens = chunk.usage.completion_tokens

        messages.append({"role": "assistant", "content": "".join(pieces)})
        self.update(prompt_tokens, completion_tokens, 1)

    def generate(
        self, messages: List[Dict[str, str]], config: LLMGenerateConfig
    ) -> list[dict[str, str]] | tuple[list[dict[str, str]], list[float] | None] | Generator[str, Any, None] | None:
        """
        Generate a response for a given input using OpenAI API.

        :param messages: List of input messages.
        :param config: Configuration for LLM generation.
        :return: The message list with the assistant reply appended; a
            ``(messages, token_logprobs)`` tuple when ``config.logprobs`` is
            set; or a generator of text pieces when ``config.stream`` is set.
        :raises RuntimeError: If the request still fails after all retries.
        """
        name = self._NAME

        # Clamp into a local so a config object shared between models is not
        # permanently rewritten.
        max_tokens = config.max_n_tokens
        if ('4k' in name or 'gemma-2-2b-it' in name) and max_tokens > 2048:
            max_tokens = 2048
            warnings.warn(f"Model {self._NAME} only supports max_n_tokens up to 4096, setting response tokens to 2048.")

        # For gemma models the system prompt is sent as part of the first
        # following message instead of as a separate system turn.
        if "gemma" in name.lower() and messages[0]["role"] == "system":
            messages = self._fold_system_prompt(messages)

        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        prompt, prompt_tokens = self._truncate_prompt(prompt)

        max_retries = 10
        for attempt in range(1, max_retries + 1):
            try:
                if config.stream:
                    stream = self.client.completions.create(
                        model=name,
                        prompt=prompt,
                        max_tokens=max_tokens,
                        temperature=config.temperature,
                        stream=True,
                    )
                    return self._stream_completion(stream, messages, prompt_tokens)

                response = self.client.completions.create(
                    model=name,
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=config.temperature,
                    logprobs=config.logprobs,
                )

                # Read everything that can raise BEFORE touching ``messages``,
                # so a failure followed by a retry never leaves duplicate
                # assistant turns in the returned conversation.
                choice = response.choices[0]
                content = choice.text
                logs = choice.logprobs.token_logprobs if config.logprobs else None
                used_prompt = response.usage.prompt_tokens
                used_completion = response.usage.completion_tokens

                self.update(used_prompt, used_completion, 1)
                messages.append({"role": "assistant", "content": content})
                if config.logprobs:
                    return messages, logs
                return messages

            except Exception as e:
                # Errors mentioning "安全" (safety) or "敏感" (sensitive) are
                # treated as a provider-side content block, not a transient
                # failure: answer with a refusal and do not retry.
                # NOTE(review): this path returns the bare message list even
                # when logprobs or streaming was requested.
                if "安全" in str(e) or "敏感" in str(e):
                    messages.append({"role": "assistant", "content": "I'm sorry, I can't help with that."})
                    print(f"API request Safety Issue, {self._NAME}, Error: {e}, returning safety message.")
                    return messages

                if attempt >= max_retries:
                    raise RuntimeError(
                        f"API request failed when testing model {self._NAME}, tried: {max_retries}, Error: {e}"
                    ) from e

                print(
                    f"API request failed when testing model {self._NAME},retrying {attempt}/{max_retries}... Error: {e}")
                # Linear back-off: wait one more second after each failure.
                time.sleep(attempt)

        return None

    def continual_generate(
        self, messages: List[Dict[str, str]], config: LLMGenerateConfig
    ):
        """
        Continue the final message instead of starting a new assistant turn.

        The conversation is rendered with the chat template, a trailing EOS
        token is removed with ``process_end_eos``, and the completion returned
        by the API is appended to the content of the last message in place.

        :param messages: List of messages for input.
        :param config: Configuration for generation.
        :return: Generated response or responses with log probabilities.
        """
        if "gemma" in self._NAME.lower() and messages[0]["role"] == "system":
            messages = self._fold_system_prompt(messages)

        # NOTE(review): the keyword below is forwarded exactly as written; the
        # parameter documented by transformers is spelled
        # ``continue_final_message``. The EOS stripping that follows looks like
        # it compensates - confirm before renaming.
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, continual_final_message=True
        )

        # Remove the final EOS so the model keeps writing the last message.
        prompt = process_end_eos(msg=prompt, eos_token=self.tokenizer.eos_token)
        prompt, _ = self._truncate_prompt(prompt)

        response = self.client.completions.create(
            model=self._NAME,
            prompt=prompt,
            max_tokens=config.max_n_tokens,
            temperature=config.temperature,
            logprobs=config.logprobs,
        )
        choice = response.choices[0]

        # Extend the last message rather than appending a new one.
        messages[-1]["content"] += choice.text

        self.update(
            response.usage.prompt_tokens,
            response.usage.completion_tokens,
            1,
        )

        if config.logprobs:
            return messages, choice.logprobs.token_logprobs
        return messages

    def evaluate_log_likelihood(
        self,
        messages: List[Dict[str, str]],
        config: LLMGenerateConfig,
        require_grad=False,
    ) -> List[float]:
        """
        Evaluate the log likelihood of the given messages.

        The rendered prompt is echoed back by the API with ``max_tokens=0`` and
        the trailing log probabilities, one per token of the last message's
        content, are returned.

        :param messages: List of messages for evaluation.
        :param config: Configuration for LLM generation.
        :param require_grad: Must be falsy; gradients are not available here.
        :return: List of log likelihood values.
        :raises NotImplementedError: If ``require_grad`` is truthy.
        :raises ValueError: If the last message has no content or the response
            carries no ``token_logprobs``.
        """
        if require_grad:
            raise NotImplementedError

        # Make sure the content exists and is not None.
        if not messages or "content" not in messages[-1] or messages[-1]["content"] is None:
            print(messages)
            raise ValueError("Last message must have valid content")

        content = messages[-1]["content"]

        # NOTE(review): add_generation_prompt=True may put template tokens
        # after the content, and the tokenizer call below may count special
        # tokens; either would shift the "last N logprobs" window - confirm.
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        response = self.client.completions.create(
            model=self._NAME,
            prompt=prompt,
            max_tokens=0,
            temperature=config.temperature,
            logprobs=1,
            echo=True,
        )

        logprobs = getattr(response.choices[0], 'logprobs', None)
        if logprobs is None or not hasattr(logprobs, 'token_logprobs'):
            # Handle the case where the logprobs structure is different.
            raise ValueError("Response does not contain expected logprobs structure")

        all_logprobs = logprobs.token_logprobs
        num_tokens = len(self.tokenizer(text=content).input_ids)

        # Keep the last ``num_tokens`` entries. The explicit start index
        # returns [] for zero tokens (a ``[-0:]`` slice would return the whole
        # list) and the whole list when there are fewer logprobs than tokens.
        start = max(len(all_logprobs) - num_tokens, 0)
        content_logprobs = all_logprobs[start:]

        self.update(response.usage.prompt_tokens, 0, 1)
        return content_logprobs
if __name__ == "__main__":
    # Manual smoke test: score a short conversation against a served model.
    from panda_guard.llms import LLMS

    print(LLMS)

    llm_gen_config = LLMGenerateConfig(
        max_n_tokens=128,
        temperature=1.0,
        logprobs=True,
        seed=42,
    )

    config = OpenAiLLMConfig(
        model_name="meta-llama/Meta-Llama-3.1-70B-Instruct",
        base_url="http://172.18.129.80:8000/v1",
    )
    llm = OpenAiLLM(config)

    # Three separate but identical user turns.
    messages = [
        {"role": "user", "content": "Hello, how are you?"}
        for _ in range(3)
    ]

    results = llm.evaluate_log_likelihood(messages, llm_gen_config)
    print(results, len(results))