# Source code for panda_guard.llms.gemini

# encoding: utf-8
# Author    : Floyed<Floyed_Shen@outlook.com>
# Datetime  : 2024/9/2 20:30
# User      : yu
# Product   : PyCharm
# Project   : panda-guard
# File      : gemini.py
# explain   : Google Gemini API integration

import os
import time
import warnings
from typing import Dict, List, Union, Any, Tuple, Generator
from dataclasses import dataclass, field


import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold


from panda_guard.llms import BaseLLM, BaseLLMConfig, LLMGenerateConfig


@dataclass
class GeminiLLMConfig(BaseLLMConfig):
    """
    Gemini LLM Configuration.

    :param llm_type: Type of LLM, default is "GeminiLLM".
    :param model_name: Name of the model, default is "gemini-1.5-pro".
    :param api_key: API key for accessing Google AI. May be left as ``None``,
        in which case ``GeminiLLM`` falls back to the ``GOOGLE_API_KEY``
        environment variable.
    :param safety_settings: Custom safety settings for the model, given as a
        mapping from ``HarmCategory`` member names to ``HarmBlockThreshold``
        member names. ``None`` (the default) applies no custom settings.
    """

    llm_type: str = field(default="GeminiLLM")
    model_name: str = field(default="gemini-1.5-pro")
    # Both fields default to None, so the annotations must admit None.
    api_key: Union[str, None] = field(default=None)
    safety_settings: Union[Dict[str, str], None] = field(default=None)
