evo_prot_grad.common.utils
evo_prot_grad.common.utils.safe_logits_to_probs(logits: torch.Tensor) -> torch.Tensor
Safely convert logits to probabilities.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logits | torch.Tensor | Logits, shape [parallel_chains, seq_len, vocab_size]. | required |
Returns:
Name | Type | Description |
---|---|---|
probs | torch.Tensor | Probabilities, shape [parallel_chains, seq_len, vocab_size]. |
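The docstring does not say what makes the conversion "safe". Below is a minimal sketch, assuming it means a max-shifted softmax over the vocab dimension followed by clamping away exact zeros so that a later log stays finite. The name safe_logits_to_probs_sketch and the eps parameter are hypothetical, not the library's API.

```python
import torch


def safe_logits_to_probs_sketch(logits: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Hypothetical sketch: logits [parallel_chains, seq_len, vocab_size] -> probs of the same shape."""
    # Shift each position so its largest logit is 0; exp() then cannot overflow.
    shifted = logits - logits.max(dim=-1, keepdim=True).values
    probs = torch.softmax(shifted, dim=-1)
    # Assumption: "safe" also means no exact zeros, so clamp and renormalize over the vocab.
    probs = probs.clamp_min(eps)
    return probs / probs.sum(dim=-1, keepdim=True)


logits = torch.tensor([[[1000.0, 0.0, -1000.0]]])  # [1 chain, 1 position, 3 tokens]
probs = safe_logits_to_probs_sketch(logits)
print(probs.shape, probs.sum(dim=-1))
```

torch.softmax already performs the max shift internally, so the explicit shift only makes the step visible; the clamp is the part that changes the result.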
evo_prot_grad.common.utils.mut_distance(x, wt)
Computes the edit distance from x to wt.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | shape [parallel_chains, seq_len, vocab_size]. | required |
wt | torch.Tensor | shape [1, seq_len, vocab_size]. | required |
Returns:
Name | Type | Description |
---|---|---|
edits | torch.Tensor | shape [parallel_chains]. |
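If x and wt are position-aligned and only substitutions are counted, the distance is the number of positions at which each chain differs from the wild type (a Hamming distance). A minimal sketch of that count, assuming both tensors are one-hot; mut_distance_sketch is a hypothetical name, and the library's own implementation may compute it differently.

```python
import torch
import torch.nn.functional as F


def mut_distance_sketch(x: torch.Tensor, wt: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: count positions where each chain differs from the wild type.

    x:  [parallel_chains, seq_len, vocab_size], one-hot
    wt: [1, seq_len, vocab_size], one-hot (broadcast across chains)
    """
    # A position differs if any entry of its one-hot vector differs: [parallel_chains, seq_len]
    differs = (x != wt).any(dim=-1)
    # Sum over positions: [parallel_chains]
    return differs.sum(dim=-1)


vocab_size = 5
wt = F.one_hot(torch.tensor([[0, 1, 2, 3]]), vocab_size).float()   # [1, 4, 5]
x = F.one_hot(torch.tensor([[0, 1, 2, 3],
                            [0, 4, 2, 3],
                            [4, 4, 4, 4]]), vocab_size).float()    # [3, 4, 5]
print(mut_distance_sketch(x, wt))  # tensor([0, 1, 4])
```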
evo_prot_grad.common.utils.mutation_mask(x, wt, mutations_value = False)
Create a boolean tensor in which locations corresponding to mutations are set to mutations_value.
For every position where x and wt differ and wt is not a gap (0), the mask is set to mutations_value.
Everywhere else, it is set to ~mutations_value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | shape [parallel_chains, seq_len, vocab_size]. x is one-hot encoded. | required |
wt | torch.Tensor | shape [*, seq_len, vocab_size]. wt is one-hot encoded. | required |
mutations_value | bool | If True, set the mask to True where mutations are present. | False |
Returns:
Name | Type | Description |
---|---|---|
mask | torch.BoolTensor | shape [parallel_chains, seq_len, vocab_size]. |
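The rule above can be read element-wise on the one-hot tensors, which is consistent with the documented [parallel_chains, seq_len, vocab_size] output shape. A minimal sketch of that literal reading, assuming "gap (0)" means a zero entry in wt; mutation_mask_sketch is a hypothetical name, and the library may implement the rule differently.

```python
import torch
import torch.nn.functional as F


def mutation_mask_sketch(x: torch.Tensor, wt: torch.Tensor, mutations_value: bool = False) -> torch.Tensor:
    """Hypothetical sketch of an element-wise reading of the mutation-mask rule.

    x:  [parallel_chains, seq_len, vocab_size], one-hot
    wt: [*, seq_len, vocab_size], one-hot (broadcast against x)
    """
    # True wherever x and wt disagree AND the wt entry is non-zero (assumed "not a gap").
    hit = (x != wt) & (wt != 0)
    # Mutated entries get mutations_value; all other entries get its negation.
    return hit if mutations_value else ~hit


wt = F.one_hot(torch.tensor([[0, 1]]), 3)          # [1, 2, 3]: wild type "0 1"
x = F.one_hot(torch.tensor([[0, 1], [0, 2]]), 3)   # chain 1 mutates position 1 from token 1 to 2
mask = mutation_mask_sketch(x, wt, mutations_value=True)
print(mask.shape, mask.nonzero().tolist())
```

Under this reading only the wild-type residue's entry at each mutated position is flagged; the other vocab entries at that position take ~mutations_value.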
evo_prot_grad.common.utils.expert_alphabet_to_canonical(expert_alphabet: List[str], device: str) -> torch.Tensor
Create a binary matrix that reorders the vocab dimension of a tensor from an expert's amino acid (AA) alphabet order to the canonical AA alphabet order.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
expert_alphabet | List[str] | The amino acid vocab used by the expert. | required |
device | str | The device on which to create the matrix. | required |
Returns:
Name | Type | Description |
---|---|---|
alignment_matrix | torch.Tensor | Tensor of shape [len(expert_alphabet), len(CANONICAL_ALPHABET)]. |
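Right-multiplying a tensor whose last dimension follows the expert's alphabet by a [len(expert_alphabet), len(CANONICAL_ALPHABET)] binary matrix reorders that dimension into canonical order. A minimal sketch, assuming a 20-letter canonical alphabet and that expert tokens missing from it map to all-zero rows; both the CANONICAL_ALPHABET value below and expert_alphabet_to_canonical_sketch are hypothetical stand-ins.

```python
from typing import List

import torch

# Assumption: a 20-letter canonical order; the library's real CANONICAL_ALPHABET may differ.
CANONICAL_ALPHABET = list("ACDEFGHIKLMNPQRSTVWY")


def expert_alphabet_to_canonical_sketch(expert_alphabet: List[str], device: str = "cpu") -> torch.Tensor:
    """Hypothetical sketch: binary matrix M with M[i, j] = 1 iff expert token i is canonical token j."""
    canon_index = {aa: j for j, aa in enumerate(CANONICAL_ALPHABET)}
    m = torch.zeros(len(expert_alphabet), len(CANONICAL_ALPHABET), device=device)
    for i, aa in enumerate(expert_alphabet):
        j = canon_index.get(aa)
        if j is not None:  # tokens outside the canonical alphabet keep an all-zero row
            m[i, j] = 1.0
    return m


# An expert whose vocab starts with a special token and lists amino acids in reverse order.
expert_alphabet = ["<pad>"] + CANONICAL_ALPHABET[::-1]
m = expert_alphabet_to_canonical_sketch(expert_alphabet)    # [21, 20]

x_expert = torch.zeros(1, 1, len(expert_alphabet))           # [chains, seq_len, expert vocab]
x_expert[0, 0, expert_alphabet.index("A")] = 1.0
x_canonical = x_expert @ m                                   # [chains, seq_len, canonical vocab]
print(x_canonical.argmax(dim=-1))
```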
evo_prot_grad.common.utils.set_seed(seed: int) -> None
Set random seed for reproducibility.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
seed | int | The seed to set. | required |
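The docstring does not list which random number generators are seeded. A minimal sketch of the common pattern, assuming Python's random module and PyTorch are the sources of randomness; set_seed_sketch is hypothetical.

```python
import random

import torch


def set_seed_sketch(seed: int) -> None:
    """Hypothetical sketch: seed the RNGs a PyTorch-based sampler typically draws from."""
    random.seed(seed)        # Python's built-in RNG
    torch.manual_seed(seed)  # PyTorch RNG (CPU and all CUDA devices)
    # A fuller version might also seed NumPy (np.random.seed) if the code uses it.


set_seed_sketch(0)
first = (random.random(), torch.rand(3))
set_seed_sketch(0)
second = (random.random(), torch.rand(3))
print(first[0] == second[0], torch.equal(first[1], second[1]))
```

Seeding alone does not make every GPU kernel deterministic; torch.use_deterministic_algorithms covers that separately.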
evo_prot_grad.common.utils.print_variant_in_color(seq: str, wt: str, ignore_gaps: bool = True) -> None
Print a variant sequence with highlighted mutations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
seq | str | The variant sequence. | required |
wt | str | The wildtype sequence. | required |
ignore_gaps | bool | If True, ignore gaps. | True |
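A minimal sketch of colored mutation highlighting with ANSI escape codes, assuming equal-length aligned sequences, gaps written as "-", and that ignore_gaps means gap positions are never highlighted. highlight_variant and print_variant_in_color_sketch are hypothetical names, and the library may use a different coloring mechanism.

```python
RED_BOLD = "\033[1;31m"  # ANSI escape: bold red foreground
RESET = "\033[0m"        # ANSI escape: reset formatting


def highlight_variant(seq: str, wt: str, ignore_gaps: bool = True, gap: str = "-") -> str:
    """Hypothetical sketch: wrap each mutated residue of seq in ANSI color codes."""
    out = []
    for s, w in zip(seq, wt):  # assumes seq and wt are aligned and equally long
        # Assumption: with ignore_gaps, a gap in either sequence is not highlighted.
        skip = ignore_gaps and gap in (s, w)
        out.append(f"{RED_BOLD}{s}{RESET}" if s != w and not skip else s)
    return "".join(out)


def print_variant_in_color_sketch(seq: str, wt: str, ignore_gaps: bool = True) -> None:
    print(highlight_variant(seq, wt, ignore_gaps))


print_variant_in_color_sketch("MKV-E", "MRVAE")  # K is highlighted; the gap is not
```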
evo_prot_grad.common.utils.read_fasta(fasta_file: str) -> str
Read a FASTA file and return the sequence as a string.
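The one-line description leaves open how wrapped and multi-record files are handled. A minimal sketch, assuming a single-record FASTA whose wrapped sequence lines should be joined; read_fasta_sketch is hypothetical and, as written, would concatenate every record of a multi-record file.

```python
import os
import tempfile


def read_fasta_sketch(fasta_file: str) -> str:
    """Hypothetical sketch: return the sequence of a single-record FASTA file as one string."""
    seq_lines = []
    with open(fasta_file) as handle:
        for line in handle:
            line = line.strip()
            if not line or line.startswith(">"):
                continue  # skip blank lines and the ">" header line
            seq_lines.append(line)
    # FASTA often wraps long sequences over several lines; join them back together.
    return "".join(seq_lines)


with tempfile.NamedTemporaryFile("w", suffix=".fasta", delete=False) as f:
    f.write(">wt example\nMKTAYIAK\nQRQISFVK\n")
    path = f.name
print(read_fasta_sketch(path))  # MKTAYIAKQRQISFVK
os.remove(path)
```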