Source code for panda_guard.role.attacks.gcg

from panda_guard.role.attacks import BaseAttacker, BaseAttackerConfig

from panda_guard.llms import create_llm, BaseLLMConfig, LLMGenerateConfig
import functools
import gc
import inspect
import transformers
import torch
import copy
from dataclasses import fields
from torch import Tensor
from tqdm import tqdm
from panda_guard.role.judges.rule_based import *

@dataclass
class GCGAttackerConfig(BaseAttackerConfig):
    """
    Configuration for the GCG Attacker.

    :param attacker_cls: Class of the attacker, default is "GCGAttacker".
    :param attacker_name: Name of the attacker.
    :param search_width: Number of candidate suffixes sampled at every step.
    :param batch_size: Number of candidates evaluated per forward pass; ``None``
        means all candidates are tried in a single batch first.
    :param topk: Number of lowest-gradient token IDs kept per suffix position
        as replacement candidates.
    :param adv_string_init: Initial adversarial suffix appended to the request.
    :param num_steps: Number of optimization steps.
    :param llm_config: Configuration of the model the suffix is optimized on.
    :param llm_gen_config: Generation configuration, stored on the attacker.
    :param n_replace: Number of suffix positions replaced in each candidate.
    :param buffer_size: Capacity of the attack buffer; 0 keeps only the most
        recently added entry.
    :param use_mellowmax: Use the mellowmax loss instead of cross-entropy.
    :param mellowmax_alpha: Alpha parameter of the mellowmax loss.
    :param early_stop: Whether to stop once a candidate's greedy prediction
        reproduces the target exactly.
    :param use_prefix_cache: Whether to cache the model state for the content
        that precedes the suffix.
    :param allow_non_ascii: When False, non-ASCII / non-printable tokens and the
        tokenizer's special tokens are excluded from the search.
    :param filter_ids: Drop candidates whose token IDs change after a
        decode / re-encode round trip.
    :param add_space_before_target: Prepend a space to the target string.
    :param seed: Random seed (not read by the code in this module's attacker).
    """

    # Field order is part of the dataclass interface and is kept as-is.
    attacker_cls: str = "GCGAttacker"
    attacker_name: str = None
    search_width: int = 512
    batch_size: int = None
    topk: int = 256
    adv_string_init: str = "x x x x x x x x x x x x x x x x x x x x x x x x x x x x x x"
    num_steps: int = 250
    llm_config: BaseLLMConfig = field(default_factory=BaseLLMConfig)
    llm_gen_config: LLMGenerateConfig = None
    n_replace: int = 1
    buffer_size: int = 0
    use_mellowmax: bool = False
    mellowmax_alpha: float = 1.0
    early_stop: bool = False
    use_prefix_cache: bool = True
    allow_non_ascii: bool = False
    filter_ids: bool = True
    add_space_before_target: bool = False
    seed: int = None
class AttackBuffer:
    """
    Keeps the best (lowest-loss) candidates seen so far, sorted by loss.

    :param size: Maximum number of entries the buffer can hold. A size of 0
        means the buffer only ever holds the most recently added entry.
    """

    def __init__(self, size: int):
        """
        Create an empty buffer.

        :param size: The maximum number of entries in the buffer.
        """
        self.size = size
        self.buffer = []  # entries are (loss, optim_ids), ascending by loss

    def add(self, loss: float, optim_ids: Tensor) -> None:
        """
        Insert a candidate into the buffer.

        While the buffer has free room the entry is appended; once it is full
        the current worst (last) entry is overwritten. The buffer is then
        re-sorted by loss.

        :param loss: The loss value of the candidate.
        :param optim_ids: The optimized token IDs corresponding to the loss.
        """
        entry = (loss, optim_ids)

        if self.size == 0:
            # Degenerate buffer: remember only the latest candidate.
            self.buffer = [entry]
            return

        if len(self.buffer) >= self.size:
            self.buffer[-1] = entry
        else:
            self.buffer.append(entry)

        self.buffer.sort(key=lambda item: item[0])

    def get_best_ids(self) -> Tensor:
        """
        Return the token IDs stored with the lowest loss.

        :return: The token IDs with the lowest loss.
        """
        _, best_ids = self.buffer[0]
        return best_ids

    def get_lowest_loss(self) -> float:
        """
        Return the smallest loss in the buffer.

        :return: The lowest loss value.
        """
        best_loss, _ = self.buffer[0]
        return best_loss

    def get_highest_loss(self) -> float:
        """
        Return the largest loss in the buffer.

        :return: The highest loss value.
        """
        worst_loss, _ = self.buffer[-1]
        return worst_loss
