SPLICEVAE
The underlying PyTorch module. Most users will interact with SPLICEVI rather than this class directly.
splicevi.splicevae.SPLICEVAE
Bases: BaseModuleClass
Variational autoencoder for joint (paired or unpaired) modeling of RNA-seq gene expression and alternative splicing (junction usage). Two encoder–decoder branches, one for expression and one for splicing, each produce a posterior over latents, and the two posteriors are mixed into a shared latent space z.
Reconstruction is performed per modality. Optional penalties align the two posteriors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_input_genes | int | Number of gene expression features (G). | 0 |
| n_input_junctions | int | Number of splicing junction features (J). | 0 |
| n_batch | int | Number of batches for batch correction (categorical covariate). | 0 |
| n_obs | int | Number of observations (cells); needed when modality_weights="cell". | 0 |
| n_labels | int | Number of labels (for gene_dispersion="gene-label"). | 0 |
| n_cats_per_cov | Iterable[int] or None | Category counts for each categorical covariate provided. | None |
| n_continuous_cov | int | Number of continuous covariates. | 0 |
| gene_likelihood | {"zinb", "nb", "poisson"} | Expression likelihood. | not stated |
| gene_dispersion | {"gene", "gene-batch", "gene-label", "gene-cell"} | Dispersion layout for expression. | not stated |
| use_size_factor_key | bool | If True, use provided library size factors for expression. Otherwise learn a library size encoder. | False |
| splicing_encoder_architecture | {"vanilla", "partial"} | Splicing encoder choice: "vanilla" uses a standard SCVI Encoder. "partial" uses PartialEncoderEDDIFaster. | not stated |
| splicing_decoder_architecture | {"vanilla", "linear"} | Splicing decoder choice: "vanilla" uses DecoderSplice (nonlinear). "linear" uses a LinearDecoder. | not stated |
| expression_architecture | {"vanilla", "linear"} | Expression decoder: "vanilla" uses a nonlinear DecoderSCVI. "linear" uses a LinearDecoderSCVI. | not stated |
| n_hidden | int or None | Hidden width for SCVI-style MLPs. If None, defaults to roughly sqrt(J) capped at 128. | None |
| n_latent | int or None | Latent dimensionality. If None, defaults to roughly sqrt(n_hidden). When modality_weights="concatenate", the mixed latent is doubled internally. | None |
| n_layers_encoder | int | Depth of SCVI encoders (expression branch, and splicing branch when splicing_encoder_architecture="vanilla"). Also controls the depth of the post-pooling MLP in the partial splicing encoder. | 2 |
| n_layers_decoder | int | Depth of nonlinear decoders (expression "vanilla", splicing "vanilla"). Not used by linear decoders. | 2 |
| dropout_rate | float | Dropout for SCVI-style MLPs. | 0.1 |
| use_batch_norm | {"encoder", "decoder", "none", "both"} | Apply BatchNorm to encoder or decoder stacks. | not stated |
| use_layer_norm | {"encoder", "decoder", "none", "both"} | Apply LayerNorm to encoder or decoder stacks. | not stated |
| latent_distribution | {"normal", "ln"} | Posterior family for encoders. If "ln" (logistic normal), the latent sample is softmax-transformed before decoding. | not stated |
| deeply_inject_covariates | bool | Deeply inject categorical and continuous covariates into decoder layers. | False |
| encode_covariates | bool | Concatenate continuous covariates to encoder inputs and pass categorical covariates via n_cat_list. | False |
| splicing_loss_type | {"binomial", "beta_binomial", "dirichlet_multinomial"} | Reconstruction loss for splicing. "binomial" expects junction counts and ATSE totals. "beta_binomial" uses a beta-binomial with a learned concentration. "dirichlet_multinomial" uses grouped softmax within ATSEs. | not stated |
| splicing_concentration | float or None | Optional scalar concentration for the beta-binomial case. | None |
| dm_concentration | {"atse", "scalar"} | For the Dirichlet-multinomial loss. Controls whether the concentration is per ATSE or scalar. | not stated |
| code_dim | int | Dimensionality of per-junction codes before pooling. | 16 |
| h_hidden_dim | int | Hidden width of the per-junction h subnetwork that combines PSI and feature embedding. | 64 |
| encoder_hidden_dim | int | Hidden width of the post-pooling MLP that maps the pooled code to (mu, log var) of z. | 128 |
| pool_mode | {"mean", "sum"} | Aggregation of per-junction codes within each cell. | not stated |
| max_nobs | int | Optional cap on the number of observed entries processed in each scatter chunk. A negative value disables chunking. | -1 |
| modality_weights | {"equal", "cell", "universal", "concatenate", "per_dimension_weighted_average"} | How to combine expression and splicing posteriors. "per_dimension_weighted_average" learns a per-latent-dimension splicing contribution weight w_d in [0,1] (expression weight = 1-w_d), constrained so the mean weight stays at 0.5 (preventing one modality from dominating globally). | not stated |
| modality_penalty | {"Jeffreys", "MMD", "None"} | Alignment penalty between the two posteriors on paired cells. | not stated |
| **model_kwargs | | Forwarded to underlying components. | {} |
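To make the "per_dimension_weighted_average" option more concrete, the following is a minimal sketch of a per-dimension blend, assuming the weights are applied elementwise to two latent vectors of equal length. The helper name `mix_latents` and the example values are hypothetical, and the sketch does not show how the module learns the weights or enforces the mean-0.5 constraint.

```python
def mix_latents(z_expr, z_splice, w_splice):
    """Blend two latent vectors dimension by dimension (illustrative helper).

    w_splice[d] is the splicing weight for dimension d; the expression
    weight is 1 - w_splice[d]. This is an assumption about how the mix is
    applied, not the module's actual code.
    """
    if not (len(z_expr) == len(z_splice) == len(w_splice)):
        raise ValueError("all inputs must have the same length")
    if any(w < 0.0 or w > 1.0 for w in w_splice):
        raise ValueError("weights must lie in [0, 1]")
    return [w * s + (1.0 - w) * e for e, s, w in zip(z_expr, z_splice, w_splice)]


# Example weights whose mean is 0.5, so neither modality dominates globally.
z_expr = [1.0, 2.0, 3.0, 4.0]
z_splice = [3.0, 2.0, 1.0, 0.0]
w_splice = [0.9, 0.5, 0.5, 0.1]
z_mixed = mix_latents(z_expr, z_splice, w_splice)
mean_weight = sum(w_splice) / len(w_splice)
```

Under this reading, setting every weight to 0.5 would reduce to a plain average of the two latents.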
Notes
The partial splicing branch uses a single PartialEncoderEDDIFaster encoder that aggregates only observed junctions per cell. The same n_layers_encoder parameter controls the depth of the final MLP that produces latent statistics for this branch.
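The idea of aggregating only observed junctions can be sketched in plain Python. This is an illustration under stated assumptions, not the PartialEncoderEDDIFaster implementation: the helper `pool_observed_codes` is hypothetical, the per-junction codes are taken as given (in the module they would come from the h subnetwork), and returning zeros for a cell with no observed junctions is a choice made here for the example.

```python
def pool_observed_codes(codes, mask, pool_mode="mean"):
    """Pool per-junction code vectors over the observed junctions of one cell.

    codes: list of J code vectors, each of length code_dim.
    mask:  list of J flags; a falsy value marks an unobserved junction.
    """
    dim = len(codes[0])
    observed = [c for c, m in zip(codes, mask) if m]
    if not observed:
        # Assumption for this sketch: an empty cell pools to a zero vector.
        return [0.0] * dim
    totals = [sum(c[d] for c in observed) for d in range(dim)]
    if pool_mode == "sum":
        return totals
    if pool_mode == "mean":
        return [t / len(observed) for t in totals]
    raise ValueError(f"unknown pool_mode: {pool_mode!r}")


# Three junctions with code_dim = 2; the third junction is unobserved
# and therefore has no influence on the pooled code.
codes = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
pooled_mean = pool_observed_codes(codes, mask, "mean")
pooled_sum = pool_observed_codes(codes, mask, "sum")
```

With "mean" pooling the result does not grow with the number of observed junctions, whereas "sum" pooling scales with coverage.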
dirichlet_multinomial_likelihood(counts, atse_counts, junc2atse, alpha, mask=None)
Computes each cell’s Dirichlet-multinomial log-likelihood, masking out any junctions where mask==0 and any ATSE groups where atse_counts==0.
Returns: Tensor of shape (N,) giving the per-cell log-likelihood: ll[i] = sum_over_groups LL_group(i,g) - sum_over_junctions LL_junc(i,p)
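For reference, the textbook Dirichlet-multinomial log-probability can be sketched for a single cell in plain Python. This is not the module's tensor implementation: the helper `dm_log_likelihood` is hypothetical, it recomputes group totals from the junction counts instead of taking atse_counts, it includes the multinomial coefficient (which an implementation may drop as a constant), and it assumes masked junctions are removed from their group entirely.

```python
import math
from collections import defaultdict


def dm_log_likelihood(counts, junc2atse, alpha, mask=None):
    """Dirichlet-multinomial log-likelihood of one cell (illustrative).

    counts:    junction counts, length J.
    junc2atse: ATSE (group) index for each junction, length J.
    alpha:     positive concentration for each junction, length J.
    mask:      optional flags; junctions with a falsy flag are skipped.
    """
    if mask is None:
        mask = [1] * len(counts)
    groups = defaultdict(list)
    for j, g in enumerate(junc2atse):
        if mask[j]:
            groups[g].append(j)

    ll = 0.0
    for members in groups.values():
        n = sum(counts[j] for j in members)
        if n == 0:
            continue  # groups with no reads contribute nothing
        a_sum = sum(alpha[j] for j in members)
        # Group-level term, including log(n!) from the multinomial coefficient.
        ll += math.lgamma(n + 1) + math.lgamma(a_sum) - math.lgamma(n + a_sum)
        # Junction-level terms.
        for j in members:
            ll += (
                math.lgamma(counts[j] + alpha[j])
                - math.lgamma(alpha[j])
                - math.lgamma(counts[j] + 1)
            )
    return ll


# One ATSE with two junctions and a flat prior: a single read is equally
# likely to land on either junction, so the probability is 0.5.
ll_single = dm_log_likelihood([1, 0], [0, 0], [1.0, 1.0])
```

The group-level and junction-level terms here correspond to the two sums in the return description above, up to how signs and constants are assigned to each term.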
generative(z, qz_m, batch_index, cont_covs=None, cat_covs=None, libsize_expr=None, use_z_mean=False, label=None)
Run the generative model to decode gene expression and splicing.
Decodes the latent representation into parameters for gene expression reconstruction and splicing probabilities.
Returns:
| Type | Description |
|---|---|
| dict | A dictionary with keys "p" (decoded splicing probabilities) and "px_scale", "px_rate", "px_dropout" (gene expression decoder outputs). |
get_reconstruction_loss_expression(x, px_rate, px_r, px_dropout)
Compute the reconstruction loss for gene expression data.
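As a point of reference for the "nb" likelihood, the sketch below evaluates a negative binomial reconstruction loss in plain Python. It assumes the mean and inverse-dispersion parameterization commonly used in scvi-tools, with px_rate as the mean and px_r as the inverse dispersion; that mapping is an assumption here, and the helper names are hypothetical. The zero-inflated and Poisson cases are not shown.

```python
import math


def nb_log_prob(x, mu, theta):
    """Log-probability of count x under NB with mean mu, inverse dispersion theta."""
    log_theta_mu = math.log(theta + mu)
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * (math.log(theta) - log_theta_mu)
        + x * (math.log(mu) - log_theta_mu)
    )


def nb_reconstruction_loss(x_row, px_rate_row, px_r_row):
    """Negative log-likelihood of one cell, summed over genes."""
    return -sum(
        nb_log_prob(x, mu, theta)
        for x, mu, theta in zip(x_row, px_rate_row, px_r_row)
    )


# Sanity check of the parameterization: probabilities sum to one and the
# distribution has mean mu.
mu, theta = 2.0, 3.0
probs = [math.exp(nb_log_prob(x, mu, theta)) for x in range(200)]
total_prob = sum(probs)
mean_count = sum(x * p for x, p in enumerate(probs))
```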
get_reconstruction_loss_splicing(x, atse_counts, junc_counts, mask, p, phi)
Parameters:
- x: (N × J) binary matrix (often unused)
- atse_counts: (N × J) denominator counts
- junc_counts: (N × J) numerator counts
- phi: (J) concentration parameter per junction
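The numerator and denominator roles of junc_counts and atse_counts can be illustrated with a masked binomial negative log-likelihood for one cell. This is a sketch, not the module's code: the helper names are hypothetical, the binomial coefficient is included even though an implementation may drop it as a constant, and the clamping constant eps is chosen only for this example.

```python
import math


def binomial_log_prob(k, n, p, eps=1e-8):
    """Log-probability of k successes out of n trials with success rate p."""
    p = min(max(p, eps), 1.0 - eps)  # keep the logarithms finite
    log_coeff = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_coeff + k * math.log(p) + (n - k) * math.log(1.0 - p)


def binomial_splicing_loss(junc_counts, atse_counts, p, mask):
    """Negative log-likelihood of one cell over observed junctions only."""
    loss = 0.0
    for k, n, p_j, m in zip(junc_counts, atse_counts, p, mask):
        if m and n > 0:
            loss -= binomial_log_prob(k, n, p_j)
    return loss


# Two junctions; the second one is masked out and does not affect the loss.
loss_value = binomial_splicing_loss(
    junc_counts=[1, 7],
    atse_counts=[2, 9],
    p=[0.5, 0.3],
    mask=[1, 0],
)
```

A beta-binomial variant would replace the fixed rate p with a Beta distribution whose spread is governed by a concentration such as phi.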
inference(x, mask, batch_index, cont_covs, cat_covs, label, cell_idx, size_factor, n_samples=1)
Run the inference network.
Splits input x into gene expression and splicing parts, encodes each branch, and mixes their latent representations.
loss(tensors, inference_outputs, generative_outputs, kl_weight=1.0)
Compute the total loss combining gene expression and splicing reconstruction losses, latent KL divergence, and the modality alignment penalty.
For splicing, if count data is provided (via the keys "atse_counts_key" and "junc_counts_key"), the loss is computed using the configured binomial, beta-binomial, or Dirichlet-multinomial likelihood (see splicing_loss_type); otherwise, binary cross-entropy is used.
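The way these terms might combine can be sketched with diagonal Gaussian posteriors. The Jeffreys divergence below is the standard symmetric KL, KL(a||b) + KL(b||a). The function `total_loss` is a hypothetical composition: whether the module scales the alignment penalty by kl_weight, or averages over cells, is not stated here, so treat those details as assumptions.

```python
import math


def kl_diag_normal(mu_p, var_p, mu_q, var_q):
    """KL(p || q) between two diagonal Gaussians, summed over dimensions."""
    return sum(
        0.5 * (math.log(vq / vp) + (vp + (mp - mq) ** 2) / vq - 1.0)
        for mp, vp, mq, vq in zip(mu_p, var_p, mu_q, var_q)
    )


def jeffreys_penalty(mu_a, var_a, mu_b, var_b):
    """Symmetric KL between the two modality posteriors of one cell."""
    return kl_diag_normal(mu_a, var_a, mu_b, var_b) + kl_diag_normal(
        mu_b, var_b, mu_a, var_a
    )


def total_loss(recon_expr, recon_splice, kl_z, penalty, kl_weight=1.0):
    """Assumed composition: only the latent KL term is scaled by kl_weight."""
    return recon_expr + recon_splice + kl_weight * kl_z + penalty


# Two unit-variance posteriors whose means differ by 1 in a single dimension:
# each KL direction equals 0.5, so the Jeffreys penalty is 1.0.
penalty = jeffreys_penalty([0.0], [1.0], [1.0], [1.0])
example_total = total_loss(10.0, 5.0, 2.0, penalty, kl_weight=0.5)
```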
Returns:
| Type | Description |
|---|---|
| LossOutput | A container with total loss, reconstruction losses, and KL divergence details. |