# Source code for panda_guard.role.defenses.repe_utils.rep_control_reading_vec

# encoding: utf-8
# Author    : Floyed<Floyed_Shen@outlook.com>
# Datetime  : 2024/9/9 16:37
# User      : yu
# Product   : PyCharm
# Project   : panda-guard
# File      : rep_control_reading_vec.py
# explain   : Adapted from https://github.com/andyzoujm/representation-engineering.git


# wrapping classes
import warnings

import numpy as np
import torch


class WrappedBlock(torch.nn.Module):
    """
    A wrapper around a transformer block or submodule that allows external
    activation control via injection of custom activation vectors.

    After each forward pass the block's hidden states (the first element of a
    tuple output, or the whole output otherwise) are kept in ``self.output``.
    For int / list / tuple / ndarray / ``"start"`` / ``"end"`` token positions
    the control is written into that same tensor in place, so ``self.output``
    then also reflects the modification.

    :param block: The transformer block or submodule to wrap.
    """

    def __init__(self, block):
        super().__init__()
        self.block = block
        self.output = None
        self.controller = None
        self.mask = None
        self.token_pos = None
        self.normalize = False
        # Populated by ``set_controller``; defined here so the attribute always exists.
        self.operator = None

    def _resolve_mask(self, modified, kwargs):
        """
        Choose the mask that scales the controller for the current forward pass.

        :param modified: Hidden states produced by the wrapped block.
        :param kwargs: Keyword arguments the wrapped block was called with.
        :return: ``self.mask`` if one was set; otherwise a tensor of shape
            ``position_ids.shape + (1,)`` derived from ``position_ids`` when it is
            given and not ``None``; otherwise the scalar ``1.0``.
        """
        if self.mask is not None:
            return self.mask
        pos = kwargs.get("position_ids")
        if pos is None:
            # Without position information padding tokens cannot be told apart;
            # with batched, padded inputs this can lead to unexpected results.
            return 1.0
        # Ignore padding tokens: the mask is one from the last position whose
        # position id is 0 onwards (everywhere if no id is 0) and zero before it.
        # Only tested with left padding.
        zero_indices = (pos == 0).cumsum(1).argmax(1, keepdim=True)
        col_indices = torch.arange(pos.size(1), device=pos.device).unsqueeze(0)
        # unsqueeze (instead of reshaping to the hidden states' batch size) lets a
        # position_ids tensor shared across the batch, shape (1, seq), broadcast.
        mask = (col_indices >= zero_indices).unsqueeze(-1)
        return mask.to(modified.dtype)

    def forward(self, *args, **kwargs):
        """
        Forward pass with optional activation control.

        :return: The wrapped block's output; when a controller is set, its hidden
            states are combined with ``controller * mask`` via ``self.operator``.
        :raises ValueError: If the controller's rank differs from the hidden
            states' rank, or ``token_pos`` is an unrecognised string.
        """
        output = self.block(*args, **kwargs)
        is_tuple = isinstance(output, tuple)
        modified = output[0] if is_tuple else output
        self.output = modified

        if self.controller is not None:
            # Computed before any in-place edit so normalisation can restore it.
            norm_pre = torch.norm(modified, dim=-1, keepdim=True)
            mask = self._resolve_mask(modified, kwargs)

            if len(self.controller.shape) == 1:
                self.controller = self.controller.reshape(1, 1, -1)
            if len(self.controller.shape) != len(modified.shape):
                raise ValueError(
                    f"Shape of controller {self.controller.shape} does not match shape of modified {modified.shape}."
                )
            self.controller = self.controller.to(modified.device)
            if isinstance(mask, torch.Tensor):
                mask = mask.to(modified.device)
            delta = self.controller * mask

            token_pos = self.token_pos
            if isinstance(token_pos, (int, list, tuple, np.ndarray)):
                modified[:, token_pos] = self.operator(modified[:, token_pos], delta)
            elif isinstance(token_pos, str):
                len_token = self.controller.shape[1]
                if token_pos == "end":
                    modified[:, -len_token:] = self.operator(modified[:, -len_token:], delta)
                elif token_pos == "start":
                    modified[:, :len_token] = self.operator(modified[:, :len_token], delta)
                else:
                    raise ValueError(f"Unknown token position {token_pos}.")
            else:
                modified = self.operator(modified, delta)

            if self.normalize:
                # Rescale so every token keeps its pre-control L2 norm.
                norm_post = torch.norm(modified, dim=-1, keepdim=True)
                modified = modified / norm_post * norm_pre

        if is_tuple:
            return (modified,) + output[1:]
        return modified

    def set_controller(self, activations, token_pos=None, masks=None, normalize=False, operator='linear_comb'):
        """
        Sets the external controller to inject activations.

        The operator name is validated before any state is modified, so an
        unsupported operator leaves the previously configured control intact.

        :param activations: Activation tensor(s) to apply; singleton dimensions are squeezed.
        :param token_pos: Target token positions: an int, a list/tuple/ndarray of
            indices, ``"start"``/``"end"``, or ``None`` for all positions.
        :param masks: Optional mask multiplied into the controller.
        :param normalize: Whether to rescale hidden states to their pre-control norm.
        :param operator: ``'linear_comb'`` (add), ``'piecewise_linear'`` (add, scaled by
            the sign of the per-token dot product with the controller), or
            ``'projection'`` (not implemented; raises when applied).
        :raises NotImplementedError: If ``operator`` is not one of the names above.
        """
        if operator == 'linear_comb':
            def op(current, controller):
                return current + controller
        elif operator == 'piecewise_linear':
            def op(current, controller):
                sign = torch.sign((current * controller).sum(-1, keepdim=True))
                return current + controller * sign
        elif operator == 'projection':
            def op(current, controller):
                raise NotImplementedError
        else:
            raise NotImplementedError(f"Operator {operator} not implemented.")

        self.normalize = normalize
        self.controller = activations.squeeze()
        self.mask = masks
        self.token_pos = token_pos
        self.operator = op

    def reset(self):
        """
        Resets all control variables to remove previously injected activations.
        """
        self.output = None
        self.controller = None
        self.mask = None
        self.token_pos = None
        self.operator = None
        self.normalize = False

    def set_masks(self, masks):
        """
        Sets an external mask to apply during activation injection.

        :param masks: The mask multiplied into the controller.
        """
        self.mask = masks