class GCGAttacker(BaseAttacker):
    """
    GCG Attacker Implementation
    Reference: https://arxiv.org/abs/2307.15043

    :param config: Configuration for the GCG Attacker.
    """

    def __init__(self, config: GCGAttackerConfig):
        """
        Build the attacker from its configuration.

        Every config field is mirrored onto the instance as an attribute of the
        same name, except ``llm_config`` (turned into ``self.llm`` through
        ``create_llm``) and ``llm_gen_config`` (stored unchanged).

        :param config: Configuration object for the attacker.
        """
        super().__init__(config)

        for spec in fields(config):
            attr = spec.name
            if "llm_gen_config" in attr:
                self.llm_gen_config = config.llm_gen_config
                continue
            if "llm_config" in attr:
                self.llm = create_llm(config.llm_config)
                continue
            setattr(self, attr, getattr(config, attr))

        self.embedding_layer = self.llm.model.get_input_embeddings()

        # Token IDs excluded from the search (None = no restriction).
        if config.allow_non_ascii:
            self.not_allowed_ids = None
        else:
            self.not_allowed_ids = self.get_nonascii_toks()

        self.prefix_cache = None
        self.stop_flag = False

        # Characters used to fill the extra random entries of the attack buffer.
        self.INIT_CHARS = list(".,!?;:()[]{}@#$%&*wxyz")
