evo_prot_grad.common.sampler
evo_prot_grad.common.sampler.DirectedEvolution
Main class for plug and play directed evolution with gradient-based discrete MCMC.
__init__(experts: List[Expert], parallel_chains: int, n_steps: int, max_mutations: int, output: str = 'last', preserved_regions: Optional[List[Tuple[int, int]]] = None, wt_protein: Optional[str] = None, wt_fasta: Optional[str] = None, verbose: bool = False, random_seed: Optional[int] = None)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
experts | List[Expert] | List of experts. | required |
parallel_chains | int | Number of parallel chains. | required |
n_steps | int | Number of steps to run directed evolution. | required |
max_mutations | int | Maximum mutation distance from the WT sequence; set to -1 to disable. | required |
output | str | Output type: 'best', 'last', or 'all'. Default is 'last'. | 'last' |
preserved_regions | List[Tuple[int, int]] | List of (start, end) tuples marking preserved regions. Default is None. | None |
wt_protein | str | WT sequence as a string. Must provide one of wt_protein or wt_fasta. | None |
wt_fasta | str | Path to a FASTA file containing the WT sequence. Must provide one of wt_protein or wt_fasta. | None |
verbose | bool | Whether to print verbose output. Default is False. | False |
random_seed | int | Random seed for reproducibility. Default is None. | None |
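The max_mutations and preserved_regions parameters both restrict which variants the sampler may reach. The sketch below shows one way such constraints can be checked; it is not the library's code. The helper names `mutable_positions`, `hamming`, and `within_mutation_budget` are hypothetical, and the sketch assumes that regions are 0-based and inclusive on both ends and that variants differ from WT only by substitutions (same length).

```python
from typing import List, Optional, Tuple


def mutable_positions(seq_len: int,
                      preserved_regions: Optional[List[Tuple[int, int]]]) -> List[bool]:
    """Per-position mask: True where a mutation is allowed.

    ASSUMPTION: (start, end) is 0-based and inclusive on both ends.
    """
    mask = [True] * seq_len
    for start, end in preserved_regions or []:
        for i in range(start, end + 1):
            mask[i] = False  # position lies inside a preserved region
    return mask


def hamming(a: str, b: str) -> int:
    """Number of differing positions (assumes equal-length sequences)."""
    return sum(x != y for x, y in zip(a, b))


def within_mutation_budget(wt: str, variant: str, max_mutations: int) -> bool:
    """True if variant is at most max_mutations substitutions from wt.

    max_mutations == -1 disables the check, mirroring the documented behavior.
    """
    if max_mutations == -1:
        return True
    return hamming(wt, variant) <= max_mutations
```

For example, `mutable_positions(6, [(1, 2)])` freezes positions 1 and 2, and `within_mutation_budget("MKTAYI", "AKTAYV", 1)` rejects a double mutant.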
Raises:
Type | Description |
---|---|
ValueError | if |
ValueError | if neither wt_protein nor wt_fasta is provided. |
ValueError | if a fasta file is passed to wt_protein. |
ValueError | if |
ValueError | if no experts are provided. |
ValueError | if any of the preserved regions are < 1 amino acid long. |
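A minimal sketch of up-front argument validation in the spirit of the checks listed above, grounded in the parameter table: at least one expert, one of wt_protein or wt_fasta, an output of 'best', 'last', or 'all', and preserved regions at least one residue long. The function `check_config` is hypothetical, the order of checks is arbitrary, and it assumes inclusive (start, end) regions, so a region's length is end - start + 1.

```python
from typing import List, Optional, Sequence, Tuple

VALID_OUTPUTS = ("best", "last", "all")


def check_config(experts: Sequence[object],
                 output: str = "last",
                 preserved_regions: Optional[List[Tuple[int, int]]] = None,
                 wt_protein: Optional[str] = None,
                 wt_fasta: Optional[str] = None) -> None:
    """Hypothetical validator; raises ValueError on invalid settings."""
    if not experts:
        raise ValueError("no experts provided")
    if wt_protein is None and wt_fasta is None:
        raise ValueError("must provide one of wt_protein or wt_fasta")
    if output not in VALID_OUTPUTS:
        raise ValueError(f"output must be one of {VALID_OUTPUTS}, got {output!r}")
    for start, end in preserved_regions or []:
        # ASSUMPTION: inclusive bounds, so the region covers end - start + 1 residues.
        if end - start + 1 < 1:
            raise ValueError(f"preserved region {(start, end)} is < 1 amino acid long")
```

Failing fast like this surfaces configuration mistakes before any expert model is loaded or any chain is run.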
reset()
Initialize the parallel chains of protein sequences.
_prepare_results(variants, scores, n_seqs_to_keep=None)
Prepare the results by sorting and selecting the top sequences.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
variants | List[str] | The list of sequence variants. Shape is (n_steps, parallel_chains). | required |
scores | np.ndarray | The scores for the sequence variants. Shape is (n_steps, parallel_chains). | required |
n_seqs_to_keep | int | Number of sequences to keep. Default is None (keep all). | None |
Returns:
Type | Description |
---|---|
pd.DataFrame | DataFrame of results. |
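A sketch of what "sorting and selecting the top sequences" can look like for (n_steps, parallel_chains) grids. To stay dependency-free it returns (step, chain, sequence, score) tuples rather than a pd.DataFrame, and it assumes higher scores are better. The name `prepare_results` is hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def prepare_results(variants: Sequence[Sequence[str]],
                    scores: Sequence[Sequence[float]],
                    n_seqs_to_keep: Optional[int] = None) -> List[Tuple[int, int, str, float]]:
    """Flatten [step][chain] grids and sort by score, best first.

    ASSUMPTION: a higher score is better.
    """
    rows = [
        (step, chain, seq, float(score))
        for step, (step_seqs, step_scores) in enumerate(zip(variants, scores))
        for chain, (seq, score) in enumerate(zip(step_seqs, step_scores))
    ]
    # sort() is stable, so ties keep their (step, chain) order.
    rows.sort(key=lambda r: r[3], reverse=True)
    return rows if n_seqs_to_keep is None else rows[:n_seqs_to_keep]
```

With `variants = [["AAA", "AAC"], ["ACA", "AAC"]]` and `scores = [[0.1, 0.5], [0.9, 0.5]]`, keeping two rows returns `[(1, 0, "ACA", 0.9), (0, 1, "AAC", 0.5)]`. Note that the same sequence can appear more than once when several chains or steps visit it.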
_get_variants_and_scores() -> Tuple[List[str], np.ndarray]
Get the variants and scores based on the output type.
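The three output types could plausibly map onto a recorded history of chain states as follows. This is an assumption, not confirmed by the docstring: 'last' returns each chain's final state, 'best' returns each chain's highest-scoring state over all steps, and 'all' returns the full [step][chain] history. The function name `select_by_output` is hypothetical.

```python
from typing import List, Sequence, Tuple


def select_by_output(history_seqs: Sequence[Sequence[str]],
                     history_scores: Sequence[Sequence[float]],
                     output: str) -> Tuple[list, list]:
    """history_*[step][chain] hold every chain's state after each step."""
    if output == "all":
        return [list(s) for s in history_seqs], [list(s) for s in history_scores]
    if output == "last":
        return list(history_seqs[-1]), list(history_scores[-1])
    if output == "best":
        # ASSUMPTION: "best" means the highest-scoring state of each chain across steps.
        best_seqs: List[str] = []
        best_scores: List[float] = []
        for c in range(len(history_seqs[0])):
            step = max(range(len(history_seqs)), key=lambda t: history_scores[t][c])
            best_seqs.append(history_seqs[step][c])
            best_scores.append(history_scores[step][c])
        return best_seqs, best_scores
    raise ValueError(f"unknown output type: {output!r}")
```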
_product_of_experts(inputs: List[str]) -> Tuple[List[torch.Tensor], torch.Tensor]
Compute the product of experts. Computes each expert score, multiplies it by the expert temperature, and aggregates the scores by summation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inputs | List[str] | List of protein sequences, one per chain (length parallel_chains). | required |
Returns:
Name | Type | Description |
---|---|---|
ohs | List[torch.Tensor] | List of one-hot encoded sequences (length parallel_chains). |
PoE | torch.Tensor | Product of experts score of shape [parallel_chains]. |
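The aggregation described above is a temperature-weighted sum. If each expert score is read as an unnormalized log-probability, that sum is the log of a product of experts, since exp(sum_i T_i * s_i(x)) = prod_i exp(s_i(x))^T_i. The sketch below shows only that arithmetic. It skips one-hot encoding, and the `ToyExpert` tuple of (temperature, score function) is a hypothetical stand-in for the library's Expert objects.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-in for an expert: (temperature, sequence -> score).
ToyExpert = Tuple[float, Callable[[str], float]]


def product_of_experts(experts: Sequence[ToyExpert], inputs: Sequence[str]) -> List[float]:
    """PoE score per chain: sum over experts of temperature * expert score."""
    return [sum(temp * score_fn(seq) for temp, score_fn in experts) for seq in inputs]


def count_a(seq: str) -> float:
    return float(seq.count("A"))


def count_w(seq: str) -> float:
    return float(seq.count("W"))
```

For example, `product_of_experts([(1.0, count_a), (0.5, count_w)], ["AAW", "WWW"])` gives `[2.5, 1.5]`: the first sequence scores 1.0 * 2 + 0.5 * 1 and the second 1.0 * 0 + 0.5 * 3.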
_compute_gradients(ohs: List[torch.Tensor], PoE: torch.Tensor) -> torch.Tensor
Compute the gradients of the product of experts with respect to the one-hot inputs. Because each expert may order its amino acid alphabet (used to construct its one-hot inputs) differently, each expert's gradients are first rearranged into a canonical alphabet order and then summed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ohs | List[torch.Tensor] | One-hot encodings, one tensor per expert, each of shape [parallel_chains, seq_len, vocab_size]. | required |
PoE | torch.Tensor | Product of experts score of shape [parallel_chains]. | required |
Returns:
Name | Type | Description |
---|---|---|
grads | torch.Tensor | Gradients of the product of experts with respect to the one-hot encodings. |
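The canonical-ordering step can be illustrated without torch. The sketch treats each expert's gradient as a nested list indexed [chain][position][vocab], permutes the vocab axis from that expert's alphabet order into a shared canonical order, and then sums elementwise. The names `to_canonical` and `sum_expert_gradients` are hypothetical, and the sketch assumes every expert alphabet contains all canonical letters (extra tokens are simply dropped).

```python
from typing import Dict, List, Sequence

Grad = List[List[List[float]]]  # [parallel_chains][seq_len][vocab_size]


def to_canonical(grad: Grad, expert_alphabet: Sequence[str], canonical: str) -> Grad:
    """Reorder the vocab axis from an expert's alphabet order to `canonical`."""
    index: Dict[str, int] = {aa: i for i, aa in enumerate(expert_alphabet)}
    return [[[pos[index[aa]] for aa in canonical] for pos in chain] for chain in grad]


def sum_expert_gradients(grads: Sequence[Grad],
                         alphabets: Sequence[Sequence[str]],
                         canonical: str) -> Grad:
    """Align every expert's gradient to the canonical order, then sum elementwise."""
    aligned = [to_canonical(g, a, canonical) for g, a in zip(grads, alphabets)]
    total = aligned[0]
    for g in aligned[1:]:
        total = [[[x + y for x, y in zip(p1, p2)] for p1, p2 in zip(c1, c2)]
                 for c1, c2 in zip(total, g)]
    return total
```

Without the reordering, summing an expert that orders its vocabulary "ACD" with one that orders it "DCA" would add the gradient for A to the gradient for D.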
save_results(csv_filename: str, variants: Optional[List[str]] = None, scores: Optional[np.ndarray] = None, n_seqs_to_keep: int = 10000) -> None
Save the output sequences and scores to a CSV file. Also saves the parameters used to run the sampler to a _params.txt file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
csv_filename | str | Filename for saving the results; should end in .csv. | required |
variants | List[List[str]] | The list of sequence variants. | None |
scores | np.ndarray | The scores for the sequence variants. | None |
n_seqs_to_keep | int | Number of sequences to keep in the results. Default is 10000. | 10000 |
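A standard-library sketch of the save step: keep the top-scoring rows, write them to a CSV file, and write the run parameters to a sibling text file. The function name `save_results_sketch`, the CSV column names, the "key: value" params format, and the rule that `run.csv` pairs with `run_params.txt` are all assumptions for illustration, not the library's exact output format.

```python
import csv
from pathlib import Path
from typing import Dict, Sequence


def save_results_sketch(csv_filename: str,
                        sequences: Sequence[str],
                        scores: Sequence[float],
                        params: Dict[str, object],
                        n_seqs_to_keep: int = 10000) -> Path:
    """Write top sequences to CSV and params to '<stem>_params.txt' (assumed naming)."""
    # ASSUMPTION: higher score is better; keep the n_seqs_to_keep best rows.
    rows = sorted(zip(sequences, scores), key=lambda r: r[1], reverse=True)[:n_seqs_to_keep]
    out = Path(csv_filename)
    with out.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["sequence", "score"])  # hypothetical column names
        writer.writerows(rows)
    params_path = out.with_name(out.stem + "_params.txt")
    params_path.write_text("".join(f"{k}: {v}\n" for k, v in params.items()))
    return params_path
```

Saving the parameters next to the results keeps each CSV traceable to the sampler settings that produced it.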
__call__() -> Tuple[List[str], np.ndarray]
Run the gradient-based MCMC sampler.
Returns:
Name | Type | Description |
---|---|---|
variants | List[str] | List of protein sequences. |
scores | np.ndarray | The product of experts scores for the variants. |
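For intuition about a gradient-based discrete MCMC step, the sketch below implements a Metropolis-Hastings step in the style of Gibbs-with-Gradients on a toy additive score f(x) = sum_i W[i][x_i]. For this score the gradient with respect to the one-hot input is W itself, so grad[i][a] - grad[i][x_i] estimates the score change of substituting residue a at position i. Single substitutions are proposed with probability proportional to exp(estimate / 2) and then accepted or rejected with the usual correction. This is an illustrative sketch, not the library's sampler: the toy score, all function names, and the omission of max_mutations, preserved_regions, and real experts are assumptions.

```python
import math
import random
from typing import List, Sequence, Tuple

ALPHABET = "ACDEFGHIKLMNPQRSTVWY"


def score(seq: str, W: Sequence[Sequence[float]]) -> float:
    """Toy additive score f(x) = sum_i W[i][x_i]; its one-hot gradient is W."""
    return sum(W[i][ALPHABET.index(aa)] for i, aa in enumerate(seq))


def proposal(seq: str, W) -> Tuple[List[Tuple[int, str]], List[float]]:
    """Log-probabilities of all single substitutions, proportional to exp(d / 2),
    where d = grad[i][a] - grad[i][x_i] is the first-order score-change estimate."""
    moves, logits = [], []
    for i, cur in enumerate(seq):
        g_cur = W[i][ALPHABET.index(cur)]
        for j, aa in enumerate(ALPHABET):
            if aa != cur:
                moves.append((i, aa))
                logits.append((W[i][j] - g_cur) / 2.0)
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))  # stable log-sum-exp
    return moves, [l - log_z for l in logits]


def mh_step(seq: str, W, rng: random.Random) -> str:
    """Propose one substitution and accept it with the Metropolis-Hastings rule."""
    moves, log_q = proposal(seq, W)
    k = rng.choices(range(len(moves)), weights=[math.exp(lq) for lq in log_q])[0]
    i, aa = moves[k]
    new = seq[:i] + aa + seq[i + 1:]
    # Reverse move: restore the original residue at position i.
    rev_moves, rev_log_q = proposal(new, W)
    log_q_rev = rev_log_q[rev_moves.index((i, seq[i]))]
    log_accept = score(new, W) - score(seq, W) + log_q_rev - log_q[k]
    if log_accept >= 0 or rng.random() < math.exp(log_accept):
        return new
    return seq


def run_chains(wt: str, W, parallel_chains: int, n_steps: int, seed: int = 0):
    """Start every chain at WT and apply n_steps MH steps to each chain."""
    rng = random.Random(seed)
    chains = [wt] * parallel_chains
    for _ in range(n_steps):
        chains = [mh_step(s, W, rng) for s in chains]
    return chains, [score(s, W) for s in chains]
```

The gradient makes the proposal favor substitutions that are predicted to raise the score, while the acceptance step corrects for the proposal's bias. Each step changes at most one residue per chain.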