# encoding: utf-8
# Author : Floyed<Floyed_Shen@outlook.com>
# Datetime : 2024/9/2 20:30
# User : yu
# Product : PyCharm
# Project : panda-guard
# File : claude.py
# explain : Anthropic Claude API integration
import os
import time
import warnings
from typing import Dict, List, Union, Any, Tuple, Generator
from dataclasses import dataclass, field
import anthropic
from panda_guard.llms import BaseLLM, BaseLLMConfig, LLMGenerateConfig
[docs]@dataclass
class ClaudeLLMConfig(BaseLLMConfig):
"""
Claude LLM Configuration.
:param llm_type: Type of LLM, default is "ClaudeLLM".
:param model_name: Name of the model.
:param api_key: API key for accessing Anthropic.
:param max_tokens_to_sample: Maximum tokens to sample, overrides max_n_tokens if provided.
"""
llm_type: str = field(default="ClaudeLLM")
model_name: str = field(default="claude-3-opus-20240229")
api_key: str = field(default=None)
max_tokens_to_sample: int = field(default=None)
[docs]class ClaudeLLM(BaseLLM):
"""
Claude LLM Implementation.
:param config: Configuration for Claude LLM.
"""
def __init__(self, config: ClaudeLLMConfig):
super().__init__(config)
# Use provided API key or try to get from environment variable
api_key = config.api_key or os.getenv("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError("API key must be provided or set as ANTHROPIC_API_KEY environment variable")
self.client = anthropic.Anthropic(api_key=api_key)
self.max_tokens_to_sample = config.max_tokens_to_sample
[docs] def generate(
self, messages: List[Dict[str, str]], config: LLMGenerateConfig
) -> list[dict[str, str]] | Generator[str, None, None] | None:
"""
Generate a response for a given input using Anthropic Claude API.
:param messages: List of input messages.
:param config: Configuration for LLM generation.
:return: Generated response, stream generator, or response with logprobs.
"""
max_tokens = self.max_tokens_to_sample or config.max_n_tokens
retry_count = 0
max_retries = 10
while retry_count < max_retries:
try:
# Convert messages to Anthropic format if needed
anthropic_messages = []
system_content = None
for msg in messages:
role = msg["role"]
# Anthropic uses "user" and "assistant", convert "system" to content for first user message
if role == "system":
system_content = msg["content"]
continue
elif role == "user":
anthropic_messages.append({"role": "user", "content": msg["content"]})
elif role == "assistant":
anthropic_messages.append({"role": "assistant", "content": msg["content"]})
# If we had a system message and there are other messages, prepend to first user message
if system_content and anthropic_messages and anthropic_messages[0]["role"] == "user":
anthropic_messages[0]["content"] = f"{system_content}\n\n{anthropic_messages[0]['content']}"
# Handle streaming mode
if config.stream:
full_content = ""
input_tokens = 0
output_tokens = 0
# Create streaming request
stream = self.client.messages.create(
model=self._NAME,
messages=anthropic_messages,
max_tokens=max_tokens,
temperature=config.temperature or 0.7,
stream=True,
)
def stream_response():
nonlocal full_content, input_tokens, output_tokens
for chunk in stream:
# Check if the chunk has content to stream
if chunk.type == "content_block_delta" and chunk.delta.text:
content_piece = chunk.delta.text
full_content += content_piece
yield content_piece
# Update token usage if available
if hasattr(chunk, 'usage') and chunk.usage:
if hasattr(chunk.usage, 'input_tokens'):
input_tokens = chunk.usage.input_tokens
if hasattr(chunk.usage, 'output_tokens'):
output_tokens = chunk.usage.output_tokens
# Handle end of stream with usage info
if chunk.type == "message_delta" and chunk.usage:
input_tokens = chunk.usage.input_tokens
output_tokens = chunk.usage.output_tokens
response_generator = stream_response()
def wrapped_generator():
yield from response_generator
# Add final response to messages
messages.append({"role": "assistant", "content": full_content})
# Update usage statistics
self.update(
input_tokens,
output_tokens,
1,
)
return wrapped_generator()
# Non-streaming mode (original code)
else:
# Claude API call
response = self.client.messages.create(
model=self._NAME,
messages=anthropic_messages,
max_tokens=max_tokens,
temperature=config.temperature or 0.7,
)
# Add generated response to messages
content = response.content[0].text
messages.append({"role": "assistant", "content": content})
# Update token usage statistics
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
self.update(
input_tokens,
output_tokens,
1,
)
# Claude API doesn't support logprobs directly, so we handle the case
if config.logprobs:
warnings.warn("Claude API does not support logprobs, returning response without them.")
return messages
return messages
except Exception as e:
# Handle safety/content policy issues
if "content policy" in str(e).lower() or "content_policy" in str(e).lower():
messages.append({"role": "assistant", "content": "I'm sorry, I can't help with that."})
print(f"API request Content Policy Issue, {self._NAME}, Error: {e}, returning safety message.")
return messages
retry_count += 1
if retry_count >= max_retries:
raise RuntimeError(
f"API request failed when testing model {self._NAME}, tried: {max_retries}, Error: {e}")
else:
print(
f"API request failed when testing model {self._NAME}, retrying {retry_count}/{max_retries}... Error: {e}")
time.sleep(retry_count) # Exponential backoff
return None
[docs] def continual_generate(
self, messages: List[Dict[str, str]], config: LLMGenerateConfig
):
"""
Generate continuation for the last message.
:param messages: List of messages for input.
:param config: Configuration for generation.
:return: Generated response.
"""
# Claude doesn't support true "continue generating" functionality like some models
# Instead, we extract current conversation and generate a continuation
# Clone messages and extract the last assistant message if it exists
convo_messages = messages.copy()
last_assistant_content = ""
# If the last message is from the assistant, we'll use it as context for continuation
if convo_messages[-1]["role"] == "assistant":
last_assistant_content = convo_messages[-1]["content"]
convo_messages.pop() # Remove last message as we'll append to it
# Add a user message asking to continue
convo_messages.append({"role": "user", "content": "Please continue from where you left off."})
# Generate the continuation
result = self.generate(convo_messages, config)
# Merge the continuation with the original message
if isinstance(result, tuple): # If logprobs were requested
cont_messages, logprobs = result
cont_content = cont_messages[-1]["content"]
messages[-1]["content"] = last_assistant_content + " " + cont_content
return messages, logprobs
else:
cont_content = result[-1]["content"]
messages[-1]["content"] = last_assistant_content + " " + cont_content
return messages
else:
warnings.warn("The last message must be from the assistant to use continual_generate.")
return self.generate(messages, config)
[docs] def evaluate_log_likelihood(
self,
messages: List[Dict[str, str]],
config: LLMGenerateConfig,
require_grad=False
) -> List[float]:
"""
Evaluate the log likelihood of the given messages.
:param messages: List of messages for evaluation.
:param config: Configuration for LLM generation.
:param require_grad: Whether to compute gradients (not supported for API models).
:raises NotImplementedError: Claude API does not support log likelihood evaluation.
"""
raise NotImplementedError(
"Claude API does not support log likelihood evaluation."
)
if __name__ == "__main__":
from panda_guard.llms import LLMS
print(LLMS)
llm_gen_config = LLMGenerateConfig(
max_n_tokens=128, temperature=1.0, logprobs=True, seed=42
)
config = ClaudeLLMConfig(
model_name="claude-3-opus-20240229",
api_key=os.getenv("ANTHROPIC_API_KEY")
)
llm = ClaudeLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
]
result = llm.generate(messages, llm_gen_config)
print(result[-1]["content"])