class GeminiLLM(BaseLLM):
    """
    Gemini LLM Implementation.

    Sends chat-style conversations to the Google Gemini API through
    ``google.generativeai``.

    :param config: Configuration for Gemini LLM.
    :raises ValueError: If no API key is found in ``config.api_key`` or in the
        ``GOOGLE_API_KEY`` environment variable.
    """

    def __init__(self, config: GeminiLLMConfig):
        super().__init__(config)

        # Use provided API key or try to get it from the environment variable
        api_key = config.api_key or os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError(
                "API key must be provided or set as GOOGLE_API_KEY environment variable"
            )

        # Configure the Gemini API
        genai.configure(api_key=api_key)

        # Resolve the safety settings once; they are reused whenever a model
        # carrying a system instruction has to be built for a request.
        self._safety_settings = self._resolve_safety_settings(config.safety_settings)
        self.model = self._build_model()

    @staticmethod
    def _resolve_safety_settings(
        raw_settings: Union[Dict[str, str], None],
    ) -> Union[Dict[Any, Any], None]:
        """
        Translate name-based safety settings into the API enum members.

        :param raw_settings: Mapping of ``HarmCategory`` member names to
            ``HarmBlockThreshold`` member names, or ``None``.
        :return: Mapping of enum members, or ``None`` if nothing valid was given.
        """
        resolved = {}
        for category, threshold in (raw_settings or {}).items():
            harm_category = getattr(HarmCategory, category, None)
            harm_threshold = getattr(HarmBlockThreshold, threshold, None)
            # Compare against None rather than truthiness so that a valid
            # member is never discarded merely for being falsy.
            if harm_category is None or harm_threshold is None:
                warnings.warn(
                    f"Ignoring unknown Gemini safety setting {category!r}: {threshold!r}."
                )
                continue
            resolved[harm_category] = harm_threshold
        return resolved or None

    def _build_model(self, system_instruction: Union[str, None] = None):
        """
        Create a ``GenerativeModel`` for this LLM.

        :param system_instruction: Optional system prompt to attach to the model.
        :return: The configured ``genai.GenerativeModel`` instance.
        """
        model_kwargs = {"model_name": self._NAME}
        if self._safety_settings:
            model_kwargs["safety_settings"] = self._safety_settings
        if system_instruction:
            # The system prompt is a model-level option, not a chat-level one.
            model_kwargs["system_instruction"] = system_instruction
        return genai.GenerativeModel(**model_kwargs)

    @staticmethod
    def _to_gemini_chat(
        messages: List[Dict[str, str]],
    ) -> Tuple[Union[str, None], List[Dict[str, Any]], str]:
        """
        Split a message list into system prompt, chat history and final prompt.

        Only a leading ``system`` message is treated as the system prompt.
        ``user`` and ``assistant`` turns are converted to Gemini's ``user`` /
        ``model`` roles; messages with any other role are skipped.

        :param messages: List of input messages.
        :return: ``(system_prompt, history, prompt)``.
        """
        system_prompt = None
        turns = messages
        if messages and messages[0]["role"] == "system":
            system_prompt = messages[0]["content"]
            turns = messages[1:]

        role_map = {"user": "user", "assistant": "model"}
        gemini_messages = [
            {"role": role_map[msg["role"]], "parts": [msg["content"]]}
            for msg in turns
            if msg["role"] in role_map
        ]

        if gemini_messages and gemini_messages[-1]["role"] == "user":
            return system_prompt, gemini_messages[:-1], gemini_messages[-1]["parts"][0]
        # No trailing user turn: keep every turn as history and send the
        # default fallback prompt.
        return system_prompt, gemini_messages, "Hello"

    def _stream_reply(
        self, stream_response, messages: List[Dict[str, str]], prompt_tokens: int
    ) -> Generator[str, None, None]:
        """
        Yield streamed text pieces, then record the full reply and usage.

        :param stream_response: Iterable of response chunks exposing ``.text``.
        :param messages: Conversation to append the completed reply to.
        :param prompt_tokens: Approximate prompt token count for usage stats.
        :return: Generator over the text pieces of the reply.
        """
        pieces = []
        completion_tokens = 0
        for chunk in stream_response:
            piece = chunk.text
            if piece:
                pieces.append(piece)
                # Rough approximation of tokens
                completion_tokens += len(piece) // 4
                yield piece

        # Runs only once the stream has been fully consumed.
        messages.append({"role": "assistant", "content": "".join(pieces)})
        self.update(prompt_tokens, completion_tokens, 1)

    def generate(
        self, messages: List[Dict[str, str]], config: LLMGenerateConfig
    ) -> Union[
        List[Dict[str, str]],
        Tuple[List[Dict[str, str]], List[float]],
        Generator[str, None, None],
    ]:
        """
        Generate a response for a given input using Google Gemini API.

        The assistant reply is appended to ``messages`` in place. In streaming
        mode the reply is appended only after the returned generator has been
        exhausted. Logprobs are not supported; requesting them only emits a
        warning.

        :param messages: List of input messages.
        :param config: Configuration for LLM generation.
        :return: The message list including the reply, or a generator of text
            pieces when ``config.stream`` is set.
        :raises RuntimeError: If the API request still fails after all retries.
        """
        # Conversion is deterministic, so it is done once outside the retry
        # loop: malformed input cannot be fixed by retrying.
        system_prompt, history, prompt = self._to_gemini_chat(messages)

        # NOTE: config.seed is not forwarded to the API.
        generation_config = {
            "temperature": (
                config.temperature if config.temperature is not None else 0.7
            ),
            "max_output_tokens": config.max_n_tokens,
        }

        # Token counts are approximated: 1 token ≈ 4 characters of English text.
        prompt_tokens = sum(len(msg.get("content", "")) for msg in messages) // 4

        max_retries = 10
        for attempt in range(1, max_retries + 1):
            try:
                model = self._build_model(system_prompt) if system_prompt else self.model
                chat = model.start_chat(history=history)
                if config.stream:
                    stream_response = chat.send_message(
                        prompt, generation_config=generation_config, stream=True
                    )
                else:
                    response = chat.send_message(
                        prompt, generation_config=generation_config
                    )
                    content = response.text
            except Exception as e:
                # Handle safety/content policy issues
                error_text = str(e).lower()
                if any(marker in error_text for marker in ("safety", "harm", "blocked")):
                    messages.append(
                        {
                            "role": "assistant",
                            "content": "I'm sorry, I can't help with that.",
                        }
                    )
                    print(
                        f"API request Safety Issue, {self._NAME}, Error: {e}, returning safety message."
                    )
                    return messages

                if attempt >= max_retries:
                    raise RuntimeError(
                        f"API request failed when testing model {self._NAME}, tried: {max_retries}, Error: {e}"
                    ) from e

                print(
                    f"API request failed when testing model {self._NAME}, retrying {attempt}/{max_retries}... Error: {e}"
                )
                time.sleep(attempt)  # Linear backoff: wait one more second per retry
                continue

            # Bookkeeping lives outside the try block so that a failure here
            # can never trigger a retry and a duplicated reply.
            if config.stream:
                return self._stream_reply(stream_response, messages, prompt_tokens)

            messages.append({"role": "assistant", "content": content})
            self.update(prompt_tokens, len(content) // 4, 1)

            if config.logprobs:
                warnings.warn(
                    "Gemini API does not support logprobs, returning response without them."
                )
            return messages

    def continual_generate(
        self, messages: List[Dict[str, str]], config: LLMGenerateConfig
    ):
        """
        Generate continuation for the last message.

        The trailing assistant message is removed from a copy of the
        conversation, a user turn asking for a continuation is added, and the
        generated text is appended (space-separated) to the original assistant
        message in ``messages``. If the conversation does not end with an
        assistant message, a warning is emitted and ``generate`` is called
        unchanged.

        :param messages: List of messages for input.
        :param config: Configuration for generation.
        :return: Generated response. When ``config.stream`` is set, the stream
            is consumed here and the merged message list is returned.
        """
        if not messages or messages[-1]["role"] != "assistant":
            warnings.warn(
                "The last message must be from the assistant to use continual_generate."
            )
            return self.generate(messages, config)

        last_assistant_content = messages[-1]["content"]

        # Work on a copy without the trailing assistant turn, and ask the
        # model to continue its previous response.
        convo_messages = messages[:-1]
        convo_messages.append(
            {"role": "user", "content": "Please continue your previous response."}
        )

        result = self.generate(convo_messages, config)

        # Merge with original content
        if isinstance(result, tuple):
            cont_messages, logprobs = result
            cont_content = cont_messages[-1]["content"]
            messages[-1]["content"] = last_assistant_content + " " + cont_content
            return messages, logprobs

        if isinstance(result, list):
            cont_content = result[-1]["content"]
        else:
            # Streaming mode returns a generator; drain it to get the text.
            cont_content = "".join(result)

        messages[-1]["content"] = last_assistant_content + " " + cont_content
        return messages

    def evaluate_log_likelihood(
        self,
        messages: List[Dict[str, str]],
        config: LLMGenerateConfig,
        require_grad=False,
    ) -> List[float]:
        """
        Evaluate the log likelihood of the given messages.

        :param messages: List of messages for evaluation.
        :param config: Configuration for LLM generation.
        :param require_grad: Whether to compute gradients (not supported for API models).
        :raises NotImplementedError: Gemini API does not support log likelihood evaluation.
        """
        raise NotImplementedError(
            "Gemini API does not support log likelihood evaluation."
        )
if __name__ == "__main__":
    # Manual smoke test: sends one short conversation through GeminiLLM and
    # prints the reply. The API key is read from GOOGLE_API_KEY.
    from panda_guard.llms import LLMS

    print(LLMS)

    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
    ]

    generation_settings = LLMGenerateConfig(
        max_n_tokens=128, temperature=1.0, logprobs=False, seed=42
    )
    gemini = GeminiLLM(
        GeminiLLMConfig(
            model_name="gemini-1.5-pro", api_key=os.getenv("GOOGLE_API_KEY")
        )
    )

    reply = gemini.generate(conversation, generation_settings)
    print(reply[-1]["content"])