# Attribute names of the per-layer sub-modules that may be wrapped individually.
BLOCK_NAMES = [
    "self_attn",
    "mlp",
    "input_layernorm",
    "post_attention_layernorm",
]
class WrappedReadingVecModel(torch.nn.Module):
    """
    A wrapper for transformer models that enables dynamic injection and
    extraction of intermediate activations at specified layers and blocks.

    The wrapped model is expected to expose its decoder layers as
    ``model.model.layers``, to provide ``device`` and ``generate``, and to
    return outputs with a ``logits`` attribute (presumably a Hugging Face-style
    causal LM; this is not checked). Block names are the entries of
    ``BLOCK_NAMES`` plus ``'decoder_block'`` for a whole decoder layer.

    :param model: The pretrained transformer model.
    :param tokenizer: The tokenizer associated with the model.
    """

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer

    def forward(self, *args, **kwargs):
        """
        Forwards inputs through the model.

        :return: Model outputs.
        """
        return self.model(*args, **kwargs)

    def generate(self, **kwargs):
        """
        Calls the `generate` method on the underlying model.

        :return: Generated output.
        """
        return self.model.generate(**kwargs)

    def get_logits(self, tokens):
        """
        Computes logits from the model given input tokens, without gradients.

        :param tokens: Input token tensor (moved to the model's device).
        :return: Logits from the model.
        """
        with torch.no_grad():
            logits = self.model(tokens.to(self.model.device)).logits
        return logits

    def run_prompt(self, prompt, **kwargs):
        """
        Tokenizes and runs a prompt through the model without gradients.

        Inputs are padded and truncated to at most 512 tokens. Extra keyword
        arguments are accepted but currently ignored.

        :param prompt: Input prompt string (or list of strings).
        :return: Model output.
        """
        with torch.no_grad():
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, max_length=512, truncation=True)
            input_ids = inputs.input_ids.to(self.model.device)
            attention_mask = inputs.attention_mask.to(self.model.device)
            output = self.model(input_ids, attention_mask=attention_mask)
        return output

    @staticmethod
    def _is_layer_id_collection(layer_ids):
        """Return True when ``layer_ids`` addresses several layers rather than one."""
        return isinstance(layer_ids, (list, tuple, np.ndarray))

    def _get_wrapped_target(self, layer_id, block_name):
        """Return the wrapped module addressed by ``(layer_id, block_name)``, or None if it is not wrapped."""
        layer = self.model.model.layers[layer_id]
        if block_name == 'decoder_block':
            return layer if self.is_wrapped(layer) else None
        if block_name not in BLOCK_NAMES:
            return None
        # Sub-blocks of a wrapped decoder layer live on the inner original layer.
        parent = layer.block if self.is_wrapped(layer) else layer
        block = getattr(parent, block_name, None)
        return block if self.is_wrapped(block) else None

    def _iter_wrapped_modules(self):
        """Yield every wrapped module: each wrapped decoder layer, then its wrapped sub-blocks."""
        for layer in self.model.model.layers:
            parent = layer
            if self.is_wrapped(layer):
                yield layer
                parent = layer.block
            for block_name in BLOCK_NAMES:
                # getattr default: tolerate architectures lacking some sub-block names.
                block = getattr(parent, block_name, None)
                if self.is_wrapped(block):
                    yield block

    def wrap(self, layer_id, block_name):
        """
        Wraps a specific sub-block in a given layer (no-op if already wrapped).

        :param layer_id: Index of the transformer layer to wrap.
        :param block_name: Name of the block to wrap (one of ``BLOCK_NAMES``).
        :raises ValueError: If ``block_name`` is not in ``BLOCK_NAMES``.
        """
        if block_name not in BLOCK_NAMES:
            raise ValueError(f"No block named {block_name}.")
        layer = self.model.model.layers[layer_id]
        # If the decoder layer itself is wrapped, wrap the sub-block inside it.
        parent = layer.block if self.is_wrapped(layer) else layer
        block = getattr(parent, block_name)
        if not self.is_wrapped(block):
            setattr(parent, block_name, WrappedBlock(block))

    def wrap_decoder_block(self, layer_id):
        """
        Wraps the entire decoder block for a given layer (no-op if already wrapped).

        :param layer_id: Index of the decoder layer to wrap.
        """
        block = self.model.model.layers[layer_id]
        if not self.is_wrapped(block):
            self.model.model.layers[layer_id] = WrappedBlock(block)

    def wrap_all(self):
        """
        Wraps all known sub-blocks in every layer, then the decoder layer itself.
        """
        for layer_id, _layer in enumerate(self.model.model.layers):
            for block_name in BLOCK_NAMES:
                self.wrap(layer_id, block_name)
            self.wrap_decoder_block(layer_id)

    def wrap_block(self, layer_ids, block_name):
        """
        Wraps specified block(s) in the given layer(s).

        :param layer_ids: Single layer ID or list/tuple/ndarray of layer IDs.
        :param block_name: ``'decoder_block'`` or one of ``BLOCK_NAMES``.
        :raises ValueError: If ``block_name`` is unknown.
        """
        def _wrap_block(layer_id):
            if block_name in BLOCK_NAMES:
                self.wrap(layer_id, block_name)
            elif block_name == 'decoder_block':
                self.wrap_decoder_block(layer_id)
            else:
                raise ValueError(f"No block named {block_name}.")

        if self._is_layer_id_collection(layer_ids):
            for layer_id in layer_ids:
                _wrap_block(layer_id)
        else:
            _wrap_block(layer_ids)

    def get_activations(self, layer_ids, block_name='decoder_block'):
        """
        Extracts the outputs cached by wrapped blocks during the last forward pass.

        :param layer_ids: Layer index or list/tuple/ndarray of indices.
        :param block_name: ``'decoder_block'`` or one of ``BLOCK_NAMES``.
        :return: ``{layer_id: activation}`` for a collection of ids, otherwise a single activation.
        :raises ValueError: If the requested block is not wrapped.
        """
        def _get_activations(layer_id):
            target = self._get_wrapped_target(layer_id, block_name)
            if target is None:
                raise ValueError(f"No wrapped block named {block_name}.")
            return target.output

        if self._is_layer_id_collection(layer_ids):
            return {layer_id: _get_activations(layer_id) for layer_id in layer_ids}
        return _get_activations(layer_ids)

    def set_controller(self, layer_ids, activations, block_name='decoder_block', token_pos=None, masks=None, normalize=False, operator='linear_comb'):
        """
        Injects controller activations into specified layers/blocks.

        Blocks that are not wrapped are skipped with a ``UserWarning`` instead of
        being ignored silently.

        :param layer_ids: Layer index or list/tuple/ndarray of indices.
        :param activations: Activation tensor, or a ``{layer_id: tensor}`` dict when
            ``layer_ids`` is a collection.
        :param block_name: Block name to apply control to.
        :param token_pos: Target token positions.
        :param masks: Optional masks.
        :param normalize: Whether to normalize activations.
        :param operator: Operator name to apply.
        :raises TypeError: If several layers are given but ``activations`` is not a dict.
        """
        def _set_controller(layer_id, layer_activations):
            target = self._get_wrapped_target(layer_id, block_name)
            if target is None:
                # stacklevel=3 points the warning at the caller of set_controller.
                warnings.warn(f"No wrapped block named {block_name}.", stacklevel=3)
                return
            target.set_controller(layer_activations, token_pos, masks, normalize, operator)

        if self._is_layer_id_collection(layer_ids):
            if not isinstance(activations, dict):
                raise TypeError("activations should be a dictionary")
            for layer_id in layer_ids:
                _set_controller(layer_id, activations[layer_id])
        else:
            _set_controller(layer_ids, activations)

    def reset(self):
        """
        Resets all internal controllers and outputs in wrapped blocks.
        """
        for module in self._iter_wrapped_modules():
            module.reset()

    def set_masks(self, masks):
        """
        Sets masks for all wrapped layers and blocks.

        :param masks: Mask multiplied into each controller.
        """
        for module in self._iter_wrapped_modules():
            module.set_masks(masks)

    def is_wrapped(self, block):
        """
        Checks whether a given block has been wrapped (duck-typed on a ``block`` attribute).

        :param block: A model sub-block (``None`` is reported as not wrapped).
        :return: True if wrapped, else False.
        """
        return hasattr(block, 'block')

    def unwrap(self):
        """
        Restores the original model structure by removing all wrappers.
        """
        layers = self.model.model.layers
        for idx, layer in enumerate(layers):
            if self.is_wrapped(layer):
                layers[idx] = layer.block
            restored = layers[idx]
            for block_name in BLOCK_NAMES:
                block = getattr(restored, block_name, None)
                if self.is_wrapped(block):
                    setattr(restored, block_name, block.block)