SpliceVI
SpliceVI is a multimodal variational autoencoder that jointly models gene expression and alternative splicing (junction usage / PSI) from single-cell data. It learns a shared low-dimensional latent space that captures variation across both modalities, enabling tasks like imputation, clustering, trajectory inference, and differential splicing analysis.
SpliceVI follows the same overall architecture and inherits the same scvi-tools mixins as MultiVI (Ashuach et al., 2021), which jointly models RNA-seq and chromatin accessibility. The key difference is that the ATAC branch is replaced throughout by a splicing-specific branch — a missingness-aware encoder, a splicing-specific decoder, and splicing likelihoods (Binomial / Beta-Binomial / Dirichlet-Multinomial) suited to the PSI scale. The DE/DS analysis API mirrors MultiVI's, with get_normalized_splicing and differential_splicing replacing the accessibility equivalents.
Ashuach T*, Gabitto MI*, Jordan MI, Yosef N. MultiVI: deep generative model for the integration of multi-modal data. bioRxiv, 2021.
Model overview
SpliceVI uses two separate encoder branches — one for gene expression and one for splicing — whose posteriors are mixed into a joint latent space. Two decoder branches reconstruct each modality from the joint latent.
Gene expression ──► GE Encoder ──────► q(z|x_e) ──┐
                                                  ├──► Mix ──► z ──► GE Decoder ──► x_e
Splicing (PSI) ───► Splice Encoder ──► q(z|x_s) ──┘              └──► Splice Decoder ──► x_s
Splicing encoder
Because splicing observations are highly sparse — many junctions are unobserved in a given cell — SpliceVI uses a missingness-aware partial encoder (PartialEncoderEDDI). The design draws inspiration from the EDDI framework (Ma et al., 2019), which handles pervasive missing values in VAEs using feature embeddings and exchangeable neural networks (see also the SpliceVI preprint). The encoder:
- Embeds each observed junction independently via a shared \(h\) network
- Pools the per-junction embeddings (mean or sum) across all observed junctions in a cell
- Maps the pooled representation through an MLP to produce posterior statistics \((\mu, \sigma^2)\)
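This avoids any imputation of missing values at the input level. Unobserved junctions simply contribute nothing to the pooled representation, which is more principled than masking or zero-filling, where a placeholder value stands in for each missing junction.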
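The three encoder steps can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the package's `PartialEncoderEDDI` implementation: the weight matrices are random stand-ins for learned parameters, the shared \(h\) network is assumed to be a two-layer ReLU MLP without biases, and the names `partial_encode` and `feature_embedding` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_junctions, code_dim, h_hidden_dim, encoder_hidden_dim, n_latent = 6, 16, 64, 128, 30

# Hypothetical parameters: random here, learned in a real model.
feature_embedding = rng.normal(size=(n_junctions, code_dim))   # one code per junction
W_h1 = rng.normal(size=(1 + code_dim, h_hidden_dim)) * 0.1     # shared h network, layer 1
W_h2 = rng.normal(size=(h_hidden_dim, code_dim)) * 0.1         # shared h network, layer 2
W_enc = rng.normal(size=(code_dim, encoder_hidden_dim)) * 0.1  # post-pooling MLP
W_mu = rng.normal(size=(encoder_hidden_dim, n_latent)) * 0.1
W_logvar = rng.normal(size=(encoder_hidden_dim, n_latent)) * 0.1


def relu(a):
    return np.maximum(a, 0.0)


def partial_encode(psi, mask, pool_mode="mean"):
    """Encode one cell from its observed junctions only.

    psi  : (n_junctions,) junction ratios; entries at unobserved positions are ignored
    mask : (n_junctions,) boolean, True where the junction was observed
    """
    obs = np.flatnonzero(mask)
    if obs.size == 0:
        pooled = np.zeros(code_dim)  # no observed junctions: fall back to a zero code
    else:
        # 1. Embed each observed junction with the shared h network on (psi_j, e_j).
        h_in = np.concatenate([psi[obs, None], feature_embedding[obs]], axis=1)
        codes = relu(h_in @ W_h1) @ W_h2
        # 2. Pool the per-junction codes (order of junctions does not matter).
        pooled = codes.mean(axis=0) if pool_mode == "mean" else codes.sum(axis=0)
    # 3. Map the pooled code to posterior statistics (mu, sigma^2).
    hidden = relu(pooled @ W_enc)
    return hidden @ W_mu, np.exp(hidden @ W_logvar)


psi = np.array([0.2, 0.9, 0.0, 0.5, 0.0, 1.0])
mask = np.array([True, True, False, True, False, True])
mu, var = partial_encode(psi, mask)
```

Because only the indices where `mask` is true enter the computation, whatever value is stored at an unobserved junction has no effect on \((\mu, \sigma^2)\).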
Splicing likelihood
Three options are supported:
| Option | Description |
|---|---|
| `binomial` | Simple binomial over junction read counts |
| `beta_binomial` | Beta-binomial with a learned per-junction concentration \(\phi_j\) |
| `dirichlet_multinomial` | Dirichlet-multinomial grouped per ATSE event; enforces that junction probabilities within an event sum to 1 |
The default and recommended option is dirichlet_multinomial.
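To make the difference between the overdispersed options concrete, here is a small standard-library sketch of the two log-likelihoods. It assumes a mean/concentration parameterization (\(\alpha = p\,\phi\), \(\beta = (1-p)\,\phi\) for the beta-binomial, and \(\alpha_k = c\,p_k\) within one ATSE for the Dirichlet-multinomial). Whether SpliceVI uses exactly this parameterization is an assumption, and the function names are illustrative only.

```python
from math import exp, lgamma


def log_beta(a, b):
    """Log of the Beta function B(a, b)."""
    return lgamma(a) + lgamma(b) - lgamma(a + b)


def beta_binomial_logpmf(y, n, p, phi):
    """Log-probability of y junction reads out of n ATSE reads (one junction)."""
    a, b = p * phi, (1.0 - p) * phi
    log_choose = lgamma(n + 1) - lgamma(y + 1) - lgamma(n - y + 1)
    return log_choose + log_beta(y + a, n - y + b) - log_beta(a, b)


def dirichlet_multinomial_logpmf(counts, probs, concentration):
    """Log-probability of the junction counts of one ATSE (probs sum to 1)."""
    n = sum(counts)
    alphas = [concentration * p for p in probs]
    a0 = sum(alphas)
    out = lgamma(a0) + lgamma(n + 1) - lgamma(n + a0)
    for y, a in zip(counts, alphas):
        out += lgamma(y + a) - lgamma(a) - lgamma(y + 1)
    return out


# Sanity check: beta-binomial probabilities over y = 0..n sum to 1.
total = sum(exp(beta_binomial_logpmf(y, 10, 0.3, 5.0)) for y in range(11))
```

For an event with exactly two junctions the Dirichlet-multinomial reduces to the beta-binomial, so the grouped likelihood generalizes the per-junction one while tying junctions of the same event together.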
Modality mixing
The two per-modality posteriors are combined into a single joint posterior before sampling. The mixing strategy is controlled by the modality_weights argument, and follows the same parameterization as MultiVI:
| Option | Description |
|---|---|
| `equal` | Fixed 50/50 average |
| `universal` | One global learned weight per modality |
| `cell` | Per-cell learned weights |
| `concatenate` | Concatenate both latents (no mixing, doubled latent size) |
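The options in this table can be illustrated with a short standard-library sketch. It assumes that learned weights are stored as unnormalized logits and turned into mixing proportions with a softmax; that normalization, and the helper names `mix_posteriors` and `concatenate_posteriors`, are assumptions for illustration rather than the package's code.

```python
from math import exp


def softmax(logits):
    """Numerically stable softmax over a short list of logits."""
    m = max(logits)
    exps = [exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]


def mix_posteriors(stat_e, stat_s, weight_logits=(0.0, 0.0)):
    """Weighted average of one posterior statistic from each modality.

    Equal logits give the fixed 50/50 average ('equal'); a single learned
    pair shared by all cells corresponds to 'universal'; one pair per cell
    corresponds to 'cell'.
    """
    w_e, w_s = softmax(list(weight_logits))
    return [w_e * a + w_s * b for a, b in zip(stat_e, stat_s)]


def concatenate_posteriors(stat_e, stat_s):
    """'concatenate': keep both latents side by side, doubling the size."""
    return list(stat_e) + list(stat_s)


mixed = mix_posteriors([0.0, 2.0], [2.0, 4.0])          # 50/50 average
joined = concatenate_posteriors([0.0, 2.0], [2.0, 4.0])  # length 4
```

The same function can be applied to the posterior means and to the posterior variances, so the joint posterior keeps the dimensionality of a single modality in every mode except `concatenate`.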
Architecture
| Component | Default | Options |
|---|---|---|
| Splicing encoder | `partial` | `vanilla`, `partial` |
| Splicing decoder | `vanilla` | `vanilla`, `linear` |
| Expression decoder | `vanilla` | `vanilla`, `linear` |
| Splicing likelihood | `dirichlet_multinomial` | `binomial`, `beta_binomial`, `dirichlet_multinomial` |
| Latent distribution | `normal` | `normal`, `ln` |
Key hyperparameters
| Parameter | Default | Description |
|---|---|---|
| `n_latent` | 30 | Joint latent dimensionality |
| `n_hidden` | auto | Width of encoder/decoder MLPs |
| `n_layers_encoder` | 2 | Depth of encoder networks |
| `dropout_rate` | 0.01 | Dropout in MLPs |
| `code_dim` | 16 | Per-junction embedding size (partial encoder) |
| `h_hidden_dim` | 64 | Hidden width of per-junction \(h\) network |
| `encoder_hidden_dim` | 128 | Hidden width of the post-pooling MLP |
API reference
splicevi.SPLICEVI
Bases: VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin
Integration of gene expression and alternative splicing signals.
SPLICEVI integrates gene expression and alternative splicing (junction-usage) modalities, learning a joint latent space with reconstruction heads per modality. It wraps the `scvi.module.SPLICEVAE` module and adds training utilities and convenience APIs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData` | AnnData/MuData object that has been registered via … | required |
| `modality_weights` | `Literal['equal', 'cell', 'universal', 'concatenate', 'per_dimension_weighted_average']` | How to weight modalities when forming the joint latent: … | `'equal'` |
| `modality_penalty` | `Literal['Jeffreys', 'MMD', 'None']` | Alignment penalty across modalities: … | `'Jeffreys'` |
| `gene_likelihood` | `Literal['zinb', 'nb', 'poisson']` | Expression likelihood: … | `'zinb'` |
| `dispersion` | `Literal['gene', 'gene-batch', 'gene-label', 'gene-cell']` | Expression dispersion parameterization: … | `'gene'` |
| `splicing_encoder_architecture` | `Literal['vanilla', 'partial']` | Splicing encoder: … | `'partial'` |
| `splicing_decoder_architecture` | `Literal['vanilla', 'linear']` | Splicing decoder: … | `'linear'` |
| `expression_architecture` | `Literal['vanilla', 'linear']` | Expression decoder: … | `'linear'` |
| `n_hidden` | `int \| None` | Width of SCVI encoders/decoders. If … | `None` |
| `n_latent` | `int \| None` | Joint latent dimensionality (per-modality before concatenation). If … | `None` |
| `n_layers_encoder` | `int` | Hidden layers in encoders (including the post-pooling MLP in the partial splicing encoder). | `2` |
| `n_layers_decoder` | `int` | Hidden layers in decoders (not used by the linear decoders). | `2` |
| `dropout_rate` | `float` | Dropout rate in MLPs. | `0.1` |
| `use_batch_norm` | `Literal['encoder', 'decoder', 'none', 'both']` | Where to apply batch norm: … | `'none'` |
| `use_layer_norm` | `Literal['encoder', 'decoder', 'none', 'both']` | Where to apply layer norm: … | `'both'` |
| `latent_distribution` | `Literal['normal', 'ln']` | Latent distribution: … | `'normal'` |
| `deeply_inject_covariates` | `bool` | If True, injects covariates at all decoder layers. | `False` |
| `encode_covariates` | `bool` | If True, provides covariates to encoders. | `False` |
| `splicing_loss_type` | `Literal['binomial', 'beta_binomial', 'dirichlet_multinomial']` | Splicing reconstruction loss: … | `'dirichlet_multinomial'` |
| `splicing_concentration` | `float \| None` | Optional concentration for beta-binomial. Ignored for binomial. | `None` |
| `dm_concentration` | `Literal['atse', 'scalar']` | For Dirichlet-multinomial: … | `'atse'` |
| `encoder_hidden_dim` | `int` | Hidden width inside the PartialEncoder MLP that maps pooled junction codes to latent statistics. | `128` |
| `code_dim` | `int` | Dimensionality of per-junction embeddings in PartialEncoderEDDIFaster. | `16` |
| `h_hidden_dim` | `int` | Hidden width of the shared h network that processes (psi, feature_embedding) for each observed junction. | `64` |
| `pool_mode` | `Literal['mean', 'sum']` | Pooling across observed junctions per cell: … | `'mean'` |
| `max_nobs` | `int` | Optional cap on the number of observed (cell, junction) pairs processed in a single scatter chunk. Set to a negative value to disable chunking. | `-1` |
| `n_genes` | `int \| None` | Number of gene expression features. Required if … | `None` |
| `n_junctions` | `int \| None` | Number of splicing features. Required if … | `None` |
| `initialize_embeddings_from_pca` | `bool` | If True and using the partial splicing encoder, initialize the per-junction embedding table via PCA on junction ratios. | `False` |
| `fully_paired` | `bool` | If True, the model exposes only a joint latent (no modality-specific latents). | `False` |
| `**model_kwargs` | | Forwarded to … | `{}` |
differential_expression(adata=None, groupby=None, group1=None, group2=None, idx1=None, idx2=None, mode='change', delta=0.25, batch_size=None, all_stats=True, batch_correction=False, batchid1=None, batchid2=None, fdr_target=0.05, silent=False, **kwargs)
Differential expression analysis.
Performs differential gene expression analysis using normalized gene expression estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | | Additional keyword arguments for differential computation. | `{}` |

Returns:
| Type | Description |
|---|---|
| `pandas.DataFrame` | A pandas DataFrame with differential expression results. |
differential_splicing(adata=None, groupby=None, group1=None, group2=None, idx1=None, idx2=None, mode='change', delta=0.1, batch_size=None, all_stats=True, batch_correction=False, batchid1=None, batchid2=None, fdr_target=0.05, silent=False, norm_splicing_function='dm_posterior_mean', **kwargs)
Differential splicing analysis.
Performs differential junction usage analysis using normalized PSI from the model's splicing decoder.
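As a rough illustration of what `mode='change'` with `delta=0.1` means, the sketch below estimates the posterior probability that the PSI of one junction differs between two groups by more than `delta`. It assumes the effect size is the plain difference in normalized PSI and that posterior samples for each group are already available. Both are assumptions made for illustration, and `prob_differential` is a hypothetical helper, not part of the SPLICEVI API.

```python
import random


def prob_differential(psi_samples_1, psi_samples_2, delta=0.1, n_pairs=5000, seed=0):
    """Monte Carlo estimate of P(|psi_1 - psi_2| > delta) for one junction.

    psi_samples_1, psi_samples_2 : posterior PSI samples for group 1 and group 2
    """
    rng = random.Random(seed)
    hits = 0
    for _ in range(n_pairs):
        # Pair a random posterior sample from each group and compare.
        diff = rng.choice(psi_samples_1) - rng.choice(psi_samples_2)
        if abs(diff) > delta:
            hits += 1
    return hits / n_pairs


strong = prob_differential([0.78, 0.80, 0.82], [0.18, 0.20, 0.22])
null = prob_differential([0.50, 0.52], [0.50, 0.52])
```

A larger `delta` makes the test more conservative, since only junctions whose usage shifts by more than that amount count as differentially spliced.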
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | | Additional keyword arguments for differential computation. | `{}` |

Returns:
| Type | Description |
|---|---|
| `pandas.DataFrame` | A pandas DataFrame with differential splicing results. |
get_latent_representation(adata=None, modality='joint', indices=None, give_mean=True, batch_size=None)
Return the latent representation for each cell.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData \| None` | AnnOrMuData object used in setup. | `None` |
| `modality` | `Literal['joint', 'expression', 'splicing']` | One of: … | `'joint'` |
| `indices` | `Sequence[int] \| None` | Cell indices to use. | `None` |
| `give_mean` | `bool` | If True, returns the mean of the latent distribution. | `True` |
| `batch_size` | `int \| None` | Batch size for processing. | `None` |

Returns:
| Type | Description |
|---|---|
| `numpy.ndarray` | A NumPy array of the latent representations. |
get_library_size_factors(adata=None, indices=None, batch_size=128)
Return library size factors for gene expression.
Note that the splicing modality does not use library size factors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData \| None` | AnnOrMuData object. | `None` |
| `indices` | `Sequence[int]` | Cell indices (default: all cells). | `None` |
| `batch_size` | `int` | Batch size for processing. | `128` |

Returns:
| Type | Description |
|---|---|
| `dict` | A dictionary with key "expression" for the gene expression library size factors. |
get_normalized_expression(adata=None, indices=None, n_samples_overall=None, transform_batch=None, gene_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)
Returns the normalized (decoded) gene expression.
This is denoted as \(\rho_n\) in the scVI paper.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData \| None` | AnnOrMuData object with equivalent structure to initial AnnData. If … | `None` |
| `indices` | `Sequence[int] \| None` | Indices of cells in adata to use. If … | `None` |
| `n_samples_overall` | `int \| None` | Number of observations to sample from … | `None` |
| `transform_batch` | `Sequence[Number \| str] \| None` | Batch to condition on. If transform_batch is: … | `None` |
| `gene_list` | `Sequence[str] \| None` | Return frequencies of expression for a subset of genes. This can save memory when working with large datasets and few genes are of interest. | `None` |
| `use_z_mean` | `bool` | If True, use the mean of the latent distribution; otherwise sample from it. | `True` |
| `n_samples` | `int` | Number of posterior samples to use for estimation. | `1` |
| `batch_size` | `int \| None` | Minibatch size for data loading into model. Defaults to … | `None` |
| `return_mean` | `bool` | Whether to return the mean of the samples. | `True` |
| `return_numpy` | `bool` | Return a NumPy array instead of a pandas DataFrame. | `False` |

Returns:
| Type | Description |
|---|---|
| `numpy.ndarray` or `pandas.DataFrame` | If `n_samples` > 1 and `return_mean` is False, the shape is `(samples, cells, genes)`. Otherwise, the shape is `(cells, genes)`, and the return type is `pandas.DataFrame` unless `return_numpy` is True. |
get_normalized_splicing(adata=None, indices=None, n_samples_overall=None, transform_batch=None, junction_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)
Returns the normalized (decoded) splicing probabilities.
This is denoted as \(p_{nj}\) in the SPLICEVI model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData \| None` | AnnOrMuData object with the same structure as used in setup. If … | `None` |
| `indices` | `Sequence[int] \| None` | Cell indices to use. If … | `None` |
| `n_samples_overall` | `int \| None` | Number of observations to sample from … | `None` |
| `transform_batch` | `Sequence[Number \| str] \| None` | Batch(es) to condition on. `None`: use the true observed batch. `int`: force all cells to that batch. `list[int \| str]`: average over those batches. | `None` |
| `junction_list` | `Sequence[str] \| None` | Subset of junction names to return. If … | `None` |
| `use_z_mean` | `bool` | If True, use the mean of the latent distribution; otherwise sample. | `True` |
| `n_samples` | `int` | Number of posterior samples to draw. If >1 and … | `1` |
| `batch_size` | `int \| None` | Minibatch size for decoding. Defaults to … | `None` |
| `return_mean` | `bool` | Whether to average over posterior samples when … | `True` |
| `return_numpy` | `bool` | If True, returns a NumPy array; otherwise a pandas DataFrame. | `False` |
| `silent` | `bool` | If True, suppresses the progress bar. | `True` |

Returns:
| Type | Description |
|---|---|
| `numpy.ndarray` or `pandas.DataFrame` | A NumPy array or pandas DataFrame of shape `(cells, junctions)` containing the decoded splicing probabilities. |
|
get_normalized_splicing_DM(adata=None, indices=None, n_samples_overall=None, transform_batch=None, junction_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)
Returns DM-normalized splicing probabilities (PSI*).
This decodes junction PSI (\(p\)) and applies a Dirichlet-multinomial posterior-mean smoothing per junction:
\[\psi^{*} = \frac{c\,p + y_j}{c + n},\]
where \(y_j\) is the observed junction count and \(n\) is the observed ATSE total for that junction. The concentration \(c\) is taken from the module:
- If `self.dm_concentration == "atse"`: use the per-ATSE values in `self.module.log_phi_j`, mapped to per-junction values via `self.module.junc2atse`.
- Otherwise: treat `self.module.log_phi_j` as a scalar concentration.
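The smoothing formula is simple enough to check by hand. The sketch below applies it to plain Python lists; it assumes the concentrations are already on the natural scale (any transformation of `log_phi_j` is not shown), and `dm_posterior_mean` and `smooth_cell` are hypothetical helper names, not methods of the model.

```python
def dm_posterior_mean(p, y, n, c):
    """psi* = (c * p + y) / (c + n) for a single junction."""
    return (c * p + y) / (c + n)


def smooth_cell(p, y, n, atse_concentration, junc2atse):
    """Apply the smoothing to every junction of one cell.

    p, y, n            : per-junction decoded PSI, junction count, ATSE total
    atse_concentration : one concentration per ATSE (the 'atse' setting)
    junc2atse          : index of the ATSE each junction belongs to
    """
    return [
        dm_posterior_mean(p[j], y[j], n[j], atse_concentration[junc2atse[j]])
        for j in range(len(p))
    ]


unobserved = dm_posterior_mean(p=0.3, y=0, n=0, c=10.0)       # no reads: stays at p
moderate = dm_posterior_mean(p=0.3, y=8, n=10, c=10.0)        # pulled toward 8/10
deep = dm_posterior_mean(p=0.3, y=8000, n=10000, c=10.0)      # dominated by the data
```

With no reads the result equals the decoded \(p\), and as \(n\) grows the estimate approaches the empirical ratio \(y_j / n\), so \(c\) acts as a pseudo-count that controls how strongly the model's prediction is trusted.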
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData \| None` | AnnOrMuData object with equivalent structure to initialization. If … | `None` |
| `indices` | `Sequence[int] \| None` | Indices of cells in … | `None` |
| `n_samples_overall` | `int \| None` | Number of observations to sample from … | `None` |
| `transform_batch` | `Sequence[Number \| str] \| None` | Batch conditioning. `None`: use the observed batch. `int`/`str`: force decoding to that batch. `list[int \| str]`: average decoded outputs over the listed batches. | `None` |
| `junction_list` | `Sequence[str] \| None` | Optional subset of junction names to return. | `None` |
| `use_z_mean` | `bool` | If True, decode from the mean of the latent; else sample. | `True` |
| `n_samples` | `int` | Posterior samples to draw. If >1 and … | `1` |
| `batch_size` | `int \| None` | Minibatch size for decoding. | `None` |
| `return_mean` | `bool` | If True and … | `True` |
| `return_numpy` | `bool` | If True, return a NumPy array; else a pandas DataFrame. | `False` |
| `silent` | `bool` | If True, disables progress display. | `True` |

Returns:
| Type | Description |
|---|---|
| `numpy.ndarray` or `pandas.DataFrame` | Array or DataFrame of shape `(cells, junctions)` with DM-normalized PSI*. |
setup_anndata(adata, layer=None, junc_ratio=None, cell_by_junction_matrix=None, cell_by_cluster_matrix=None, psi_mask_layer=None, batch_key=None, size_factor_key=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)
classmethod
Set up an AnnData object for SPLICEVI.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `junc_ratio` | `str \| None` | Key in … | `None` |
| `cell_by_junction_matrix` | `str \| None` | Key in … | `None` |
| `cell_by_cluster_matrix` | `str \| None` | Key in … | `None` |
| `psi_mask_layer` | `str \| None` | Layer with binary mask (1=observed, 0=missing) per junction. | `None` |
Notes
Use this method if your splicing data is stored in an AnnData object where gene expression and splicing features are concatenated.
setup_mudata(mdata, rna_layer=None, junc_ratio_layer=None, atse_counts_layer=None, junc_counts_layer=None, psi_mask_layer=None, batch_key=None, size_factor_key=None, categorical_covariate_keys=None, continuous_covariate_keys=None, idx_layer=None, modalities=None, **kwargs)
classmethod
Set up a MuData object for SPLICEVI.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rna_layer` | `str \| None` | Key in the RNA AnnData for gene expression counts. If … | `None` |
| `junc_ratio_layer` | `str \| None` | Key in the splicing AnnData for junction ratio values. If … | `None` |
| `atse_counts_layer` | `str \| None` | Key in the splicing AnnData for total event counts. If … | `None` |
| `junc_counts_layer` | `str \| None` | Key in the splicing AnnData for observed junction counts. If … | `None` |
| `psi_mask_layer` | `str \| None` | Layer with binary mask (1=observed, 0=missing) per junction. | `None` |
| `size_factor_key` | `str \| None` | Key in … | `None` |
Examples:
>>> import mudata as mu
>>> import scvi
>>> mdata = mu.MuData({
... "rna": ge_anndata.copy(),
... "splicing": atse_anndata.copy()
... })
>>> scvi.model.SPLICEVI.setup_mudata(
... mdata,
... modalities={"rna_layer": "rna", "junc_ratio_layer": "splicing"},
... rna_layer="raw_counts", # gene expression data is in the GE AnnData's "raw_counts" layer
... junc_ratio_layer="junc_ratio", # splicing data is in the ATSE AnnData's "junc_ratio" layer
... )
>>> model = scvi.model.SPLICEVI(mdata)
train(max_epochs=500, lr=0.0001, accelerator='auto', devices='auto', train_size=None, validation_size=None, shuffle_set_split=True, batch_size=128, weight_decay=0.001, eps=1e-08, early_stopping=True, early_stopping_patience=50, check_val_every_n_epoch=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=50, adversarial_mixing=True, lr_scheduler_type='plateau', reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', step_size=10, gradient_clipping=True, gradient_clipping_max_norm=5.0, datasplitter_kwargs=None, plan_kwargs=None, **kwargs)
Trains the model using amortized variational inference on gene expression and splicing modalities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `max_epochs` | `int` | Number of epochs to train over. | `500` |
| `lr` | `float` | Learning rate for optimization. | `0.0001` |
| `accelerator` | `str` | Hardware acceleration options. | `'auto'` |
| `devices` | `str` | Hardware acceleration options. | `'auto'` |
| `train_size` | `float \| None` | Proportions for splitting the data into train and validation sets. | `None` |
| `validation_size` | `float \| None` | Proportions for splitting the data into train and validation sets. | `None` |
| `shuffle_set_split` | `bool` | Whether to shuffle before splitting. | `True` |
| `batch_size` | `int` | Minibatch size for training. | `128` |
| `weight_decay` | `float` | Optimizer hyperparameters. | `0.001` |
| `eps` | `float` | Optimizer hyperparameters. | `1e-08` |
| `early_stopping` | `bool` | Whether to enable early stopping. | `True` |
| `early_stopping_patience` | `int` | Number of epochs with no improvement before stopping. | `50` |
| `check_val_every_n_epoch` | `int \| None` | How often (in epochs) to run validation. | `None` |
| `n_steps_kl_warmup` | `int \| None` | KL divergence warmup parameters (by steps or epochs). | `None` |
| `n_epochs_kl_warmup` | `int \| None` | KL divergence warmup parameters (by steps or epochs). | `50` |
| `adversarial_mixing` | `bool` | Whether to include an adversarial classifier during training. | `True` |
| `lr_scheduler_type` | `Literal['plateau', 'step']` | Scheduler type in TrainingPlan: `'plateau'` (reduce-on-plateau) or `'step'` (fixed-step). | `'plateau'` |
| `reduce_lr_on_plateau` | `bool` | If True and using the plateau scheduler, enable ReduceLROnPlateau. | `False` |
| `lr_factor` | `float` | Multiplicative factor for LR reduction (used for both plateau and step schedulers). | `0.6` |
| `lr_patience` | `int` | Number of epochs with no improvement for the plateau scheduler. | `30` |
| `lr_threshold` | `float` | Threshold for measuring a new optimum (plateau scheduler). | `0.0` |
| `lr_scheduler_metric` | `Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation']` | Metric to monitor for the plateau scheduler. | `'elbo_validation'` |
| `step_size` | `int` | Epoch interval between LR drops (step scheduler). | `10` |
| `gradient_clipping` | `bool` | Whether to use gradient norm clipping. | `True` |
| `gradient_clipping_max_norm` | `float` | Maximum gradient norm used in gradient clipping. | `5.0` |
| `datasplitter_kwargs` | `dict \| None` | Additional kwargs for the data splitter. | `None` |
| `plan_kwargs` | `dict \| None` | Additional kwargs to pass to the TrainingPlan constructor. | `None` |
| `**kwargs` | | Additional Trainer kwargs (callbacks, strategy, etc.). | `{}` |
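
To give a feel for how the schedule-related defaults interact, the sketch below computes a KL weight and a step-scheduler learning rate per epoch. It assumes a linear KL ramp from 0 to 1 over `n_epochs_kl_warmup` and a step schedule that multiplies the rate by `lr_factor` every `step_size` epochs. Both are common conventions and are assumptions here, not a statement of what the TrainingPlan does internally; `kl_weight` and `step_lr` are hypothetical helpers.

```python
def kl_weight(epoch, n_epochs_kl_warmup=50):
    """Linear KL warmup: 0 at epoch 0, 1 from n_epochs_kl_warmup onward."""
    if not n_epochs_kl_warmup:
        return 1.0  # warmup disabled
    return min(1.0, epoch / n_epochs_kl_warmup)


def step_lr(epoch, lr=1e-4, lr_factor=0.6, step_size=10):
    """Learning rate under a fixed-step schedule."""
    return lr * lr_factor ** (epoch // step_size)


# Schedule for a few epochs, using the defaults from the signature.
schedule = [(epoch, kl_weight(epoch), step_lr(epoch)) for epoch in (0, 10, 25, 50)]
```

Under these assumptions the KL term reaches full weight at epoch 50, by which point a step schedule would already have reduced the learning rate five times.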