
evo_prot_grad.common.utils

evo_prot_grad.common.utils.safe_logits_to_probs(logits: torch.Tensor) -> torch.Tensor

Safely convert logits to probabilities.

Parameters:

Name    Type          Description                              Default
logits  torch.Tensor  [parallel_chains, seq_len, vocab_size]   required

Returns:

Name   Type          Description
probs  torch.Tensor  [parallel_chains, seq_len, vocab_size]
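The documentation does not say what makes this conversion "safe". The following is a minimal sketch under the assumption that it means a softmax over the vocabulary dimension that never returns NaN or exact zeros; `safe_logits_to_probs_sketch` and its `eps` parameter are hypothetical and not part of the library.

```python
import torch

def safe_logits_to_probs_sketch(logits: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Hypothetical sketch: guarded softmax over the last (vocab) dimension.

    logits: [parallel_chains, seq_len, vocab_size]
    returns: same shape; each distribution over the vocab sums to 1.
    """
    # torch.softmax subtracts the row maximum internally, so large logits do not overflow.
    probs = torch.softmax(logits, dim=-1)
    # Rows that are entirely -inf produce NaN; zero them out, then clamp so
    # no probability is exactly 0 (keeps a later log() finite).
    probs = torch.nan_to_num(probs, nan=0.0).clamp_min(eps)
    # Renormalize so each distribution sums to 1 again after clamping.
    return probs / probs.sum(dim=-1, keepdim=True)


# One chain, two positions: extreme logits, and a row masked out entirely with -inf.
logits = torch.tensor([[[1000.0, 0.0, -1000.0],
                        [float("-inf"), float("-inf"), float("-inf")]]])
probs = safe_logits_to_probs_sketch(logits)
```

With this choice, a fully masked position falls back to a uniform distribution instead of propagating NaN into later steps.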

evo_prot_grad.common.utils.mut_distance(x, wt)

Compute the edit distance from x to wt, i.e. the number of mutated positions in each parallel chain.

Parameters:

Name  Type          Description                                    Default
x     torch.Tensor  Shape [parallel_chains, seq_len, vocab_size].  required
wt    torch.Tensor  Shape [1, seq_len, vocab_size].                required

Returns:

Name   Type          Description
edits  torch.Tensor  Shape [parallel_chains].
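For aligned one-hot sequences of equal length, the edit distance reduces to counting the positions where the two encodings differ. A minimal sketch, assuming that substitution-only interpretation; `mut_distance_sketch` is a hypothetical name, not the library's implementation.

```python
import torch
import torch.nn.functional as F

def mut_distance_sketch(x: torch.Tensor, wt: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: number of positions where each chain differs from wt.

    x:  [parallel_chains, seq_len, vocab_size], one-hot
    wt: [1, seq_len, vocab_size], one-hot (broadcast across chains)
    returns: [parallel_chains]
    """
    # A position is mutated if any entry of its one-hot vector differs from wt.
    differs = (x != wt).any(dim=-1)  # [parallel_chains, seq_len]
    return differs.sum(dim=-1)       # [parallel_chains]


# Wildtype residue indices 0..3 in a 20-letter vocabulary; the second chain has two substitutions.
wt = F.one_hot(torch.tensor([[0, 1, 2, 3]]), num_classes=20).float()
x = F.one_hot(torch.tensor([[0, 1, 2, 3], [5, 1, 2, 7]]), num_classes=20).float()
print(mut_distance_sketch(x, wt).tolist())  # [0, 2]
```

Comparing whole one-hot vectors with `any(-1)`, rather than comparing `argmax` indices, also counts a position whose vector is all zeros as a difference.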

evo_prot_grad.common.utils.mutation_mask(x, wt, mutations_value = False)

Create a boolean mask in which entries corresponding to mutations are set to mutations_value.

For every position where x and wt differ and wt is not a gap (0), the mask is set to mutations_value. Everywhere else it is set to the opposite value (not mutations_value).

Parameters:

Name             Type          Description                                                          Default
x                torch.Tensor  Shape [parallel_chains, seq_len, vocab_size]. x is one-hot encoded.  required
wt               torch.Tensor  Shape [*, seq_len, vocab_size]. wt is one-hot encoded.               required
mutations_value  bool          If True, set the mask to True where mutations are present.           False

Returns:

Name  Type              Description
mask  torch.BoolTensor  Shape [parallel_chains, seq_len, vocab_size].
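A minimal sketch that follows the docstring literally, comparing x and wt entry by entry: an entry counts as a mutation when the two tensors disagree there and the wt entry is nonzero. `mutation_mask_sketch` is hypothetical; the library's own implementation may differ in details.

```python
import torch
import torch.nn.functional as F

def mutation_mask_sketch(x: torch.Tensor, wt: torch.Tensor,
                         mutations_value: bool = False) -> torch.Tensor:
    """Hypothetical sketch of the documented rule, applied entry by entry.

    x:  [parallel_chains, seq_len, vocab_size], one-hot
    wt: [*, seq_len, vocab_size], one-hot (broadcast against x)
    returns: bool tensor with x's shape when wt's leading dim is 1 or matches x
    """
    # Entries where x and wt differ AND the wt entry is nonzero (not a gap/0).
    is_mutation = (x != wt) & (wt != 0)
    # Mutated entries get mutations_value; every other entry gets its negation.
    return is_mutation if mutations_value else ~is_mutation


wt = F.one_hot(torch.tensor([[0, 1]]), num_classes=3)  # wildtype residues 0, 1
x = F.one_hot(torch.tensor([[0, 2]]), num_classes=3)   # position 1 mutated to residue 2
mask = mutation_mask_sketch(x, wt, mutations_value=True)
print(mask.nonzero().tolist())  # [[0, 1, 1]]
```

Under this literal reading, only the wildtype channel at the mutated position is flagged; the new residue's channel is not, because wt is 0 there.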

evo_prot_grad.common.utils.expert_alphabet_to_canonical(expert_alphabet: List[str], device: str) -> torch.Tensor

Create a binary matrix that reorders the vocabulary dimension of a tensor from an expert's amino acid alphabet order to the canonical amino acid alphabet order.

Parameters:

Name             Type       Description                                                   Default
expert_alphabet  List[str]  The amino acid vocabulary used by the expert.                  required
device           str        The torch device on which to create the matrix (e.g. "cpu").  required

Returns:

Name              Type          Description
alignment_matrix  torch.Tensor  Tensor of shape [len(expert_alphabet), len(CANONICAL_ALPHABET)].
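The alignment matrix can be built by matching each expert token to its position in the canonical alphabet; multiplying a one-hot tensor in the expert's order by the matrix then moves each residue into canonical order. The sketch below assumes a 20-letter alphabet `CANONICAL_ALPHABET_EXAMPLE` for illustration (the library's `CANONICAL_ALPHABET` may differ) and assumes that expert tokens with no canonical counterpart, such as special tokens, map to all-zero rows. `expert_alphabet_to_canonical_sketch` is hypothetical.

```python
from typing import List
import torch
import torch.nn.functional as F

# Hypothetical 20-letter canonical alphabet used only for this illustration.
CANONICAL_ALPHABET_EXAMPLE = list("ACDEFGHIKLMNPQRSTVWY")

def expert_alphabet_to_canonical_sketch(expert_alphabet: List[str],
                                        device: str = "cpu",
                                        canonical: List[str] = CANONICAL_ALPHABET_EXAMPLE) -> torch.Tensor:
    """Hypothetical sketch: M[i, j] = 1 iff expert token i is canonical token j."""
    position = {aa: j for j, aa in enumerate(canonical)}
    matrix = torch.zeros(len(expert_alphabet), len(canonical), device=device)
    for i, token in enumerate(expert_alphabet):
        j = position.get(token)
        if j is not None:  # tokens absent from the canonical alphabet keep an all-zero row
            matrix[i, j] = 1.0
    return matrix


expert = ["<pad>", "Y", "A", "C"]  # hypothetical expert vocabulary
M = expert_alphabet_to_canonical_sketch(expert)
x_expert = F.one_hot(torch.tensor([[2, 1]]), num_classes=len(expert)).float()  # "A", "Y"
x_canonical = x_expert @ M  # [1, 2, 4] @ [4, 20] -> [1, 2, 20]
print(x_canonical.argmax(dim=-1).tolist())  # [[0, 19]]
```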

evo_prot_grad.common.utils.set_seed(seed: int) -> None

Set random seed for reproducibility.

Parameters:

Name  Type  Description       Default
seed  int   The seed to set.  required
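A common way to make PyTorch experiments reproducible is to seed Python's, NumPy's, and PyTorch's random number generators together. The sketch below shows that pattern; `set_seed_sketch` is hypothetical, and the library's `set_seed` may seed a different set of generators.

```python
import random

import numpy as np
import torch

def set_seed_sketch(seed: int) -> None:
    """Hypothetical sketch: seed the usual sources of randomness in one call."""
    random.seed(seed)                 # Python's built-in RNG
    np.random.seed(seed)              # NumPy's global RNG
    torch.manual_seed(seed)           # PyTorch RNG (documented to seed all devices)
    torch.cuda.manual_seed_all(seed)  # explicit CUDA seeding; safe to call without a GPU


set_seed_sketch(0)
first = (random.random(), np.random.rand(), torch.rand(1).item())
set_seed_sketch(0)
second = (random.random(), np.random.rand(), torch.rand(1).item())
```

Re-seeding with the same value reproduces the same draws from all three generators. Bitwise-identical GPU results can additionally depend on deterministic-algorithm settings, which this sketch does not touch.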

evo_prot_grad.common.utils.print_variant_in_color(seq: str, wt: str, ignore_gaps: bool = True) -> None

Print a variant sequence with highlighted mutations.

Parameters:

Name         Type  Description                                       Default
seq          str   The variant sequence.                             required
wt           str   The wildtype sequence.                            required
ignore_gaps  bool  If True, ignore gaps (- or X) in the comparison.  True
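One way to highlight mutations in a terminal is to wrap each mutated residue in ANSI color escape codes. The sketch assumes equal-length, aligned sequences and assumes that, with `ignore_gaps=True`, a position where either sequence has `-` or `X` is never highlighted. `format_variant_in_color` and `print_variant_in_color_sketch` are hypothetical helpers, and the library's choice of colors is not documented here.

```python
RED = "\033[91m"   # ANSI escape: bright red foreground
RESET = "\033[0m"  # ANSI escape: reset formatting
GAP_CHARS = {"-", "X"}

def format_variant_in_color(seq: str, wt: str, ignore_gaps: bool = True) -> str:
    """Hypothetical helper: return seq with mutated residues wrapped in ANSI color codes."""
    pieces = []
    for s, w in zip(seq, wt):  # assumes aligned, equal-length sequences
        gap = s in GAP_CHARS or w in GAP_CHARS
        if s != w and not (ignore_gaps and gap):
            pieces.append(f"{RED}{s}{RESET}")
        else:
            pieces.append(s)
    return "".join(pieces)

def print_variant_in_color_sketch(seq: str, wt: str, ignore_gaps: bool = True) -> None:
    print(format_variant_in_color(seq, wt, ignore_gaps))


print_variant_in_color_sketch("MKTAYIAK", "MKTGYIAK")  # the "A" at position 4 is printed in red
```

Building the colored string separately from printing it keeps the comparison logic easy to test.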

evo_prot_grad.common.utils.read_fasta(fasta_file: str) -> str

Read a FASTA file and return it as a string.
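A minimal sketch of a FASTA reader, assuming a single-record file and assuming the returned string is the sequence with the header line and line breaks removed (a multi-record file would have its sequences concatenated). `read_fasta_sketch` is hypothetical, and the library's function may return the text differently.

```python
import os
import tempfile

def read_fasta_sketch(fasta_file: str) -> str:
    """Hypothetical sketch: return the sequence of a single-record FASTA file."""
    residues = []
    with open(fasta_file) as f:
        for line in f:
            line = line.strip()
            # Skip blank lines and the ">" header; keep sequence lines.
            if line and not line.startswith(">"):
                residues.append(line)
    # Join wrapped sequence lines into one string.
    return "".join(residues)


# Usage with a temporary record whose sequence is wrapped over two lines.
with tempfile.NamedTemporaryFile("w", suffix=".fasta", delete=False) as fh:
    fh.write(">wildtype\nMKTAYI\nAKQR\n")
    path = fh.name
sequence = read_fasta_sketch(path)  # "MKTAYIAKQR"
os.remove(path)
```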