
SPLICEVAE

The underlying PyTorch module. Most users will interact with SPLICEVI rather than this class directly.

splicevi.splicevae.SPLICEVAE

Bases: BaseModuleClass

Variational autoencoder for joint (paired or unpaired) RNA-seq gene expression and alternative splicing (junction usage). Two encoder–decoder branches (expression and splicing) each produce a posterior over the latent variables, and the two posteriors are mixed into a shared latent space z.

Reconstruction is performed per modality. Optional penalties align the two posteriors.
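As a rough illustration of how two branch posteriors can be mixed for paired and unpaired cells, the sketch below averages the two latent means when both modalities are observed and falls back to the single available branch otherwise. The helper name `mix_posterior_means` and the fallback rule are assumptions made for illustration, not the class's documented behaviour.

```python
def mix_posterior_means(mu_expr, mu_spl, has_expr, has_spl):
    """Mix two per-cell latent means into one shared mean.

    Hypothetical helper: paired cells get the plain average, cells with a
    single observed modality reuse that modality's mean (assumed rule).
    """
    if has_expr and has_spl:
        return [(e + s) / 2.0 for e, s in zip(mu_expr, mu_spl)]
    if has_expr:
        return list(mu_expr)
    if has_spl:
        return list(mu_spl)
    raise ValueError("cell has neither modality")


# Paired cell: both branches contribute equally.
paired = mix_posterior_means([1.0, 3.0], [3.0, 5.0], True, True)
# Expression-only cell: the splicing branch is ignored.
expr_only = mix_posterior_means([1.0, 3.0], [0.0, 0.0], True, False)
```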

Parameters:

Name Type Description Default
n_input_genes int

Number of gene expression features (G).

0
n_input_junctions int

Number of splicing junction features (J).

0
n_batch int

Number of batches for batch correction (categorical covariate).

0
n_obs int

Number of observations (cells); needed when modality_weights="cell".

0
n_labels int

Number of labels (for gene_dispersion="gene-label").

0
n_cats_per_cov Iterable[int] or None

Category counts for each categorical covariate provided.

None
n_continuous_cov int

Number of continuous covariates.

0
gene_likelihood (zinb, nb, poisson)

Expression likelihood.

"zinb","nb","poisson"
gene_dispersion (gene, gene-batch, gene-label, gene-cell)

Dispersion layout for expression.

"gene","gene-batch","gene-label","gene-cell"
use_size_factor_key bool

If True, use provided library size factors for expression. Otherwise learn a library size encoder.

False
splicing_encoder_architecture (vanilla, partial)

Splicing encoder choice: "vanilla" uses a standard SCVI Encoder. "partial" uses PartialEncoderEDDIFaster.

"vanilla","partial"
splicing_decoder_architecture (vanilla, linear)

Splicing decoder choice: "vanilla" uses DecoderSplice (nonlinear). "linear" uses a LinearDecoder.

"vanilla","linear"
expression_architecture (vanilla, linear)

Expression decoder: "vanilla" uses a nonlinear DecoderSCVI. "linear" uses a LinearDecoderSCVI.

"vanilla","linear"
n_hidden int or None

Hidden width for SCVI-style MLPs. If None, defaults to roughly sqrt(J), capped at 128.

None
n_latent int or None

Latent dimensionality. If None, defaults to roughly sqrt(n_hidden). When modality_weights="concatenate", the mixed latent is doubled internally.

None
n_layers_encoder int

Depth of SCVI encoders (the expression branch, and the splicing branch when splicing_encoder_architecture="vanilla"). Also controls the depth of the post-pooling MLP in the partial splicing encoder.

2
n_layers_decoder int

Depth of nonlinear decoders (expression "vanilla", splicing "vanilla"). Not used by linear decoders.

2
dropout_rate float

Dropout rate for SCVI-style MLPs.

0.1
use_batch_norm (encoder, decoder, none, both)

Apply BatchNorm to encoder or decoder stacks.

"encoder","decoder","none","both"
use_layer_norm (encoder, decoder, none, both)

Apply LayerNorm to encoder or decoder stacks.

"encoder","decoder","none","both"
latent_distribution (normal, ln)

Posterior family for encoders. If "ln" (logistic normal), the latent sample is softmax transformed before decoding.

"normal","ln"
deeply_inject_covariates bool

Deeply inject categorical and continuous covariates into decoder layers.

False
encode_covariates bool

Concatenate continuous covariates to encoder inputs and pass categorical covariates via n_cat_list.

False
splicing_loss_type (binomial, beta_binomial, dirichlet_multinomial)

Reconstruction loss for splicing. "binomial" expects junction counts and ATSE totals. "beta_binomial" uses a beta-binomial with a learned concentration. "dirichlet_multinomial" uses a grouped softmax within ATSEs.

"binomial","beta_binomial","dirichlet_multinomial"
splicing_concentration float or None

Optional scalar concentration for the beta-binomial case.

None
dm_concentration (atse, scalar)

For the Dirichlet-multinomial loss: controls whether the concentration is per ATSE or a single scalar.

"atse","scalar"
code_dim int

Dimensionality of per-junction codes before pooling.

16
h_hidden_dim int

Hidden width of the per-junction h subnetwork that combines PSI and feature embedding.

64
encoder_hidden_dim int

Hidden width of the post-pooling MLP that maps the pooled code to (mu, log var) of z.

128
pool_mode (mean, sum)

Aggregation of per-junction codes within each cell.

"mean","sum"
max_nobs int

Optional cap on the number of observed entries processed in each scatter chunk. A negative value disables chunking.

-1
modality_weights (equal, cell, universal, concatenate, per_dimension_weighted_average)

How to combine expression and splicing posteriors. "per_dimension_weighted_average" learns a per-latent-dimension splicing contribution weight w_d in [0,1] (expression weight = 1-w_d), constrained so the mean weight stays at 0.5 (preventing one modality from dominating globally).

"equal","cell","universal","concatenate","per_dimension_weighted_average"
modality_penalty (Jeffreys, MMD, None)

Alignment penalty between the two posteriors on paired cells.

"Jeffreys","MMD","None"
**model_kwargs

Forwarded to underlying components.

{}
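The "per_dimension_weighted_average" option in the table above describes per-dimension weights w_d in [0,1] whose mean stays at 0.5. How the class enforces that constraint is not stated here; the sketch below shows one possible construction (squash, centre, rescale) that satisfies both properties exactly. The function names and the tanh-based parameterisation are assumptions for illustration only.

```python
import math


def constrained_weights(logits):
    """Map free parameters to weights in [0, 1] whose mean is 0.5.

    Assumed construction, not necessarily the one used by SPLICEVAE.
    """
    squashed = [math.tanh(v) for v in logits]        # each value in (-1, 1)
    mean = sum(squashed) / len(squashed)
    centered = [s - mean for s in squashed]          # zero mean, each in (-2, 2)
    scale = max(1.0, max(abs(c) for c in centered))  # shrink only when needed
    return [0.5 + 0.5 * c / scale for c in centered]


def mix(mu_expr, mu_spl, w):
    """Per-dimension blend: splicing weight w_d, expression weight 1 - w_d."""
    return [(1.0 - wd) * e + wd * s for e, s, wd in zip(mu_expr, mu_spl, w)]


w = constrained_weights([2.0, -1.0, 0.3, 0.0])
z_mean = mix([0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], w)
```

Because centring leaves a zero-mean vector and the rescaling is a single scalar, the mean of w is 0.5 by construction, so no dimension-wide drift toward one modality is possible in this sketch.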
Notes

The partial splicing branch uses a single PartialEncoderEDDIFaster encoder that aggregates only observed junctions per cell. The same n_layers_encoder parameter controls the depth of the final MLP that produces latent statistics for this branch.
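The idea of aggregating only observed junctions can be sketched as masked pooling over per-junction code vectors, matching the "mean" and "sum" choices of pool_mode. This is a plain-Python illustration with a hypothetical helper name; the actual PartialEncoderEDDIFaster operates on tensors and its internals are not shown in this page.

```python
def pool_observed(codes, mask, pool_mode="mean"):
    """Aggregate per-junction code vectors over observed junctions only.

    codes: list of J code vectors (each of length code_dim)
    mask:  list of J flags, truthy where the junction was observed
    """
    if pool_mode not in ("mean", "sum"):
        raise ValueError("pool_mode must be 'mean' or 'sum'")
    dim = len(codes[0])
    total = [0.0] * dim
    n_observed = 0
    for code, observed in zip(codes, mask):
        if not observed:
            continue  # unobserved junctions never reach the pooled code
        n_observed += 1
        for d in range(dim):
            total[d] += code[d]
    if pool_mode == "sum" or n_observed == 0:
        return total
    return [t / n_observed for t in total]


codes = [[1.0, 2.0], [100.0, 100.0], [3.0, 4.0]]
mask = [1, 0, 1]  # the middle junction is unobserved and is ignored
pooled_mean = pool_observed(codes, mask, "mean")
pooled_sum = pool_observed(codes, mask, "sum")
```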

dirichlet_multinomial_likelihood(counts, atse_counts, junc2atse, alpha, mask=None)

Computes each cell's Dirichlet-multinomial log-likelihood, masking out any junctions where mask==0 and any ATSE groups where atse_counts==0.

Returns: Tensor of shape (N,) giving the per-cell log-likelihood:

ll[i] = sum_over_groups LL_group(i,g) - sum_over_junctions LL_junc(i,p)
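For intuition, a single-cell version of the group-minus-junction decomposition can be written with log-gamma terms. The sketch assumes LL_group = lgamma(A) - lgamma(n + A) and LL_junc = lgamma(alpha) - lgamma(x + alpha), which is the standard Dirichlet-multinomial form with the multinomial coefficient omitted. It also assumes `junc2atse` is a list of group indices and derives group totals from the unmasked counts instead of taking `atse_counts` as input; the real method's tensor layout may differ.

```python
import math


def dm_log_likelihood(counts, junc2atse, alpha, mask=None):
    """Dirichlet-multinomial log-likelihood for one cell (coefficient omitted).

    counts[j]:    reads supporting junction j
    junc2atse[j]: index of the ATSE group of junction j (assumed layout)
    alpha[j]:     positive concentration for junction j
    mask[j]:      falsy drops junction j; None keeps every junction
    """
    if mask is None:
        mask = [1] * len(counts)
    group_n = {}  # total count per ATSE group
    group_a = {}  # total concentration per ATSE group
    junc_term = 0.0
    for x, g, a, m in zip(counts, junc2atse, alpha, mask):
        if not m:
            continue
        group_n[g] = group_n.get(g, 0) + x
        group_a[g] = group_a.get(g, 0.0) + a
        junc_term += math.lgamma(a) - math.lgamma(x + a)
    group_term = 0.0
    for g, n in group_n.items():
        if n == 0:
            continue  # groups without reads contribute nothing
        group_term += math.lgamma(group_a[g]) - math.lgamma(n + group_a[g])
    return group_term - junc_term


# One ATSE with two junctions, uniform concentration, a single read.
ll = dm_log_likelihood([1, 0], [0, 0], [1.0, 1.0])
```

With alpha = (1, 1) and one read, the exact probability is 1/2, so `ll` equals log(1/2).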

generative(z, qz_m, batch_index, cont_covs=None, cat_covs=None, libsize_expr=None, use_z_mean=False, label=None)

Run the generative model to decode gene expression and splicing.

Decodes the latent representation into parameters for gene expression reconstruction and splicing probabilities.

Returns:

Type Description
dict

A dictionary with keys:
- "p": decoded splicing probabilities
- "px_scale", "px_rate", "px_dropout": gene expression decoder outputs

get_reconstruction_loss_expression(x, px_rate, px_r, px_dropout)

Compute the reconstruction loss for gene expression data.
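As a reference point for the "nb" likelihood, the sketch below evaluates a negative binomial log-probability in the mean and inverse-dispersion parameterisation and sums its negative over genes. It assumes px_rate plays the role of the mean and px_r the inverse dispersion, as is conventional in SCVI-style models; the zero-inflated and Poisson variants, and the method's actual tensor handling, are not covered.

```python
import math


def nb_log_prob(x, mu, theta):
    """Negative binomial log-pmf with mean mu and inverse dispersion theta."""
    log_theta_mu = math.log(theta + mu)
    return (
        theta * (math.log(theta) - log_theta_mu)
        + x * (math.log(mu) - log_theta_mu)
        + math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
    )


def nb_reconstruction_loss(x_row, rate_row, r_row):
    """Negative log-likelihood of one cell, summed over genes (sketch)."""
    return -sum(nb_log_prob(x, mu, th) for x, mu, th in zip(x_row, rate_row, r_row))


# Two genes for a single cell; values are made up for illustration.
loss_cell = nb_reconstruction_loss([0, 4], [1.0, 3.0], [1.0, 2.0])
```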

get_reconstruction_loss_splicing(x, atse_counts, junc_counts, mask, p, phi)

- x: (N × J) binary matrix (often unused)
- atse_counts: (N × J) denominator counts
- junc_counts: (N × J) numerator counts
- phi: (J) concentration parameter per junction
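To make the roles of the numerator, denominator, and concentration concrete, here is a per-cell sketch of a masked beta-binomial negative log-likelihood. It assumes the mean/concentration parameterisation alpha = p * phi and beta = (1 - p) * phi, and it keeps the binomial coefficient; whether the method uses exactly this form is not stated in this page, and the helper names are hypothetical.

```python
import math


def log_binom_coef(n, k):
    return math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)


def log_beta(a, b):
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def beta_binomial_log_prob(k, n, p, phi):
    """Beta-binomial log-pmf with mean p and concentration phi (assumed form)."""
    a, b = p * phi, (1.0 - p) * phi
    return log_binom_coef(n, k) + log_beta(k + a, n - k + b) - log_beta(a, b)


def splicing_nll(junc_counts, atse_counts, mask, p, phi):
    """Sum the negative log-likelihood over observed junctions of one cell."""
    total = 0.0
    for k, n, m, pj, phij in zip(junc_counts, atse_counts, mask, p, phi):
        if m and n > 0:  # skip masked junctions and empty ATSEs
            total -= beta_binomial_log_prob(k, n, pj, phij)
    return total


# Second junction is masked out, so only the first contributes.
nll = splicing_nll([1, 7], [1, 9], [1, 0], [0.5, 0.5], [2.0, 2.0])
```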

inference(x, mask, batch_index, cont_covs, cat_covs, label, cell_idx, size_factor, n_samples=1)

Run the inference network.

Splits input x into gene expression and splicing parts, encodes each branch, and mixes their latent representations.
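The split step can be pictured as slicing each row of x by feature count. The sketch assumes gene columns come first and junction columns second; that column order is an assumption, not something this page confirms, and the helper name is hypothetical.

```python
def split_modalities(x_row, n_input_genes, n_input_junctions):
    """Split one cell's feature vector into expression and splicing parts.

    Assumes the layout [genes | junctions]; the real ordering may differ.
    """
    if len(x_row) != n_input_genes + n_input_junctions:
        raise ValueError("row length does not match G + J")
    return x_row[:n_input_genes], x_row[n_input_genes:]


x_expr, x_spl = split_modalities([5, 0, 2, 0.1, 0.9], 3, 2)
```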

loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)

Compute the total loss combining gene expression and splicing reconstruction losses, latent KL divergence, and the modality alignment penalty.

For splicing, if count data is provided (via the keys "atse_counts_key" and "junc_counts_key"), the loss is computed using the specified binomial, beta-binomial, or Dirichlet-multinomial likelihood; otherwise, binary cross-entropy is used.

Returns:

Type Description
LossOutput

A container with total loss, reconstruction losses, and KL divergence details.
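The pieces named above can be combined as in the following sketch, which also spells out a Jeffreys (symmetric KL) penalty between two diagonal Gaussian posteriors. The additive form and the choice to leave the penalty unscaled by kl_weight are assumptions; the function names are illustrative and not part of the class.

```python
import math


def kl_diag_normal(mu_p, var_p, mu_q, var_q):
    """KL(p || q) between diagonal Gaussians, summed over dimensions."""
    return sum(
        0.5 * (math.log(vq / vp) + (vp + (mp - mq) ** 2) / vq - 1.0)
        for mp, vp, mq, vq in zip(mu_p, var_p, mu_q, var_q)
    )


def jeffreys(mu_1, var_1, mu_2, var_2):
    """Jeffreys divergence: KL in both directions added together."""
    return kl_diag_normal(mu_1, var_1, mu_2, var_2) + kl_diag_normal(mu_2, var_2, mu_1, var_1)


def total_loss(recon_expr, recon_spl, kl_z, penalty, kl_weight=1.0):
    """Assumed combination: reconstruction + weighted KL + alignment penalty."""
    return recon_expr + recon_spl + kl_weight * kl_z + penalty


# Unit-variance posteriors whose means differ by 1 in a single dimension.
penalty = jeffreys([0.0], [1.0], [1.0], [1.0])
loss_value = total_loss(1.0, 2.0, 4.0, penalty, kl_weight=0.5)
```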