SPLICEVI

splicevi.SPLICEVI

Bases: VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin

Integration of gene expression and alternative splicing signals.

SPLICEVI integrates gene expression and alternative splicing (junction-usage) modalities, learning a joint latent space with one reconstruction head per modality. It wraps the :class:`~scvi.module.SPLICEVAE` module and adds training utilities and convenience APIs.

Parameters:

Name Type Description Default
adata AnnOrMuData

AnnData/MuData object that has been registered via :meth:`~scvi.model.SPLICEVI.setup_anndata` or :meth:`~scvi.model.SPLICEVI.setup_mudata`.

required
modality_weights Literal['equal', 'cell', 'universal', 'concatenate', 'per_dimension_weighted_average']

How to weight modalities when forming the joint latent:

  • "equal" – equal weight per modality,
  • "universal" – one global weight per modality (learned),
  • "cell" – per-cell weights (learned),
  • "concatenate" – do not mix; concatenate per-modality latents,
  • "per_dimension_weighted_average" – one learned splicing weight per latent dimension (expression weight = 1 - splicing weight), constrained so the mean weight stays at 0.5 to prevent either modality from dominating globally.

'equal'
modality_penalty Literal['Jeffreys', 'MMD', 'None']

Alignment penalty across modalities: "Jeffreys", "MMD", or "None".

'Jeffreys'
gene_likelihood Literal['zinb', 'nb', 'poisson']

Expression likelihood: "zinb", "nb", or "poisson".

'zinb'
dispersion Literal['gene', 'gene-batch', 'gene-label', 'gene-cell']

Expression dispersion parameterization: "gene", "gene-batch", "gene-label", or "gene-cell".

'gene'
splicing_encoder_architecture Literal['vanilla', 'partial']

Splicing encoder: * "vanilla" – SCVI Encoder, * "partial" – PartialEncoderEDDIFaster.

'partial'
splicing_decoder_architecture Literal['vanilla', 'linear']

Splicing decoder: * "vanilla" – nonlinear DecoderSplice, * "linear" – linear decoder.

'linear'
expression_architecture Literal['vanilla', 'linear']

Expression decoder: "vanilla" (nonlinear) or "linear" (linear decoder).

'linear'
n_hidden int | None

Width of SCVI encoders/decoders. If None, uses √(n_junctions) (capped at 128).

None
n_latent int | None

Joint latent dimensionality (per-modality before concatenation). If None, uses √(n_hidden).

None
n_layers_encoder int

Hidden layers in encoders (including the post pooling MLP in the partial splicing encoder).

2
n_layers_decoder int

Hidden layers in decoders (not used by the linear decoders).

2
dropout_rate float

Dropout rate in MLPs.

0.1
use_batch_norm Literal['encoder', 'decoder', 'none', 'both']

Where to apply batch norm: "encoder", "decoder", "both", or "none".

'none'
use_layer_norm Literal['encoder', 'decoder', 'none', 'both']

Where to apply layer norm: "encoder", "decoder", "both", or "none".

'both'
latent_distribution Literal['normal', 'ln']

Latent distribution: "normal" or "ln" (logistic normal).

'normal'
deeply_inject_covariates bool

If True, injects covariates at all decoder layers.

False
encode_covariates bool

If True, provides covariates to encoders.

False
splicing_loss_type Literal['binomial', 'beta_binomial', 'dirichlet_multinomial']

Splicing reconstruction loss: "binomial", "beta_binomial", or "dirichlet_multinomial".

'dirichlet_multinomial'
splicing_concentration float | None

Optional concentration for beta binomial. Ignored for binomial.

None
dm_concentration Literal['atse', 'scalar']

For Dirichlet multinomial: "atse" (per ATSE concentration) or "scalar" (single shared).

'atse'
encoder_hidden_dim int

Hidden width inside the PartialEncoder MLP that maps pooled junction codes to latent statistics.

128
code_dim int

Dimensionality of per junction embeddings in PartialEncoderEDDIFaster.

16
h_hidden_dim int

Hidden width of the shared h network that processes (psi, feature_embedding) for each observed junction.

64
pool_mode Literal['mean', 'sum']

Pooling across observed junctions per cell: "mean" or "sum".

'mean'
max_nobs int

Optional cap on the number of observed (cell, junction) pairs processed in a single scatter chunk. Set to a negative value to disable chunking.

-1
n_genes int | None

Number of gene expression features. Required if adata is AnnData.

None
n_junctions int | None

Number of splicing features. Required if adata is AnnData.

None
initialize_embeddings_from_pca bool

If True and using the partial splicing encoder, initialize the per junction embedding table via PCA on junction ratios.

False
fully_paired bool

If True, the model exposes only a joint latent (no modality specific latents).

False
**model_kwargs

Forwarded to :class:`~scvi.module.SPLICEVAE`.

{}
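As a rough illustration of the "per_dimension_weighted_average" option described in the modality_weights entry, the sketch below mixes one cell's expression and splicing latents with one splicing weight per latent dimension. The function name mix_per_dimension and the example values are hypothetical. How SPLICEVAE learns the weights and enforces the 0.5 mean is not shown on this page, so the weights here are simply chosen to satisfy that constraint.

```python
def mix_per_dimension(z_expr, z_splice, w_splice):
    """Mix two latent vectors using one splicing weight per latent dimension.

    Hypothetical helper: the expression weight is 1 - splicing weight,
    following the description of "per_dimension_weighted_average".
    """
    if not (len(z_expr) == len(z_splice) == len(w_splice)):
        raise ValueError("inputs must have one entry per latent dimension")
    return [
        w * zs + (1.0 - w) * ze
        for ze, zs, w in zip(z_expr, z_splice, w_splice)
    ]


# Toy latents for a single cell with four latent dimensions.
z_expr = [1.0, 2.0, 3.0, 4.0]
z_splice = [0.0, 0.0, 0.0, 0.0]

# Example weights picked by hand so that their mean is 0.5 (assumption:
# the real model enforces this during training, mechanism not shown here).
w_splice = [0.2, 0.4, 0.6, 0.8]

z_joint = mix_per_dimension(z_expr, z_splice, w_splice)
mean_weight = sum(w_splice) / len(w_splice)
```

Dimensions with a small splicing weight follow the expression latent closely, while the fixed mean of 0.5 keeps either modality from taking over every dimension at once.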

differential_expression(adata=None, groupby=None, group1=None, group2=None, idx1=None, idx2=None, mode='change', delta=0.25, batch_size=None, all_stats=True, batch_correction=False, batchid1=None, batchid2=None, fdr_target=0.05, silent=False, **kwargs)

Differential expression analysis.

Performs differential gene expression analysis using normalized gene expression estimates.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments for differential computation.

{}

Returns:

Type Description
A pandas DataFrame with differential expression results.

differential_splicing(adata=None, groupby=None, group1=None, group2=None, idx1=None, idx2=None, mode='change', delta=0.1, batch_size=None, all_stats=True, batch_correction=False, batchid1=None, batchid2=None, fdr_target=0.05, silent=False, norm_splicing_function='dm_posterior_mean', **kwargs)

Differential splicing analysis.

Performs differential junction usage analysis using normalized PSI from the model's splicing decoder.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments for differential computation.

{}

Returns:

Type Description
A pandas DataFrame with differential splicing results.

get_latent_representation(adata=None, modality='joint', indices=None, give_mean=True, batch_size=None)

Return the latent representation for each cell.

Parameters:

Name Type Description Default
adata AnnOrMuData | None

AnnOrMuData object used in setup.

None
modality Literal['joint', 'expression', 'splicing']

One of: - "joint": joint latent space, - "expression": expression-specific latent space, - "splicing": splicing-specific latent space.

'joint'
indices Sequence[int] | None

Cell indices to use.

None
give_mean bool

If True, returns the mean of the latent distribution.

True
batch_size int | None

Batch size for processing.

None

Returns:

Type Description
A NumPy array of the latent representations.

get_library_size_factors(adata=None, indices=None, batch_size=128)

Return library size factors for gene expression.

Note that the splicing modality does not use library size factors.

Parameters:

Name Type Description Default
adata AnnOrMuData | None

AnnOrMuData object.

None
indices Sequence[int]

Cell indices (default: all cells).

None
batch_size int

Batch size for processing.

128

Returns:

Type Description
A dictionary with key "expression" for the gene expression library size factors.

get_normalized_expression(adata=None, indices=None, n_samples_overall=None, transform_batch=None, gene_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)

Returns the normalized (decoded) gene expression.

This is denoted as :math:`\rho_n` in the scVI paper.

Parameters:

Name Type Description Default
adata AnnOrMuData | None

AnnOrMuData object with equivalent structure to initial AnnData. If None, defaults to the AnnOrMuData object used to initialize the model.

None
indices Sequence[int] | None

Indices of cells in adata to use. If None, all cells are used.

None
n_samples_overall int | None

Number of observations to sample from indices if indices is provided.

None
transform_batch Sequence[Number | str] | None

Batch to condition on. If transform_batch is:

  • None, then real observed batch is used.
  • int, then batch transform_batch is used.
None
gene_list Sequence[str] | None

Return frequencies of expression for a subset of genes. This can save memory when working with large datasets and few genes are of interest.

None
use_z_mean bool

If True, use the mean of the latent distribution; otherwise sample from it.

True
n_samples int

Number of posterior samples to use for estimation.

1
batch_size int | None

Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.

None
return_mean bool

Whether to return the mean of the samples.

True
return_numpy bool

Return a numpy array instead of a pandas DataFrame.

False

Returns:

Type Description
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
Otherwise, shape is `(cells, genes)`. In this case, return type is
:class:`~pandas.DataFrame` unless `return_numpy` is True.
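The shape rule in the Returns entry can be sketched with plain nested lists. The helper summarize_samples is hypothetical and only mirrors the documented behaviour: the leading sample axis survives only when n_samples > 1 and return_mean is False. It is not the library's implementation.

```python
def summarize_samples(samples, return_mean=True):
    """Reduce posterior samples shaped (n_samples, n_cells, n_genes).

    Hypothetical helper mirroring the documented shape rule:
    keep the sample axis only if n_samples > 1 and return_mean is False.
    """
    n_samples = len(samples)
    if n_samples > 1 and not return_mean:
        return samples  # shape stays (samples, cells, genes)
    n_cells = len(samples[0])
    n_genes = len(samples[0][0])
    # Average over the sample axis, giving shape (cells, genes).
    return [
        [sum(s[c][g] for s in samples) / n_samples for g in range(n_genes)]
        for c in range(n_cells)
    ]


# Two posterior samples for one cell and two genes.
samples = [[[1.0, 2.0]], [[3.0, 4.0]]]
averaged = summarize_samples(samples, return_mean=True)
stacked = summarize_samples(samples, return_mean=False)
```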

get_normalized_splicing(adata=None, indices=None, n_samples_overall=None, transform_batch=None, junction_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)

Returns the normalized (decoded) splicing probabilities.

This is denoted as :math:`p_{nj}` in the SPLICEVI model.

Parameters:

Name Type Description Default
adata AnnOrMuData | None

AnnOrMuData object with the same structure as used in setup. If None, defaults to the AnnOrMuData object used to initialize the model.

None
indices Sequence[int] | None

Cell indices to use. If None, all cells are used.

None
n_samples_overall int | None

Number of observations to sample from indices if provided.

None
transform_batch Sequence[Number | str] | None

Batch(s) to condition on: - None: use true observed batch - int: force all cells to that batch - list[int|str]: average over those batches

None
junction_list Sequence[str] | None

Subset of junction names to return. If None, returns all junctions.

None
use_z_mean bool

If True, use the mean of the latent distribution; otherwise sample.

True
n_samples int

Number of posterior samples to draw. If >1 and return_mean is True, the result is averaged over draws.

1
batch_size int | None

Minibatch size for decoding. Defaults to scvi.settings.batch_size.

None
return_mean bool

Whether to average over posterior samples when n_samples>1.

True
return_numpy bool

If True, returns a NumPy array; otherwise a pandas DataFrame.

False
silent bool

If True, suppresses the progress bar.

True

Returns:

Type Description
A NumPy array or pandas DataFrame of shape `(cells, junctions)` containing the
decoded splicing probabilities.
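The three transform_batch cases listed in the parameters (observed batch, one forced batch, average over several batches) can be sketched as follows. Both decode_with_batches and toy_decode are hypothetical stand-ins. The real decoder is a neural network, and the page does not show how the averaging is implemented internally.

```python
def decode_with_batches(decode, z, observed_batch, transform_batch=None):
    """Decode a latent vector under the documented transform_batch cases.

    Hypothetical helper: `decode(z, batch)` stands in for the model decoder.
    """
    if transform_batch is None:
        # Use the batch the cell was actually observed in.
        return decode(z, observed_batch)
    if isinstance(transform_batch, (int, str)):
        # A single batch: force every cell to that batch.
        transform_batch = [transform_batch]
    # Several batches: decode under each one and average the outputs.
    outputs = [decode(z, b) for b in transform_batch]
    return [sum(vals) / len(outputs) for vals in zip(*outputs)]


# Toy decoder (assumption): a batch only shifts the output by an offset.
BATCH_OFFSET = {0: 0.0, 1: 0.2}


def toy_decode(z, batch):
    return [zi + BATCH_OFFSET[batch] for zi in z]


z = [0.1, 0.3]
as_observed = decode_with_batches(toy_decode, z, observed_batch=1)
forced = decode_with_batches(toy_decode, z, observed_batch=1, transform_batch=0)
averaged = decode_with_batches(toy_decode, z, observed_batch=1, transform_batch=[0, 1])
```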

get_normalized_splicing_DM(adata=None, indices=None, n_samples_overall=None, transform_batch=None, junction_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)

Returns DM-normalized splicing probabilities (PSI*).

This decodes junction PSI (:math:`p`) and applies a Dirichlet–multinomial posterior-mean smoothing per junction: :math:`\psi^* = (c p + y_j) / (c + n)`, where :math:`y_j` is the observed junction count and :math:`n` is the observed ATSE total for that junction. The concentration :math:`c` is taken from the module:

  • If self.dm_concentration == "atse": use the per-ATSE values in self.module.log_phi_j, mapped to junctions via self.module.junc2atse.
  • Otherwise: treat self.module.log_phi_j as a scalar concentration.
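A minimal sketch of the smoothing formula above, in plain Python. The helper names and toy numbers are hypothetical, and the concentrations are passed in directly: how the module turns log_phi_j into a concentration is not shown on this page.

```python
def dm_posterior_mean(p, y, n, c):
    """Posterior-mean PSI for one junction: (c * p + y) / (c + n)."""
    if c <= 0:
        raise ValueError("concentration must be positive")
    return (c * p + y) / (c + n)


def smooth_cell(p_row, y_row, n_row, conc_per_atse, junc2atse):
    """Apply the smoothing to every junction of one cell.

    junc2atse[j] gives the ATSE index of junction j, so each junction
    uses the concentration of the ATSE it belongs to (the "atse" case).
    """
    return [
        dm_posterior_mean(p, y, n, conc_per_atse[junc2atse[j]])
        for j, (p, y, n) in enumerate(zip(p_row, y_row, n_row))
    ]


# Toy cell with three junctions in two ATSEs (all values are made up).
p_row = [0.5, 0.5, 0.2]    # decoded PSI
y_row = [3, 0, 1]          # observed junction counts
n_row = [4, 0, 10]         # observed ATSE totals per junction
conc_per_atse = [2.0, 5.0]
junc2atse = [0, 0, 1]

psi_star = smooth_cell(p_row, y_row, n_row, conc_per_atse, junc2atse)
```

When a junction has no observed reads (y = n = 0) the result falls back to the decoded PSI, and as n grows the estimate moves toward the empirical ratio y / n.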

Parameters:

Name Type Description Default
adata AnnOrMuData | None

AnnOrMuData object with equivalent structure to initialization. If None, defaults to the object used to initialize the model.

None
indices Sequence[int] | None

Indices of cells in adata to use. If None, all cells are used.

None
n_samples_overall int | None

Number of observations to sample from indices if provided.

None
transform_batch Sequence[Number | str] | None

Batch conditioning: - None: use observed batch - int/str: force decode to that batch - list[int|str]: average decoded outputs over listed batches

None
junction_list Sequence[str] | None

Optional subset of junction names to return.

None
use_z_mean bool

If True, decode from the mean of the latent; else sample.

True
n_samples int

Posterior samples to draw. If >1 and return_mean is True, results are averaged over samples.

1
batch_size int | None

Minibatch size for decoding.

None
return_mean bool

If True and n_samples > 1, average samples before returning.

True
return_numpy bool

If True, return a NumPy array; else a pandas DataFrame.

False
silent bool

If True, disables progress display.

True

Returns:

Type Description
Array or DataFrame of shape ``(cells, junctions)`` with DM-normalized PSI*.

setup_anndata(adata, layer=None, junc_ratio=None, cell_by_junction_matrix=None, cell_by_cluster_matrix=None, psi_mask_layer=None, batch_key=None, size_factor_key=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs) classmethod

Set up an AnnData object for SPLICEVI.

Parameters:

Name Type Description Default
junc_ratio str | None

Key in adata.layers for junction ratio values.

None
cell_by_junction_matrix str | None

Key in adata.layers for the cell-by-junction matrix.

None
cell_by_cluster_matrix str | None

Key in adata.layers for the cell-by-cluster splicing matrix.

None
psi_mask_layer str | None

Layer with binary mask (1=observed, 0=missing) per junction.

None
Notes

Use this method if your splicing data is stored in an AnnData object where gene expression and splicing features are concatenated.
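To make the relationship between the splicing layers concrete, the sketch below derives junction ratios and a binary mask from two count matrices. It assumes that the junction ratio is the junction count divided by the total count of its cluster (ATSE), and that a junction counts as observed when that total is positive. Neither rule is stated on this page, and junction_ratio_and_mask is a hypothetical helper.

```python
def junction_ratio_and_mask(junc_counts, cluster_counts):
    """Build junction ratios and a 1=observed / 0=missing mask.

    Both inputs are cells x junctions nested lists; cluster_counts holds,
    for each junction, the total count of the cluster it belongs to.
    Assumption: ratio = junction count / cluster total, observed = total > 0.
    """
    ratios, mask = [], []
    for junc_row, cluster_row in zip(junc_counts, cluster_counts):
        ratio_row, mask_row = [], []
        for y, n in zip(junc_row, cluster_row):
            observed = n > 0
            mask_row.append(1 if observed else 0)
            # Missing entries get a placeholder of 0.0; the mask marks them.
            ratio_row.append(y / n if observed else 0.0)
        ratios.append(ratio_row)
        mask.append(mask_row)
    return ratios, mask


# Two cells, three junctions (toy counts).
junc_counts = [[3, 1, 0], [0, 0, 2]]
cluster_counts = [[4, 4, 0], [0, 0, 2]]
junc_ratio, psi_mask = junction_ratio_and_mask(junc_counts, cluster_counts)
```

The resulting matrices would then be stored in adata.layers under the keys passed as junc_ratio and psi_mask_layer.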

setup_mudata(mdata, rna_layer=None, junc_ratio_layer=None, atse_counts_layer=None, junc_counts_layer=None, psi_mask_layer=None, batch_key=None, size_factor_key=None, categorical_covariate_keys=None, continuous_covariate_keys=None, idx_layer=None, modalities=None, **kwargs) classmethod

Set up a MuData object for SPLICEVI.

Parameters:

Name Type Description Default
rna_layer str | None

Key in the RNA AnnData for gene expression counts. If None, the primary data (.X) of that AnnData is used.

None
junc_ratio_layer str | None

Key in the splicing AnnData for junction ratio values. If None, the primary data (.X) of that AnnData is used.

None
atse_counts_layer str | None

Key in the splicing AnnData for total event counts. If None, defaults to "cell_by_cluster_matrix".

None
junc_counts_layer str | None

Key in the splicing AnnData for observed junction counts. If None, defaults to "cell_by_junction_matrix".

None
psi_mask_layer str | None

Layer with binary mask (1=observed, 0=missing) per junction.

None
size_factor_key str | None

Key in mdata.obsm for size factors.

None

Examples:

>>> import mudata as mu
>>> import scvi
>>> mdata = mu.MuData({
...    "rna": ge_anndata.copy(),
...    "splicing": atse_anndata.copy()
... })
>>> scvi.model.SPLICEVI.setup_mudata(
...     mdata,
...     modalities={"rna_layer": "rna", "junc_ratio_layer": "splicing"},
...     rna_layer="raw_counts",            # gene expression data is in the GE AnnData's "raw_counts" layer
...     junc_ratio_layer="junc_ratio",     # splicing data is in the ATSE AnnData's "junc_ratio" layer
... )
>>> model = scvi.model.SPLICEVI(mdata)

train(max_epochs=500, lr=0.0001, accelerator='auto', devices='auto', train_size=None, validation_size=None, shuffle_set_split=True, batch_size=128, weight_decay=0.001, eps=1e-08, early_stopping=True, early_stopping_patience=50, check_val_every_n_epoch=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=50, adversarial_mixing=True, lr_scheduler_type='plateau', reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', step_size=10, gradient_clipping=True, gradient_clipping_max_norm=5.0, datasplitter_kwargs=None, plan_kwargs=None, **kwargs)

Trains the model using amortized variational inference on gene expression and splicing modalities.

Parameters:

Name Type Description Default
max_epochs int

Number of epochs to train over.

500
lr float

Learning rate for optimization.

0.0001
accelerator str

Accelerator type passed to the trainer, e.g. "cpu", "gpu", or "auto".

'auto'
devices str

Devices to use on the selected accelerator, or "auto" to choose automatically.

'auto'
train_size float | None

Fraction of cells assigned to the training set.

None
validation_size float | None

Fraction of cells assigned to the validation set.

None
shuffle_set_split bool

Whether to shuffle before splitting.

True
batch_size int

Minibatch size for training.

128
weight_decay float

Weight decay passed to the optimizer.

0.001
eps float

Numerical-stability epsilon passed to the optimizer.

1e-08
early_stopping bool

Whether to enable early stopping.

True
early_stopping_patience int

Number of epochs with no improvement before stopping.

50
check_val_every_n_epoch int | None

How often (in epochs) to run validation.

None
n_steps_kl_warmup int | None

Length of the KL divergence warmup, in training steps.

None
n_epochs_kl_warmup int | None

Length of the KL divergence warmup, in epochs.

50
adversarial_mixing bool

Whether to include adversarial classifier during training.

True
lr_scheduler_type Literal['plateau', 'step']

Scheduler type in TrainingPlan: "plateau" (reduce-on-plateau) or "step" (fixed-step).

'plateau'
reduce_lr_on_plateau bool

If True and using plateau scheduler, enable ReduceLROnPlateau.

False
lr_factor float

Multiplicative factor for LR reduction (used for both plateau and step schedulers).

0.6
lr_patience int

Number of epochs with no improvement for plateau scheduler.

30
lr_threshold float

Threshold for measuring new optimum (plateau scheduler).

0.0
lr_scheduler_metric Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation']

Metric to monitor for plateau scheduler.

'elbo_validation'
step_size int

Epoch interval between LR drops (step scheduler).

10
gradient_clipping bool

Whether to apply gradient norm clipping.

True
gradient_clipping_max_norm float

Maximum gradient norm used for gradient clipping.

5.0
datasplitter_kwargs dict | None

Additional kwargs for the data splitter.

None
plan_kwargs dict | None

Additional kwargs to pass to the TrainingPlan constructor.

None
**kwargs

Additional Trainer kwargs (callbacks, strategy, etc.).

{}
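The effect of the "step" scheduler parameters and of the epoch-based KL warmup can be sketched with two small functions. Both are illustrations under stated assumptions rather than the TrainingPlan code: step_lr assumes the learning rate is multiplied by lr_factor once every step_size epochs, and kl_weight assumes a linear ramp from 0 to 1 over n_epochs_kl_warmup epochs.

```python
def step_lr(epoch, lr=1e-4, lr_factor=0.6, step_size=10):
    """Learning rate at a given epoch under a fixed-step schedule.

    Assumption: the rate drops by a factor of lr_factor every step_size epochs.
    """
    return lr * lr_factor ** (epoch // step_size)


def kl_weight(epoch, n_epochs_kl_warmup=50):
    """KL term weight at a given epoch.

    Assumption: linear warmup from 0 to 1, then constant at 1.
    A value of None or 0 disables warmup in this sketch.
    """
    if not n_epochs_kl_warmup:
        return 1.0
    return min(1.0, epoch / n_epochs_kl_warmup)


# Defaults from the train() signature: lr=0.0001, lr_factor=0.6, step_size=10.
lr_schedule = [step_lr(e) for e in (0, 9, 10, 25)]
kl_schedule = [kl_weight(e) for e in (0, 25, 50, 100)]
```

With the default values, the learning rate would fall from 1e-4 to about 3.6e-5 by epoch 25 under this schedule, while the KL term would reach full weight at epoch 50.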