evo_prot_grad.common.embeddings

OneHotEmbedding

evo_prot_grad.common.embeddings.OneHotEmbedding

Bases: nn.Module

Computes embeddings for a sequence of amino acids. The sequence is first converted to one-hot vectors, and the resulting one-hot tensors are cached so that gradients can be computed with respect to them.

forward(input_ids: torch.LongTensor) -> torch.Tensor

Compute the embeddings for a batch of amino acid sequences. The one-hot tensors built from input_ids are cached so that gradients can later be computed with respect to them.

Parameters:

Name | Type | Description | Default
input_ids | torch.LongTensor | Amino acid sequences of shape [batch_size, max_sequence_len]. | required

Returns:

Name | Type | Description
embeddings | torch.FloatTensor | Amino acid embeddings of shape [batch_size, max_sequence_len, embedding_dim].
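
The behavior described above (convert token ids to one-hot vectors, cache them, then map them to embeddings) can be sketched as follows. This is a minimal illustration, not the library's implementation: the class name OneHotEmbeddingSketch, its constructor arguments, and the weight attribute are assumptions made for the example. Only the one-hot caching and the input and output shapes follow the description above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OneHotEmbeddingSketch(nn.Module):
    """Hypothetical stand-in for illustration; not evo_prot_grad's class."""

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        # Assumed: a plain embedding matrix of shape [vocab_size, embedding_dim].
        self.weight = nn.Parameter(torch.randn(vocab_size, embedding_dim))
        self.one_hots = None  # holds the most recent one-hot tensor

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        # [batch_size, max_sequence_len] -> [batch_size, max_sequence_len, vocab_size]
        one_hots = F.one_hot(input_ids, num_classes=self.weight.shape[0]).float()
        # Mark the one-hot tensor as a leaf that tracks gradients, then cache it.
        one_hots.requires_grad_(True)
        self.one_hots = one_hots
        # A matrix product with one-hot rows selects rows of the embedding matrix.
        return one_hots @ self.weight


torch.manual_seed(0)
emb = OneHotEmbeddingSketch(vocab_size=20, embedding_dim=8)
input_ids = torch.randint(0, 20, (2, 5))  # [batch_size=2, max_sequence_len=5]

embeddings = emb(input_ids)               # [2, 5, 8]
embeddings.sum().backward()               # any scalar score would work here
grad = emb.one_hots.grad                  # [2, 5, 20], gradient w.r.t. the one-hots
```

Expressing the lookup as a matrix product, rather than indexing by integer ids, is what makes the one-hot tensor part of the autograd graph, so a gradient with respect to it is available after the backward pass.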

IdentityEmbedding

evo_prot_grad.common.embeddings.IdentityEmbedding

Bases: nn.Module

A module that does nothing except store the most recent one_hots tensor.

forward(one_hots: torch.Tensor) -> torch.Tensor

Cache the one_hots tensor and return it.

Parameters:

Name | Type | Description | Default
one_hots | torch.Tensor | A torch.FloatTensor of shape [batch_size, max_sequence_len, vocab_size]. | required

Returns:

Name | Type | Description
one_hots | torch.Tensor | The same one_hots tensor that was passed in.
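
A pass-through module of this kind can be sketched as below. The class name IdentityEmbeddingSketch and the toy scoring step are assumptions made for illustration and are not taken from the library. The sketch only shows the documented contract: the input tensor is cached and returned unchanged.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class IdentityEmbeddingSketch(nn.Module):
    """Hypothetical stand-in for illustration; not evo_prot_grad's class."""

    def __init__(self):
        super().__init__()
        self.one_hots = None  # holds the most recent one_hots tensor

    def forward(self, one_hots: torch.Tensor) -> torch.Tensor:
        # Cache the tensor and hand back the very same object.
        self.one_hots = one_hots
        return one_hots


torch.manual_seed(0)
identity = IdentityEmbeddingSketch()

# Build one-hot inputs of shape [batch_size=2, max_sequence_len=5, vocab_size=20].
input_ids = torch.randint(0, 20, (2, 5))
one_hots = F.one_hot(input_ids, num_classes=20).float().requires_grad_(True)

out = identity(one_hots)

# Toy downstream score (assumed): a weighted sum over the one-hot entries.
site_weights = torch.randn(20)
score = (out * site_weights).sum()
score.backward()

grad = identity.one_hots.grad  # [2, 5, 20], gradient w.r.t. the cached one-hots
```

Because the module returns the same tensor object it received, the cached one_hots attribute and the tensor used downstream share one gradient.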