[docs] def get_nonascii_toks(self): """ Returns a list of non-ASCII tokens in the tokenizer's vocabulary. :return: Tensor containing the IDs of non-ASCII tokens. """ device = self.llm.model.device def is_ascii(s): return s.isascii() and s.isprintable() nonascii_toks = [] for i in range(self.llm.tokenizer.vocab_size): if not is_ascii(self.llm.tokenizer.decode([i])): nonascii_toks.append(i) if self.llm.tokenizer.bos_token_id is not None: nonascii_toks.append(self.llm.tokenizer.bos_token_id) if self.llm.tokenizer.eos_token_id is not None: nonascii_toks.append(self.llm.tokenizer.eos_token_id) if self.llm.tokenizer.pad_token_id is not None: nonascii_toks.append(self.llm.tokenizer.pad_token_id) if self.llm.tokenizer.unk_token_id is not None: nonascii_toks.append(self.llm.tokenizer.unk_token_id) return torch.tensor(nonascii_toks, device=device)
[docs] def attack( self, messages: List[Dict[str, str]], **kwargs ) -> List[Dict[str, str]]: """ Execute GCG attack. :param messages: List of messages in the conversation. :param kwargs: Additional parameters for the attack, must include "request_reformulated". :return: Modified list of messages with the reformulated request. """ instruction = [copy.deepcopy(messages[0])] target = 'Sure, here is' instruction[-1]["content"] = instruction[-1]["content"] + " " + self.adv_string_init prompt_formatted = self.llm.tokenizer.apply_chat_template(instruction, tokenize=False, add_generation_prompt=True) before_str, after_str = prompt_formatted.split(f"{self.adv_string_init}") target = " " + target if self.add_space_before_target else target before_ids = self.llm.tokenizer([before_str], padding=False, return_tensors="pt")["input_ids"].to(self.llm.model.device, torch.int64) after_ids = self.llm.tokenizer([after_str], add_special_tokens=False, return_tensors="pt")["input_ids"].to( self.llm.model.device, torch.int64) target_ids = self.llm.tokenizer([target], add_special_tokens=False, return_tensors="pt")["input_ids"].to( self.llm.model.device, torch.int64) before_embeds, after_embeds, target_embeds = [self.embedding_layer(ids) for ids in (before_ids, after_ids, target_ids)] if self.use_prefix_cache: with torch.no_grad(): output = self.llm.model(inputs_embeds=before_embeds, use_cache=True) self.prefix_cache = output.past_key_values self.target_ids = target_ids self.before_embeds = before_embeds self.after_embeds = after_embeds self.target_embeds = target_embeds buffer = self.init_buffer() optim_ids = buffer.get_best_ids() losses = [] optim_strings = [] for _ in tqdm(range(self.num_steps)): # Compute the token gradient optim_ids_onehot_grad = self.compute_token_gradient(optim_ids) # [1, len(optim_ids), len(embeds)] with torch.no_grad(): # Sample candidate token sequences based on the token gradient sampled_ids = self.sample_ids_from_grad( optim_ids.squeeze(0), 
optim_ids_onehot_grad.squeeze(0), self.search_width, self.topk, self.n_replace, not_allowed_ids=self.not_allowed_ids, ) if self.filter_ids: sampled_ids = self.filter_ids_op(sampled_ids, self.llm.tokenizer) new_search_width = sampled_ids.shape[0] # Compute loss on all candidate sequences batch_size = new_search_width if self.batch_size is None else self.batch_size if self.prefix_cache: input_embeds = torch.cat([ self.embedding_layer(sampled_ids), after_embeds.repeat(new_search_width, 1, 1), target_embeds.repeat(new_search_width, 1, 1), ], dim=1) else: input_embeds = torch.cat([ before_embeds.repeat(new_search_width, 1, 1), self.embedding_layer(sampled_ids), after_embeds.repeat(new_search_width, 1, 1), target_embeds.repeat(new_search_width, 1, 1), ], dim=1) loss = self.find_executable_batch_size(self.compute_candidates_loss, batch_size)(input_embeds) current_loss = loss.min().item() optim_ids = sampled_ids[loss.argmin()].unsqueeze(0) losses.append(current_loss) if buffer.size == 0 or current_loss < buffer.get_highest_loss(): buffer.add(current_loss, optim_ids) optim_ids = buffer.get_best_ids() optim_str = self.llm.tokenizer.batch_decode(optim_ids)[0] optim_strings.append(optim_str) if self.stop_flag: print("Early stopping due to finding a perfect match.") break min_loss_index = losses.index(min(losses)) messages[0]["content"] = messages[0]["content"] + " " + optim_strings[min_loss_index] return [messages[0]]
[docs] def init_buffer(self) -> AttackBuffer: model = self.llm.model tokenizer = self.llm.tokenizer # Create the attack buffer and initialize the buffer ids buffer = AttackBuffer(self.buffer_size) if isinstance(self.adv_string_init, str): init_optim_ids = tokenizer(self.adv_string_init, add_special_tokens=False, return_tensors="pt")[ "input_ids"].to(model.device) if self.buffer_size > 1: init_buffer_ids = tokenizer(self.INIT_CHARS, add_special_tokens=False, return_tensors="pt")[ "input_ids"].squeeze().to(model.device) init_indices = torch.randint(0, init_buffer_ids.shape[0], (self.buffer_size - 1, init_optim_ids.shape[1])) init_buffer_ids = torch.cat([init_optim_ids, init_buffer_ids[init_indices]], dim=0) else: init_buffer_ids = init_optim_ids true_buffer_size = max(1, self.buffer_size) # Compute the loss on the initial buffer entries if self.prefix_cache: init_buffer_embeds = torch.cat([ self.embedding_layer(init_buffer_ids), self.after_embeds.repeat(true_buffer_size, 1, 1), self.target_embeds.repeat(true_buffer_size, 1, 1), ], dim=1) else: init_buffer_embeds = torch.cat([ self.before_embeds.repeat(true_buffer_size, 1, 1), self.embedding_layer(init_buffer_ids), self.after_embeds.repeat(true_buffer_size, 1, 1), self.target_embeds.repeat(true_buffer_size, 1, 1), ], dim=1) init_buffer_losses = self.find_executable_batch_size(self.compute_candidates_loss, true_buffer_size)( init_buffer_embeds) # Populate the buffer for i in range(true_buffer_size): buffer.add(init_buffer_losses[i], init_buffer_ids[[i]]) print("Initialized attack buffer.") return buffer
[docs] def compute_token_gradient( self, optim_ids: Tensor, ) -> Tensor: """ Computes the gradient of the GCG loss with respect to the one-hot token matrix. :param optim_ids: Tensor, shape = (1, n_optim_ids), token IDs being optimized. :return: Tensor, shape = (1, n_optim_ids, vocab_size), gradient of the loss with respect to the one-hot token matrix. """ model = self.llm.model embedding_layer = self.embedding_layer # Create the one-hot encoding matrix of our optimized token ids optim_ids_onehot = torch.nn.functional.one_hot(optim_ids, num_classes=embedding_layer.num_embeddings) optim_ids_onehot = optim_ids_onehot.to(model.device, model.dtype) optim_ids_onehot.requires_grad_() # (1, num_optim_tokens, vocab_size) @ (vocab_size, embed_dim) -> (1, num_optim_tokens, embed_dim) optim_embeds = optim_ids_onehot @ embedding_layer.weight if self.prefix_cache: input_embeds = torch.cat([optim_embeds, self.after_embeds, self.target_embeds], dim=1) output = model(inputs_embeds=input_embeds, past_key_values=self.prefix_cache) else: input_embeds = torch.cat([self.before_embeds, optim_embeds, self.after_embeds, self.target_embeds], dim=1) output = model(inputs_embeds=input_embeds) logits = output.logits # Shift logits so token n-1 predicts token n shift = input_embeds.shape[1] - self.target_ids.shape[1] shift_logits = logits[..., shift - 1:-1, :].contiguous() # (1, num_target_ids, vocab_size) shift_labels = self.target_ids if self.use_mellowmax: label_logits = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1) loss = self.mellowmax(-label_logits, alpha=self.mellowmax_alpha, dim=-1) else: loss = torch.nn.functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) optim_ids_onehot_grad = torch.autograd.grad(outputs=[loss], inputs=[optim_ids_onehot])[0] return optim_ids_onehot_grad
[docs] def compute_candidates_loss( self, search_batch_size: int, input_embeds: Tensor, ) -> Tensor: """ Computes the GCG loss on all candidate token ID sequences. :param search_batch_size: int, number of candidate sequences to evaluate in a given batch. :param input_embeds: Tensor, shape = (search_width, seq_len, embd_dim), embeddings of the candidate sequences to evaluate. """ all_loss = [] prefix_cache_batch = [] for i in range(0, input_embeds.shape[0], search_batch_size): with torch.no_grad(): input_embeds_batch = input_embeds[i:i + search_batch_size] current_batch_size = input_embeds_batch.shape[0] if self.prefix_cache: if not prefix_cache_batch or current_batch_size != search_batch_size: prefix_cache_batch = [[x.expand(current_batch_size, -1, -1, -1) for x in self.prefix_cache[i]] for i in range(len(self.prefix_cache))] outputs = self.llm.model(inputs_embeds=input_embeds_batch, past_key_values=prefix_cache_batch) else: outputs = self.llm.model(inputs_embeds=input_embeds_batch) logits = outputs.logits tmp = input_embeds.shape[1] - self.target_ids.shape[1] shift_logits = logits[..., tmp - 1:-1, :].contiguous() shift_labels = self.target_ids.repeat(current_batch_size, 1) if self.use_mellowmax: label_logits = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1) loss = self.mellowmax(-label_logits, alpha=self.mellowmax_alpha, dim=-1) else: loss = torch.nn.functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="none") loss = loss.view(current_batch_size, -1).mean(dim=-1) all_loss.append(loss) if self.early_stop: if torch.any(torch.all(torch.argmax(shift_logits, dim=-1) == shift_labels, dim=-1)).item(): self.stop_flag = True del outputs gc.collect() torch.cuda.empty_cache() return torch.cat(all_loss, dim=0)
[docs] def filter_ids_op(self, ids: Tensor, tokenizer: transformers.PreTrainedTokenizer): """ Filters out sequences of token IDs that change after retokenization. :param ids: Tensor, shape = (search_width, n_optim_ids), token IDs to evaluate. :param tokenizer: transformers.PreTrainedTokenizer, the model's tokenizer. :return: Tensor, shape = (new_search_width, n_optim_ids), token IDs that remain the same after retokenization. """ ids_decoded = tokenizer.batch_decode(ids) filtered_ids = [] for i in range(len(ids_decoded)): # Retokenize the decoded token ids ids_encoded = \ tokenizer(ids_decoded[i], return_tensors="pt", add_special_tokens=False).to(ids.device)["input_ids"][0] if torch.equal(ids[i], ids_encoded): filtered_ids.append(ids[i]) if not filtered_ids: # This occurs in some cases, e.g. using the Llama-3 tokenizer with a bad initialization raise RuntimeError( "No token sequences are the same after decoding and re-encoding. " "Consider setting `filter_ids=False` or trying a different `optim_str_init`" ) return torch.stack(filtered_ids)
[docs] def sample_ids_from_grad( self, ids: Tensor, grad: Tensor, search_width: int, topk: int = 256, n_replace: int = 1, not_allowed_ids: Tensor = False, ): """ Returns `search_width` combinations of token IDs based on the token gradient. :param ids: Tensor, shape = (n_optim_ids), the sequence of token IDs being optimized. :param grad: Tensor, shape = (n_optim_ids, vocab_size), the gradient of the GCG loss with respect to the one-hot token embeddings. :param search_width: int, the number of candidate sequences to return. :param topk: int, the number of top-k tokens to sample from the gradient. :param n_replace: int, the number of token positions to update per sequence. :param not_allowed_ids: Tensor, shape = (n_ids), token IDs that should not be used in optimization. :return: Tensor, shape = (search_width, n_optim_ids), sampled token IDs. """ n_optim_tokens = len(ids) original_ids = ids.repeat(search_width, 1) if not_allowed_ids is not None: grad[:, not_allowed_ids.to(grad.device)] = float("inf") topk_ids = (-grad).topk(topk, dim=1).indices sampled_ids_pos = torch.argsort(torch.rand((search_width, n_optim_tokens), device=grad.device))[..., :n_replace] sampled_ids_val = torch.gather( topk_ids[sampled_ids_pos], 2, torch.randint(0, topk, (search_width, n_replace, 1), device=grad.device) ).squeeze(2) new_ids = original_ids.scatter_(1, sampled_ids_pos, sampled_ids_val) return new_ids
[docs] def should_reduce_batch_size(self, exception: Exception) -> bool: """ Checks if `exception` relates to CUDA out-of-memory, CUDNN not supported, or CPU out-of-memory. :param exception: `Exception`, the exception to check. """ _statements = [ "CUDA out of memory.", # CUDA OOM "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.", # CUDNN SNAFU "DefaultCPUAllocator: can't allocate memory", # CPU OOM ] if isinstance(exception, RuntimeError) and len(exception.args) == 1: return any(err in exception.args[0] for err in _statements) return False
# modified from https://github.com/huggingface/accelerate/blob/85a75d4c3d0deffde2fc8b917d9b1ae1cb580eb2/src/accelerate/utils/memory.py#L87
[docs] def find_executable_batch_size(self, function: callable = None, starting_batch_size: int = 128): """ A basic decorator that executes `function`. If it fails due to out-of-memory or CUDNN errors, the batch size is halved and retried. `function` must accept a `batch_size` parameter as its first argument. :param function: callable, the function to wrap. :param starting_batch_size: int, the initial batch size to try. Example: # Example usage of the decorator @find_executable_batch_size def train_model(batch_size: int): # Train model logic here train_model(256) """ if function is None: return functools.partial(self.find_executable_batch_size, starting_batch_size=starting_batch_size) batch_size = starting_batch_size def decorator(*args, **kwargs): nonlocal batch_size gc.collect() torch.cuda.empty_cache() params = list(inspect.signature(function).parameters.keys()) # Guard against user error if len(params) < (len(args) + 1): arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])]) raise TypeError( f"Batch size was passed into `{function.__name__}` as the first argument when called." f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`" ) while True: if batch_size == 0: raise RuntimeError("No executable batch size found, reached zero.") try: return function(batch_size, *args, **kwargs) except Exception as e: if self.should_reduce_batch_size(e): gc.collect() torch.cuda.empty_cache() batch_size //= 2 else: raise return decorator
[docs] def mellowmax(self, t: Tensor, alpha=1.0, dim=-1): return 1.0 / alpha * (torch.logsumexp(alpha * t, dim=dim) - torch.log( torch.tensor(t.shape[-1], dtype=t.dtype, device=t.device)))