evo_prot_grad.common.embeddings
OneHotEmbedding
evo_prot_grad.common.embeddings.OneHotEmbedding
Bases: nn.Module
Computes embeddings for a batch of amino acid sequences. The input token IDs are first converted to one-hot vectors, and these one-hot tensors are cached so that gradients can be computed with respect to them.
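The reason for routing through one-hot vectors is that multiplying a one-hot tensor by an embedding matrix gives the same result as an index lookup, while still letting autograd differentiate with respect to the one-hot input. The sketch below illustrates this with a plain `nn.Embedding`; the vocabulary size, embedding dimension and the `embeddings.sum()` score are assumptions for illustration, not values taken from `evo_prot_grad`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, embedding_dim = 20, 8  # assumed sizes for illustration
emb = nn.Embedding(vocab_size, embedding_dim)

# A batch of token IDs: [batch_size=2, max_sequence_len=5]
input_ids = torch.randint(0, vocab_size, (2, 5))

# One-hot encode and make the one-hots a leaf tensor that tracks gradients
one_hots = F.one_hot(input_ids, num_classes=vocab_size).float()
one_hots.requires_grad_(True)

# Multiplying by the embedding matrix gives the same result as a lookup
embeddings = one_hots @ emb.weight  # [2, 5, embedding_dim]
assert torch.allclose(embeddings, emb(input_ids))

# Any scalar computed downstream can now be differentiated w.r.t. the one-hots
score = embeddings.sum()  # assumed stand-in for a real model score
score.backward()
print(one_hots.grad.shape)  # torch.Size([2, 5, 20])
```

A plain `emb(input_ids)` lookup would give no gradient with respect to the token choice at each position, because integer indices are not differentiable.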
forward(input_ids: torch.LongTensor) -> torch.Tensor
Compute the embeddings for a batch of amino acid sequences, caching the intermediate one-hot tensors so that gradients can be computed with respect to them.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_ids | torch.LongTensor | Amino acid sequences of shape [batch_size, max_sequence_len]. | required |
Returns:
Name | Type | Description |
---|---|---|
embeddings | torch.FloatTensor | Amino acid embeddings of shape [batch_size, max_sequence_len, embedding_dim]. |
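The documented `forward` behaviour (one-hot encode, cache, embed) can be sketched as a small module. This is a hypothetical re-creation under stated assumptions, not the library's source: the class name `OneHotEmbeddingSketch`, the constructor taking an existing `nn.Embedding`, and the `one_hots` attribute name are all illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OneHotEmbeddingSketch(nn.Module):
    """Hypothetical re-creation: one-hot encode, cache, then project."""

    def __init__(self, embedding: nn.Embedding):
        super().__init__()
        self.embedding = embedding  # assumed: reuse an existing embedding table
        self.one_hots = None  # most recent one-hot tensor (for gradients)

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        vocab_size = self.embedding.num_embeddings
        # [batch_size, max_sequence_len] -> [batch_size, max_sequence_len, vocab_size]
        one_hots = F.one_hot(input_ids, num_classes=vocab_size).float()
        one_hots.requires_grad_(True)
        self.one_hots = one_hots  # cache so callers can read .grad after backward()
        # [B, L, V] @ [V, D] -> [B, L, D]
        return one_hots @ self.embedding.weight


emb = nn.Embedding(25, 16)
module = OneHotEmbeddingSketch(emb)
input_ids = torch.tensor([[1, 4, 7], [2, 2, 0]])
out = module(input_ids)
out.pow(2).sum().backward()
print(out.shape, module.one_hots.grad.shape)  # torch.Size([2, 3, 16]) torch.Size([2, 3, 25])
```

The trade-off of this design is memory: the cached one-hot tensor has shape [batch_size, max_sequence_len, vocab_size], which is small for a roughly 20-letter amino acid alphabet but grows with the vocabulary.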
IdentityEmbedding
evo_prot_grad.common.embeddings.IdentityEmbedding
Bases: nn.Module
A pass-through module that stores the most recent one_hots tensor and returns it unchanged.
forward(one_hots: torch.Tensor) -> torch.Tensor
Cache the one_hots tensor and return it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
one_hots | torch.Tensor | A torch.FloatTensor of shape [batch_size, max_sequence_len, vocab_size]. | required |
Returns:
Name | Type | Description |
---|---|---|
one_hots | torch.Tensor | The same one_hots tensor that was passed in. |
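A pass-through module of this kind is useful when the caller already supplies one-hot inputs: caching the tensor gives a handle for reading its gradient after `backward()`. The sketch below assumes the caller has set `requires_grad` on the one-hots itself; the class name `IdentityEmbeddingSketch` and the `nn.Linear` downstream model are hypothetical stand-ins, not part of `evo_prot_grad`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class IdentityEmbeddingSketch(nn.Module):
    """Hypothetical re-creation: store the latest one_hots and return them unchanged."""

    def __init__(self):
        super().__init__()
        self.one_hots = None

    def forward(self, one_hots: torch.Tensor) -> torch.Tensor:
        self.one_hots = one_hots  # cache for later gradient inspection
        return one_hots


identity = IdentityEmbeddingSketch()
one_hots = F.one_hot(torch.tensor([[0, 3, 1]]), num_classes=4).float().requires_grad_(True)

# Assumed downstream model: a linear layer over the vocabulary dimension.
downstream = nn.Linear(4, 1)
score = downstream(identity(one_hots)).sum()
score.backward()

assert identity.one_hots is one_hots  # the same tensor object, not a copy
print(identity.one_hots.grad.shape)  # torch.Size([1, 3, 4])
```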