
API Documentation

Here lies the official top-level API for interacting with jax-unirep.

Calculating Representations

jax_unirep.get_reps

jax_unirep.get_reps(seqs, params=None)

Get reps of proteins.

This function generates representations of protein sequences using the 1900 hidden-unit mLSTM model with pre-trained weights from the UniRep paper.

Each element of the output 3-tuple is a np.array of shape (n_input_sequences, 1900):

  • h_avg: Average hidden state of the mLSTM over the whole sequence.
  • h_final: Final hidden state of the mLSTM.
  • c_final: Final cell state of the mLSTM.
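The sketch below shows how the three summaries relate to the mLSTM's per-residue states for a single sequence. It uses NumPy only. The state arrays are random placeholders, and the hidden size is shrunk from 1900 to 4. It illustrates the definitions above and is not how get_reps computes them internally.

```python
import numpy as np

# Placeholder per-residue states for ONE sequence (hypothetical values).
seq_len, hidden_dim = 5, 4  # the real model uses hidden_dim = 1900
rng = np.random.default_rng(0)
hidden_states = rng.normal(size=(seq_len, hidden_dim))  # h_t for each residue
cell_states = rng.normal(size=(seq_len, hidden_dim))    # c_t for each residue

h_avg = hidden_states.mean(axis=0)  # average hidden state over the sequence
h_final = hidden_states[-1]         # hidden state after the last residue
c_final = cell_states[-1]           # cell state after the last residue

# Stacking one row per sequence gives the (n_sequences, hidden_dim) layout
# described above; here we fake a batch of three identical rows.
h_avg_batch = np.stack([h_avg, h_avg, h_avg])
print(h_avg_batch.shape)  # (3, 4)
```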

You should not use this function if you want to do further JAX-based computations on the output vectors! In that case, pass the DeviceArray futures returned by mLSTM1900 directly into the next step instead of converting them to np.arrays. The conversion to np.arrays is done in the dispatched rep_x_lengths functions to force Python to wait until the computation has completed before returning the values.

The keys of the params dictionary must be:

b, gh, gmh, gmx, gx, wh, wmh, wmx, wx
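If you pass your own params, a quick key check can catch mistakes before the model runs. check_mlstm_params is a hypothetical helper and is not part of jax_unirep. It assumes the dictionary must contain exactly the nine keys listed above.

```python
# Keys listed above for a custom mLSTM1900 params dictionary.
EXPECTED_KEYS = {"b", "gh", "gmh", "gmx", "gx", "wh", "wmh", "wmx", "wx"}


def check_mlstm_params(params: dict) -> None:
    """Raise ValueError unless params has exactly the expected keys.

    Hypothetical helper for illustration; not part of jax_unirep.
    """
    present = set(params)
    missing = EXPECTED_KEYS - present
    extra = present - EXPECTED_KEYS
    if missing or extra:
        raise ValueError(
            f"missing keys: {sorted(missing)}, unexpected keys: {sorted(extra)}"
        )


# Passes silently: all nine keys present (values are dummies here).
check_mlstm_params({k: None for k in EXPECTED_KEYS})
```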

Parameters

  • seqs: A list of sequences as strings or a single string.
  • params: A dictionary of mLSTM1900 weights. Optional; if None, the pre-trained weights from the UniRep paper are used.

Returns

A 3-tuple of np.arrays containing the reps, in the order h_avg, h_final, and c_final. Each np.array has shape (n_sequences, 1900).

Evotuning

jax_unirep.fit

jax_unirep.fit(params, sequences, n_epochs, batch_method='random', batch_size=25, step_size=0.0001, holdout_seqs=None, proj_name='temp', epochs_per_print=1, backend='cpu')

Return mLSTM weights fitted to predict the next letter in each AA sequence.

The training loop is as follows, depending on the batching strategy:

Length batching:

  • At each iteration, one of the sequence lengths present in sequences is chosen at random.
  • Next, batch_size sequences of the chosen length are selected at random.
  • If there are fewer sequences of a given length than batch_size, all sequences of that length are chosen.
  • Those sequences then get passed through the model. No padding of sequences occurs.

To batch sequences by length, we call batch_sequences from our utils.py module. It returns a list of sub-lists, and each sub-list contains the indices (in the original list of sequences) of all sequences that share a particular length.
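The length-batching steps can be sketched in plain Python as follows. Both group_indices_by_length (a stand-in for batch_sequences) and sample_length_batch are hypothetical helpers written for illustration. The real utils.py code may differ in its details.

```python
import random
from collections import defaultdict


def group_indices_by_length(sequences):
    """Return a list of sub-lists, each holding indices of equal-length sequences.

    Hypothetical stand-in for batch_sequences in utils.py.
    """
    groups = defaultdict(list)
    for idx, seq in enumerate(sequences):
        groups[len(seq)].append(idx)
    return list(groups.values())


def sample_length_batch(sequences, batch_size, rng):
    """Pick one length uniformly at random, then up to batch_size sequences of it."""
    idx_group = rng.choice(group_indices_by_length(sequences))
    # If fewer than batch_size sequences share this length, take all of them.
    chosen = rng.sample(idx_group, min(batch_size, len(idx_group)))
    return [sequences[i] for i in chosen]


seqs = ["MKT", "MKV", "MKTA", "MKVA", "MKTAY", "MAT"]
batch = sample_length_batch(seqs, batch_size=2, rng=random.Random(0))
print(batch)  # all sequences in the batch share one length; no padding needed
```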

Random batching:

  • Before training, all sequences get padded to be the same length as the longest sequence in sequences.
  • Then, at each iteration, we randomly sample batch_size sequences and pass them through the model.
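A minimal sketch of random batching follows, under two assumptions. First, sequences are right-padded with a hypothetical "-" character; the padding token jax-unirep actually uses may differ. Second, each batch is drawn without replacement.

```python
import random


def pad_sequences(sequences, pad_char="-"):
    """Right-pad every sequence to the length of the longest one.

    pad_char is an assumption for illustration only.
    """
    max_len = max(len(s) for s in sequences)
    return [s.ljust(max_len, pad_char) for s in sequences]


def sample_random_batch(padded, batch_size, rng):
    """Draw up to batch_size padded sequences uniformly, without replacement."""
    return rng.sample(padded, min(batch_size, len(padded)))


padded = pad_sequences(["MKT", "MKTAY", "MA"])
print(padded)  # ['MKT--', 'MKTAY', 'MA---']
batch = sample_random_batch(padded, batch_size=2, rng=random.Random(0))
```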

The training loop does not adhere to the common notion of epochs, in which the model sees every sequence exactly once per epoch. Instead, sequences are always sampled at random, and one epoch consists of approximately round(len(sequences) / batch_size) weight updates. Asymptotically, this should be approximately equivalent to making epoch passes over the dataset.
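The arithmetic behind this "epoch" is shown below. The second function adds an illustrative calculation that is not in the original text. It assumes each batch is drawn independently and uniformly, without replacement within a batch. Under that assumption it computes the chance that a given sequence is never sampled during one such epoch. That chance approaches 1/e (about 37%) as the dataset grows.

```python
def updates_per_epoch(n_sequences: int, batch_size: int) -> int:
    """Approximate number of weight updates that make up one 'epoch'."""
    return round(n_sequences / batch_size)


def fraction_never_seen(n_sequences: int, batch_size: int) -> float:
    """Probability a given sequence is absent from every batch in one epoch.

    Assumes independent, uniform batches drawn without replacement.
    """
    p_absent_one_batch = 1 - batch_size / n_sequences
    return p_absent_one_batch ** updates_per_epoch(n_sequences, batch_size)


print(updates_per_epoch(1000, 25))  # 40
print(updates_per_epoch(1010, 25))  # 40 (40.4 rounds down)
print(f"{fraction_never_seen(1000, 25):.3f}")  # 0.363
```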

To learn more about how params are passed, see the evotune function docstring.

You can optionally dump parameters and print weights every epochs_per_print epochs to monitor training progress. For ergonomics, training and holdout set losses are estimated on a single batch of size batch_size, rather than calculated exactly on the entire set. Set epochs_per_print to None to avoid parameter dumping.

Parameters

  • params: mLSTM1900 and Dense parameters.
  • sequences: List of sequences to evotune on.
  • n_epochs: The number of epochs to evotune for.
  • batch_method: One of "length" or "random".
  • batch_size: Number of sequences per batch; with length batching, the maximum number of sequences of the chosen length. As a rule of thumb, a batch size of 50 consumes about 5GB of GPU RAM.
  • step_size: The learning rate.
  • holdout_seqs: Holdout set, an optional input.
  • proj_name: The path of the directory to which weights are written.
  • epochs_per_print: Number of epochs between each printing and dumping of weights. Must be greater than or equal to 1, or None to disable parameter dumping.
  • backend: Which backend to use. Defaults to "cpu", but can be set to "gpu" to train on the GPU.

Returns

Final optimized parameters.

jax_unirep.evotune

jax_unirep.evotune(sequences, params=None, proj_name='temp', out_dom_seqs=None, n_trials=20, n_epochs_config=None, learning_rate_config=None, n_splits=5, epochs_per_print=200)

Evolutionarily tune the model to a set of sequences.

Evotuning is described in the original UniRep and eUniRep papers. This reimplementation of evotune provides a nicer API that automatically handles multiple sequences of variable lengths.

Evotuning always needs a starter set of weights. By default, the pre-trained weights from the Nature Methods paper are used. However, other pre-trained weights can be used as well.

We first use Optuna to determine how many epochs to train for before overfitting sets in, along with a suitable learning rate. To save on computation time, the number of trials run defaults to 20, but can be configured.

By default (params=None), the mLSTM1900 and Dense weights from the paper are used. If you want to use randomly initialized weights instead:

from jax_unirep.evotuning import init_fun
from jax.random import PRNGKey

# init_fun returns (output_shape, params); keep only the randomly initialized params.
_, params = init_fun(PRNGKey(0), input_shape=(-1, 10))

or dumped weights:

from jax_unirep.utils import load_params

# Load weights previously dumped to disk during evotuning.
params = load_params(folderpath="path/to/params/folder")

This function is intended as an automagic way of identifying the best model and training routine hyperparameters. If you want more control over how fitting happens, please use the fit() function directly. There is an example in the examples/ directory that shows how to use it.

Parameters

  • sequences: Sequences to evotune against.
  • params: Parameters to be passed into mLSTM1900 and Dense. Optional; if None, will default to weights from paper, or you can pass in your own set of parameters, as long as they are stax-compatible.
  • proj_name: Name of the project, used to name created output directory.
  • out_dom_seqs: Out-of-domain holdout set of sequences on which the loss is checked, to prevent overfitting.
  • n_trials: The number of trials Optuna should attempt.
  • n_epochs_config: A dictionary of kwargs to trial.suggest_discrete_uniform, which are: name, low, high, q. This controls how many epochs to have Optuna test. See source code for default configuration, at the definition of n_epochs_kwargs.
  • learning_rate_config: A dictionary of kwargs to trial.suggest_loguniform, which are: name, low, high. This controls the learning rate of the model. See source code for default configuration, at the definition of learning_rate_kwargs.
  • n_splits: The number of folds of cross-validation to do.
  • epochs_per_print: The number of epochs between each printing and dumping of weights in the final evotuning step, which uses the optimized hyperparameters.
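The sketch below shows the shape of the two config dictionaries and how their kwargs map onto Optuna's trial methods. The numeric ranges are hypothetical; the real defaults live in the source (n_epochs_kwargs and learning_rate_kwargs). FakeTrial is a stand-in so the example runs without Optuna installed. Recent Optuna releases deprecate suggest_discrete_uniform and suggest_loguniform in favor of suggest_float.

```python
# Hypothetical configurations; the library's real defaults may differ.
n_epochs_config = {"name": "n_epochs", "low": 1, "high": 10, "q": 1}
learning_rate_config = {"name": "learning_rate", "low": 1e-5, "high": 1e-2}


class FakeTrial:
    """Minimal stand-in for optuna.trial.Trial, only to show the kwargs."""

    def suggest_discrete_uniform(self, name, low, high, q):
        return low  # a real trial samples from [low, high] in steps of q

    def suggest_loguniform(self, name, low, high):
        return low  # a real trial samples log-uniformly from [low, high]


trial = FakeTrial()
n_epochs = trial.suggest_discrete_uniform(**n_epochs_config)
learning_rate = trial.suggest_loguniform(**learning_rate_config)
```

With real dictionaries like these, the call would look like jax_unirep.evotune(sequences, n_epochs_config=n_epochs_config, learning_rate_config=learning_rate_config).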

Returns

  • study: The optuna study object, containing information about all evotuning trials.
  • evotuned_params: A dictionary of the final, optimized weights.

Sampling

jax_unirep.sample_one_chain

jax_unirep.sample_one_chain(starter_sequence, n_steps, scoring_func, is_accepted_kwargs={}, trust_radius=7, propose_kwargs={}, verbose=False)

Return one chain of MCMC samples of new sequences.

Given a starter_sequence, this function will sample one chain of protein sequences, scored using a user-provided scoring_func.

Design choices made here include the following.

Firstly, we record all sequences that were sampled, and not just the accepted ones. This behaviour differs from other MCMC samplers that record only the accepted values. We do this so that sequences that are still "good" (but not better than the current one) are not lost when they are rejected. As a result, we get a cluster of sequences that are one mutation away from each newly accepted sequence.

Secondly, we check the Hamming distance between each newly proposed sequence and the original. This corresponds to the "trust radius" specified in the jax-unirep paper. If the Hamming distance exceeds the trust radius, we reject the sequence outright.
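The trust-radius check amounts to counting mismatched positions. Below is a minimal sketch, assuming proposals are equal-length substitutions of the starter sequence. hamming and within_trust_radius are hypothetical helper names.

```python
def hamming(a: str, b: str) -> int:
    """Number of positions at which two equal-length sequences differ."""
    if len(a) != len(b):
        raise ValueError("Hamming distance needs equal-length sequences")
    return sum(x != y for x, y in zip(a, b))


def within_trust_radius(original: str, proposed: str, trust_radius: int = 7) -> bool:
    """True if proposed stays within trust_radius mutations of original."""
    return hamming(original, proposed) <= trust_radius


print(hamming("MKTAY", "MKVAF"))                                  # 2
print(within_trust_radius("MKTAY", "MKVAF", trust_radius=1))      # False
```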

A dictionary containing the following key-value pairs is returned:

  • "sequences": All proposed sequences.
  • "scores": All scores from the scoring function.
  • "accept": Whether the sequence was accepted as the new 'current sequence' on which new sequences are proposed.

This can be turned into a pandas DataFrame.

Parameters

  • starter_sequence: The starting sequence.
  • n_steps: Number of steps for the MC chain to walk.
  • scoring_func: Scoring function for a new sequence. It should only accept a string sequence.
  • is_accepted_kwargs: Dictionary of kwargs to pass into is_accepted function. See is_accepted docstring for more details.
  • trust_radius: Maximum allowed number of mutations away from starter sequence.
  • propose_kwargs: Dictionary of kwargs to pass into propose function. See propose docstring for more details.
  • verbose: Whether or not to print the iteration number and associated sequence + score. Defaults to False.

Returns

A dictionary with sequences, scores and accept as keys.
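To tie the design choices together, here is a toy, stdlib-only chain sampler. It is a sketch and not jax-unirep's implementation. toy_sample_one_chain, propose and is_accepted are hypothetical stand-ins. The proposal makes a single random substitution, and the acceptance rule is an assumed Metropolis-style rule with a temperature; the library's own propose and is_accepted may behave differently. The sampler records every proposal, rejects proposals outside the trust radius, and returns the three lists described above.

```python
import math
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def hamming(a, b):
    """Number of mismatched positions between equal-length sequences."""
    return sum(x != y for x, y in zip(a, b))


def propose(seq, rng):
    """Hypothetical proposal: substitute one random position with another residue."""
    pos = rng.randrange(len(seq))
    new_aa = rng.choice(AMINO_ACIDS.replace(seq[pos], ""))
    return seq[:pos] + new_aa + seq[pos + 1:]


def is_accepted(current_score, new_score, temperature, rng):
    """Hypothetical Metropolis-style rule: always accept improvements,
    accept worse scores with probability exp(delta / temperature)."""
    if new_score >= current_score:
        return True
    return rng.random() < math.exp((new_score - current_score) / temperature)


def toy_sample_one_chain(starter_sequence, n_steps, scoring_func,
                         trust_radius=7, temperature=0.1, seed=0):
    rng = random.Random(seed)
    current, current_score = starter_sequence, scoring_func(starter_sequence)
    chain = {"sequences": [], "scores": [], "accept": []}
    for _ in range(n_steps):
        candidate = propose(current, rng)
        score = scoring_func(candidate)
        # Reject outright if the proposal strays beyond the trust radius.
        accepted = (
            hamming(starter_sequence, candidate) <= trust_radius
            and is_accepted(current_score, score, temperature, rng)
        )
        if accepted:
            current, current_score = candidate, score
        # Record every proposal, accepted or not.
        chain["sequences"].append(candidate)
        chain["scores"].append(score)
        chain["accept"].append(accepted)
    return chain


chain = toy_sample_one_chain("MKTAYIAKQR", n_steps=50,
                             scoring_func=lambda s: s.count("A"),
                             trust_radius=2)
```

Because all three lists have the same length, pandas.DataFrame(chain) turns the result into a table with one row per proposal.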