evo_prot_grad.common.variant_scoring
evo_prot_grad.common.variant_scoring.VariantScoring
Every Expert has a VariantScoring object that it uses to score variant sequences containing multiple mutations.
Supported scoring strategies:
1) attribute_value
- Uses a model's predicted attribute value for a given variant, normalized by subtracting the wildtype's predicted value.
2) pseudolikelihood_ratio
3) mutant_marginal
See https://www.biorxiv.org/content/10.1101/2021.07.09.450648v2 for strategies (2) and (3).
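As a rough illustration of strategy (1), the sketch below assumes the model outputs one scalar attribute per sequence, with shape [parallel_chains, 1] for the variants and [1, 1] for the wildtype. The function name attribute_value_score is hypothetical and not part of evo_prot_grad.

```python
import torch


def attribute_value_score(variant_pred: torch.Tensor, wt_pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: variant attribute minus wildtype attribute.

    variant_pred: [parallel_chains, 1], one predicted value per variant (assumed)
    wt_pred:      [1, 1], the wildtype's predicted value (assumed)
    """
    # [parallel_chains, 1] - [1, 1] broadcasts to [parallel_chains, 1];
    # dropping the trailing dimension leaves one score per chain.
    return (variant_pred - wt_pred).squeeze(-1)
```

A positive score means the model predicts a higher attribute value for the variant than for the wildtype.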
pseudolikelihood(x_oh: torch.Tensor, logits: torch.Tensor) -> torch.Tensor
A pseudo-log-likelihood (PLL) score for a protein sequence, given model logits.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x_oh | torch.Tensor | one-hot encoded variant sequences, shape [parallel_chains, seq_len, vocab_size] | required |
logits | torch.Tensor | predicted logits, of shape [parallel_chains, seq_len, vocab_size] | required |
Returns:
Type | Description |
---|---|
torch.Tensor | of shape [parallel_chains] |
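A minimal sketch of one common way to compute a PLL from logits, assuming it is the sum, over positions, of the log-softmax probability the model assigns to the residue actually present. This free function only mirrors the documented name and arguments; it is not the library's implementation. Strictly, a masked-language-model PLL is often computed by masking each position in turn; this sketch simply uses whatever logits it is given, and it reduces to one score per chain (shape [parallel_chains]).

```python
import torch
import torch.nn.functional as F


def pseudolikelihood(x_oh: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    """Sketch: sum of log-probabilities of the observed residues.

    x_oh:   [parallel_chains, seq_len, vocab_size] one-hot sequences
    logits: [parallel_chains, seq_len, vocab_size] model logits
    returns [parallel_chains]
    """
    # Normalize logits into log-probabilities over the vocabulary at each position.
    log_probs = F.log_softmax(logits, dim=-1)
    # The one-hot mask keeps only the log-probability of the residue present at each
    # position; summing over positions and vocabulary gives one score per chain.
    return (x_oh * log_probs).sum(dim=(1, 2))
```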
pseudolikelihood_ratio(x_oh: torch.Tensor, logits: torch.Tensor) -> torch.Tensor
PLL ratio with respect to the wildtype, for scoring variants.
The difference of two terms: the PLL of (a) the variant and (b) the wildtype. The model's input is the variant when computing (a) and the wildtype when computing (b). Since both terms are log-likelihoods, their difference is the log of the pseudolikelihood ratio.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x_oh | torch.Tensor | one-hot encoded variant sequences, shape [parallel_chains, seq_len, vocab_size] | required |
logits | torch.Tensor | predicted logits, of shape [parallel_chains, seq_len, vocab_size] | required |
Returns:
Type | Description |
---|---|
torch.Tensor | of shape [parallel_chains] |
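The sketch below spells out the two terms, assuming the same PLL definition as a sum of log-softmax probabilities of the observed residues. The helper pll and the four-argument signature are hypothetical: the documented method takes only x_oh and logits, which suggests the wildtype term is supplied from a cached value (see cache_wt_score).

```python
import torch
import torch.nn.functional as F


def pll(x_oh: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: [chains, seq_len, vocab] -> [chains]
    return (x_oh * F.log_softmax(logits, dim=-1)).sum(dim=(1, 2))


def pseudolikelihood_ratio(
    x_oh: torch.Tensor,
    variant_logits: torch.Tensor,
    wt_oh: torch.Tensor,
    wt_logits: torch.Tensor,
) -> torch.Tensor:
    # (a) variant PLL, with logits from a forward pass on the variant: [chains]
    variant_pll = pll(x_oh, variant_logits)
    # (b) wildtype PLL, with logits from a forward pass on the wildtype: [1]
    wt_pll = pll(wt_oh, wt_logits)
    # Difference of log-likelihoods = log of the pseudolikelihood ratio;
    # the [1]-shaped wildtype term broadcasts across chains.
    return variant_pll - wt_pll
```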
mutant_marginal(x_oh: torch.Tensor, logits: torch.Tensor, wt_oh: torch.Tensor) -> torch.Tensor
Mutant marginal variant scoring mechanism.
The difference of two terms, summed over the mutated positions: the log-likelihood of (a) the variant's residues and (b) the wildtype's residues. The model's input for computing the logits is the variant, for both (a) and (b). This differs from the pseudolikelihood ratio, where the wildtype is the input for (b); here, the variant is used to compute the likelihood of the wildtype.
See https://www.biorxiv.org/content/10.1101/2021.07.09.450648v2.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x_oh | torch.Tensor | one-hot encoded variant sequences, shape [parallel_chains, seq_len, vocab_size] | required |
logits | torch.Tensor | predicted logits, of shape [parallel_chains, seq_len, vocab_size] | required |
wt_oh | torch.Tensor | one-hot encoded wildtype sequence, shape [parallel_chains, seq_len, vocab_size] | required |
Returns:
Type | Description |
---|---|
torch.Tensor | of shape [parallel_chains] |
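A minimal sketch of the mutant marginal score, assuming x_oh and wt_oh are one-hot tensors over the same vocabulary and that wt_oh may have a leading dimension of 1 that broadcasts across chains. Both terms read their log-probabilities from the same logits, computed on the variant. This is an illustration, not evo_prot_grad's code.

```python
import torch
import torch.nn.functional as F


def mutant_marginal(x_oh: torch.Tensor, logits: torch.Tensor, wt_oh: torch.Tensor) -> torch.Tensor:
    # Log-probabilities from the model run on the *variant*: [chains, seq_len, vocab]
    log_probs = F.log_softmax(logits, dim=-1)
    # Positions where the variant differs from the wildtype: [chains, seq_len]
    mutated = (x_oh != wt_oh).any(dim=-1).to(log_probs.dtype)
    # Per-position log p(variant residue) and log p(wildtype residue), same logits
    variant_lp = (x_oh * log_probs).sum(dim=-1)
    wt_lp = (wt_oh * log_probs).sum(dim=-1)
    # Sum the difference over mutated positions only: [chains]
    return ((variant_lp - wt_lp) * mutated).sum(dim=-1)
```

The mask is not strictly necessary, since at unmutated positions the two terms are identical and cancel, but it makes the restriction to mutation locations explicit.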
__call__(x_oh: torch.Tensor, x_pred: torch.Tensor, wt_oh: torch.Tensor) -> torch.Tensor
Returns the mutation score.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x_oh | torch.Tensor | one-hot encoded variant sequence, shape [parallel_chains, seq_len, vocab_size] | required |
x_pred | torch.Tensor | model prediction for the variant, for example, logits. First dimension should be | required |
wt_oh | torch.Tensor | one-hot encoded wildtype sequence, shape [parallel_chains, seq_len, vocab_size] | required |
Returns:
Name | Type | Description |
---|---|---|
variant_score | torch.Tensor | of shape [parallel_chains] |
cache_wt_score(wt_oh: torch.Tensor, wt_pred: torch.Tensor) -> None
Caches the score for the wildtype protein, if needed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
wt_oh | torch.Tensor | of shape [1, seq_len, vocab_size]. The one-hot encoded wildtype protein. | required |
wt_pred | torch.Tensor | of shape [1, *]. The model's prediction for the wildtype protein. | required |
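To show how cache_wt_score and __call__ might fit together, here is a hypothetical class, VariantScoringSketch, that caches the wildtype term once and then scores batches of variants. The attribute names, the dispatch on a scoring_strategy string, and the assumption that attribute predictions are one scalar per sequence are illustrative guesses, not evo_prot_grad's code.

```python
import torch
import torch.nn.functional as F


class VariantScoringSketch:
    """Hypothetical re-implementation for illustration only."""

    def __init__(self, scoring_strategy: str):
        self.scoring_strategy = scoring_strategy
        self.wt_score_cache = None  # filled by cache_wt_score when a strategy needs it

    @staticmethod
    def _pll(x_oh: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        # Sum of log-probabilities of the observed residues: [chains]
        return (x_oh * F.log_softmax(logits, dim=-1)).sum(dim=(1, 2))

    def cache_wt_score(self, wt_oh: torch.Tensor, wt_pred: torch.Tensor) -> None:
        if self.scoring_strategy == "pseudolikelihood_ratio":
            # Wildtype PLL from logits computed on the wildtype: shape [1]
            self.wt_score_cache = self._pll(wt_oh, wt_pred)
        elif self.scoring_strategy == "attribute_value":
            # Assumes a single scalar attribute, so [1, 1] -> [1]
            self.wt_score_cache = wt_pred.reshape(-1)
        # mutant_marginal takes both terms from the variant's logits: nothing to cache

    def __call__(self, x_oh: torch.Tensor, x_pred: torch.Tensor, wt_oh: torch.Tensor) -> torch.Tensor:
        if self.scoring_strategy == "attribute_value":
            # Assumes x_pred has shape [chains, 1]
            return x_pred.reshape(x_pred.shape[0]) - self.wt_score_cache
        if self.scoring_strategy == "pseudolikelihood_ratio":
            return self._pll(x_oh, x_pred) - self.wt_score_cache
        if self.scoring_strategy == "mutant_marginal":
            log_probs = F.log_softmax(x_pred, dim=-1)
            # x_oh - wt_oh is zero at unmutated positions, so only mutations contribute
            return ((x_oh - wt_oh) * log_probs).sum(dim=(1, 2))
        raise ValueError(f"unknown scoring strategy: {self.scoring_strategy}")
```

The cache means the wildtype's forward pass and score are computed once rather than on every call; mutant_marginal needs no cache because both of its terms come from the variant's logits.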