| # This file stores ChatNT and all associated layers and configs | |
| from dataclasses import asdict, dataclass, field | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F # noqa: N812 | |
| from transformers import PretrainedConfig, PreTrainedModel | |
| @dataclass | |
| class RotaryEmbeddingConfig: | |
| """ | |
| Rotary Positional Embedding configuration | |
| max_seq_len: The number of positions to encode and cache. | |
| dim: Dimension of RoPE. | |
| theta: Base used to compute the rotation frequencies. | |
| """ | |
| max_seq_len: int | |
| dim: int | |
| theta: float | |
| @dataclass | |
| class PerceiverResamplerConfig: | |
| """ | |
| Parameters to initialize a PerceiverResampler model. | |
| Args: | |
| emb_layer_norm_before: Whether to use layer norm before the first attention | |
| layer. | |
| attention_heads: Number of attention heads. | |
| key_size: The dimension of the query, key, and values within each attention | |
| head, if not specified, it is set to embed_dim // attention_heads. | |
| It can be useful to set a custom key size if we want to impose the size of | |
| the query, key and value tensor ( for example, tensors shaped with | |
| power of 2 are more efficiently handled on TPUs ). | |
| Note: Parametrizing the model with a custom key size has been done in : | |
| Brown, Tom, et al. "Language models are few-shot learners." | |
| Advances in neural information processing systems 33 (2020): 1877-1901. | |
| embed_dim: Embedding dimension. | |
| ffn_embed_dim: Feed forward embedding dimension. | |
| num_layers: Number of attention blocks. | |
| ffn_activation_name: Activation function to be used in FFN block. Supported | |
| names are "gelu", "gelu-no-approx", "relu", "swish". | |
| use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed | |
| Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg | |
| to True and use swish as ffn_activation_name. | |
| Same principle for a gated-relu. To keep the same number of parameters in | |
| the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU. | |
| See https://arxiv.org/pdf/2002.05202.pdf for more details. | |
| resampled_length: length of the resampled output of the module | |
| use_gradient_checkpointing: Whether to use gradient checkpointing (recompute | |
| activations during the backward pass instead of storing them, trading | |
| compute for memory). | |
| """ | |
| # architecture | |
| emb_layer_norm_before: bool = False | |
| attention_heads: int = 20 | |
| key_size: Optional[int] = None | |
| embed_dim: int = 1280 | |
| ffn_embed_dim: int = 5120 | |
| num_layers: int = 24 | |
| add_bias_kv: bool = False | |
| add_bias_ffn: bool = True | |
| ffn_activation_name: str = "gelu-no-approx" | |
| use_glu_in_ffn: bool = False | |
| resampled_length: int = 64 | |
| # performance | |
| use_gradient_checkpointing: bool = False | |
| def __post_init__(self) -> None: | |
| """ | |
| Checks that the given values are compatible. | |
| """ | |
| if self.key_size is None: | |
| if not self.embed_dim % self.attention_heads == 0: | |
| raise ValueError( | |
| f"When no key size is provided, the embedding dimension should be " | |
| f"divisible by the number of heads, however provided embedding " | |
| f"dimension is {self.embed_dim} and the number of heads is " | |
| f"{self.attention_heads}." | |
| ) | |
| self.key_size = self.embed_dim // self.attention_heads | |
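The default `key_size` rule in `__post_init__` above can be illustrated in isolation. The following is a minimal sketch, not the model's own configuration class: `HeadSizeConfig` is a hypothetical stand-in that keeps only the three fields involved in the check.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class HeadSizeConfig:
    """Hypothetical stand-in keeping only the fields used by the check."""

    embed_dim: int = 1280
    attention_heads: int = 20
    key_size: Optional[int] = None

    def __post_init__(self) -> None:
        # Derive the per-head size only when the caller did not set one.
        if self.key_size is None:
            if self.embed_dim % self.attention_heads != 0:
                raise ValueError(
                    f"embed_dim={self.embed_dim} is not divisible by "
                    f"attention_heads={self.attention_heads}"
                )
            self.key_size = self.embed_dim // self.attention_heads


default_cfg = HeadSizeConfig()             # key_size is derived as 1280 // 20
custom_cfg = HeadSizeConfig(key_size=128)  # an explicit value is kept as-is
```

With an explicit `key_size`, the divisibility check is skipped entirely, which is what allows head sizes that are not `embed_dim // attention_heads`.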
| @dataclass | |
| class GptConfig: | |
| """ | |
| Parameters to initialize a Gpt model. | |
| NOTE: the pad token is not defined | |
| Args: | |
| vocab_size: Size of the token vocabulary. | |
| eos_token_id: used to stop sentence generation | |
| embed_dim: Embedding dimension. | |
| ffn_embed_dim: Feed forward embedding dimension. | |
| num_heads: Number of attention heads. | |
| num_kv_heads: Number of key and value heads to support Grouped-Query and | |
| Multi-Query Attention. If None, the number of key and value heads is | |
| equal to the number of attention heads. | |
| num_layers: Number of decoder layers. | |
| rope_config: The configuration for the rotary positional embeddings | |
| add_bias_ffn: Add bias in feed forward network block. | |
| ffn_activation_name: Activation function to be used in FFN block. Supported | |
| names are "gelu", "gelu-no-approx", "relu", "swish". | |
| use_glu_in_ffn: whether to use Gated Linear Unit (GLU) in Feed | |
| Forward Network (FFN) block. | |
| example: To do a swiGLU (gated-swish) put this arg | |
| to True and use swish as ffn_activation_name. | |
| Same principle for a gated-relu. | |
| add_bias_lm_head: whether to use bias in the final LM layer | |
| norm_type: The type of norm used (pre-normalization scheme). Can be | |
| one of ["layer_norm", "RMS_norm"]. | |
| parallel_attention_ff: Whether to do the attention and the MLP in parallel, | |
| and then sum up the results as it is done in Gpt-NeoX : | |
| Black, Sid, et al. "Gpt-neox-20b: An open-source autoregressive | |
| language model." arXiv preprint arXiv:2204.06745 (2022). | |
| It is reported to reduce training time by about 15% when compiling with JAX. | |
| use_gradient_checkpointing: Whether to use gradient checkpointing (recompute | |
| activations during the backward pass instead of storing them, trading | |
| compute for memory). | |
| add_bias_attn: Add bias to the attention mechanism (key, query, value, and | |
| output projections). | |
| """ | |
| # vocabulary | |
| vocab_size: int | |
| eos_token_id: int | |
| # architecture | |
| embed_dim: int = 16 | |
| ffn_embed_dim: int = 64 | |
| num_heads: int = 2 | |
| num_kv_heads: Optional[int] = None | |
| num_layers: int = 2 | |
| rope_config: RotaryEmbeddingConfig = field( | |
| default_factory=lambda: RotaryEmbeddingConfig( | |
| max_seq_len=512, dim=8, theta=10000.0 | |
| ) | |
| ) | |
| add_bias_ffn: bool = False | |
| ffn_activation_name: str = "swish" | |
| use_glu_in_ffn: bool = True | |
| add_bias_lm_head: bool = False | |
| norm_type: str = "RMS_norm" | |
| rms_norm_eps: float = 1e-6 | |
| parallel_attention_ff: bool = True | |
| # inference / backward behavior | |
| use_gradient_checkpointing: bool = False | |
| # architecture params with default values | |
| add_bias_attn: bool = False | |
| def __post_init__(self) -> None: | |
| """ | |
| Checks that the given values are compatible. | |
| """ | |
| if not self.embed_dim % self.num_heads == 0: | |
| raise ValueError( | |
| f"The embedding dimension should be " | |
| f"divisible by the number of heads, however provided embedding " | |
| f"dimension is {self.embed_dim} and the number of heads is " | |
| f"{self.num_heads}." | |
| ) | |
| if not self.embed_dim // self.num_heads > 1: | |
| raise ValueError( | |
| "embed_dim // num_heads must be at least 2 to apply rotary embeddings" | |
| ) | |
| if not self.embed_dim // self.num_heads >= self.rope_config.dim: | |
| raise ValueError( | |
| "embed_dim // num_heads must be at least rope_config.dim " | |
| "to apply rotary embeddings" | |
| ) | |
| def to_dict(self): # type: ignore | |
| output = asdict(self) | |
| output["rope_config"] = asdict(self.rope_config) | |
| return output | |
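`to_dict` above relies on `dataclasses.asdict`. As a minimal sketch with hypothetical stand-in classes (`RopeSketch`, `GptSketch`, reduced to a few fields), the snippet below shows that `asdict` already converts nested dataclass fields recursively, and that rebuilding the config from the dictionary requires re-wrapping the nested entry first.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class RopeSketch:
    """Hypothetical stand-in for the rotary embedding config."""

    max_seq_len: int = 512
    dim: int = 8
    theta: float = 10000.0


@dataclass
class GptSketch:
    """Hypothetical stand-in keeping only three fields."""

    vocab_size: int
    eos_token_id: int
    rope_config: RopeSketch = field(default_factory=RopeSketch)

    def to_dict(self) -> dict:
        output = asdict(self)
        # asdict() has already turned rope_config into a dict; this
        # assignment is kept only to mirror the original method.
        output["rope_config"] = asdict(self.rope_config)
        return output


cfg = GptSketch(vocab_size=32000, eos_token_id=3)
as_dict = cfg.to_dict()

# Rebuild: the nested dict has to become a dataclass again before it is
# passed to the outer constructor.
rebuilt_kwargs = dict(as_dict)
rebuilt_kwargs["rope_config"] = RopeSketch(**as_dict["rope_config"])
restored = GptSketch(**rebuilt_kwargs)
```

The rebuild step is the same pattern a loader has to follow when the saved config comes back as plain dictionaries.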
| @dataclass | |
| class NucleotideTransformerConfig: | |
| """ | |
| Parameters to initialize an NT model. | |
| Args: | |
| alphabet_size: Size of the token vocabulary. | |
| pad_token_id: ID of pad token. | |
| mask_token_id: ID of mask token. | |
| max_positions: Maximum sequence length. | |
| embed_scale: Correction ratio applied to the embeddings to make up for the | |
| norm difference between the input during training and inference. | |
| emb_layer_norm_before: Whether to use layer norm before the first attention | |
| layer. | |
| attention_heads: Number of attention heads. | |
| key_size: The dimension of the query, key, and values within each attention | |
| head, if not specified, it is set to embed_dim // attention_heads. | |
| It can be useful to set a custom key size if we want to impose the size of | |
| the query, key and value tensor ( for example, tensors shaped with | |
| power of 2 are more efficiently handled on TPUs ). | |
| Note: Parametrizing the model with a custom key size has been done in : | |
| Brown, Tom, et al. "Language models are few-shot learners." | |
| Advances in neural information processing systems 33 (2020): 1877-1901. | |
| embed_dim: Embedding dimension. | |
| ffn_embed_dim: Feed forward embedding dimension. | |
| num_layers: Number of attention blocks. | |
| positional_embedding: Type of positional embedding to use before the first | |
| attention layer. Options: "learned", "learned_standard", "sinusoidal", | |
| "alibi_dnabert_2" or None. | |
| NOTE: "learned" is the positional embedding of ESM, and "learned_standard" | |
| is a more standard one, used for example in DNAbert. | |
| lm_head: type of language model head. Options: "simple", "roberta" or None. | |
| add_bias_kv: Add bias in attention layer. | |
| add_bias_ffn: Add bias in feed forward network block. | |
| use_rotary_embedding: Whether to use rotary embeddings. Requires: | |
| positional_embedding = None. | |
| rescaling_factor: Scaling factor to use for rotary embeddings. | |
| ffn_activation_name: Activation function to be used in FFN block. Supported | |
| names are "gelu", "relu", "swish". | |
| use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed | |
| Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg | |
| to True and use swish as ffn_activation_name. | |
| Same principle for a gated-relu. To keep the same number of parameters in | |
| the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU. | |
| See https://arxiv.org/pdf/2002.05202.pdf for more details. | |
| mask_before_attention: Use mask before attention layers. | |
| layer_norm_eps: the eps factor in the different layer norms of the model (refer | |
| to layer norm implementation) | |
| token_dropout: Token dropout. | |
| masking_ratio: Masking ratio (used if token dropout is enabled). | |
| masking_prob: Masking probability (used if token dropout is enabled). | |
| use_gradient_checkpointing: Whether to use gradient checkpointing (recompute | |
| activations during the backward pass instead of storing them, trading | |
| compute for memory). | |
| """ | |
| alphabet_size: int | |
| pad_token_id: int | |
| mask_token_id: int | |
| max_positions: int = 1024 | |
| embed_scale: float = 1.0 | |
| # architecture | |
| emb_layer_norm_before: bool = False | |
| attention_heads: int = 20 | |
| key_size: Optional[int] = None | |
| embed_dim: int = 1280 | |
| ffn_embed_dim: int = 5120 | |
| num_layers: int = 24 | |
| positional_embedding: Optional[str] = "learned" | |
| lm_head: Optional[str] = "simple" | |
| add_bias_kv: bool = False | |
| add_bias_ffn: bool = True | |
| use_rotary_embedding: bool = False | |
| rescaling_factor: Optional[float] = None | |
| ffn_activation_name: str = "gelu-no-approx" | |
| use_glu_in_ffn: bool = False | |
| mask_before_attention: bool = False | |
| layer_norm_eps: float = 1e-5 | |
| pre_layer_norm: bool = True | |
| bias_word_embedding: bool = False | |
| # dropout | |
| token_dropout: bool = False | |
| masking_ratio: float = 0.1 | |
| masking_prob: float = 0.8 | |
| # performance | |
| use_gradient_checkpointing: bool = False | |
| # return | |
| embeddings_layers_to_save: List[int] = field(default_factory=list) | |
| attention_maps_to_save: List[Tuple[int, int]] = field(default_factory=list) | |
| def __post_init__(self) -> None: | |
| """ | |
| Checks that the given values are compatible. | |
| """ | |
| if self.key_size is None: | |
| if not self.embed_dim % self.attention_heads == 0: | |
| raise ValueError( | |
| f"When no key size is provided, the embedding dimension should be " | |
| f"divisible by the number of heads, however provided embedding " | |
| f"dimension is {self.embed_dim} and the number of heads is " | |
| f"{self.attention_heads}." | |
| ) | |
| self.key_size = self.embed_dim // self.attention_heads | |
| if self.positional_embedding is not None: | |
| if not isinstance(self.positional_embedding, str): | |
| raise TypeError | |
| if self.positional_embedding not in [ | |
| "learned", | |
| "sinusoidal", | |
| "learned_standard", | |
| "alibi_dnabert_2", | |
| ]: | |
| raise ValueError( | |
| "The positional_embedding argument should either be None, " | |
| "`learned`, `sinusoidal`, `learned_standard` or `alibi_dnabert_2`." | |
| ) | |
| if self.lm_head is not None: | |
| if not isinstance(self.lm_head, str): | |
| raise TypeError | |
| if self.lm_head not in ["simple", "roberta"]: | |
| raise ValueError( | |
| "The lm_head argument should either be None, " | |
| "`simple` or `roberta`." | |
| ) | |
| if self.use_rotary_embedding and self.positional_embedding is not None: | |
| raise ValueError( | |
| "When using rotary embedding, positional_embedding must be set to none" | |
| ) | |
| if self.add_bias_kv and self.use_rotary_embedding: | |
| raise ValueError( | |
| "Biases on key and values are not compatible with Rotary embeddings." | |
| ) | |
| if self.positional_embedding == "alibi_dnabert_2": | |
| assert not self.add_bias_kv | |
| class ChatNTConfig(PretrainedConfig): | |
| model_type = "ChatNT" | |
| def __init__(self, **kwargs): # type: ignore | |
| self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3)) | |
| self.nt_config: NucleotideTransformerConfig = kwargs.get( | |
| "nt_config", NucleotideTransformerConfig(4000, 1, 4) | |
| ) | |
| self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get( | |
| "perceiver_resampler_config", PerceiverResamplerConfig() | |
| ) | |
| self.seq_token_id: int = kwargs.get("seq_token_id", 32000) | |
| self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1) | |
| self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2) | |
| super().__init__(**kwargs) | |
| def to_dict(self): # type: ignore | |
| output = super().to_dict() | |
| def serialize(obj): # type: ignore | |
| return obj.to_dict() if hasattr(obj, "to_dict") else vars(obj) | |
| output["gpt_config"] = serialize(self.gpt_config) # type: ignore | |
| output["nt_config"] = serialize(self.nt_config) # type: ignore | |
| output["perceiver_resampler_config"] = serialize( # type: ignore | |
| self.perceiver_resampler_config | |
| ) | |
| return output | |
| class TorchBioBrainDecoder(nn.Module): | |
| def __init__( | |
| self, | |
| gpt_config: GptConfig, | |
| seq_token_id: int, | |
| ): | |
| """ | |
| Initializes the BioBrain decoder, using a GPT model for text generation with | |
| bio embeddings. | |
| Args: | |
| gpt_config: Configuration for the GPT model | |
| seq_token_id: Index of the SEQ token | |
| """ | |
| super(TorchBioBrainDecoder, self).__init__() | |
| self.gpt_config = gpt_config | |
| self.seq_token_id = seq_token_id | |
| # Initialize the GPT model (assumed you have it already in PyTorch) | |
| self.gpt_model = TorchGptDecoder(self.gpt_config) | |
| def forward( | |
| self, | |
| english_token_ids: torch.Tensor, | |
| projected_bio_embeddings: Optional[torch.Tensor], | |
| ) -> torch.Tensor: | |
| """ | |
| Forward pass through the model. | |
| Args: | |
| english_token_ids: Tensor of English token IDs with shape | |
| (batch_size, num_english_tokens). | |
| projected_bio_embeddings: Optional tensor of bio embeddings with shape | |
| (batch_size, num_bio_sequences, ?, embed_dim). | |
| Returns: | |
| torch.Tensor: The logits from the GPT model, | |
| shaped (batch_size, num_english_tokens, vocab_size). | |
| """ | |
| # Compute English token embeddings | |
| tokens_embeddings = self.gpt_model.token_embed(english_token_ids) | |
| if projected_bio_embeddings is not None: | |
| ( | |
| batch_size, | |
| num_bio_sequences, | |
| _, | |
| bio_embed_dim, | |
| ) = projected_bio_embeddings.shape | |
| # Insert the bio embeddings at the SEQ token positions | |
| processed_tokens_ids = english_token_ids.clone() | |
| for bio_seq_num in range(num_bio_sequences): | |
| tokens_embeddings, processed_tokens_ids = self.insert_embeddings( | |
| processed_tokens_ids, | |
| tokens_embeddings, | |
| projected_bio_embeddings[:, bio_seq_num, :, :], | |
| bio_seq_num=bio_seq_num, | |
| ) | |
| # Regular GPT pass through | |
| embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings) | |
| embeddings = self.gpt_model.final_norm(embeddings) | |
| # Compute logits | |
| logits = self.gpt_model.lm_head(embeddings) | |
| if projected_bio_embeddings is not None: | |
| # Clean logits sequentially | |
| processed_tokens_ids = english_token_ids.clone() | |
| resampled_length = projected_bio_embeddings.shape[-2] | |
| for _ in range(num_bio_sequences): | |
| logits, processed_tokens_ids = self.cleanup_logits( | |
| tokens=processed_tokens_ids, | |
| logits=logits, | |
| resampled_length=resampled_length, | |
| ) | |
| return logits | |
| def insert_embeddings( | |
| self, | |
| tokens: torch.Tensor, | |
| input_embeddings: torch.Tensor, | |
| resampled_embeddings: torch.Tensor, | |
| bio_seq_num: int, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Inserts resampled embeddings in input_embeddings, starting at the SEQ token | |
| Args: | |
| tokens (torch.Tensor): Shape (batch_size, num_tokens) | |
| input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim) | |
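| resampled_embeddings (torch.Tensor): | |
| Shape (batch_size, bio_sequence_length, embed_dim) | |
| bio_seq_num (int): Index of the bio sequence being inserted. It offsets | |
| the insertion position by the length of the sequences inserted before. | |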
| Returns: | |
| Tuple[torch.Tensor, torch.Tensor]: | |
| - input_embeddings with resampled_embeddings inserted at the SEQ token | |
| - tokens with the SEQ token set to -1 | |
| """ | |
| def _insert( | |
| tokens_1d: torch.Tensor, | |
| input_embeddings_1d: torch.Tensor, | |
| resampled_embeddings_1d: torch.Tensor, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Args: | |
| tokens_1d (torch.Tensor): Shape (num_tokens,) | |
| input_embeddings_1d (torch.Tensor): Shape (num_tokens, embed_dim,) | |
| resampled_embeddings_1d (torch.Tensor): | |
| Shape (bio_sequence_length, embed_dim,) | |
| """ | |
| indices = torch.where(tokens_1d == self.seq_token_id)[0] | |
| if indices.numel() > 0: | |
| idx = indices[0].item() | |
| insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num | |
| x = torch.cat( | |
| [ | |
| input_embeddings_1d[:insertion_pos, :], | |
| resampled_embeddings_1d, | |
| input_embeddings_1d[insertion_pos:, :], | |
| ], | |
| dim=0, | |
| )[: tokens_1d.shape[0] + 1, :] | |
| x = torch.roll(torch.roll(x, shifts=-idx, dims=0), shifts=idx, dims=0)[ | |
| :-1, : | |
| ] | |
| tokens_1d[idx] = -1 | |
| return x, tokens_1d | |
| else: | |
| return ( | |
| input_embeddings_1d, | |
| tokens_1d, | |
| ) # Return unchanged if seq_token_id is not found | |
| tokens_acc = [] | |
| embeddings_acc = [] | |
| for i in range(tokens.shape[0]): | |
| embeddings_out, tokens_out = _insert( | |
| tokens[i].clone(), | |
| input_embeddings[i].clone(), | |
| resampled_embeddings[i].clone(), | |
| ) | |
| tokens_acc.append(tokens_out) | |
| embeddings_acc.append(embeddings_out) | |
| tokens_acc = torch.stack(tokens_acc) | |
| embeddings_acc = torch.stack(embeddings_acc) | |
| return embeddings_acc, tokens_acc | |
| def cleanup_logits( | |
| self, tokens: torch.Tensor, logits: torch.Tensor, resampled_length: int | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Removes the logits corresponding to the unused embeddings. | |
| Args: | |
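| tokens: Input English tokens. | |
| logits: Input logits. | |
| resampled_length: Number of bio embedding positions inserted per SEQ token. | |
| Returns: | |
| Cleaned logits (the last resampled_length rows are set to 0) and the | |
| tokens with the processed SEQ token set to -1. | |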
| """ | |
| def _clean( | |
| token: torch.Tensor, logit: torch.Tensor | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| indices = torch.where(token == self.seq_token_id)[0] | |
| if indices.numel() > 0: | |
| idx = indices[0].item() | |
| mask_idx = ( | |
| torch.arange(logit.shape[0] - resampled_length, device=logit.device) | |
| > idx | |
| ) | |
| mask_idx = mask_idx.unsqueeze(1) | |
| # Remove values corresponding to bio tokens | |
| logit = ( | |
| logit[:-resampled_length] * (~mask_idx) | |
| + logit[resampled_length:] * mask_idx | |
| ) | |
| # Append zeros at the end | |
| logit = torch.cat( | |
| ( | |
| logit, | |
| torch.zeros( | |
| (resampled_length, logit.shape[1]), | |
| dtype=logit.dtype, | |
| device=logit.device, | |
| ), | |
| ) | |
| ) | |
| # Update token | |
| token[idx] = -1 | |
| return logit, token | |
| else: | |
| return logit, token | |
| tokens_acc = [] | |
| logits_acc = [] | |
| for i in range(tokens.shape[0]): | |
| logits_out, tokens_out = _clean(tokens[i].clone(), logits[i].clone()) | |
| tokens_acc.append(tokens_out) | |
| logits_acc.append(logits_out) | |
| tokens_acc = torch.stack(tokens_acc) | |
| logits_acc = torch.stack(logits_acc) | |
| return logits_acc, tokens_acc | |
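The two helpers above, `insert_embeddings` and `cleanup_logits`, amount to a splice on the input side and its removal on the output side. The sketch below reproduces that index arithmetic for a single sequence, using plain Python lists in place of tensors. `splice_block`, `drop_block` and the `SEQ` id are hypothetical names chosen for illustration, and strings stand in for embedding and logit rows.

```python
from typing import List, Tuple

SEQ = 9  # hypothetical SEQ token id, for this sketch only


def splice_block(
    tokens: List[int], rows: List[str], block: List[str], n_prev: int = 0
) -> Tuple[List[str], List[int]]:
    """Insert `block` at the first SEQ position, keeping the original length."""
    if SEQ not in tokens:
        return rows, tokens
    idx = tokens.index(SEQ)
    # Each earlier insertion already shifted the rows right by len(block).
    pos = idx + len(block) * n_prev
    spliced = (rows[:pos] + block + rows[pos:])[: len(tokens)]
    marked = tokens[:idx] + [-1] + tokens[idx + 1 :]
    return spliced, marked


def drop_block(
    tokens: List[int], rows: List[str], length: int, fill: str = "0"
) -> Tuple[List[str], List[int]]:
    """Remove `length` rows after the first SEQ position and pad the tail."""
    if SEQ not in tokens:
        return rows, tokens
    idx = tokens.index(SEQ)
    kept = rows[: idx + 1] + rows[idx + 1 + length :]
    marked = tokens[:idx] + [-1] + tokens[idx + 1 :]
    return kept + [fill] * length, marked


tokens = [5, SEQ, 7, 0, 0]  # two trailing pad positions
rows = ["e5", "eSEQ", "e7", "pad", "pad"]
spliced, spliced_tokens = splice_block(tokens, rows, ["b0", "b1"])
cleaned, _ = drop_block(tokens, ["l0", "l1", "l2", "l3", "l4"], length=2)
```

In this example the splice pushes the two pad rows off the end. This suggests the prompt needs at least `len(block)` trailing pad positions per bio sequence, otherwise real token embeddings would be truncated.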
| class TorchMultiOmicsModel(PreTrainedModel): | |
| config_class = ChatNTConfig | |
| def __init__(self, config: ChatNTConfig) -> None: | |
| if isinstance(config, dict): | |
| # If config is a dictionary instead of ChatNTConfig (which can happen | |
| # depending how the config was saved), we convert it to the config | |
| config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig( | |
| **config["gpt_config"]["rope_config"] | |
| ) | |
| config["gpt_config"] = GptConfig(**config["gpt_config"]) | |
| config["nt_config"] = NucleotideTransformerConfig(**config["nt_config"]) | |
| config["perceiver_resampler_config"] = PerceiverResamplerConfig( | |
| **config["perceiver_resampler_config"] | |
| ) | |
| config = ChatNTConfig(**config) # type: ignore | |
| else: | |
| if isinstance(config.gpt_config, dict): | |
| config.gpt_config["rope_config"] = RotaryEmbeddingConfig( | |
| **config.gpt_config["rope_config"] | |
| ) | |
| config.gpt_config = GptConfig(**config.gpt_config) | |
| if isinstance(config.nt_config, dict): | |
| config.nt_config = NucleotideTransformerConfig(**config.nt_config) | |
| if isinstance(config.perceiver_resampler_config, dict): | |
| config.perceiver_resampler_config = PerceiverResamplerConfig( | |
| **config.perceiver_resampler_config | |
| ) | |
| super().__init__(config=config) | |
| self.gpt_config = config.gpt_config | |
| self.nt_config = config.nt_config | |
| self.perceiver_resampler_config = config.perceiver_resampler_config | |
| self.seq_token_id = config.seq_token_id | |
| self.bio_pad_token_id = config.bio_pad_token_id | |
| self.english_pad_token_id = config.english_pad_token_id | |
| # Correct seq_token_id | |
| self.seq_token_id -= 1 | |
| self.biobrain_encoder = TorchBioBrainEncoder(nt_config=self.nt_config) | |
| self.biobrain_decoder = TorchBioBrainDecoder( | |
| gpt_config=self.gpt_config, seq_token_id=self.seq_token_id | |
| ) | |
| self.projection_model = TorchMultiModalPerceiverResamplerProjection( | |
| perceiver_resampler_config=self.perceiver_resampler_config, | |
| input_embed_dim=self.nt_config.embed_dim, | |
| embed_dim=self.gpt_config.embed_dim, | |
| english_vocab_size=self.gpt_config.vocab_size, | |
| bio_pad_token_id=self.bio_pad_token_id, | |
| english_pad_token_id=self.english_pad_token_id, | |
| ) | |
| def forward( | |
| self, | |
| multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor | None], | |
| projection_english_tokens_ids: torch.Tensor, | |
| projected_bio_embeddings: Optional[torch.Tensor] = None, | |
| ) -> dict[str, torch.Tensor]: | |
| """ | |
| Args: | |
| multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]): | |
| english_tokens_ids: Represents the prompt tokens (english tokens) | |
| Shape (batch_size, num_english_tokens) | |
| bio_tokens_ids: Represents the bio sequences tokens | |
| Shape (batch_size, num_bio_sequences, num_bio_tokens) | |
| projection_english_tokens_ids (torch.Tensor): | |
| Shape (batch_size, num_english_tokens) | |
| projected_bio_embeddings (torch.Tensor, optional): | |
| Shape (batch_size, num_bio_sequences, ?, embed_dim). | |
| Defaults to None. | |
| Returns: | |
| dict[str, torch.Tensor] containing: | |
| - logits: | |
| Shape (batch_size, num_tokens, vocab_size) | |
| - projected_bio_embeddings: | |
| Shape (batch_size, num_bio_sequences, ?, embed_dim) | |
| """ | |
| english_token_ids, bio_token_ids = multi_omics_tokens_ids | |
| english_token_ids = english_token_ids.clone() | |
| projection_english_tokens_ids = projection_english_tokens_ids.clone() | |
| if bio_token_ids is not None: | |
| bio_token_ids = bio_token_ids.clone() | |
| if projected_bio_embeddings is not None: | |
| projected_bio_embeddings = projected_bio_embeddings.clone() | |
| # Remap token ids so that they fit in the embedding table. | |
| # The configured vocab size (32000) does not count the SEQ token, which was | |
| # added with id 32000 (= vocab_size). As a workaround that avoids changing | |
| # the vocab size in the config, the SEQ token is moved to id 31999 and any | |
| # occurrence of the original token 31999 is mapped to 0 (the unknown token). | |
| vocab_size = self.gpt_config.vocab_size | |
| # Replace vocab | |
| english_token_ids[english_token_ids == vocab_size - 1] = 0 | |
| projection_english_tokens_ids[ | |
| projection_english_tokens_ids == vocab_size - 1 | |
| ] = 0 | |
| english_token_ids[english_token_ids == vocab_size] = vocab_size - 1 | |
| projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = ( | |
| vocab_size - 1 | |
| ) | |
| if bio_token_ids is None: | |
| projected_bio_embeddings = None | |
| else: | |
| num_bio_sequences = bio_token_ids.shape[1] | |
| if projected_bio_embeddings is None: | |
| # Compute bio sequences embeddings | |
| bio_embeddings_list = [ | |
| self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num]) | |
| for bio_seq_num in range(num_bio_sequences) | |
| ] | |
| # Project these embeddings | |
| projected_bio_embeddings = [ | |
| self.projection_model( | |
| bio_token_ids=bio_token_ids[:, bio_seq_num], | |
| bio_embeddings=bio_embeddings, | |
| english_token_ids=projection_english_tokens_ids, | |
| ) | |
| for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list) | |
| ] | |
| projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1) | |
| # decode | |
| logits = self.biobrain_decoder( | |
| english_token_ids=english_token_ids, | |
| projected_bio_embeddings=projected_bio_embeddings, | |
| ) | |
| logits = logits.to(torch.float32) | |
| outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings} | |
| return outs | |
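The token-id remapping in `forward` above is order-sensitive: ids equal to `vocab_size - 1` have to be sent to 0 before the SEQ id (`vocab_size`) is moved down into that slot, otherwise the SEQ token would be overwritten as well. A minimal sketch on a plain list, with `remap_ids` as a hypothetical helper name:

```python
from typing import List


def remap_ids(ids: List[int], vocab_size: int) -> List[int]:
    # Step 1: free the last slot by mapping it to 0 (the unknown token).
    freed = [0 if i == vocab_size - 1 else i for i in ids]
    # Step 2: move the SEQ id (== vocab_size) into the freed slot.
    return [vocab_size - 1 if i == vocab_size else i for i in freed]


remapped = remap_ids([1, 31999, 32000, 42], vocab_size=32000)
```

After the remap every id lies in `[0, vocab_size - 1]`, which matches the `self.seq_token_id -= 1` correction applied in `__init__`.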
| class TorchRotaryEmbedding(torch.nn.Module): | |
| def __init__(self, config: RotaryEmbeddingConfig): | |
| super().__init__() | |
| self.max_seq_len = config.max_seq_len | |
| self.dim = config.dim | |
| self.theta = config.theta | |
| self.sincos_cache = None | |
| def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor: | |
| """ | |
| Create the sines and cosines for the RoPE. | |
| Returns: | |
| Sinusoidal positions of shape (self.max_seq_len, self.dim). | |
| """ | |
| # Create the inverse frequency based on theta and dim | |
| inv_freq = 1.0 / ( | |
| self.theta | |
| ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim) | |
| ) | |
| # Compute sinusoidal input using the broadcasting | |
| sinusoid_inp = torch.einsum( | |
| "i,j->ij", torch.arange(self.max_seq_len, device=device).float(), inv_freq | |
| ) | |
| # Apply sin and cos to the sinusoidal input | |
| sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos() | |
| # Allocate a tensor for the final sin-cos values | |
| sincos = torch.zeros( | |
| (self.max_seq_len, self.dim), dtype=torch.float32, device=device | |
| ) | |
| # Fill the sincos tensor with sin and cos values | |
| sentinel = self.dim // 2 + self.dim % 2 | |
| sincos[:, :sentinel] = sin | |
| sincos[:, sentinel:] = cos | |
| return sincos | |
| def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Prepare a tensor to apply the RoPE mechanism. | |
| Args: | |
| x: Tensor of shape (batch_size, seq_len, num_heads, head_dim), | |
| typically this is the key or query tensor. | |
| Returns: | |
| Tensor in which each pair (x[2i], x[2i+1]) of the last dimension is | |
| replaced by (-x[2i+1], x[2i]). | |
| Tensor of shape (batch_size, seq_len, num_heads, head_dim). | |
| """ | |
| # Pair each negated odd-indexed value with the preceding even-indexed value | |
| rotate_half = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1) | |
| # Flatten the pairs back into the original last dimension | |
| rotate_half = rotate_half.view(rotate_half.shape[:-2] + (-1,)) | |
| return rotate_half | |
| def _apply_rotary_pos_emb( | |
| self, x: torch.Tensor, sincos: torch.Tensor | |
| ) -> torch.Tensor: | |
| """ | |
| Applies rotary embeddings to x. | |
| Args: | |
| x: Tensor of shape (batch_size, seq_len, num_heads, head_dim), | |
| typically this is the key or query tensor. | |
| sincos: Tuple of sine and cosine tensors for position encoding. | |
| Returns: | |
| RoPE embeddings tensor. | |
| """ | |
| sin_pos, cos_pos = sincos | |
| # Reshape the sin and cos tensors for broadcasting | |
| sin_pos = torch.repeat_interleave(sin_pos.unsqueeze(2), repeats=2, dim=-1) | |
| cos_pos = torch.repeat_interleave(cos_pos.unsqueeze(2), repeats=2, dim=-1) | |
| # Apply the rotary embedding mechanism | |
| return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos) | |
| def __call__( | |
| self, k: torch.Tensor, q: torch.Tensor, positions: Optional[torch.Tensor] = None | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Applies rotary embeddings to k and q. | |
| Args: | |
| k: key tensor of shape (batch_size, seq_len, num_heads, head_dim). | |
| q: query tensor of shape (batch_size, seq_len, num_heads, head_dim). | |
| positions: optional positions offset, useful when caching. | |
| Returns: | |
| RoPE embeddings for the keys and queries. | |
| """ | |
| if self.sincos_cache is None: | |
| device = k.device | |
| self.sincos_cache = self._create_sinusoidal_positions(device=device) | |
| batch_size, seq_len, num_heads, head_dim = k.shape | |
| # Generate position ids | |
| position_ids = ( | |
| torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1) | |
| ) | |
| if positions is not None: | |
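| # Out-of-place add: position_ids is an expanded view, so it cannot be safely written in place | |
| position_ids = position_ids + positions | |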
| # Retrieve sincos values using the position_ids | |
| sincos = self.sincos_cache[position_ids] # type: ignore | |
| # Split sincos into sin_pos and cos_pos | |
| sincos = torch.chunk(sincos, 2, dim=-1) | |
| # Apply rotary position embedding to key (k) and query (q) | |
| k_rot = self._apply_rotary_pos_emb(k[..., : self.dim], sincos) | |
| k_pass = k[..., self.dim :] | |
| q_rot = self._apply_rotary_pos_emb(q[..., : self.dim], sincos) | |
| q_pass = q[..., self.dim :] | |
| # Concatenate the rotated and non-rotated parts | |
| keys = torch.cat([k_rot, k_pass], dim=-1) | |
| queries = torch.cat([q_rot, q_pass], dim=-1) | |
| return keys, queries | |
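The pairwise rotation used by `_rotate_every_two` and `_apply_rotary_pos_emb` can be checked on a single head vector without torch. The following is a minimal sketch, not the implementation above: it works on plain Python lists, and the frequency schedule `theta ** (-2i / dim)` is the standard RoPE choice and an assumption here, since the construction of `sincos_cache` is not shown in this part of the file. The names `rotate_every_two`, `apply_rope` and `dot` are illustrative.

```python
import math


def rotate_every_two(x):
    """Map each pair (x[2i], x[2i+1]) to (-x[2i+1], x[2i])."""
    out = []
    for i in range(0, len(x), 2):
        out.extend([-x[i + 1], x[i]])
    return out


def apply_rope(x, position, theta=10000.0):
    """Rotate one head vector x (even length) to the given position."""
    dim = len(x)
    rotated = rotate_every_two(x)
    out = []
    for j in range(dim):
        # Both elements of pair j // 2 share one angle, mirroring repeat_interleave(2).
        angle = position * theta ** (-2 * (j // 2) / dim)
        out.append(x[j] * math.cos(angle) + rotated[j] * math.sin(angle))
    return out


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.0, 0.4, -0.6, 0.9]

# The same offset (3) at different absolute positions gives the same score.
score_a = dot(apply_rope(q, 5), apply_rope(k, 2))
score_b = dot(apply_rope(q, 13), apply_rope(k, 10))
print(abs(score_a - score_b) < 1e-9)
```

Because every pair is rotated by an angle proportional to its position, the query-key dot product depends only on the position difference, which is the property the attention layers rely on.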
| class TorchGptGroupedQueryAttention(nn.Module): | |
| def __init__( | |
| self, | |
| embed_dim: int, | |
| num_heads: int, | |
| rope_config: RotaryEmbeddingConfig, | |
| num_kv_heads: int = None, # type: ignore | |
| head_dim: int = None, # type: ignore | |
| add_bias_attn: bool = False, # type: ignore | |
| ) -> None: | |
| super().__init__() | |
| self.num_heads = num_heads | |
| self.num_kv_heads = num_kv_heads or num_heads | |
| self.embed_dim = embed_dim | |
| self.head_dim = head_dim or (embed_dim // num_heads) | |
| self.add_bias_attn = add_bias_attn | |
| self.rope = TorchRotaryEmbedding(rope_config) | |
| self.query_linear = nn.Linear( | |
| embed_dim, self.num_heads * self.head_dim, bias=add_bias_attn | |
| ) | |
| self.key_linear = nn.Linear( | |
| embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn | |
| ) | |
| self.value_linear = nn.Linear( | |
| embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn | |
| ) | |
| self.out_linear = nn.Linear( | |
| self.num_heads * self.head_dim, embed_dim, bias=add_bias_attn | |
| ) | |
| def forward( | |
| self, | |
| query_inputs: torch.Tensor, | |
| key_inputs: torch.Tensor, | |
| value_inputs: torch.Tensor, | |
| attention_mask: torch.Tensor = None, | |
| ) -> torch.Tensor: | |
| batch_size, seq_len, _ = query_inputs.shape | |
| queries = self.query_linear(query_inputs).view( # noqa | |
| batch_size, seq_len, self.num_heads, self.head_dim | |
| ) | |
| keys = self.key_linear(key_inputs).view( # noqa | |
| batch_size, seq_len, self.num_kv_heads, self.head_dim | |
| ) | |
| values = self.value_linear(value_inputs).view( # noqa | |
| batch_size, seq_len, self.num_kv_heads, self.head_dim | |
| ) | |
| keys, queries = self.rope(keys, queries) | |
| n_rep = self.num_heads // self.num_kv_heads | |
| keys = keys.repeat_interleave(n_rep, dim=2) | |
| values = values.repeat_interleave(n_rep, dim=2) | |
| attention_logits = torch.einsum("bthd,bThd->bhtT", queries, keys) / ( | |
| self.head_dim**0.5 | |
| ) | |
| if attention_mask is not None: | |
| attention_logits = attention_logits.masked_fill( | |
| attention_mask == 0, float("-inf") | |
| ) | |
| attention_weights = nn.functional.softmax(attention_logits, dim=-1) | |
| attention_weights = attention_weights.to(values.dtype) | |
| values = torch.einsum("bhtT,bThd->bthd", attention_weights, values) | |
| values = values.contiguous().view(batch_size, seq_len, -1) | |
| return self.out_linear(values) | |
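The `repeat_interleave(n_rep, dim=2)` calls in `forward` decide which key/value head each query head attends with. A small plain-Python sketch of that mapping follows; the head counts are made-up example values and `repeat_interleave` here is a list-based stand-in for the tensor method.

```python
def repeat_interleave(heads, n_rep):
    """Repeat each element n_rep times consecutively, like Tensor.repeat_interleave."""
    return [h for h in heads for _ in range(n_rep)]


num_heads, num_kv_heads = 8, 2  # example values, not taken from the model config
n_rep = num_heads // num_kv_heads  # query heads served by each key/value head

kv_heads = [f"kv{i}" for i in range(num_kv_heads)]
expanded = repeat_interleave(kv_heads, n_rep)

for query_head, kv_head in enumerate(expanded):
    print(f"query head {query_head} -> {kv_head}")
```

Consecutive query heads share a key/value head. If `num_heads` is not a multiple of `num_kv_heads`, the integer division yields fewer expanded heads than query heads and the attention einsum shapes no longer match.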
| class TorchGptDecoder(nn.Module): | |
| def __init__(self, config: GptConfig, name: Optional[str] = None): | |
| super().__init__() | |
| self.config = config | |
| self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim) | |
| if config.norm_type == "layer_norm": | |
| self.final_norm = nn.LayerNorm(config.embed_dim) | |
| elif config.norm_type == "RMS_norm": | |
| self.final_norm = TorchRMSNorm(config.embed_dim, eps=config.rms_norm_eps) | |
| else: | |
| raise ValueError(f"unrecognized norm_type in config {config.norm_type}") | |
| self.layers = nn.ModuleList( | |
| [ | |
| TorchGptDecoderLayer( | |
| embed_dim=config.embed_dim, | |
| ffn_embed_dim=config.ffn_embed_dim, | |
| num_heads=config.num_heads, | |
| rope_config=config.rope_config, | |
| norm_type=config.norm_type, | |
| parallel_attention_ff=config.parallel_attention_ff, | |
| add_bias_ffn=config.add_bias_ffn, | |
| ffn_activation_name=config.ffn_activation_name, | |
| use_glu_in_ffn=config.use_glu_in_ffn, | |
| num_kv_heads=config.num_kv_heads, # type: ignore | |
| add_bias_attn=config.add_bias_attn, | |
| rms_norm_eps=config.rms_norm_eps, | |
| ) | |
| for _ in range(config.num_layers) | |
| ] | |
| ) | |
| self.lm_head = TorchSimpleLMHead( | |
| embed_dim=config.embed_dim, | |
| alphabet_size=config.vocab_size, | |
| add_bias_lm_head=config.add_bias_lm_head, | |
| ) | |
| def apply_transformer_layers( | |
| self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None | |
| ) -> torch.Tensor: | |
| if attention_mask is None: | |
| attention_mask = build_causal_attention_mask( | |
| 1, embeddings.shape[1], device=embeddings.device | |
| ) | |
| for layer in self.layers: | |
| embeddings = layer(embeddings, attention_mask) | |
| return embeddings | |
| def forward( | |
| self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None | |
| ) -> dict[str, torch.Tensor]: | |
| if attention_mask is None: | |
| attention_mask = build_causal_attention_mask( | |
| 1, token_ids.shape[1], device=token_ids.device | |
| ) | |
| tokens_embeddings = self.token_embed(token_ids) | |
| after_transformer_embeddings = self.apply_transformer_layers( | |
| tokens_embeddings, attention_mask=attention_mask | |
| ) | |
| embeddings = self.final_norm(after_transformer_embeddings) | |
| logits = self.lm_head(embeddings) | |
| return {"embeddings": embeddings, "logits": logits} | |
| class TorchSimpleLMHead(nn.Module): | |
| def __init__( | |
| self, embed_dim: int, alphabet_size: int, add_bias_lm_head: bool = True | |
| ) -> None: | |
| super().__init__() | |
| self.fc = nn.Linear(embed_dim, alphabet_size, bias=add_bias_lm_head) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return self.fc(x) | |
| class TorchGptDecoderLayer(nn.Module): | |
| def __init__( | |
| self, | |
| embed_dim: int, | |
| ffn_embed_dim: int, | |
| num_heads: int, | |
| rope_config: RotaryEmbeddingConfig, | |
| norm_type: str, | |
| parallel_attention_ff: bool, | |
| add_bias_ffn: bool, | |
| ffn_activation_name: str, | |
| use_glu_in_ffn: bool, | |
| num_kv_heads: int, | |
| add_bias_attn: bool, | |
| rms_norm_eps: float = 1e-6, | |
| ) -> None: | |
| super().__init__() | |
| self.num_heads = num_heads | |
| self.parallel_attention_ff = parallel_attention_ff | |
| self.use_glu_in_ffn = use_glu_in_ffn | |
| # Self-Attention layer | |
| self.self_attn = TorchGptGroupedQueryAttention( | |
| embed_dim=embed_dim, | |
| num_heads=num_heads, | |
| num_kv_heads=num_kv_heads, | |
| rope_config=rope_config, | |
| add_bias_attn=add_bias_attn, | |
| ) | |
| # Normalization layers | |
| if norm_type == "layer_norm": | |
| self.attn_norm = nn.LayerNorm(embed_dim) | |
| if not self.parallel_attention_ff: | |
| self.ffn_norm = nn.LayerNorm(embed_dim) | |
| elif norm_type == "RMS_norm": | |
| self.attn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps) | |
| if not self.parallel_attention_ff: | |
| self.ffn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps) | |
| else: | |
| raise ValueError(f"unrecognized norm_type: {norm_type}") | |
| # Feedforward network | |
| self.activation = get_activation_fn(ffn_activation_name) | |
| ffn_hidden_dim = ffn_embed_dim * (2 if use_glu_in_ffn else 1) | |
| self.fc1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=add_bias_ffn) | |
| self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn) | |
| def forward( | |
| self, embeddings: torch.Tensor, attention_mask: torch.Tensor | |
| ) -> torch.Tensor: | |
| residuals = embeddings | |
| if self.parallel_attention_ff: | |
| # Parallel Attention + MLP | |
| embeddings_normed = self.attn_norm(embeddings) | |
| attn_output = self.self_attn( | |
| embeddings_normed, | |
| embeddings_normed, | |
| embeddings_normed, | |
| attention_mask=attention_mask, | |
| ) | |
| ffn_output = self.mlp(embeddings_normed) # type: ignore | |
| return residuals + attn_output + ffn_output | |
| else: | |
| # Sequential Attention + MLP | |
| normed_embeddings = self.attn_norm(embeddings) | |
| attn_output = embeddings + self.self_attn( | |
| normed_embeddings, | |
| normed_embeddings, | |
| normed_embeddings, | |
| attention_mask=attention_mask, | |
| ) | |
| normed_embeddings2 = self.ffn_norm(attn_output) | |
| ffn_output = self.mlp(normed_embeddings2) # type: ignore | |
| return attn_output + ffn_output # Residual connection | |
| def mlp(self, x: torch.Tensor) -> torch.Tensor: | |
| """Applies the feedforward network (MLP) with optional GLU.""" | |
| ffn_output = self.fc1(x) | |
| if self.use_glu_in_ffn: | |
| ffn_output1, ffn_output2 = ffn_output.chunk(2, dim=-1) | |
| ffn_output = self.activation(ffn_output1) * ffn_output2 | |
| else: | |
| ffn_output = self.activation(ffn_output) | |
| return self.fc2(ffn_output) | |
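The gated branch of `mlp` explains why `fc1` is twice as wide as the input of `fc2` when `use_glu_in_ffn` is set. A minimal sketch on plain Python lists, assuming a SiLU activation; `silu` and `glu_gate` are illustrative names that do not exist in this file.

```python
import math


def silu(v):
    """SiLU / swish activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def glu_gate(hidden, activation=silu):
    """Split hidden into two halves and gate the second with the activated first."""
    half = len(hidden) // 2
    first, second = hidden[:half], hidden[half:]
    return [activation(a) * b for a, b in zip(first, second)]


fc1_output = [0.0, 2.0, 5.0, 7.0]  # stands in for a 2 * ffn_embed_dim vector
gated = glu_gate(fc1_output)
print(len(fc1_output), "->", len(gated))
```

The gated vector has half the width of the `fc1` output, so the second linear layer takes `ffn_embed_dim` inputs rather than `2 * ffn_embed_dim`.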
| class TorchRMSNorm(nn.Module): | |
| def __init__(self, dim: int, eps: float = 1e-6) -> None: | |
| super().__init__() | |
| self.eps = eps | |
| self.scale = nn.Parameter(torch.ones(dim)) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return ( | |
| x | |
| * self.scale | |
| / torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps) | |
| ) | |
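The normalisation in `TorchRMSNorm.forward` divides by the root mean square of the last dimension and then applies a learned per-feature scale. A plain-Python sketch of the same arithmetic on one vector, with `rms_norm` as an illustrative helper name:

```python
import math


def rms_norm(x, scale, eps=1e-6):
    """Scale x by 1 / sqrt(mean(x^2) + eps), then by the per-feature scale."""
    mean_square = sum(v * v for v in x) / len(x)
    denom = math.sqrt(mean_square + eps)
    return [v * s / denom for v, s in zip(x, scale)]


x = [3.0, 4.0]
y = rms_norm(x, scale=[1.0, 1.0])

# With a unit scale, the output has a root mean square close to 1.
rms_y = math.sqrt(sum(v * v for v in y) / len(y))
print(y, rms_y)
```

Unlike layer normalisation, no mean is subtracted and there is no bias term, so the direction of the input vector is preserved.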
| def get_activation_fn(activation_name: str): # type: ignore | |
| activations = { | |
| "gelu": nn.functional.gelu, | |
| "relu": nn.functional.relu, | |
| "swish": nn.functional.silu, | |
| "silu": nn.functional.silu, | |
| } | |
| return activations.get(activation_name, nn.functional.relu) | |
| def build_causal_attention_mask( | |
| batch_size: int, seq_len: int, device: torch.device | |
| ) -> torch.Tensor: | |
| """ | |
| Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed | |
| to an attention layer. | |
| Args: | |
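| batch_size: Batch size. | |
| seq_len: Length of the sequences. | |
| device: Device on which the mask is allocated. | |
| Returns: | |
| Lower-triangular mask with 1 where attention is allowed and 0 elsewhere. | |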
| """ | |
| mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device) | |
| causal_mask = torch.tril(mask) | |
| return causal_mask | |
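For a single sequence, the lower-triangular pattern produced by `torch.tril` can be written out directly. A minimal plain-Python sketch, with `causal_mask` as an illustrative name:

```python
def causal_mask(seq_len):
    """Row = query position, column = key position; 1 means attention is allowed."""
    return [
        [1 if key <= query else 0 for key in range(seq_len)]
        for query in range(seq_len)
    ]


for row in causal_mask(4):
    print(row)
```

Each query position can attend to itself and to earlier positions only. The diagonal is always 1, so every row keeps at least one allowed key when masked positions are filled with `-inf` before the softmax.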
| @dataclass | |
| class RotaryEmbeddingConfigBis: | |
| """ | |
| Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows | |
| the rotary embeddings to be adapted to sequences longer than those used for training. | |
| One such strategy is presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa | |
| Args: | |
| rescaling_factor: Optional factor used to rescale the rotary base frequency. | |
| If None, the base frequency is left unchanged. | |
| """ | |
| rescaling_factor: Optional[float] | |
| class RotaryEmbeddingBis(torch.nn.Module): | |
| """ | |
| Rotary position embeddings based on those in | |
| [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). | |
| Queries and keys are transformed by rotation matrices that depend on their | |
| positions, so that their dot product depends only on their relative position. | |
| """ | |
| def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfigBis): | |
| super().__init__() | |
| # Extract argument from the config | |
| self.rescaling_factor = rotary_embedding_config.rescaling_factor | |
| self.upper_freq = 10000 | |
| self.dim = dim | |
| self._seq_len_cached = None | |
| self._cos_cached = None | |
| self._sin_cached = None | |
| def _apply_rotary_pos_emb( | |
| self, | |
| heads: torch.Tensor, | |
| cos: torch.Tensor, | |
| sin: torch.Tensor, | |
| ) -> torch.Tensor: | |
| """Applies the rotary embedding by rotating the two halves of the last dimension.""" | |
| x_first, x_second = ( | |
| heads[..., : heads.shape[-1] // 2], | |
| heads[..., heads.shape[-1] // 2 :], | |
| ) | |
| first_part = x_first * cos - x_second * sin | |
| second_part = x_second * cos + x_first * sin | |
| return torch.cat((first_part, second_part), dim=-1) | |
| def _compute_cos_sin_tables( | |
| self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2 | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| seq_len = x.shape[seq_dimension] | |
| # Recompute the tables on every call. The sequence length is only recorded, | |
| # so previously cached values are never reused. | |
| self._seq_len_cached = seq_len | |
| t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq) | |
| # freqs = torch.outer(t, inv_freq) | |
| freqs = torch.einsum("i, j -> ij", t, inv_freq) | |
| self._cos_cached = torch.cos(freqs)[None, :, None, :] | |
| self._sin_cached = torch.sin(freqs)[None, :, None, :] | |
| # emb = torch.cat((freqs, freqs), dim=-1).to(x.device) | |
| # self._cos_cached = emb.cos()[None, None, :, :] | |
| # self._sin_cached = emb.sin()[None, None, :, :] | |
| return self._cos_cached, self._sin_cached | |
| def forward( | |
| self, q: torch.Tensor, k: torch.Tensor | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| if self.rescaling_factor is None: | |
| inv_freq = 1.0 / ( | |
| self.upper_freq | |
| ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim) | |
| ) | |
| else: | |
| updated_base = self.upper_freq * ( | |
| self.rescaling_factor ** (self.dim / (self.dim - 2)) | |
| ) | |
| inv_freq = 1.0 / ( | |
| updated_base | |
| ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim) | |
| ) | |
| self._cos_cached, self._sin_cached = self._compute_cos_sin_tables( | |
| q, | |
| inv_freq, | |
| seq_dimension=-3, | |
| ) | |
| return ( | |
| self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), | |
| self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), | |
| ) | |
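The two branches of `forward` differ only in the base used for the inverse frequencies. A plain-Python sketch of that computation, mirroring the formula above; `inverse_frequencies` is an illustrative name and the dimension and factor are example values.

```python
def inverse_frequencies(dim, base=10000.0, rescaling_factor=None):
    """Return one inverse frequency per pair of features."""
    if rescaling_factor is not None:
        # Same update as in forward: base * factor ** (dim / (dim - 2)).
        base = base * rescaling_factor ** (dim / (dim - 2))
    return [1.0 / base ** (i / dim) for i in range(0, dim, 2)]


plain = inverse_frequencies(8)
scaled = inverse_frequencies(8, rescaling_factor=2.0)

for p, s in zip(plain, scaled):
    print(f"{p:.6f} -> {s:.6f}")
```

The highest frequency (index 0) is unchanged, while the lowest one is divided exactly by the rescaling factor, because the exponent `dim / (dim - 2)` cancels against `(dim - 2) / dim`. Intermediate frequencies are stretched by amounts in between.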
| class MultiHeadAttention(nn.Module): | |
| def __init__( | |
| self, | |
| num_heads: int, | |
| key_size: int, | |
| rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None, | |
| add_bias_kv: bool = False, | |
| value_size: Optional[int] = None, | |
| model_size: Optional[int] = None, | |
| name: Optional[str] = None, | |
| ): | |
| super().__init__() | |
| if not model_size: | |
| model_size = key_size * num_heads | |
| if not value_size: | |
| value_size = key_size | |
| self.model_size = model_size | |
| self.key_size = key_size | |
| self.value_size = value_size | |
| self.add_bias_kv = add_bias_kv | |
| self.name = name | |
| self.num_heads = num_heads | |
| self._rotary_embedding_config = rotary_embedding_config | |
| self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size) | |
| self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size) | |
| self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size) | |
| self.output = nn.Linear(self.num_heads * self.value_size, self.model_size) | |
| if self._rotary_embedding_config: | |
| self._rotary_embedding = RotaryEmbeddingBis( | |
| self.key_size, self._rotary_embedding_config | |
| ) | |
| def apply_rotary_embeddings( | |
| self, | |
| query: torch.Tensor, | |
| key: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Applies rotary embeddings to the query and key heads.""" | |
| query, key = self._rotary_embedding(query, key) | |
| return query, key | |
| def forward( | |
| self, | |
| query: torch.Tensor, | |
| key: torch.Tensor, | |
| value: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| attention_weight_bias: Optional[torch.Tensor] = None, | |
| ) -> dict[str, torch.Tensor]: | |
| """ | |
| Returns: | |
| dictionary containing attention weights | |
| and outputs. | |
| """ | |
| key_heads = self.w_k(key).reshape( | |
| (*key.shape[:-1], self.num_heads, self.key_size) | |
| ) | |
| query_heads = self.w_q(query).reshape( | |
| (*query.shape[:-1], self.num_heads, self.key_size) | |
| ) | |
| value_heads = self.w_v(value).reshape( | |
| (*value.shape[:-1], self.num_heads, self.value_size) | |
| ) | |
| if self._rotary_embedding_config: | |
| query_heads, key_heads = self.apply_rotary_embeddings( | |
| query_heads, key_heads | |
| ) | |
| attention_weights = torch.einsum( | |
| "...thd, ...Thd -> ...htT", query_heads, key_heads | |
| ) | |
| sqrt_key_size = np.sqrt(self.key_size) | |
| attention_weights = attention_weights / sqrt_key_size | |
| if attention_mask is not None: | |
| attention_weights = torch.where(attention_mask, attention_weights, -1e30) | |
| attention_weights = attention_weights.to(value_heads.dtype) | |
| if attention_weight_bias is not None: | |
| attention_weights = F.softmax( | |
| attention_weights + attention_weight_bias, dim=-1 | |
| ) | |
| else: | |
| attention_weights = F.softmax(attention_weights, dim=-1) | |
| value_out = torch.einsum( | |
| "...htT, ...Thd->...thd", attention_weights, value_heads | |
| ) | |
| value_out = value_out.reshape((*value_out.shape[:-2], -1)) | |
| embeddings = self.output(value_out) | |
| return {"attention_weights": attention_weights, "embeddings": embeddings} | |
| class SelfAttentionBlock(nn.Module): | |
| def __init__( | |
| self, | |
| num_heads: int, | |
| embed_dim: int, | |
| ffn_embed_dim: int, | |
| key_size: Optional[int] = None, | |
| add_bias_kv: bool = False, | |
| add_bias_fnn: bool = True, | |
| ffn_activation_name: str = "gelu-no-approx", | |
| use_glu_in_ffn: bool = False, | |
| layer_norm_eps: float = 1e-5, # this is the default haiku value | |
| pre_layer_norm: bool = True, | |
| name: Optional[str] = None, | |
| rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None, | |
| ): | |
| super().__init__() | |
| if key_size is None: | |
| if embed_dim % num_heads != 0: | |
| raise ValueError( | |
| f"The embedding dimension should be divisible by the number of " | |
| f"heads, however provided embedding dimension is {embed_dim} and " | |
| f"the number of heads is {num_heads}." | |
| ) | |
| else: | |
| key_size = embed_dim // num_heads | |
| # Store block options | |
| self._pre_layer_norm = pre_layer_norm | |
| self._use_glu_in_fnn = use_glu_in_ffn | |
| # Define layers | |
| if use_glu_in_ffn: | |
| # user should multiply ffn_embed_dim by 2/3 when using GLU | |
| # to keep total number of parameters equal | |
| # see https://arxiv.org/pdf/2002.05202.pdf. for more details | |
| # we multiply by 2 here as the output will be split in 2 for GLU | |
| self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn) | |
| else: | |
| self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn) | |
| self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn) | |
| self.layer_norm_self_attention = nn.LayerNorm( | |
| embed_dim, | |
| ) | |
| self.layer_norm_mlp = nn.LayerNorm(embed_dim) | |
| if ffn_activation_name == "swish": | |
| self._ffn_activation_fn = nn.SiLU() | |
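| elif ffn_activation_name == "gelu-no-approx": | |
| # NOTE: despite the name, this selects the tanh approximation of GELU | |
| self._ffn_activation_fn = nn.GELU(approximate="tanh") | |
| else: | |
| # Instantiate the module: getattr alone returns the class, not a callable layer | |
| self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)() | |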
| self.mha = MultiHeadAttention( | |
| num_heads=num_heads, | |
| key_size=key_size, | |
| add_bias_kv=add_bias_kv, | |
| model_size=embed_dim, | |
| name="self_attention", | |
| rotary_embedding_config=rotary_embedding_config, | |
| ) | |
| def mlp(self, embed: torch.Tensor) -> torch.Tensor: | |
| if self._pre_layer_norm: | |
| x = self.layer_norm_mlp(embed) | |
| else: | |
| x = embed | |
| if self._use_glu_in_fnn: | |
| x = self.fc1(x) | |
| x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1) | |
| x = self._ffn_activation_fn(x1) * x2 | |
| else: | |
| x = self._ffn_activation_fn(self.fc1(x)) | |
| x = self.fc2(x) | |
| if not self._pre_layer_norm: | |
| x = self.layer_norm_mlp(x + embed) | |
| return x | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| attention_weight_bias: Optional[torch.Tensor] = None, | |
| ) -> dict[str, torch.Tensor]: | |
| res = x | |
| if self._pre_layer_norm: | |
| x = self.layer_norm_self_attention(x) | |
| output: dict[str, torch.Tensor] = self.mha( | |
| x, | |
| x, | |
| x, | |
| attention_mask=attention_mask, | |
| attention_weight_bias=attention_weight_bias, | |
| ) | |
| if not self._pre_layer_norm: | |
| output["embeddings"] = self.layer_norm_self_attention( | |
| output["embeddings"] + res | |
| ) | |
| x = output["embeddings"] | |
| else: | |
| x = output["embeddings"] | |
| x = res + x | |
| # MLP | |
| if not self._pre_layer_norm: | |
| x = self.mlp(x) | |
| else: | |
| x = x + self.mlp(x) | |
| output["embeddings"] = x | |
| return output | |
| class RobertaLMHead(nn.Module): | |
| """ | |
| Roberta Language Model head. Transforms final attention layer output into a | |
| distribution over tokens at each position. | |
| """ | |
| def __init__(self, embed_dim: int, alphabet_size: int): | |
| """ | |
| Args: | |
| embed_dim: Embedding dimension. | |
| alphabet_size: Number of tokens in the alphabet. | |
| """ | |
| super().__init__() | |
| self.embed_dim = embed_dim | |
| self.alphabet_size = alphabet_size | |
| # Define layers | |
| self._first_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True) | |
| self._fc1 = nn.Linear(embed_dim, embed_dim) | |
| self._second_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True) | |
| self._final_fc = nn.Linear(embed_dim, alphabet_size) | |
| def forward(self, x: torch.Tensor) -> dict: | |
| x = self._first_layer_norm(x) | |
| embeddings = x | |
| x = self._fc1(x) | |
| x = nn.functional.gelu(x) | |
| x = self._second_layer_norm(x) | |
| logits = self._final_fc(x) | |
| return {"embeddings": embeddings, "logits": logits} | |
| class TorchNucleotideTransformer(nn.Module): | |
| def __init__( | |
| self, | |
| nt_config: NucleotideTransformerConfig, | |
| ): | |
| super(TorchNucleotideTransformer, self).__init__() | |
| self.nt_config = nt_config | |
| # Other cases are not implemented | |
| assert nt_config.positional_embedding is None | |
| assert nt_config.lm_head == "roberta" | |
| assert nt_config.use_rotary_embedding is True | |
| assert nt_config.token_dropout is False | |
| assert nt_config.emb_layer_norm_before is False | |
| assert nt_config.mask_before_attention is False | |
| assert nt_config.bias_word_embedding is False | |
| assert nt_config.use_gradient_checkpointing is False | |
| self.embed_layer = nn.Embedding(nt_config.alphabet_size, nt_config.embed_dim) | |
| self.lm_head = RobertaLMHead( | |
| embed_dim=nt_config.embed_dim, | |
| alphabet_size=nt_config.alphabet_size, | |
| ) | |
| self.rotary_embedding_config = RotaryEmbeddingConfigBis( | |
| rescaling_factor=nt_config.rescaling_factor | |
| ) | |
| self.attention_blocks = nn.ModuleList( | |
| [ | |
| SelfAttentionBlock( # type: ignore | |
| num_heads=nt_config.attention_heads, | |
| embed_dim=nt_config.embed_dim, | |
| key_size=nt_config.key_size, | |
| ffn_embed_dim=nt_config.ffn_embed_dim, | |
| add_bias_kv=nt_config.add_bias_kv, | |
| add_bias_fnn=nt_config.add_bias_ffn, | |
| ffn_activation_name=nt_config.ffn_activation_name, | |
| use_glu_in_ffn=nt_config.use_glu_in_ffn, | |
| rotary_embedding_config=self.rotary_embedding_config, | |
| layer_norm_eps=nt_config.layer_norm_eps, | |
| pre_layer_norm=nt_config.pre_layer_norm, | |
| ) | |
| for _ in range(nt_config.num_layers) | |
| ] | |
| ) | |
| def forward( | |
| self, tokens: torch.Tensor, attention_mask: torch.Tensor = None | |
| ) -> torch.Tensor: | |
| """ | |
| Computes the embeddings based on the input tokens. | |
| Args: | |
| tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len). | |
| attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len). | |
| If no mask is provided, a default mask is computed that equals 1 over | |
| all non-pad tokens and 0 over pad tokens. | |
| Returns: | |
| Embeddings after the first layer norm of the LM head, of shape | |
| (batch_size, seq_len, embed_dim). | |
| """ | |
| x = self.embed_layer(tokens) | |
| # RoBERTa's mask scaling factor | |
| x = self.nt_config.embed_scale * x | |
| if attention_mask is None: | |
| attention_mask = build_padding_attention_mask( | |
| tokens=tokens, pad_token_id=self.nt_config.pad_token_id | |
| ) | |
| for layer in self.attention_blocks: | |
| x = layer(x, attention_mask)["embeddings"] | |
| assert self.nt_config.lm_head == "roberta" | |
| x = self.lm_head(x)["embeddings"] | |
| return x | |
| def build_padding_attention_mask( | |
| tokens: torch.Tensor, pad_token_id: int | |
| ) -> torch.Tensor: | |
| """ | |
| Builds a padding mask from a sequence of tokens by masking <pad> in the attention. | |
| Args: | |
| tokens: Batch of sequences of shape (batch_size, seq_len). | |
| pad_token_id: Int corresponding to the <pad> token to mask. | |
| Returns: | |
| Batch of attention masks, masking out <pad> tokens. | |
| """ | |
| padding_mask = tokens != pad_token_id | |
| padding_mask = padding_mask.unsqueeze(1) | |
| padding_mask = torch.einsum("bhT, bht -> bhtT", padding_mask, padding_mask) | |
| return padding_mask | |
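The einsum in `build_padding_attention_mask` is an outer product of the per-token padding flags with themselves. A plain-Python sketch for a single sequence, with `padding_attention_mask` as an illustrative name and a made-up pad id:

```python
def padding_attention_mask(tokens, pad_token_id):
    """Entry [t][T] is True only if both query t and key T are real tokens."""
    keep = [token != pad_token_id for token in tokens]
    return [[query and key for key in keep] for query in keep]


mask = padding_attention_mask([5, 6, 0], pad_token_id=0)
for row in mask:
    print(row)
```

Rows that belong to padded query positions are masked entirely. In `MultiHeadAttention` masked logits are set to the finite value `-1e30`, so such rows produce a uniform softmax rather than NaNs.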
| class TorchBioBrainEncoder(nn.Module): | |
| def __init__( | |
| self, | |
| nt_config: NucleotideTransformerConfig, | |
| ): | |
| super(TorchBioBrainEncoder, self).__init__() | |
| self.nt_config = nt_config | |
| self.nt_model = TorchNucleotideTransformer(self.nt_config) | |
| def forward( | |
| self, | |
| bio_token_ids: torch.Tensor, | |
| ) -> torch.Tensor: | |
| """ | |
| Args: | |
| bio_token_ids (torch.Tensor): | |
| Shape (batch_size, num_bio_tokens) | |
| Returns: | |
| torch.Tensor: | |
| Shape (batch_size, num_bio_tokens, embed_dim) | |
| """ | |
| bio_embeddings = self.nt_model(tokens=bio_token_ids) | |
| return bio_embeddings | |
| class TorchMultiModalPerceiverResamplerBlock(nn.Module): | |
| def __init__( | |
| self, | |
| num_heads: int, | |
| embed_dim: int, | |
| ffn_embed_dim: int, | |
| key_size: Optional[int] = None, | |
| add_bias_kv: bool = False, | |
| add_bias_ffn: bool = True, | |
| ffn_activation_name: str = "gelu", | |
| use_glu_in_ffn: bool = False, | |
| ): | |
| super().__init__() | |
| if key_size is None: | |
| if embed_dim % num_heads != 0: | |
| raise ValueError( | |
| f"Embedding dimension {embed_dim} should be divisible by " | |
| f"num_heads {num_heads}." | |
| ) | |
| key_size = embed_dim // num_heads | |
| self.num_heads = num_heads | |
| self.embed_dim = embed_dim | |
| self.ffn_embed_dim = ffn_embed_dim * 2 if use_glu_in_ffn else ffn_embed_dim | |
| self.use_glu_in_ffn = use_glu_in_ffn | |
| self.cross_attention_1 = MultiHeadAttention( | |
| num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv | |
| ) | |
| self.cross_attention_2 = MultiHeadAttention( | |
| num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv | |
| ) | |
| self.norm_cross_attention_1 = nn.LayerNorm(embed_dim) | |
| self.norm_cross_attention_2 = nn.LayerNorm(embed_dim) | |
| self.norm_mlp = nn.LayerNorm(embed_dim) | |
| self.fc1 = nn.Linear(embed_dim, self.ffn_embed_dim, bias=add_bias_ffn) | |
| # fc2 consumes the gated output, which has ffn_embed_dim features even when GLU doubles fc1 | |
| self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn) | |
| self.activation_fn = getattr( | |
| nn.functional, ffn_activation_name, nn.functional.gelu | |
| ) | |
| def mlp(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.norm_mlp(x) | |
| if self.use_glu_in_ffn: | |
| x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1) | |
| x = self.activation_fn(x1) * x2 | |
| else: | |
| x = self.activation_fn(self.fc1(x)) | |
| return self.fc2(x) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| cross_attention_embeddings_1: torch.Tensor, | |
| cross_attention_embeddings_2: torch.Tensor, | |
| attention_mask_1: Optional[torch.Tensor] = None, | |
| attention_mask_2: Optional[torch.Tensor] = None, | |
| ) -> Dict[str, torch.Tensor]: | |
| res = x | |
| x = self.norm_cross_attention_1(x) | |
| attn_output = self.cross_attention_1( | |
| query=x, | |
| key=cross_attention_embeddings_1, | |
| value=cross_attention_embeddings_1, | |
| attention_mask=attention_mask_1, | |
| )["embeddings"] | |
| x = res + attn_output | |
| res = x | |
| x = self.norm_cross_attention_2(x) | |
| attn_output = self.cross_attention_2( | |
| query=x, | |
| key=cross_attention_embeddings_2, | |
| value=cross_attention_embeddings_2, | |
| attention_mask=attention_mask_2, | |
| )["embeddings"] | |
| x = res + attn_output | |
| x = x + self.mlp(x) | |
| return {"embeddings": x} | |
| class TorchMultiModalPerceiverResampler(nn.Module): | |
| """ | |
| Perceiver Resampler model, made of successive PerceiverResamplerBlocks. | |
| """ | |
| def __init__( | |
| self, | |
| config: PerceiverResamplerConfig, | |
| name: Optional[str] = None, | |
| ): | |
| """ | |
| Initialize a Perceiver Resampler model. | |
| Args: | |
| config: Dataclass containing model hyperparameters. | |
| name: Name for the module (a custom name will break weight loading). | |
| """ | |
| super().__init__() | |
| self.config = config | |
| self.name = name | |
| self.layers = nn.ModuleList( | |
| [ | |
| TorchMultiModalPerceiverResamplerBlock( | |
| num_heads=self.config.attention_heads, | |
| embed_dim=self.config.embed_dim, | |
| key_size=self.config.key_size, | |
| ffn_embed_dim=self.config.ffn_embed_dim, | |
| add_bias_kv=self.config.add_bias_kv, | |
| add_bias_ffn=self.config.add_bias_ffn, | |
| ffn_activation_name=self.config.ffn_activation_name, | |
| use_glu_in_ffn=self.config.use_glu_in_ffn, | |
| ) | |
| for _ in range(self.config.num_layers) | |
| ] | |
| ) | |
| self.latent_queries = torch.nn.Parameter( | |
| torch.randn(self.config.resampled_length, self.config.embed_dim) | |
| * ( | |
| 1.0 | |
| / torch.sqrt(torch.tensor(self.config.embed_dim, dtype=torch.float32)) | |
| ) | |
| ) | |
| def apply_attention_blocks( | |
| self, | |
| x: torch.Tensor, | |
| xf_1: torch.Tensor, | |
| xf_2: torch.Tensor, | |
| outs: Dict[str, torch.Tensor], | |
| attention_mask_1: Optional[torch.Tensor] = None, | |
| attention_mask_2: Optional[torch.Tensor] = None, | |
| ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |
| """ | |
| Applies the attention blocks in sequence. At each layer, the latent queries | |
| attend to each input concatenated with the current latents. | |
| """ | |
| for layer in self.layers: | |
| concat_input_1 = torch.cat([xf_1, x], dim=1) | |
| concat_input_2 = torch.cat([xf_2, x], dim=1) | |
| output = layer( | |
| x=x, | |
| cross_attention_embeddings_1=concat_input_1, | |
| cross_attention_embeddings_2=concat_input_2, | |
| attention_mask_1=attention_mask_1, | |
| attention_mask_2=attention_mask_2, | |
| ) | |
| x = output["embeddings"] | |
| return x, outs | |
| def forward( | |
| self, | |
| input_embeddings_1: torch.Tensor, | |
| input_embeddings_2: torch.Tensor, | |
| attention_mask_1: Optional[torch.Tensor] = None, | |
| attention_mask_2: Optional[torch.Tensor] = None, | |
| ) -> Dict[str, torch.Tensor]: | |
| """ | |
| Resamples the two input embedding sequences into resampled_length latent embeddings. | |
| """ | |
| assert ( | |
| input_embeddings_1.shape[-1] == self.config.embed_dim | |
| ), "The input embedding dim should match the model embed dim" | |
| assert ( | |
| input_embeddings_2.shape[-1] == self.config.embed_dim | |
| ), "The input embedding dim should match the model embed dim" | |
| batch_size = input_embeddings_1.shape[0] | |
| latent_queries = self.latent_queries.unsqueeze(0).repeat(batch_size, 1, 1) | |
| outs: Dict[str, torch.Tensor] = {} | |
| x = latent_queries | |
| x, outs = self.apply_attention_blocks( | |
| x=x, | |
| xf_1=input_embeddings_1, | |
| xf_2=input_embeddings_2, | |
| outs=outs, | |
| attention_mask_1=attention_mask_1, | |
| attention_mask_2=attention_mask_2, | |
| ) | |
| outs["embeddings"] = x | |
| return outs | |
| class TorchMultiModalPerceiverResamplerProjection(nn.Module): | |
| def __init__( | |
| self, | |
| perceiver_resampler_config: PerceiverResamplerConfig, | |
| input_embed_dim: int, | |
| embed_dim: int, | |
| bio_pad_token_id: int, | |
| english_pad_token_id: int, | |
| english_vocab_size: int, | |
| ): | |
| super().__init__() | |
| self.config = perceiver_resampler_config | |
| self.input_embed_dim = input_embed_dim | |
| self.embed_dim = embed_dim | |
| self.bio_pad_token_id = bio_pad_token_id | |
| self.english_pad_token_id = english_pad_token_id | |
| self.english_vocab_size = english_vocab_size | |
| self.bio_projection = nn.Linear(input_embed_dim, embed_dim) | |
| self.token_embedding = nn.Embedding(english_vocab_size, embed_dim) | |
| self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config) | |
| def forward( | |
| self, | |
| bio_token_ids: torch.Tensor, | |
| bio_embeddings: torch.Tensor, | |
| english_token_ids: torch.Tensor, | |
| ) -> torch.Tensor: | |
| """ | |
| Args: | |
| bio_token_ids (torch.Tensor): | |
| Shape (batch_size, num_bio_tokens) | |
| bio_embeddings (torch.Tensor): | |
| Shape (batch_size, num_bio_tokens, embed_dim) | |
| english_token_ids (torch.Tensor): | |
| Shape (batch_size, num_english_tokens) | |
| """ | |
| projected_bio_embeddings = self.bio_projection(bio_embeddings) | |
| english_embeddings = self.token_embedding(english_token_ids) | |
| bio_attention_mask = build_perceiver_padding_attention_mask( | |
| bio_token_ids, self.config.resampled_length, self.bio_pad_token_id | |
| ) | |
| english_attention_mask = build_perceiver_padding_attention_mask( | |
| english_token_ids, self.config.resampled_length, self.english_pad_token_id | |
| ) | |
| projected_embeddings = self.perceiver_resampler( | |
| input_embeddings_1=projected_bio_embeddings, | |
| attention_mask_1=bio_attention_mask, | |
| input_embeddings_2=english_embeddings, | |
| attention_mask_2=english_attention_mask, | |
| )["embeddings"] | |
| return projected_embeddings | |
| def build_perceiver_padding_attention_mask( | |
| tokens: torch.Tensor, resampled_length: int, pad_token_id: int | |
| ) -> torch.Tensor: | |
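| """ | |
| Builds a padding mask of shape | |
| (batch_size, 1, resampled_length, seq_len + resampled_length). The trailing | |
| resampled_length columns correspond to the latent queries and are never masked. | |
| """ | |
| batch_size, seq_len = tokens.shape | |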
| padding_mask = tokens != pad_token_id # (batch_size, seq_len) | |
| padding_mask = torch.cat( | |
| [ | |
| padding_mask, | |
| torch.ones( | |
| (batch_size, resampled_length), dtype=torch.bool, device=tokens.device | |
| ), | |
| ], | |
| dim=1, | |
| ) # (batch_size, seq_len + resampled_length) | |
| padding_mask = padding_mask[:, None, None, :] | |
| padding_mask = padding_mask.repeat(1, 1, resampled_length, 1) # noqa | |
| return padding_mask | |
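The mask built by `build_perceiver_padding_attention_mask` has one row per latent query and one column per key, where the keys are the input tokens followed by the latents themselves. A plain-Python sketch for a single sequence, with `perceiver_padding_mask` as an illustrative name and made-up token ids:

```python
def perceiver_padding_mask(tokens, resampled_length, pad_token_id):
    """Return resampled_length identical rows over tokens + latent columns."""
    keep = [token != pad_token_id for token in tokens]
    keep += [True] * resampled_length  # latent columns are always visible
    return [list(keep) for _ in range(resampled_length)]


mask = perceiver_padding_mask([7, 0], resampled_length=3, pad_token_id=0)
for row in mask:
    print(row)
```

The column count is `seq_len + resampled_length`, which matches the `torch.cat([xf, x], dim=1)` keys used in `apply_attention_blocks`, and only padded input tokens are hidden from the latent queries.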