
evo_prot_grad.common.variant_scoring

evo_prot_grad.common.variant_scoring.VariantScoring

Every Expert holds a VariantScoring object, which it uses to score variant sequences that may contain multiple mutations.

Supported scoring strategies

1) attribute_value - Uses a model's predicted attribute value for a given variant, normalized by subtracting the wildtype's predicted value.
2) pseudolikelihood_ratio
3) mutant_marginal

See https://www.biorxiv.org/content/10.1101/2021.07.09.450648v2 for strategies (2) and (3).
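As a compact summary, the three strategies can be written as follows, based on the descriptions given for each method in this reference. The notation is an assumption introduced here for illustration: f_θ is the model's predicted attribute, p_θ(· | x) is the per-position distribution obtained from the logits when x is the model input, L is the sequence length, and M is the set of positions where the variant x differs from the wildtype x^wt. Treating the pll as a sum of per-position log-probabilities is inferred from the method signatures, not confirmed against the implementation.

```latex
\begin{align*}
s_{\text{attribute\_value}}(x) &= f_\theta(x) - f_\theta(x^{\mathrm{wt}}) \\
s_{\text{pseudolikelihood\_ratio}}(x) &= \sum_{i=1}^{L} \log p_\theta(x_i \mid x)
  - \sum_{i=1}^{L} \log p_\theta(x^{\mathrm{wt}}_i \mid x^{\mathrm{wt}}) \\
s_{\text{mutant\_marginal}}(x) &= \sum_{i \in M} \left[ \log p_\theta(x_i \mid x)
  - \log p_\theta(x^{\mathrm{wt}}_i \mid x) \right]
\end{align*}
```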

pseudolikelihood(x_oh: torch.Tensor, logits: torch.Tensor) -> torch.Tensor

A pseudo-log-likelihood (pll) score for a protein sequence and model logits.

Parameters:

Name | Type | Description | Default
x_oh | torch.Tensor | one-hot encoded variant sequences, shape [parallel_chains, seq_len, vocab_size] | required
logits | torch.Tensor | predicted logits, shape [parallel_chains, seq_len, vocab_size] | required

Returns:

Type | Description
torch.Tensor | of shape [parallel_chains, seq_len, vocab_size]

pseudolikelihood_ratio(x_oh: torch.Tensor, logits: torch.Tensor) -> torch.Tensor

Pseudo-log-likelihood (pll) ratio with respect to the wildtype, for scoring variants.

The difference of two terms: the pll of (a) the variant and (b) the wildtype. The model input is the variant when computing (a) and the wildtype when computing (b).

Parameters:

Name | Type | Description | Default
x_oh | torch.Tensor | one-hot encoded variant sequences, shape [parallel_chains, seq_len, vocab_size] | required
logits | torch.Tensor | predicted logits, shape [parallel_chains, seq_len, vocab_size] | required

Returns:

Type | Description
torch.Tensor | of shape [parallel_chains]
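The pll ratio described above can be sketched as follows. This is a minimal illustration, not the library's implementation: it uses nested Python lists in place of tensors so that it runs without PyTorch, and the helper names `log_softmax`, `pll`, and `pll_ratio` are hypothetical. It also assumes that the pll is the sum, over positions, of the log-probability of the observed token, and it takes the wildtype logits as an explicit argument, whereas the documented method only receives the variant's `x_oh` and `logits`.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one position's logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def pll(x_oh, logits):
    """Per-chain sum of the log-probabilities of the observed tokens.

    Both arguments are nested lists shaped [parallel_chains][seq_len][vocab_size].
    """
    scores = []
    for seq_oh, seq_logits in zip(x_oh, logits):
        total = 0.0
        for pos_oh, pos_logits in zip(seq_oh, seq_logits):
            log_probs = log_softmax(pos_logits)
            # The one-hot vector selects the log-prob of the observed token.
            total += sum(o * lp for o, lp in zip(pos_oh, log_probs))
        scores.append(total)
    return scores

def pll_ratio(x_oh, x_logits, wt_oh, wt_logits):
    """pll(variant) - pll(wildtype); each is scored with logits from its own input."""
    wt_score = pll(wt_oh, wt_logits)[0]  # the wildtype is a batch of size 1
    return [s - wt_score for s in pll(x_oh, x_logits)]

# Toy data: 1 chain, 2 positions, vocabulary of 2 tokens.
wt_oh = [[[1, 0], [1, 0]]]
wt_logits = [[[0.0, 0.0], [0.0, 0.0]]]            # uniform at both positions
x_oh = [[[1, 0], [0, 1]]]                         # position 1 is mutated
x_logits = [[[0.0, 0.0], [0.0, math.log(3.0)]]]   # mutated token has probability 3/4
ratio = pll_ratio(x_oh, x_logits, wt_oh, wt_logits)
```

In this toy case the result has one entry per chain, matching the documented return shape [parallel_chains], and its value is log(3/4) - log(1/2) = log(1.5).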

mutant_marginal(x_oh: torch.Tensor, logits: torch.Tensor, wt_oh: torch.Tensor) -> torch.Tensor

Mutant marginal variant scoring mechanism.

The difference of two terms: the log-likelihood of (a) the variant and (b) the wildtype, summed over the mutated positions. The logits for both (a) and (b) are computed with the variant as the model input. This differs from the pseudolikelihood ratio: here the variant, not the wildtype, is the model input when computing the likelihood of the wildtype (b).

See https://www.biorxiv.org/content/10.1101/2021.07.09.450648v2.

Parameters:

Name | Type | Description | Default
x_oh | torch.Tensor | one-hot encoded variant sequences, shape [parallel_chains, seq_len, vocab_size] | required
logits | torch.Tensor | predicted logits, shape [parallel_chains, seq_len, vocab_size] | required
wt_oh | torch.Tensor | one-hot encoded wildtype sequence, shape [parallel_chains, seq_len, vocab_size] | required

Returns:

Type | Description
torch.Tensor | of shape [parallel_chains]
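To make the mutant-marginal description concrete, here is a minimal sketch under stated assumptions. It is not the library's code: nested Python lists stand in for tensors so the example runs without PyTorch, the `log_softmax` helper is hypothetical, and mutation sites are assumed to be the positions where the variant's one-hot row differs from the wildtype's.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one position's logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def mutant_marginal(x_oh, logits, wt_oh):
    """Sum over mutated positions of log p(variant token) - log p(wildtype token).

    All arguments are nested lists shaped [parallel_chains][seq_len][vocab_size].
    The logits are assumed to come from running the model on the variant.
    """
    scores = []
    for seq_oh, seq_logits, seq_wt in zip(x_oh, logits, wt_oh):
        total = 0.0
        for pos_oh, pos_logits, pos_wt in zip(seq_oh, seq_logits, seq_wt):
            if pos_oh == pos_wt:
                continue  # not a mutation site, contributes nothing
            log_probs = log_softmax(pos_logits)
            variant_lp = sum(o * lp for o, lp in zip(pos_oh, log_probs))
            wt_lp = sum(o * lp for o, lp in zip(pos_wt, log_probs))
            total += variant_lp - wt_lp
        scores.append(total)
    return scores

# Toy data: 1 chain, 2 positions, vocabulary of 2 tokens; only position 1 is mutated.
wt_oh = [[[1, 0], [1, 0]]]
x_oh = [[[1, 0], [0, 1]]]
logits = [[[2.0, -1.0], [0.0, math.log(3.0)]]]
scores = mutant_marginal(x_oh, logits, wt_oh)
```

Position 0 is skipped because it is unchanged, so its logits do not affect the score. At position 1 the variant token has probability 3/4 and the wildtype token 1/4, giving a score of log(3).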

__call__(x_oh: torch.Tensor, x_pred: torch.Tensor, wt_oh: torch.Tensor) -> torch.Tensor

Returns the mutation score.

Parameters:

Name | Type | Description | Default
x_oh | torch.Tensor | one-hot encoded variant sequence, shape [parallel_chains, seq_len, vocab_size] | required
x_pred | torch.Tensor | model prediction for the variant (for example, logits); the first dimension should be parallel_chains | required
wt_oh | torch.Tensor | one-hot encoded wildtype sequence, shape [parallel_chains, seq_len, vocab_size] | required

Returns:

Name | Type | Description
variant_score | torch.Tensor | of shape [parallel_chains]

cache_wt_score(wt_oh: torch.Tensor, wt_pred: torch.Tensor) -> None

Caches the wildtype protein's score, if needed.

Parameters:

Name | Type | Description | Default
wt_oh | torch.Tensor | the one-hot encoded wildtype protein, shape [1, seq_len, vocab_size] | required
wt_pred | torch.Tensor | the model's prediction for the wildtype protein, shape [1, *] | required
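The interplay between caching the wildtype score and scoring variants can be illustrated with the attribute_value strategy, where the score is the variant's prediction minus the wildtype's. The class below is a hypothetical stand-in written for this example, not the library's VariantScoring: it uses plain Python lists of floats instead of tensors, and its method signatures are simplified (no one-hot inputs).

```python
class AttributeValueScorer:
    """Hypothetical stand-in that mimics cache-then-score for attribute_value."""

    def __init__(self):
        self.wt_score = None  # filled in by cache_wt_score

    def cache_wt_score(self, wt_pred):
        # wt_pred mimics a prediction of shape [1, *]; here it is a one-element list.
        self.wt_score = wt_pred[0]

    def __call__(self, x_pred):
        # x_pred holds one predicted attribute value per parallel chain.
        if self.wt_score is None:
            raise RuntimeError("call cache_wt_score before scoring variants")
        return [p - self.wt_score for p in x_pred]

scorer = AttributeValueScorer()
scorer.cache_wt_score([1.5])               # wildtype prediction, computed once
variant_scores = scorer([2.0, 1.0, 1.5])   # three parallel chains
```

Caching means the wildtype prediction is computed once and reused for every batch of variants. In this sketch a positive score marks a variant predicted to exceed the wildtype, and a variant identical in prediction to the wildtype scores zero.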