SPLICEVI
splicevi.SPLICEVI
Bases: VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin
Integration of gene expression and alternative splicing signals.
SPLICEVI integrates gene expression and alternative splicing (junction-usage) modalities,
learning a joint latent space with reconstruction heads per modality. It wraps the
:class:`~scvi.module.SPLICEVAE` module and adds training utilities and convenience APIs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData` | AnnData/MuData object that has been registered via :meth: | required |
| `modality_weights` | `Literal['equal', 'cell', 'universal', 'concatenate', 'per_dimension_weighted_average']` | How to weight modalities when forming the joint latent: | `'equal'` |
| `modality_penalty` | `Literal['Jeffreys', 'MMD', 'None']` | Alignment penalty across modalities: | `'Jeffreys'` |
| `gene_likelihood` | `Literal['zinb', 'nb', 'poisson']` | Expression likelihood: | `'zinb'` |
| `dispersion` | `Literal['gene', 'gene-batch', 'gene-label', 'gene-cell']` | Expression dispersion parameterization: | `'gene'` |
| `splicing_encoder_architecture` | `Literal['vanilla', 'partial']` | Splicing encoder: | `'partial'` |
| `splicing_decoder_architecture` | `Literal['vanilla', 'linear']` | Splicing decoder: | `'linear'` |
| `expression_architecture` | `Literal['vanilla', 'linear']` | Expression decoder: | `'linear'` |
| `n_hidden` | `int \| None` | Width of SCVI encoders/decoders. If | `None` |
| `n_latent` | `int \| None` | Joint latent dimensionality (per-modality before concatenation). If | `None` |
| `n_layers_encoder` | `int` | Hidden layers in encoders (including the post pooling MLP in the partial splicing encoder). | `2` |
| `n_layers_decoder` | `int` | Hidden layers in decoders (not used by the linear decoders). | `2` |
| `dropout_rate` | `float` | Dropout rate in MLPs. | `0.1` |
| `use_batch_norm` | `Literal['encoder', 'decoder', 'none', 'both']` | Where to apply batch norm: | `'none'` |
| `use_layer_norm` | `Literal['encoder', 'decoder', 'none', 'both']` | Where to apply layer norm: | `'both'` |
| `latent_distribution` | `Literal['normal', 'ln']` | Latent distribution: | `'normal'` |
| `deeply_inject_covariates` | `bool` | If True, injects covariates at all decoder layers. | `False` |
| `encode_covariates` | `bool` | If True, provides covariates to encoders. | `False` |
| `splicing_loss_type` | `Literal['binomial', 'beta_binomial', 'dirichlet_multinomial']` | Splicing reconstruction loss: | `'dirichlet_multinomial'` |
| `splicing_concentration` | `float \| None` | Optional concentration for beta binomial. Ignored for binomial. | `None` |
| `dm_concentration` | `Literal['atse', 'scalar']` | For Dirichlet multinomial: | `'atse'` |
| `encoder_hidden_dim` | `int` | Hidden width inside the PartialEncoder MLP that maps pooled junction codes to latent statistics. | `128` |
| `code_dim` | `int` | Dimensionality of per junction embeddings in PartialEncoderEDDIFaster. | `16` |
| `h_hidden_dim` | `int` | Hidden width of the shared h network that processes (psi, feature_embedding) for each observed junction. | `64` |
| `pool_mode` | `Literal['mean', 'sum']` | Pooling across observed junctions per cell: | `'mean'` |
| `max_nobs` | `int` | Optional cap on the number of observed (cell, junction) pairs processed in a single scatter chunk. Set to a negative value to disable chunking. | `-1` |
| `n_genes` | `int \| None` | Number of gene expression features. Required if | `None` |
| `n_junctions` | `int \| None` | Number of splicing features. Required if | `None` |
| `initialize_embeddings_from_pca` | `bool` | If True and using the partial splicing encoder, initialize the per junction embedding table via PCA on junction ratios. | `False` |
| `fully_paired` | `bool` | If True, the model exposes only a joint latent (no modality specific latents). | `False` |
| `**model_kwargs` | | Forwarded to :class: | `{}` |
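The `pool_mode` option pools per-junction codes over the junctions observed in each cell. The following is a minimal NumPy sketch of that idea, not the PartialEncoder implementation: the helper name `masked_pool`, the dense `(cells, junctions, code_dim)` layout, and the handling of cells with no observed junctions are all assumptions made for illustration.

```python
import numpy as np

def masked_pool(codes, mask, pool_mode="mean"):
    """Pool per-junction codes over the observed junctions of each cell.

    codes: (cells, junctions, code_dim) array of per-junction codes.
    mask:  (cells, junctions) binary array, 1 = observed, 0 = missing.
    """
    mask = mask.astype(codes.dtype)
    # Zero out codes of unobserved junctions, then sum over the junction axis.
    summed = (codes * mask[:, :, None]).sum(axis=1)
    if pool_mode == "sum":
        return summed
    if pool_mode == "mean":
        # Divide by the number of observed junctions; cells with none stay at zero.
        n_obs = np.maximum(mask.sum(axis=1, keepdims=True), 1.0)
        return summed / n_obs
    raise ValueError(f"unknown pool_mode: {pool_mode!r}")

codes = np.arange(12, dtype=float).reshape(2, 3, 2)
mask = np.array([[1, 0, 1], [0, 0, 0]])
pooled_mean = masked_pool(codes, mask, "mean")
pooled_sum = masked_pool(codes, mask, "sum")
```

With `'mean'`, the pooled vector does not grow with the number of observed junctions, whereas `'sum'` retains that coverage information in its magnitude.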
differential_expression(adata=None, groupby=None, group1=None, group2=None, idx1=None, idx2=None, mode='change', delta=0.25, batch_size=None, all_stats=True, batch_correction=False, batchid1=None, batchid2=None, fdr_target=0.05, silent=False, **kwargs)
Differential expression analysis.
Performs differential gene expression analysis using normalized gene expression estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | | Additional keyword arguments for differential computation. | `{}` |

Returns:
| Type | Description |
|---|---|
| `pandas.DataFrame` | A pandas DataFrame with differential expression results. |
differential_splicing(adata=None, groupby=None, group1=None, group2=None, idx1=None, idx2=None, mode='change', delta=0.1, batch_size=None, all_stats=True, batch_correction=False, batchid1=None, batchid2=None, fdr_target=0.05, silent=False, norm_splicing_function='dm_posterior_mean', **kwargs)
Differential splicing analysis.
Performs differential junction usage analysis using normalized PSI from the model's splicing decoder.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | | Additional keyword arguments for differential computation. | `{}` |

Returns:
| Type | Description |
|---|---|
| `pandas.DataFrame` | A pandas DataFrame with differential splicing results. |
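To make the roles of `mode='change'` and `delta` concrete, here is a small sketch of a change-style test on PSI. It assumes the effect is the difference in decoded PSI between two groups and that the reported quantity is the fraction of posterior samples whose absolute effect exceeds `delta`. The function `prob_change` and the simulated samples are hypothetical; how SPLICEVI draws and pairs samples internally is not described in this section.

```python
import numpy as np

def prob_change(psi_group1, psi_group2, delta=0.1):
    """Fraction of paired posterior samples with |PSI difference| > delta.

    psi_group1, psi_group2: (samples, junctions) arrays of decoded PSI.
    """
    effect = psi_group1 - psi_group2
    return (np.abs(effect) > delta).mean(axis=0)

rng = np.random.default_rng(0)
# Simulated samples: junction 0 differs strongly between groups, junction 1 does not.
psi_a = np.clip(rng.normal([0.8, 0.5], 0.02, size=(1000, 2)), 0, 1)
psi_b = np.clip(rng.normal([0.3, 0.5], 0.02, size=(1000, 2)), 0, 1)
proba = prob_change(psi_a, psi_b, delta=0.1)
```

A larger `delta` makes the test more conservative, since only bigger PSI shifts count as a change.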
get_latent_representation(adata=None, modality='joint', indices=None, give_mean=True, batch_size=None)
Return the latent representation for each cell.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData \| None` | AnnOrMuData object used in setup. | `None` |
| `modality` | `Literal['joint', 'expression', 'splicing']` | One of: | `'joint'` |
| `indices` | `Sequence[int] \| None` | Cell indices to use. | `None` |
| `give_mean` | `bool` | If True, returns the mean of the latent distribution. | `True` |
| `batch_size` | `int \| None` | Batch size for processing. | `None` |

Returns:
| Type | Description |
|---|---|
| `numpy.ndarray` | A NumPy array of the latent representations. |
get_library_size_factors(adata=None, indices=None, batch_size=128)
Return library size factors for gene expression.
Note that the splicing modality does not use library size factors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData \| None` | AnnOrMuData object. | `None` |
| `indices` | `Sequence[int]` | Cell indices (default: all cells). | `None` |
| `batch_size` | `int` | Batch size for processing. | `128` |

Returns:
| Type | Description |
|---|---|
| `dict` | A dictionary with key "expression" for the gene expression library size factors. |
get_normalized_expression(adata=None, indices=None, n_samples_overall=None, transform_batch=None, gene_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)
Returns the normalized (decoded) gene expression.
This is denoted as :math:`\rho_n` in the scVI paper.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData \| None` | AnnOrMuData object with equivalent structure to initial AnnData. If | `None` |
| `indices` | `Sequence[int] \| None` | Indices of cells in adata to use. If | `None` |
| `n_samples_overall` | `int \| None` | Number of observations to sample from | `None` |
| `transform_batch` | `Sequence[Number \| str] \| None` | Batch to condition on. If transform_batch is: | `None` |
| `gene_list` | `Sequence[str] \| None` | Return frequencies of expression for a subset of genes. This can save memory when working with large datasets and few genes are of interest. | `None` |
| `use_z_mean` | `bool` | If True, use the mean of the latent distribution; otherwise sample from it. | `True` |
| `n_samples` | `int` | Number of posterior samples to use for estimation. | `1` |
| `batch_size` | `int \| None` | Minibatch size for data loading into model. Defaults to | `None` |
| `return_mean` | `bool` | Whether to return the mean of the samples. | `True` |
| `return_numpy` | `bool` | Return a numpy array instead of a pandas DataFrame. | `False` |

Returns:
| Type | Description |
|---|---|
| `numpy.ndarray \| pandas.DataFrame` | If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`. Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. |
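The two return shapes described above can be illustrated with plain NumPy. The array below is random stand-in data rather than model output, and the variable names are chosen for this sketch only.

```python
import numpy as np

n_samples, n_cells, n_genes = 5, 4, 3
rng = np.random.default_rng(0)
# Stand-in for decoded posterior samples of normalized expression.
samples = rng.random((n_samples, n_cells, n_genes))

# n_samples > 1 with return_mean=False keeps the sample axis.
per_sample = samples
# return_mean=True averages over the sample axis.
averaged = samples.mean(axis=0)
```

Keeping the sample axis is useful for downstream uncertainty estimates, at the cost of `n_samples` times more memory.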
get_normalized_splicing(adata=None, indices=None, n_samples_overall=None, transform_batch=None, junction_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)
Returns the normalized (decoded) splicing probabilities.
This is denoted as :math:`p_{nj}` in the SPLICEVI model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData \| None` | AnnOrMuData object with the same structure as used in setup. If | `None` |
| `indices` | `Sequence[int] \| None` | Cell indices to use. If | `None` |
| `n_samples_overall` | `int \| None` | Number of observations to sample from | `None` |
| `transform_batch` | `Sequence[Number \| str] \| None` | Batch(es) to condition on. `None`: use the true observed batch. `int`: force all cells to that batch. `list[int \| str]`: average over those batches. | `None` |
| `junction_list` | `Sequence[str] \| None` | Subset of junction names to return. If | `None` |
| `use_z_mean` | `bool` | If True, use the mean of the latent distribution; otherwise sample. | `True` |
| `n_samples` | `int` | Number of posterior samples to draw. If >1 and | `1` |
| `batch_size` | `int \| None` | Minibatch size for decoding. Defaults to | `None` |
| `return_mean` | `bool` | Whether to average over posterior samples when | `True` |
| `return_numpy` | `bool` | If True, returns a NumPy array; otherwise a pandas DataFrame. | `False` |
| `silent` | `bool` | If True, suppresses the progress bar. | `True` |

Returns:
| Type | Description |
|---|---|
| `numpy.ndarray \| pandas.DataFrame` | A NumPy array or pandas DataFrame of shape `(cells, junctions)` containing the decoded splicing probabilities. |
get_normalized_splicing_DM(adata=None, indices=None, n_samples_overall=None, transform_batch=None, junction_list=None, use_z_mean=True, n_samples=1, batch_size=None, return_mean=True, return_numpy=False, silent=True)
Returns DM-normalized splicing probabilities (PSI*).
This decodes junction PSI (:math:`p`) and applies a Dirichlet-multinomial
posterior mean smoothing per junction:
:math:`\psi^* = (c\,p + y_j) / (c + n)`,
where :math:`y_j` is the observed junction count and :math:`n` is the
observed ATSE total for that junction. The concentration :math:`c` is
taken from the module:
* If `self.dm_concentration == "atse"`: use the per-ATSE values in
`self.module.log_phi_j` mapped to per-junction via
`self.module.junc2atse`.
* Otherwise: treat `self.module.log_phi_j` as a scalar concentration.
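The smoothing formula above can be checked numerically with a short NumPy sketch. The helper `dm_posterior_mean` is hypothetical and takes the concentration `c` directly as a number or array; how the module derives `c` from `log_phi_j` (for example, whether it is exponentiated) is not stated here and is left out of the sketch.

```python
import numpy as np

def dm_posterior_mean(p, junc_counts, atse_counts, concentration):
    """Compute psi* = (c * p + y) / (c + n) elementwise."""
    c = np.asarray(concentration, dtype=float)
    return (c * p + junc_counts) / (c + atse_counts)

p = np.array([[0.2, 0.8]])    # decoded PSI for one cell, two junctions
y = np.array([[9.0, 1.0]])    # observed junction counts
n = np.array([[10.0, 10.0]])  # observed ATSE totals
psi_star = dm_posterior_mean(p, y, n, concentration=10.0)

# With no observed reads, the estimate falls back to the decoded PSI.
psi_unobserved = dm_posterior_mean(p, np.zeros_like(y), np.zeros_like(n), 10.0)
```

Here `psi_star` lies between the decoded PSI and the empirical ratio `y / n`; a larger `c` pulls it toward the decoded value, while a larger `n` pulls it toward the observed ratio.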
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | `AnnOrMuData \| None` | AnnOrMuData object with equivalent structure to initialization. If | `None` |
| `indices` | `Sequence[int] \| None` | Indices of cells in | `None` |
| `n_samples_overall` | `int \| None` | Number of observations to sample from | `None` |
| `transform_batch` | `Sequence[Number \| str] \| None` | Batch conditioning. `None`: use the observed batch. `int`/`str`: force decoding to that batch. `list[int \| str]`: average decoded outputs over the listed batches. | `None` |
| `junction_list` | `Sequence[str] \| None` | Optional subset of junction names to return. | `None` |
| `use_z_mean` | `bool` | If True, decode from the mean of the latent; else sample. | `True` |
| `n_samples` | `int` | Posterior samples to draw. If >1 and | `1` |
| `batch_size` | `int \| None` | Minibatch size for decoding. | `None` |
| `return_mean` | `bool` | If True and | `True` |
| `return_numpy` | `bool` | If True, return a NumPy array; else a pandas DataFrame. | `False` |
| `silent` | `bool` | If True, disables progress display. | `True` |

Returns:
| Type | Description |
|---|---|
| `numpy.ndarray \| pandas.DataFrame` | Array or DataFrame of shape `(cells, junctions)` with DM-normalized PSI*. |
setup_anndata(adata, layer=None, junc_ratio=None, cell_by_junction_matrix=None, cell_by_cluster_matrix=None, psi_mask_layer=None, batch_key=None, size_factor_key=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)
classmethod
Set up an AnnData object for SPLICEVI.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `junc_ratio` | `str \| None` | Key in | `None` |
| `cell_by_junction_matrix` | `str \| None` | Key in | `None` |
| `cell_by_cluster_matrix` | `str \| None` | Key in | `None` |
| `psi_mask_layer` | `str \| None` | Layer with binary mask (1=observed, 0=missing) per junction. | `None` |
Notes
Use this method if your splicing data is stored in an AnnData object where gene expression and splicing features are concatenated.
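The setup methods expect junction ratios, junction counts, ATSE (cluster) totals, and a binary observation mask. The sketch below shows one way these matrices relate to each other, assuming a junction counts as observed whenever its ATSE total is positive and that unobserved ratios are filled with 0. Both conventions, and the variable names, are assumptions for illustration rather than requirements stated in this section.

```python
import numpy as np

# Toy data: 2 cells x 3 junctions.
junc_counts = np.array([[3, 0, 5], [0, 0, 2]], dtype=float)
# ATSE total for the event each junction belongs to.
atse_counts = np.array([[4, 0, 5], [0, 0, 8]], dtype=float)

# Assumed convention: a junction is observed when its ATSE has at least one read.
psi_mask = (atse_counts > 0).astype(np.uint8)

# Ratio = junction count / ATSE total; unobserved entries are left at 0.
junc_ratio = np.divide(
    junc_counts,
    atse_counts,
    out=np.zeros_like(junc_counts),
    where=atse_counts > 0,
)
```

The mask is what distinguishes a true ratio of 0 (observed ATSE, no reads on this junction) from a missing value.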
setup_mudata(mdata, rna_layer=None, junc_ratio_layer=None, atse_counts_layer=None, junc_counts_layer=None, psi_mask_layer=None, batch_key=None, size_factor_key=None, categorical_covariate_keys=None, continuous_covariate_keys=None, idx_layer=None, modalities=None, **kwargs)
classmethod
Set up a MuData object for SPLICEVI.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rna_layer` | `str \| None` | Key in the RNA AnnData for gene expression counts. If | `None` |
| `junc_ratio_layer` | `str \| None` | Key in the splicing AnnData for junction ratio values. If | `None` |
| `atse_counts_layer` | `str \| None` | Key in the splicing AnnData for total event counts. If | `None` |
| `junc_counts_layer` | `str \| None` | Key in the splicing AnnData for observed junction counts. If | `None` |
| `psi_mask_layer` | `str \| None` | Layer with binary mask (1=observed, 0=missing) per junction. | `None` |
| `size_factor_key` | `str \| None` | Key in | `None` |
Examples:
>>> import mudata as mu
>>> import scvi
>>> mdata = mu.MuData({
... "rna": ge_anndata.copy(),
... "splicing": atse_anndata.copy()
... })
>>> scvi.model.SPLICEVI.setup_mudata(
... mdata,
... modalities={"rna_layer": "rna", "junc_ratio_layer": "splicing"},
... rna_layer="raw_counts", # gene expression data is in the GE AnnData's "raw_counts" layer
... junc_ratio_layer="junc_ratio", # splicing data is in the ATSE AnnData's "junc_ratio" layer
... )
>>> model = scvi.model.SPLICEVI(mdata)
train(max_epochs=500, lr=0.0001, accelerator='auto', devices='auto', train_size=None, validation_size=None, shuffle_set_split=True, batch_size=128, weight_decay=0.001, eps=1e-08, early_stopping=True, early_stopping_patience=50, check_val_every_n_epoch=None, n_steps_kl_warmup=None, n_epochs_kl_warmup=50, adversarial_mixing=True, lr_scheduler_type='plateau', reduce_lr_on_plateau=False, lr_factor=0.6, lr_patience=30, lr_threshold=0.0, lr_scheduler_metric='elbo_validation', step_size=10, gradient_clipping=True, gradient_clipping_max_norm=5.0, datasplitter_kwargs=None, plan_kwargs=None, **kwargs)
Trains the model using amortized variational inference on gene expression and splicing modalities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `max_epochs` | `int` | Number of epochs to train over. | `500` |
| `lr` | `float` | Learning rate for optimization. | `0.0001` |
| `accelerator` | `str` | Hardware acceleration options. | `'auto'` |
| `devices` | `str` | Hardware acceleration options. | `'auto'` |
| `train_size` | `float \| None` | Proportions for splitting the data into train and validation sets. | `None` |
| `validation_size` | `float \| None` | Proportions for splitting the data into train and validation sets. | `None` |
| `shuffle_set_split` | `bool` | Whether to shuffle before splitting. | `True` |
| `batch_size` | `int` | Minibatch size for training. | `128` |
| `weight_decay` | `float` | Optimizer hyperparameters. | `0.001` |
| `eps` | `float` | Optimizer hyperparameters. | `1e-08` |
| `early_stopping` | `bool` | Whether to enable early stopping. | `True` |
| `early_stopping_patience` | `int` | Number of epochs with no improvement before stopping. | `50` |
| `check_val_every_n_epoch` | `int \| None` | How often (in epochs) to run validation. | `None` |
| `n_steps_kl_warmup` | `int \| None` | KL divergence warmup parameters (by steps or epochs). | `None` |
| `n_epochs_kl_warmup` | `int \| None` | KL divergence warmup parameters (by steps or epochs). | `50` |
| `adversarial_mixing` | `bool` | Whether to include an adversarial classifier during training. | `True` |
| `lr_scheduler_type` | `Literal['plateau', 'step']` | Scheduler type in TrainingPlan: `'plateau'` (reduce-on-plateau) or `'step'` (fixed-step). | `'plateau'` |
| `reduce_lr_on_plateau` | `bool` | If True and using the plateau scheduler, enable ReduceLROnPlateau. | `False` |
| `lr_factor` | `float` | Multiplicative factor for LR reduction (used for both plateau and step schedulers). | `0.6` |
| `lr_patience` | `int` | Number of epochs with no improvement for the plateau scheduler. | `30` |
| `lr_threshold` | `float` | Threshold for measuring a new optimum (plateau scheduler). | `0.0` |
| `lr_scheduler_metric` | `Literal['elbo_validation', 'reconstruction_loss_validation', 'kl_local_validation']` | Metric to monitor for the plateau scheduler. | `'elbo_validation'` |
| `step_size` | `int` | Epoch interval between LR drops (step scheduler). | `10` |
| `gradient_clipping` | `bool` | Whether to use gradient norm clipping. | `True` |
| `gradient_clipping_max_norm` | `float` | Max norm of the gradients used in gradient clipping. | `5.0` |
| `datasplitter_kwargs` | `dict \| None` | Additional kwargs for the data splitter. | `None` |
| `plan_kwargs` | `dict \| None` | Additional kwargs to pass to the TrainingPlan constructor. | `None` |
| `**kwargs` | | Additional Trainer kwargs (callbacks, strategy, etc.). | `{}` |
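To give a feel for how `step_size`, `lr_factor`, and `n_epochs_kl_warmup` interact over training, the sketch below computes a fixed-step learning rate and a linear KL warmup weight in plain Python. Both helpers are hypothetical: `step_lr` follows the usual step-decay rule (multiply by the factor every `step_size` epochs), and `kl_weight` assumes a linear ramp from 0 to 1. Whether the TrainingPlan uses exactly these rules is an assumption.

```python
def step_lr(base_lr, epoch, step_size=10, lr_factor=0.6):
    """Learning rate at `epoch` under a fixed-step decay schedule."""
    return base_lr * lr_factor ** (epoch // step_size)

def kl_weight(epoch, n_epochs_kl_warmup=50, min_weight=0.0, max_weight=1.0):
    """Linearly ramp the KL weight from min_weight to max_weight during warmup."""
    if not n_epochs_kl_warmup:
        # No warmup configured: apply the full KL term from the start.
        return max_weight
    frac = min(epoch / n_epochs_kl_warmup, 1.0)
    return min_weight + frac * (max_weight - min_weight)

lrs = [step_lr(1e-4, e) for e in (0, 9, 10, 25)]
weights = [kl_weight(e) for e in (0, 25, 50, 100)]
```

With the defaults listed above, the step schedule would cut the learning rate to 60% every 10 epochs, while the KL term would reach full weight after 50 epochs.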