
SpliceVI

SpliceVI is a multimodal variational autoencoder that jointly models gene expression and alternative splicing (junction usage / PSI) from single-cell data. It learns a shared low-dimensional latent space that captures variation across both modalities, enabling tasks like imputation, clustering, trajectory inference, and differential splicing analysis.

SpliceVI follows the same overall architecture and inherits the same scvi-tools mixins as MultiVI (Ashuach et al., 2021), which jointly models RNA-seq and chromatin accessibility. The key difference is that the ATAC branch is replaced throughout by a splicing-specific branch — a missingness-aware encoder, a splicing-specific decoder, and splicing likelihoods (Binomial / Beta-Binomial / Dirichlet-Multinomial) suited to the PSI scale. The DE/DS analysis API mirrors MultiVI's, with get_normalized_splicing and differential_splicing replacing the accessibility equivalents.

Ashuach T*, Gabitto MI*, Jordan MI, Yosef N. MultiVI: deep generative model for the integration of multi-modal data. bioRxiv, 2021.


Model overview

SpliceVI uses two separate encoder branches — one for gene expression and one for splicing — whose posteriors are mixed into a joint latent space. Two decoder branches reconstruct each modality from the joint latent.

Gene expression ──► GE Encoder ─────► q(z|x_e) ──┐
                                                 ├──► Mix ──► z ──┬──► GE Decoder ─────► x_e
Splicing (PSI) ───► Splice Encoder ─► q(z|x_s) ──┘                └──► Splice Decoder ─► x_s

Splicing encoder

Because splicing observations are highly sparse — many junctions are unobserved in a given cell — SpliceVI uses a missingness-aware partial encoder (PartialEncoderEDDI). The design draws inspiration from the EDDI framework (Ma et al., 2019), which handles pervasive missing values in VAEs using feature embeddings and exchangeable neural networks (see also the SpliceVI preprint). The encoder:

  1. Embeds each observed junction independently via a shared \(h\) network
  2. Pools the per-junction embeddings (mean or sum) across all observed junctions in a cell
  3. Maps the pooled representation through an MLP to produce posterior statistics \((\mu, \sigma^2)\)

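This avoids any imputation of missing values at the input level. Because only observed junctions are embedded and pooled, an unobserved junction contributes nothing to the cell's representation, whereas zero-filling would make it indistinguishable from a junction observed with PSI = 0.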
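The three steps can be sketched in NumPy for a single cell. This is a toy, assuming two small ReLU layers in both the \(h\) network and the post-pooling MLP, with random weights standing in for trained ones. It is not the PartialEncoderEDDI implementation, only the embed, pool, and map pattern it follows; `encode_cell` and the weight names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes; the names mirror SPLICEVI arguments, the values are arbitrary.
n_junctions, code_dim, h_hidden_dim, encoder_hidden_dim, n_latent = 6, 4, 8, 16, 2

# Random stand-ins for learned parameters (hypothetical, illustration only).
feature_emb = rng.normal(size=(n_junctions, code_dim))     # per-junction embedding table
W_h1 = rng.normal(size=(1 + code_dim, h_hidden_dim))        # shared h network, layer 1
W_h2 = rng.normal(size=(h_hidden_dim, code_dim))            # shared h network, layer 2
W_e1 = rng.normal(size=(code_dim, encoder_hidden_dim))      # post-pooling MLP, layer 1
W_e2 = rng.normal(size=(encoder_hidden_dim, 2 * n_latent))  # post-pooling MLP, output layer


def relu(x):
    return np.maximum(x, 0.0)


def encode_cell(psi, observed, pool_mode="mean"):
    """Return (mu, var) of q(z | x_s) computed from the observed junctions only."""
    idx = np.flatnonzero(observed)
    if idx.size == 0:
        pooled = np.zeros(code_dim)  # no observed junctions: empty pool
    else:
        # 1. Embed each observed junction from (psi_j, embedding_j) with the shared h.
        h_in = np.column_stack([psi[idx], feature_emb[idx]])
        codes = relu(h_in @ W_h1) @ W_h2
        # 2. Pool across observed junctions; unobserved junctions never enter.
        pooled = codes.mean(axis=0) if pool_mode == "mean" else codes.sum(axis=0)
    # 3. Map the pooled code to posterior statistics (mean and log-variance).
    stats = relu(pooled @ W_e1) @ W_e2
    mu, log_var = stats[:n_latent], stats[n_latent:]
    return mu, np.exp(log_var)


psi = np.array([0.9, 0.1, 0.0, 0.5, 0.0, 0.3])
observed = np.array([True, True, False, True, False, True])
mu, var = encode_cell(psi, observed)
```

Changing `psi` at an unobserved position leaves `mu` and `var` unchanged, which is exactly the property that makes imputation unnecessary.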

Splicing likelihood

Three options are supported:

| Option | Description |
|---|---|
| binomial | Simple binomial over junction read counts |
| beta_binomial | Beta-binomial with a learned per-junction concentration \(\phi_j\) |
| dirichlet_multinomial | Dirichlet-multinomial grouped per ATSE event; enforces that junction probabilities within an event sum to 1 |

The default and recommended option is dirichlet_multinomial.
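For intuition, here are the three likelihoods as plain-Python log-probabilities for one event. This is a sketch, assuming the beta-binomial and Dirichlet-multinomial are parameterized by a decoded PSI \(p\) and a concentration (\(\phi\) or \(c\)) through \(\alpha = c\,p\); SpliceVI's internal parameterization may differ, and the function names are hypothetical.

```python
from math import comb, exp, lgamma, log


def binomial_logpmf(y, n, p):
    """Binomial log-probability of y junction reads out of n event reads."""
    return log(comb(n, y)) + y * log(p) + (n - y) * log(1 - p)


def beta_binomial_logpmf(y, n, p, phi):
    """Beta-binomial with mean p and concentration phi (alpha = phi*p, beta = phi*(1-p))."""
    a, b = phi * p, phi * (1 - p)
    return (log(comb(n, y))
            + lgamma(y + a) + lgamma(n - y + b) - lgamma(n + a + b)
            + lgamma(a + b) - lgamma(a) - lgamma(b))


def dirichlet_multinomial_logpmf(ys, ps, c):
    """DM over all junctions of one ATSE; ps sums to 1 and alpha_k = c * p_k."""
    n = sum(ys)
    alphas = [c * p for p in ps]
    a0 = sum(alphas)
    out = lgamma(n + 1) + lgamma(a0) - lgamma(n + a0)
    for y, a in zip(ys, alphas):
        out += lgamma(y + a) - lgamma(a) - lgamma(y + 1)
    return out


# A two-junction event: the DM reduces to the beta-binomial.
print(dirichlet_multinomial_logpmf([3, 5], [0.3, 0.7], 10.0))
print(beta_binomial_logpmf(3, 8, 0.3, 10.0))
```

The DM couples all junctions of an event through a shared total, which is why it can enforce the sum-to-one constraint that per-junction binomial and beta-binomial terms cannot; as the concentration grows, the beta-binomial approaches the binomial.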

Modality mixing

The two per-modality posteriors are combined into a single joint posterior before sampling. The mixing strategy is controlled by the modality_weights argument, and follows the same parameterization as MultiVI:

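| Option | Description |
|---|---|
| equal | Fixed 50/50 average (default) |
| universal | One global learned weight per modality |
| cell | Per-cell learned weights |
| concatenate | Concatenate both latents; no mixing, so the latent size doubles |
| per_dimension_weighted_average | One learned splicing weight per latent dimension (expression weight = 1 - splicing weight), constrained so the mean weight stays at 0.5 |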
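A minimal NumPy sketch of the mixing options, under two assumptions: each per-modality posterior is a diagonal Gaussian, and mixing is a weighted average of means and variances with the same weights. The weight maps, including the hypothetical `per_dimension_weights` helper (one way to satisfy the documented mean-0.5 constraint), are illustrations rather than SpliceVI's parameterization.

```python
import numpy as np


def mix_posteriors(mu_e, var_e, mu_s, var_s, w_s):
    """Weighted average of the expression and splicing posteriors.

    w_s is the splicing weight: a scalar, a per-cell column (n_cells, 1), or a
    per-dimension row (n_latent,). The expression weight is 1 - w_s.
    """
    mu = (1.0 - w_s) * mu_e + w_s * mu_s
    var = (1.0 - w_s) * var_e + w_s * var_s  # assumption: variances averaged the same way
    return mu, var


def per_dimension_weights(logits):
    """Hypothetical map from free logits to weights in [0, 1] whose mean is exactly 0.5."""
    logits = np.asarray(logits, dtype=float)
    d = logits - logits.mean()          # centered, so d.mean() == 0
    scale = np.abs(d).max()
    if scale == 0:
        return np.full(logits.shape, 0.5)
    return 0.5 + 0.5 * d / scale        # |d / scale| <= 1 keeps weights in [0, 1]


rng = np.random.default_rng(1)
n_cells, n_latent = 4, 3
mu_e, mu_s = rng.normal(size=(2, n_cells, n_latent))
var_e, var_s = rng.uniform(0.5, 2.0, size=(2, n_cells, n_latent))

# "equal": fixed 50/50 average.
mu_equal, var_equal = mix_posteriors(mu_e, var_e, mu_s, var_s, 0.5)

# "universal": one global weight per modality, here a softmax over two learned logits.
logits = np.array([0.2, -0.3])          # [expression, splicing], made-up values
w_universal = np.exp(logits[1]) / np.exp(logits).sum()
mu_universal, _ = mix_posteriors(mu_e, var_e, mu_s, var_s, w_universal)

# "cell": one splicing weight per cell, broadcast over latent dimensions.
w_cell = rng.uniform(size=(n_cells, 1))
mu_cell, _ = mix_posteriors(mu_e, var_e, mu_s, var_s, w_cell)

# "per_dimension_weighted_average": one splicing weight per latent dimension.
w_dim = per_dimension_weights([1.0, -0.5, 0.2])
mu_dim, _ = mix_posteriors(mu_e, var_e, mu_s, var_s, w_dim)

# "concatenate": no mixing; the latent size doubles.
mu_concat = np.concatenate([mu_e, mu_s], axis=-1)
```

In the real model the "cell" weights would come from learned per-cell parameters rather than random draws.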

Architecture

| Component | Default | Options |
|---|---|---|
| Splicing encoder | partial | vanilla, partial |
| Splicing decoder | vanilla | vanilla, linear |
| Expression decoder | vanilla | vanilla, linear |
| Splicing likelihood | dirichlet_multinomial | binomial, beta_binomial, dirichlet_multinomial |
| Latent distribution | normal | normal, ln |

Key hyperparameters

| Parameter | Default | Description |
|---|---|---|
| n_latent | 30 | Joint latent dimensionality |
| n_hidden | auto | Width of encoder/decoder MLPs |
| n_layers_encoder | 2 | Depth of encoder networks |
| dropout_rate | 0.01 | Dropout in MLPs |
| code_dim | 16 | Per-junction embedding size (partial encoder) |
| h_hidden_dim | 64 | Hidden width of per-junction \(h\) network |
| encoder_hidden_dim | 128 | Hidden width of the post-pooling MLP |
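The API reference below states that n_hidden=None resolves to √(n_junctions) capped at 128, and n_latent=None to √(n_hidden). A small sketch of that rule, assuming both square roots are rounded down to integers (the rounding is not documented); `default_sizes` is a hypothetical helper.

```python
import math


def default_sizes(n_junctions, n_hidden=None, n_latent=None):
    """Resolve n_hidden / n_latent when left as None (rounding down is an assumption)."""
    if n_hidden is None:
        n_hidden = min(128, int(math.sqrt(n_junctions)))
    if n_latent is None:
        n_latent = int(math.sqrt(n_hidden))
    return n_hidden, n_latent


print(default_sizes(10_000))  # (100, 10)
print(default_sizes(50_000))  # (128, 11): the cap applies
```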

API reference

splicevi.SPLICEVI

Bases: VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin

Integration of gene expression and alternative splicing signals.

SPLICEVI integrates gene expression and alternative splicing (junction-usage) modalities, learning a joint latent space with reconstruction heads per modality. It wraps the SPLICEVAE module and adds training utilities and convenience APIs.

Parameters:

Name Type Description Default
adata AnnOrMuData

AnnData/MuData object that has been registered via setup_anndata or setup_mudata.

required
modality_weights Literal['equal', 'cell', 'universal', 'concatenate', 'per_dimension_weighted_average']

How to weight modalities when forming the joint latent:

  • "equal" – equal weight per modality,
  • "universal" – one global weight per modality (learned),
  • "cell" – per-cell weights (learned),
  • "concatenate" – do not mix; concatenate per-modality latents,
  • "per_dimension_weighted_average" – one learned splicing weight per latent dimension (expression weight = 1 - splicing weight), constrained so the mean weight stays at 0.5 to prevent either modality from dominating globally.

'equal'
modality_penalty Literal['Jeffreys', 'MMD', 'None']

Alignment penalty across modalities: "Jeffreys", "MMD", or "None".

'Jeffreys'
gene_likelihood Literal['zinb', 'nb', 'poisson']

Expression likelihood: "zinb", "nb", or "poisson".

'zinb'
dispersion Literal['gene', 'gene-batch', 'gene-label', 'gene-cell']

Expression dispersion parameterization: "gene", "gene-batch", "gene-label", or "gene-cell".

'gene'
splicing_encoder_architecture Literal['vanilla', 'partial']

Splicing encoder:

  • "vanilla" – SCVI Encoder,
  • "partial" – PartialEncoderEDDIFaster.

'partial'
splicing_decoder_architecture Literal['vanilla', 'linear']

Splicing decoder:

  • "vanilla" – nonlinear DecoderSplice,
  • "linear" – linear decoder.

'linear'
expression_architecture Literal['vanilla', 'linear']

Expression decoder: "vanilla" (nonlinear) or "linear" (linear decoder).

'linear'
n_hidden int | None

Width of SCVI encoders/decoders. If None, uses √(n_junctions) (capped at 128).

None
n_latent int | None

Joint latent dimensionality (per-modality before concatenation). If None, uses √(n_hidden).

None
n_layers_encoder int

Hidden layers in encoders (including the post-pooling MLP in the partial splicing encoder).

2
n_layers_decoder int

Hidden layers in decoders (not used by the linear decoders).

2
dropout_rate float

Dropout rate in MLPs.

0.1
use_batch_norm Literal['encoder', 'decoder', 'none', 'both']

Where to apply batch norm: "encoder", "decoder", "both", or "none".

'none'
use_layer_norm Literal['encoder', 'decoder', 'none', 'both']

Where to apply layer norm: "encoder", "decoder", "both", or "none".

'both'
latent_distribution Literal['normal', 'ln']

Latent distribution: "normal" or "ln" (logistic normal).

'normal'
deeply_inject_covariates bool

If True, injects covariates at all decoder layers.

False
encode_covariates bool

If True, provides covariates to encoders.

False
splicing_loss_type Literal['binomial', 'beta_binomial', 'dirichlet_multinomial']

Splicing reconstruction loss: "binomial", "beta_binomial", or "dirichlet_multinomial".

'dirichlet_multinomial'
splicing_concentration float | None

Optional concentration for the beta-binomial likelihood. Ignored for binomial.

None
dm_concentration Literal['atse', 'scalar']

For the Dirichlet-multinomial: "atse" (per-ATSE concentration) or "scalar" (single shared concentration).

'atse'
encoder_hidden_dim int

Hidden width inside the PartialEncoder MLP that maps pooled junction codes to latent statistics.

128
code_dim int

Dimensionality of per-junction embeddings in PartialEncoderEDDIFaster.

16
h_hidden_dim int

Hidden width of the shared h network that processes (psi, feature_embedding) for each observed junction.

64
pool_mode Literal['mean', 'sum']

Pooling across observed junctions per cell: "mean" or "sum".

'mean'
max_nobs int

Optional cap on the number of observed (cell, junction) pairs processed in a single scatter chunk. Set to a negative value to disable chunking.

-1
n_genes int | None

Number of gene expression features. Required if adata is AnnData.

None
n_junctions int | None

Number of splicing features. Required if adata is AnnData.

None
initialize_embeddings_from_pca bool

If True and using the partial splicing encoder, initialize the per-junction embedding table via PCA on junction ratios.

False
fully_paired bool

If True, the model exposes only a joint latent (no modality specific latents).

False
**model_kwargs

Forwarded to SPLICEVAE.

{}
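To make modality_penalty="Jeffreys" concrete: the Jeffreys divergence is the symmetric KL divergence \(\mathrm{KL}(q_e \| q_s) + \mathrm{KL}(q_s \| q_e)\). The sketch below computes it per cell for diagonal-Gaussian posteriors, assuming the penalty compares the expression and splicing posteriors as in MultiVI. How SpliceVI scales the penalty in the loss or restricts it to paired cells is not shown, and the function names are hypothetical.

```python
import numpy as np


def kl_diag_gaussian(mu1, var1, mu2, var2):
    """KL(N(mu1, var1) || N(mu2, var2)) for diagonal Gaussians, summed over the last axis."""
    return 0.5 * np.sum(np.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0, axis=-1)


def jeffreys_penalty(mu_e, var_e, mu_s, var_s):
    """Symmetric KL (Jeffreys divergence) between per-cell expression and splicing posteriors."""
    return (kl_diag_gaussian(mu_e, var_e, mu_s, var_s)
            + kl_diag_gaussian(mu_s, var_s, mu_e, var_e))


rng = np.random.default_rng(2)
mu_e, mu_s = rng.normal(size=(2, 5, 3))                 # 5 cells, 3 latent dimensions
var_e, var_s = rng.uniform(0.5, 2.0, size=(2, 5, 3))
penalty = jeffreys_penalty(mu_e, var_e, mu_s, var_s)    # one value per cell, shape (5,)
```

Because the log-variance terms cancel in the symmetric sum, the penalty reduces to \(\tfrac{1}{2}\sum_d \big[(\sigma_e^2 + \Delta^2)/\sigma_s^2 + (\sigma_s^2 + \Delta^2)/\sigma_e^2 - 2\big]\) with \(\Delta = \mu_e - \mu_s\); it is zero only when the two posteriors coincide.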

differential_expression(adata=None, groupby=None, group1=None, group2=None, idx1=None, idx2=None, mode='change', delta=0.25, batch_size=None, all_stats=True, batch_correction=False, batchid1=None, batchid2=None, fdr_target=0.05, silent=False, **kwargs)

Differential expression analysis.

Performs differential gene expression analysis using normalized gene expression estimates.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments for differential computation.

{}

Returns:

Type Description
A pandas DataFrame with differential expression results.

differential_splicing(adata=None, groupby=None, group1=None, group2=None, idx1=None, idx2=None, mode='change', delta=0.1, batch_size=None, all_stats=True, batch_correction=False, batchid1=None, batchid2=None, fdr_target=0.05, silent=False, norm_splicing_function='dm_posterior_mean', **kwargs)

Differential splicing analysis.

Performs differential junction usage analysis using normalized PSI from the model's splicing decoder.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments for differential computation.

{}

Returns:

Type Description
A pandas DataFrame with differential splicing results.
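differential_splicing defaults to mode="change" with delta=0.1. The sketch below shows the idea behind a change-mode test in the style of scvi-tools: from posterior PSI draws for two groups, estimate per junction the probability that the effect exceeds delta, then call junctions while keeping the posterior expected FDR at fdr_target. It assumes the effect is a plain PSI difference; SpliceVI's actual change function, and how it draws the samples, may differ. `change_mode_ds` is hypothetical.

```python
import numpy as np


def change_mode_ds(psi_g1, psi_g2, delta=0.1, fdr_target=0.05):
    """Posterior probability of a PSI change larger than delta, plus FDR-controlled calls.

    psi_g1, psi_g2: posterior PSI draws for each group, shape (n_draws, n_junctions).
    """
    diff = psi_g1 - psi_g2
    proba = (np.abs(diff) > delta).mean(axis=0)          # P(|PSI_1 - PSI_2| > delta)
    order = np.argsort(-proba)                           # most confident junctions first
    # Posterior expected FDR of the top-k set: mean of (1 - proba) over those k junctions.
    expected_fdr = np.cumsum(1.0 - proba[order]) / np.arange(1, proba.size + 1)
    n_called = int(np.sum(expected_fdr <= fdr_target))   # expected_fdr is non-decreasing
    is_ds = np.zeros(proba.size, dtype=bool)
    is_ds[order[:n_called]] = True
    return proba, is_ds


rng = np.random.default_rng(3)
n_draws = 2000
# Made-up posterior draws for three junctions: a large shift, no shift, and a small shift.
psi_g1 = np.clip(rng.normal([0.80, 0.50, 0.20], 0.02, size=(n_draws, 3)), 0, 1)
psi_g2 = np.clip(rng.normal([0.30, 0.50, 0.25], 0.02, size=(n_draws, 3)), 0, 1)
proba, is_ds = change_mode_ds(psi_g1, psi_g2)
```

Only the first junction is called here: the third shifts by 0.05, which is below delta, so its probability of a meaningful change stays small.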

get_latent_representation(adata=None, modality='joint', indices=None, give_mean=True, batch_size=None)

Return the latent representation for each cell.

Parameters:

Name Type Description Default
adata AnnOrMuData | None

AnnOrMuData object used in setup.

None
modality Literal['joint', 'expression', 'splicing']

One of:

  • "joint": joint latent space,
  • "expression": expression-specific latent space,
  • "splicing": splicing-specific latent space.

'joint'
indices Sequence[int] | None

Cell indices to use.

None
give_mean bool

If True, returns the mean of the latent distribution.

True
batch_size int | None

Batch size for processing.

None

Returns:

Type Description
A NumPy array of the latent representations.

get_library_size_factors(adata=None, indices=None, batch_size=128)

Return library size factors for gene expression.

Note that the splicing modality does not use library size factors.

Parameters:

Name Type Description Default
adata AnnOrMuData | None

AnnOrMuData object.

None
indices Sequence[int]

Cell indices (default: all cells).

None
batch_size int

Batch size for processing.

128

Returns:

Type Description
A dictionary with key "expression" for the gene expression library size factors.

get_normalized_expression(adata=None, indices=None, n_samples_overall=None, transform_batch=None, gene_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)

Returns the normalized (decoded) gene expression.

This is denoted as \(\rho_n\) in the scVI paper.

Parameters:

Name Type Description Default
adata AnnOrMuData | None

AnnOrMuData object with equivalent structure to initial AnnData. If None, defaults to the AnnOrMuData object used to initialize the model.

None
indices Sequence[int] | None

Indices of cells in adata to use. If None, all cells are used.

None
n_samples_overall int | None

Number of observations to sample from indices if indices is provided.

None
transform_batch Sequence[Number | str] | None

Batch to condition on. If transform_batch is:

  • None, then the real observed batch is used.
  • int, then batch transform_batch is used.
  • list of int or str, then decoded values are averaged over the provided batches.

None
gene_list Sequence[str] | None

Return frequencies of expression for a subset of genes. This can save memory when working with large datasets and few genes are of interest.

None
use_z_mean bool

If True, use the mean of the latent distribution; otherwise, sample from it.

True
n_samples int

Number of posterior samples to use for estimation.

1
batch_size int | None

Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.

None
return_mean bool

Whether to return the mean of the samples.

True
return_numpy bool

Return a numpy array instead of a pandas DataFrame.

False

Returns:

Type Description
If `n_samples` > 1 and `return_mean` is False, the shape is `(samples, cells, genes)`.
Otherwise, the shape is `(cells, genes)`, and the return type is a pandas `DataFrame`
unless `return_numpy` is True.

get_normalized_splicing(adata=None, indices=None, n_samples_overall=None, transform_batch=None, junction_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)

Returns the normalized (decoded) splicing probabilities.

This is denoted as \(p_{nj}\) in the SPLICEVI model.

Parameters:

Name Type Description Default
adata AnnOrMuData | None

AnnOrMuData object with the same structure as used in setup. If None, defaults to the AnnOrMuData object used to initialize the model.

None
indices Sequence[int] | None

Cell indices to use. If None, all cells are used.

None
n_samples_overall int | None

Number of observations to sample from indices if provided.

None
transform_batch Sequence[Number | str] | None

Batch(es) to condition on:

  • None: use the true observed batch
  • int: force all cells to that batch
  • list[int|str]: average over those batches

None
junction_list Sequence[str] | None

Subset of junction names to return. If None, returns all junctions.

None
use_z_mean bool

If True, use the mean of the latent distribution; otherwise sample.

True
n_samples int

Number of posterior samples to draw. If >1 and return_mean is True, the result is averaged over draws.

1
batch_size int | None

Minibatch size for decoding. Defaults to scvi.settings.batch_size.

None
return_mean bool

Whether to average over posterior samples when n_samples>1.

True
return_numpy bool

If True, returns a NumPy array; otherwise a pandas DataFrame.

False
silent bool

If True, suppresses the progress bar.

True

Returns:

Type Description
A NumPy array or pandas DataFrame of shape `(cells, junctions)` containing the
decoded splicing probabilities.

get_normalized_splicing_DM(adata=None, indices=None, n_samples_overall=None, transform_batch=None, junction_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)

Returns DM-normalized splicing probabilities (PSI*).

This decodes junction PSI \(p\) and applies Dirichlet-multinomial posterior-mean smoothing per junction: \(\psi^* = (c\,p + y_j) / (c + n)\), where \(y_j\) is the observed junction count and \(n\) is the observed ATSE total for that junction. The concentration \(c\) is taken from the module:

  • If self.dm_concentration == "atse": use the per-ATSE values in self.module.log_phi_j, mapped to per-junction values via self.module.junc2atse.
  • Otherwise: treat self.module.log_phi_j as a scalar concentration.

Parameters:

Name Type Description Default
adata AnnOrMuData | None

AnnOrMuData object with equivalent structure to initialization. If None, defaults to the object used to initialize the model.

None
indices Sequence[int] | None

Indices of cells in adata to use. If None, all cells are used.

None
n_samples_overall int | None

Number of observations to sample from indices if provided.

None
transform_batch Sequence[Number | str] | None

Batch conditioning:

  • None: use the observed batch
  • int/str: force decoding to that batch
  • list[int|str]: average decoded outputs over the listed batches

None
junction_list Sequence[str] | None

Optional subset of junction names to return.

None
use_z_mean bool

If True, decode from the mean of the latent; else sample.

True
n_samples int

Posterior samples to draw. If >1 and return_mean is True, results are averaged over samples.

1
batch_size int | None

Minibatch size for decoding.

None
return_mean bool

If True and n_samples > 1, average samples before returning.

True
return_numpy bool

If True, return a NumPy array; else a pandas DataFrame.

False
silent bool

If True, disables progress display.

True

Returns:

Type Description
Array or DataFrame of shape ``(cells, junctions)`` with DM-normalized PSI*.
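The smoothing above is the posterior mean of a Dirichlet prior with parameters \(c\,p\) over the junctions of an ATSE after observing the read counts: with no reads (\(n = 0\)) it returns the decoded \(p\), and as \(n\) grows it moves toward the empirical ratio \(y_j / n\). A NumPy sketch, assuming \(c = \exp(\text{log\_phi\_j})\) (the attribute name suggests a log scale, but the docstring does not say so) and that junc2atse is an integer index array; `dm_posterior_mean` is a hypothetical helper, not the method itself.

```python
import numpy as np


def dm_posterior_mean(p, y, n, log_phi, junc2atse=None):
    """psi* = (c * p + y) / (c + n), elementwise over (cells, junctions).

    p: decoded PSI; y: observed junction counts; n: observed ATSE total of each junction.
    log_phi: per-ATSE log-concentrations (indexed via junc2atse) or a scalar.
    Using c = exp(log_phi) is an assumption of this sketch.
    """
    log_phi = np.asarray(log_phi, dtype=float)
    if junc2atse is not None:
        c = np.exp(log_phi[np.asarray(junc2atse)])  # per-junction concentration ("atse")
    else:
        c = np.exp(log_phi)                          # one shared concentration ("scalar")
    return (c * p + y) / (c + n)


# Two cells, four junctions grouped into two ATSEs: junctions 0-1 and 2-3.
junc2atse = [0, 0, 1, 1]
p = np.array([[0.6, 0.4, 0.5, 0.5],
              [0.3, 0.7, 0.9, 0.1]])
y = np.array([[8, 2, 0, 0],
              [1, 1, 3, 1]])
n = np.array([[10, 10, 0, 0],   # cell 0 has no reads in the second ATSE
              [2, 2, 4, 4]])
psi_star = dm_posterior_mean(p, y, n, np.log([5.0, 20.0]), junc2atse)
```

When the decoded \(p\) and the junction counts each add up across an event, the smoothed values within that event still sum to 1.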

setup_anndata(adata, layer=None, junc_ratio=None, cell_by_junction_matrix=None, cell_by_cluster_matrix=None, psi_mask_layer=None, batch_key=None, size_factor_key=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs) classmethod

Set up an AnnData object for SPLICEVI.

Parameters:

Name Type Description Default
junc_ratio str | None

Key in adata.layers for junction ratio values.

None
cell_by_junction_matrix str | None

Key in adata.layers for the cell-by-junction matrix.

None
cell_by_cluster_matrix str | None

Key in adata.layers for the cell-by-cluster splicing matrix.

None
psi_mask_layer str | None

Layer with binary mask (1=observed, 0=missing) per junction.

None
Notes

Use this method if your splicing data is stored in an AnnData object where gene expression and splicing features are concatenated.
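setup_anndata and setup_mudata expect junction ratios, junction counts, ATSE totals, and optionally a binary observation mask. A NumPy sketch of deriving the ratio and mask from the two count matrices, assuming the ratio is a junction's reads divided by its ATSE total (the \(y_j\) and \(n\) of get_normalized_splicing_DM above) and that a junction counts as observed when its ATSE has at least one read; `junction_ratio_and_mask` is hypothetical.

```python
import numpy as np


def junction_ratio_and_mask(junc_counts, atse_counts):
    """Hypothetical helper: PSI = junction reads / ATSE reads, masked where the ATSE has no reads.

    Both inputs have shape (cells, junctions); atse_counts holds, for every junction,
    the total reads of the ATSE that junction belongs to.
    """
    junc_counts = np.asarray(junc_counts, dtype=float)
    atse_counts = np.asarray(atse_counts, dtype=float)
    observed = atse_counts > 0
    # Divide only where the ATSE was observed; unobserved entries stay 0 and are masked out.
    ratio = np.divide(junc_counts, atse_counts, out=np.zeros_like(junc_counts), where=observed)
    return ratio, observed.astype(np.uint8)


junc_counts = np.array([[3, 1, 0], [0, 0, 2]])
atse_counts = np.array([[4, 4, 0], [0, 0, 5]])
junc_ratio, psi_mask = junction_ratio_and_mask(junc_counts, atse_counts)
```

The resulting arrays correspond to the layers passed as junc_ratio, cell_by_junction_matrix, cell_by_cluster_matrix, and psi_mask_layer. The 0 stored for unobserved entries is a placeholder that the mask tells the model to ignore.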

setup_mudata(mdata, rna_layer=None, junc_ratio_layer=None, atse_counts_layer=None, junc_counts_layer=None, psi_mask_layer=None, batch_key=None, size_factor_key=None, categorical_covariate_keys=None, continuous_covariate_keys=None, idx_layer=None, modalities=None, **kwargs) classmethod

Set up a MuData object for SPLICEVI.

Parameters:

Name Type Description Default
rna_layer str | None

Key in the RNA AnnData for gene expression counts. If None, the primary data (.X) of that AnnData is used.

None
junc_ratio_layer str | None

Key in the splicing AnnData for junction ratio values. If None, the primary data (.X) of that AnnData is used.

None
atse_counts_layer str | None

Key in the splicing AnnData for total event counts. If None, defaults to "cell_by_cluster_matrix".

None
junc_counts_layer str | None

Key in the splicing AnnData for observed junction counts. If None, defaults to "cell_by_junction_matrix".

None
psi_mask_layer str | None

Layer with binary mask (1=observed, 0=missing) per junction.

None
size_factor_key str | None

Key in mdata.obsm for size factors.

None

Examples:

>>> import mudata as mu
>>> from splicevi import SPLICEVI
>>> mdata = mu.MuData({
...     "rna": ge_anndata.copy(),
...     "splicing": atse_anndata.copy(),
... })
>>> SPLICEVI.setup_mudata(
...     mdata,
...     modalities={"rna_layer": "rna", "junc_ratio_layer": "splicing"},
...     rna_layer="raw_counts",         # gene expression counts are in the GE AnnData's "raw_counts" layer
...     junc_ratio_layer="junc_ratio",  # junction ratios are in the ATSE AnnData's "junc_ratio" layer
... )
>>> model = SPLICEVI(mdata)

train(max_epochs=500, lr=0.0001, accelerator='auto', devices='auto', train_size=None, validation_size=None, shuffle_set_split=True, batch_size=128, weight_decay=0.001, eps=1e-08, early_stopping=True, early_stopping_patience=50, check_val_every_n_epoch=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=50, adversarial_mixing=True, lr_scheduler_type='plateau', reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', step_size=10, gradient_clipping=True, gradient_clipping_max_norm=5.0, datasplitter_kwargs=None, plan_kwargs=None, **kwargs)

Trains the model using amortized variational inference on gene expression and splicing modalities.

Parameters:

Name Type Description Default
max_epochs int

Number of epochs to train over.

500
lr float

Learning rate for optimization.

0.0001
accelerator str

Hardware accelerator to use (for example "cpu", "gpu", or "auto").

'auto'
devices str

Devices to use (for example a device count, a list of device indices, or "auto").

'auto'
train_size float | None

Fraction of cells used for the training set.

None
validation_size float | None

Fraction of cells used for the validation set.

None
shuffle_set_split bool

Whether to shuffle before splitting.

True
batch_size int

Minibatch size for training.

128
weight_decay float

Weight decay (L2 penalty) applied by the optimizer.

0.001
eps float

Epsilon added to the optimizer's denominator for numerical stability.

1e-08
early_stopping bool

Whether to enable early stopping.

True
early_stopping_patience int

Number of epochs with no improvement before stopping.

50
check_val_every_n_epoch int | None

How often (in epochs) to run validation.

None
n_steps_kl_warmup int | None

Number of training steps (minibatches) over which the KL divergence weight is warmed up.

None
n_epochs_kl_warmup int | None

Number of epochs over which the KL divergence weight is warmed up.

50
adversarial_mixing bool

Whether to include adversarial classifier during training.

True
lr_scheduler_type Literal['plateau', 'step']

Scheduler type in TrainingPlan: "plateau" (reduce-on-plateau) or "step" (fixed-step).

'plateau'
reduce_lr_on_plateau bool

If True and using plateau scheduler, enable ReduceLROnPlateau.

False
lr_factor float

Multiplicative factor for LR reduction (used for both plateau and step schedulers).

0.6
lr_patience int

Number of epochs with no improvement before the plateau scheduler reduces the learning rate.

30
lr_threshold float

Threshold for measuring new optimum (plateau scheduler).

0.0
lr_scheduler_metric Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation']

Metric to monitor for plateau scheduler.

'elbo_validation'
step_size int

Epoch interval between LR drops (step scheduler).

10
gradient_clipping bool

Whether to apply gradient-norm clipping.

True
gradient_clipping_max_norm float

Maximum gradient norm used for gradient clipping.

5.0
datasplitter_kwargs dict | None

Additional kwargs for the data splitter.

None
plan_kwargs dict | None

Additional kwargs to pass to the TrainingPlan constructor.

None
**kwargs

Additional Trainer kwargs (callbacks, strategy, etc.).

{}
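With lr_scheduler_type="step", the documented lr_factor and step_size imply the schedule sketched below, which matches PyTorch's StepLR with gamma=lr_factor. The KL warmup is sketched as a linear ramp from 0 to 1, the shape used by scvi-tools training plans. Both are assumptions about SpliceVI's TrainingPlan, and the helpers are hypothetical.

```python
def step_lr(epoch, lr=1e-4, lr_factor=0.6, step_size=10):
    """Learning rate under the fixed-step schedule: multiply by lr_factor every step_size epochs."""
    return lr * lr_factor ** (epoch // step_size)


def kl_weight(epoch, n_epochs_kl_warmup=50):
    """KL weight under a linear epoch-based warmup from 0 to 1 (assumed schedule shape)."""
    if not n_epochs_kl_warmup:
        return 1.0  # no warmup: full KL weight from the start
    return min(1.0, epoch / n_epochs_kl_warmup)


for epoch in (0, 10, 25, 60):
    print(epoch, step_lr(epoch), kl_weight(epoch))
```

Note that with the defaults (lr_scheduler_type="plateau", reduce_lr_on_plateau=False) neither scheduler changes the learning rate; only the KL warmup is active.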