API Documentation
Here lies the official top-level API for interacting with jax-unirep.
Calculating Representations
jax_unirep.get_reps(seqs, params=None)
Get reps of proteins.
This function generates representations of protein sequences using the 1900 hidden-unit mLSTM model with pre-trained weights from the UniRep paper.
Each element of the output 3-tuple is a np.array of shape (n_input_sequences, 1900):
- h_avg: Average hidden state of the mLSTM over the whole sequence.
- h_final: Final hidden state of the mLSTM.
- c_final: Final cell state of the mLSTM.
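To make the three outputs concrete, here is a minimal stdlib sketch of how they relate to the per-position states of an mLSTM, assuming the model produces one hidden-state vector and one cell-state vector per amino acid. The helper `summarize_states` and the 3-dimensional toy vectors are hypothetical stand-ins for the 1900-dimensional states; this is an illustration, not jax-unirep's implementation.

```python
from typing import List, Tuple

Vector = List[float]


def summarize_states(
    hidden: List[Vector], cell: List[Vector]
) -> Tuple[Vector, Vector, Vector]:
    """Reduce per-position mLSTM states to (h_avg, h_final, c_final).

    hidden[i] and cell[i] are assumed to be the hidden and cell state
    after reading the i-th amino acid (hypothetical layout).
    """
    n_positions = len(hidden)
    dim = len(hidden[0])
    # h_avg: element-wise mean of the hidden state over all positions.
    h_avg = [sum(h[d] for h in hidden) / n_positions for d in range(dim)]
    # h_final / c_final: the states after the last amino acid.
    h_final = list(hidden[-1])
    c_final = list(cell[-1])
    return h_avg, h_final, c_final


# Toy states for a 2-residue sequence with 3 hidden units.
hidden = [[1.0, 0.0, 2.0], [3.0, 4.0, 0.0]]
cell = [[0.5, 0.5, 0.5], [0.1, 0.2, 0.3]]
h_avg, h_final, c_final = summarize_states(hidden, cell)
print(h_avg)  # [2.0, 2.0, 1.0]
```

Stacking one such vector per input sequence gives the (n_input_sequences, 1900) arrays described above.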
You should not use this function if you want to do further JAX-based computations on the output vectors! In that case, the DeviceArray futures returned by mLSTM1900 should be passed directly into the next step instead of being converted to np.arrays. The conversion to np.arrays is done in the dispatched rep_x_lengths functions to force Python to wait until the computation has completed before returning the values.
The keys of the params dictionary must be: b, gh, gmh, gmx, gx, wh, wmh, wmx, wx.
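A quick way to catch a malformed weights dictionary before calling get_reps is to compare its keys against the required set. The helper `check_mlstm_params` below is hypothetical (jax-unirep may validate differently or not at all), and it assumes "must be" means exactly these keys.

```python
from typing import Any, Dict

# Keys required in the params dictionary, as listed in the get_reps docstring.
REQUIRED_MLSTM_KEYS = {"b", "gh", "gmh", "gmx", "gx", "wh", "wmh", "wmx", "wx"}


def check_mlstm_params(params: Dict[str, Any]) -> None:
    """Raise KeyError describing any missing or unexpected keys."""
    keys = set(params)
    missing = REQUIRED_MLSTM_KEYS - keys
    extra = keys - REQUIRED_MLSTM_KEYS
    if missing or extra:
        raise KeyError(
            f"missing keys: {sorted(missing)}; unexpected keys: {sorted(extra)}"
        )


# Passes silently: every required key is present and nothing else is.
check_mlstm_params({key: None for key in REQUIRED_MLSTM_KEYS})
```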
Parameters
- seqs: A list of sequences as strings or a single string.
- params: A dictionary of mLSTM1900 weights.
Returns
A 3-tuple of np.arrays containing the reps, in the order h_avg, h_final, and c_final. Each np.array has shape (n_sequences, 1900).
Evotuning
jax_unirep.fit(params, sequences, n_epochs, batch_method='random', batch_size=25, step_size=0.0001, holdout_seqs=None, proj_name='temp', epochs_per_print=1, backend='cpu')
Return mLSTM weights fitted to predict the next letter in each AA sequence.
The training loop is as follows, depending on the batching strategy:
Length batching:
- At each iteration, one length is chosen at random from all sequence lengths present in sequences.
- Next, batch_size sequences of the chosen length are selected at random.
- If there are fewer sequences of a given length than batch_size, all sequences of that length are chosen.
- Those sequences are then passed through the model. No padding of sequences occurs.
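One iteration of length batching can be sketched as follows. The function name `sample_length_batch` is hypothetical, and choosing uniformly among the distinct lengths (rather than weighting by how many sequences have each length) is an assumption about what "chosen at random" means.

```python
import random
from typing import List


def sample_length_batch(
    sequences: List[str], batch_size: int, rng: random.Random
) -> List[str]:
    """One iteration of length batching (illustrative sketch)."""
    # Pick one of the distinct lengths present in `sequences` at random.
    lengths = sorted({len(s) for s in sequences})
    chosen_length = rng.choice(lengths)
    same_length = [s for s in sequences if len(s) == chosen_length]
    # If fewer than batch_size sequences share this length, take them all;
    # otherwise draw batch_size of them without replacement.
    if len(same_length) <= batch_size:
        return same_length
    return rng.sample(same_length, batch_size)
```

Because every sequence in a batch has the same length, the batch can be fed to the model without padding.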
To batch sequences by length, we call batch_sequences from our utils.py module, which returns a list of sub-lists, where each sub-list contains the indices (in the original list of sequences) of all sequences of one particular length.
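The grouping that batch_sequences performs can be sketched with a dictionary keyed by length. `group_indices_by_length` is a hypothetical reimplementation for illustration; the real utility may order its sub-lists differently.

```python
from collections import defaultdict
from typing import Dict, List


def group_indices_by_length(sequences: List[str]) -> List[List[int]]:
    """Return one sub-list of indices per distinct sequence length."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for idx, seq in enumerate(sequences):
        groups[len(seq)].append(idx)
    # Order sub-lists by length so the output is deterministic.
    return [groups[length] for length in sorted(groups)]


print(group_indices_by_length(["MKT", "MK", "MKV", "M"]))
# [[3], [1], [0, 2]]
```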
Random batching:
- Before training, all sequences are padded to be the same length as the longest sequence in sequences.
- Then, at each iteration, we randomly sample batch_size sequences and pass them through the model.
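Random batching can be sketched as a one-time padding step followed by uniform sampling at each iteration. The padding character `-`, right-padding, and sampling without replacement within a batch are all assumptions made for illustration; the function names are hypothetical.

```python
import random
from typing import List

PAD = "-"  # hypothetical padding character, chosen for illustration


def pad_sequences(sequences: List[str], pad: str = PAD) -> List[str]:
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(s) for s in sequences)
    return [s + pad * (max_len - len(s)) for s in sequences]


def sample_random_batch(
    padded: List[str], batch_size: int, rng: random.Random
) -> List[str]:
    """Draw batch_size padded sequences uniformly at random."""
    # Sampling without replacement within a batch is an assumption.
    return rng.sample(padded, min(batch_size, len(padded)))


padded = pad_sequences(["MKT", "MKTAYIAK", "MA"])
print(padded)  # ['MKT-----', 'MKTAYIAK', 'MA------']
```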
The training loop does not adhere to the common notion of epochs, where all sequences would be seen by the model exactly once per epoch. Instead, sequences are always sampled at random, and one epoch consists of approximately round(len(sequences) / batch_size) weight updates. Asymptotically, this should be approximately equivalent to doing epoch passes over the dataset.
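As a worked example of the epoch definition above: with 1,000 sequences and the default batch_size=25, one epoch is round(1000 / 25) = 40 weight updates, so n_epochs=10 corresponds to roughly 400 updates. The helper name below is hypothetical.

```python
def updates_per_epoch(n_sequences: int, batch_size: int) -> int:
    # One "epoch" is approximately round(len(sequences) / batch_size) updates.
    return round(n_sequences / batch_size)


print(updates_per_epoch(1000, 25))  # 40
```

If the rounding is Python's built-in round, exact ties go to the even integer: 1,010 sequences with a batch size of 20 gives round(50.5), which is 50.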
To learn more about the passing of params, have a look at the evotune function docstring.
You can optionally dump parameters and print weights every epochs_per_print epochs to monitor training progress. For ergonomics, training/holdout set losses are estimated on a single batch of batch_size sequences, rather than calculated exactly on the entire set. Set epochs_per_print to None to avoid parameter dumping.
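The reporting logic described above can be sketched as two small helpers: one deciding whether the current epoch should trigger printing and dumping, and one estimating a set's loss from a single batch. Both names are hypothetical, drawing the batch at random is an assumption, and `loss_fn` stands in for whatever loss the model computes on a batch.

```python
import random
from typing import Callable, List, Optional


def should_report(epoch: int, epochs_per_print: Optional[int]) -> bool:
    # epochs_per_print=None disables printing and parameter dumping.
    return epochs_per_print is not None and epoch % epochs_per_print == 0


def estimate_loss(
    loss_fn: Callable[[List[str]], float],
    sequences: List[str],
    batch_size: int,
    rng: random.Random,
) -> float:
    """Estimate a set's loss from one random batch instead of the full set."""
    batch = rng.sample(sequences, min(batch_size, len(sequences)))
    return loss_fn(batch)
```

The estimate is cheap but noisy: two calls on the same set can return different values when the set is larger than batch_size.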
Parameters
- params: mLSTM1900 and Dense parameters.
- sequences: List of sequences to evotune on.
- n_epochs: The number of epochs to evotune for.
- batch_method: One of "length" or "random".
- batch_size: Number of sequences per batch (with length batching, the maximum number). As a rule of thumb, a batch size of 50 consumes about 5GB of GPU RAM.
- step_size: The learning rate.
- holdout_seqs: Holdout set, an optional input.
- proj_name: The directory path for weights to be output to.
- epochs_per_print: Number of epochs to progress before printing and dumping of weights. If not None, must be greater than or equal to 1.
- backend: Whether or not to use the GPU. Defaults to "cpu", but can be set to "gpu" if desired.
Returns
Final optimized parameters.
jax_unirep.evotune(sequences, params=None, proj_name='temp', out_dom_seqs=None, n_trials=20, n_epochs_config=None, learning_rate_config=None, n_splits=5, epochs_per_print=200)
Evolutionarily tune the model to a set of sequences.
Evotuning is described in the original UniRep and eUniRep papers. This reimplementation of evotune provides a nicer API that automatically handles multiple sequences of variable lengths.
Evotuning always needs a starter set of weights. By default, the pre-trained weights from the Nature Methods paper are used. However, other pre-trained weights are legitimate.
We first use Optuna to figure out how many epochs to fit before overfitting happens. To save on computation time, the number of trials run defaults to 20, but can be configured.
By default, mLSTM1900 and Dense weights from the paper are used by passing in params=None, but if you want to use randomly initialized weights:
```python
# Randomly initialize mLSTM1900 and Dense weights from a fixed PRNG key.
from jax_unirep.evotuning import init_fun
from jax.random import PRNGKey
_, params = init_fun(PRNGKey(0), input_shape=(-1, 10))
```
or dumped weights:
```python
# Load weights previously dumped to disk, e.g. during fitting.
from jax_unirep.utils import load_params
params = load_params(folderpath="path/to/params/folder")
```
This function is intended as an automagic way of identifying the best model and training routine hyperparameters. If you want more control over how fitting happens, please use the fit() function directly. There is an example in the examples/ directory that shows how to use it.
Parameters
- sequences: Sequences to evotune against.
- params: Parameters to be passed into mLSTM1900 and Dense. Optional; if None, will default to weights from the paper, or you can pass in your own set of parameters, as long as they are stax-compatible.
- proj_name: Name of the project, used to name the created output directory.
- out_dom_seqs: Out-domain holdout set of sequences, on which the loss is checked to prevent overfitting.
- n_trials: The number of trials Optuna should attempt.
- n_epochs_config: A dictionary of kwargs to trial.suggest_discrete_uniform, which are: name, low, high, q. This controls how many epochs Optuna tests. See the source code for the default configuration, at the definition of n_epochs_kwargs.
- learning_rate_config: A dictionary of kwargs to trial.suggest_loguniform, which are: name, low, high. This controls the learning rate of the model. See the source code for the default configuration, at the definition of learning_rate_kwargs.
- n_splits: The number of folds of cross-validation to do.
- epochs_per_print: The number of steps between each printing and dumping of weights in the final evotuning step using the optimized hyperparameters.
Returns
- study: The Optuna study object, containing information about all evotuning trials.
- evotuned_params: A dictionary of the final, optimized weights.
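The two config dictionaries carry the keyword arguments named above, presumably unpacked into the Optuna calls. The numeric values below are hypothetical and are not jax-unirep's defaults (see n_epochs_kwargs and learning_rate_kwargs in the source for those). The helper `discrete_uniform_grid` is also hypothetical: it lists the candidate values a discrete-uniform search over [low, high] with step q can produce, which is how the epoch search is discretized.

```python
from typing import List

# Hypothetical values for illustration only.
n_epochs_config = {"name": "n_epochs", "low": 10, "high": 50, "q": 10}
learning_rate_config = {"name": "learning_rate", "low": 1e-5, "high": 1e-3}


def discrete_uniform_grid(low: float, high: float, q: float) -> List[float]:
    """Candidate values of a discrete-uniform search over [low, high], step q."""
    n_steps = int(round((high - low) / q))
    return [low + i * q for i in range(n_steps + 1)]


print(discrete_uniform_grid(
    n_epochs_config["low"], n_epochs_config["high"], n_epochs_config["q"]
))  # [10, 20, 30, 40, 50]
```

In recent Optuna releases both suggest_discrete_uniform and suggest_loguniform are deprecated in favour of trial.suggest_float(name, low, high, step=q) and trial.suggest_float(name, low, high, log=True).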
Sampling
jax_unirep.sample_one_chain(starter_sequence, n_steps, scoring_func, is_accepted_kwargs={}, trust_radius=7, propose_kwargs={})
Return one chain of MCMC samples of new sequences.
Given a starter_sequence, this function will sample one chain of protein sequences, scored using a user-provided scoring_func.
Design choices made here include the following.
Firstly, we record all sequences that were sampled, and not just the accepted ones. This behaviour differs from other MCMC samplers that record only the accepted values. We do this in case sequences that are still "good" (but not better than the current one) are rejected. The effect here is that we get a cluster of sequences that are one mutation apart from newly accepted sequences.
Secondly, we check the Hamming distance between each newly proposed sequence and the original. This corresponds to the "trust radius" specified in the jax-unirep paper. If the Hamming distance exceeds the trust radius, we reject the sequence outright.
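The trust-radius check reduces to counting mismatched positions against the starter sequence. This sketch assumes proposals are substitutions, so proposed and starter sequences have equal length; both function names are hypothetical.

```python
def hamming_distance(seq1: str, seq2: str) -> int:
    """Number of positions at which two equal-length sequences differ."""
    if len(seq1) != len(seq2):
        raise ValueError("Hamming distance requires equal-length sequences")
    return sum(a != b for a, b in zip(seq1, seq2))


def within_trust_radius(proposed: str, starter: str, trust_radius: int = 7) -> bool:
    # Proposals more than trust_radius mutations from the starter are
    # rejected outright, before any score-based acceptance test.
    return hamming_distance(proposed, starter) <= trust_radius


print(hamming_distance("MKTAYIAK", "MKTAHIAR"))  # 2
```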
A dictionary containing the following key-value pairs is returned:
- "sequences": All proposed sequences.
- "scores": All scores from the scoring function.
- "accept": Whether the sequence was accepted as the new 'current sequence' on which new sequences are proposed.
This can be turned into a pandas DataFrame.
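The design choices above can be illustrated with a toy chain sampler. This is a sketch under several assumptions, not jax-unirep's sampler: proposals are single random point mutations (the hypothetical `propose_point_mutation`), higher scores are better, acceptance uses a Metropolis criterion with a temperature, and the starter sequence itself is not recorded. What it does mirror is that every proposal is recorded and that proposals outside the trust radius are rejected outright.

```python
import math
import random
from typing import Callable, Dict, List

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def propose_point_mutation(seq: str, rng: random.Random) -> str:
    """Hypothetical proposal: substitute one random position."""
    pos = rng.randrange(len(seq))
    new_aa = rng.choice([aa for aa in AMINO_ACIDS if aa != seq[pos]])
    return seq[:pos] + new_aa + seq[pos + 1:]


def toy_chain(
    starter_sequence: str,
    n_steps: int,
    scoring_func: Callable[[str], float],
    trust_radius: int = 7,
    temperature: float = 1.0,
    seed: int = 0,
) -> Dict[str, List]:
    rng = random.Random(seed)
    current_seq = starter_sequence
    current_score = scoring_func(current_seq)
    chain: Dict[str, List] = {"sequences": [], "scores": [], "accept": []}
    for _ in range(n_steps):
        new_seq = propose_point_mutation(current_seq, rng)
        new_score = scoring_func(new_seq)
        distance = sum(a != b for a, b in zip(new_seq, starter_sequence))
        if distance > trust_radius:
            accept = False  # outside the trust radius: reject outright
        else:
            # Assumed Metropolis criterion: always accept improvements,
            # accept worse sequences with probability exp(delta / T).
            delta = new_score - current_score
            accept = delta >= 0 or rng.random() < math.exp(delta / temperature)
        # Record every proposal, not only the accepted ones.
        chain["sequences"].append(new_seq)
        chain["scores"].append(new_score)
        chain["accept"].append(accept)
        if accept:
            current_seq, current_score = new_seq, new_score
    return chain
```

Since all three lists have the same length, passing the result to pandas.DataFrame yields one row per proposal.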
Parameters
- starter_sequence: The starting sequence.
- n_steps: Number of steps for the MC chain to walk.
- scoring_func: Scoring function for a new sequence. It should only accept a string sequence.
- is_accepted_kwargs: Dictionary of kwargs to pass into the is_accepted function. See the is_accepted docstring for more details.
- trust_radius: Maximum allowed number of mutations away from the starter sequence.
- propose_kwargs: Dictionary of kwargs to pass into the propose function. See the propose docstring for more details.
- verbose: Whether or not to print the iteration number and associated sequence + score. Defaults to False.
Returns
A dictionary with sequences, accept and scores as keys.