
# spatialfusion.embed.embed

Autoencoder (AE) and Graph Convolutional Network (GCN) embedding pipeline for SpatialFusion.

This module provides utilities for:

- Running a paired autoencoder on UNI / scGPT embeddings
- Combining embeddings across modalities
- Constructing spatial graphs
- Running a GCN to produce final embeddings
- Orchestrating the full end-to-end embedding workflow

### AEInputs (dataclass)

Container for in-memory autoencoder inputs.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `adata` | `AnnData` | AnnData object containing spatial metadata. |
| `z_uni` | `DataFrame` | UNI embeddings indexed by cell. |
| `z_scgpt` | `Optional[DataFrame]` | Optional scGPT embeddings indexed by cell. |
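A minimal construction sketch, assuming `AEInputs` is importable from `spatialfusion.embed.embed`, that its fields double as constructor arguments (standard dataclass behaviour), and that the file names and formats below are placeholders:

```python
import anndata as ad
import pandas as pd

from spatialfusion.embed.embed import AEInputs

# Placeholder inputs: an AnnData with spatial metadata plus per-cell
# UNI and (optionally) scGPT embeddings, all indexed by the same cell IDs.
adata = ad.read_h5ad("sample_01.h5ad")
z_uni = pd.read_parquet("sample_01_uni.parquet")
z_scgpt = pd.read_parquet("sample_01_scgpt.parquet")  # optional; may be None

ae_inputs = AEInputs(adata=adata, z_uni=z_uni, z_scgpt=z_scgpt)
```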

### GCNInputs (dataclass)

Container for inputs to the GCN stage.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `z_joint` | `DataFrame` | Joint AE embedding indexed by cell ID. |
| `adata_by_sample` | `Dict[str, AnnData]` | Mapping from sample ID to AnnData objects, each containing spatial coordinates and metadata. |
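A minimal sketch of assembling `GCNInputs`, assuming the dataclass fields double as constructor arguments; `z_joint_df` stands for a joint AE embedding such as the one returned by `ae_from_arrays` below, and `adata` is an AnnData object carrying spatial coordinates:

```python
from spatialfusion.embed.embed import GCNInputs

gcn_inputs = GCNInputs(
    z_joint=z_joint_df,                    # joint AE embedding, indexed by cell ID
    adata_by_sample={"sample_01": adata},  # hypothetical sample ID
)
```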

### ae_from_arrays(model, inputs, device='cuda:0', combine_mode='average')

Run a pretrained PairedAE on in-memory UNI and scGPT embeddings.

This function standardizes the input embeddings using the same preprocessing applied during training, runs the paired autoencoder, and combines the resulting latent representations according to `combine_mode`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `PairedAE` | Pretrained PairedAE model. | required |
| `inputs` | `AEInputs` | In-memory inputs containing AnnData and modality embeddings. | required |
| `device` | `str` | Torch device used for inference. | `'cuda:0'` |
| `combine_mode` | `Literal['average', 'concat', 'z1', 'z2']` | Strategy for combining modality embeddings. One of `{"average", "concat", "z1", "z2"}`. | `'average'` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[DataFrame, DataFrame, DataFrame]` | A tuple `(z1_df, z2_df, z_joint_df)` where `z1_df` holds latent embeddings for the UNI modality (cells × latent_dim), `z2_df` holds latent embeddings for the scGPT modality (cells × latent_dim) or is empty if not produced, and `z_joint_df` is the combined latent embedding according to `combine_mode`. All DataFrames are indexed by cell IDs aligned to `inputs.adata.obs_names`. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If required embeddings are missing or no overlapping cells are found between the AnnData and the embeddings. |
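A usage sketch under the assumption that `ae_inputs` is the `AEInputs` object built above and that `model` is a `PairedAE` restored with `load_paired_ae` (documented further down this page):

```python
from spatialfusion.embed.embed import ae_from_arrays

z1_df, z2_df, z_joint_df = ae_from_arrays(
    model,                   # pretrained PairedAE
    ae_inputs,               # AEInputs with adata, z_uni, z_scgpt
    device="cuda:0",
    combine_mode="average",  # or "concat", "z1", "z2"
)
# All three DataFrames are indexed by the cell IDs in ae_inputs.adata.obs_names.
```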

### ae_from_disk_for_samples(model, sample_list, base_path, device='cuda:0', combine_mode='average', save_dir=None)

Run the paired autoencoder using disk-based inputs for multiple samples.

This function loads UNI and scGPT embeddings from disk for each sample, runs the PairedAE model, combines modality embeddings, and optionally saves the outputs to disk.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `PairedAE` | Pretrained PairedAE model. | required |
| `sample_list` | `Iterable[str]` | Iterable of sample identifiers. | required |
| `base_path` | `Union[str, Path]` | Base directory containing sample subfolders. | required |
| `device` | `str` | Torch device used for inference. | `'cuda:0'` |
| `combine_mode` | `Literal['average', 'concat', 'z1', 'z2']` | Strategy for combining modality embeddings. One of `{"average", "concat", "z1", "z2"}`. | `'average'` |
| `save_dir` | `Optional[Union[str, Path]]` | Optional directory in which to save AE outputs. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[DataFrame, DataFrame, DataFrame]` | A tuple `(z1_df, z2_df, z_joint_df)` containing the UNI latent embeddings, the scGPT latent embeddings, and the combined latent embeddings, respectively. |
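A disk-based sketch; the sample IDs and directory layout below are placeholders, and `model` is assumed to be a pretrained `PairedAE`:

```python
from spatialfusion.embed.embed import ae_from_disk_for_samples

z1_df, z2_df, z_joint_df = ae_from_disk_for_samples(
    model,
    sample_list=["sample_01", "sample_02"],
    base_path="/data/spatialfusion",
    device="cuda:0",
    combine_mode="average",
    save_dir="/data/spatialfusion/ae_out",  # optional; omit to skip saving
)
```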

### gcn_embeddings_from_joint(gcn_model, z_joint, adata_by_sample, base_path, device='cuda:0', spatial_key='spatial', celltype_key='celltypes', k=30)

Generate GCN embeddings from joint AE embeddings and spatial graphs.

This function constructs spatial graphs for each sample and applies a pretrained GCN model to produce final embeddings with associated metadata.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `gcn_model` | `GCNAutoencoder` | Pretrained GCN autoencoder. | required |
| `z_joint` | `DataFrame` | Joint AE embedding indexed by cell ID. | required |
| `adata_by_sample` | `Dict[str, AnnData]` | Mapping from sample ID to AnnData. | required |
| `base_path` | `Union[str, Path]` | Base path used for metadata resolution. | required |
| `device` | `str` | Torch device used for inference. | `'cuda:0'` |
| `spatial_key` | `str` | Key in `AnnData.obsm` containing spatial coordinates. | `'spatial'` |
| `celltype_key` | `str` | Key in `AnnData.obs` containing cell type annotations. | `'celltypes'` |
| `k` | `int` | Number of neighbors for KNN graph construction. | `30` |

Returns:

| Type | Description |
| --- | --- |
| `DataFrame` | DataFrame containing GCN embeddings and associated metadata. |
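A sketch assuming `z_joint_df` comes from the AE stage, `adata` holds the matching spatial data, and `gcn_model` is a `GCNAutoencoder` restored with `load_gcn` (documented below); sample ID and base directory are placeholders:

```python
from spatialfusion.embed.embed import gcn_embeddings_from_joint

emb_df = gcn_embeddings_from_joint(
    gcn_model,
    z_joint=z_joint_df,
    adata_by_sample={"sample_01": adata},  # hypothetical sample ID
    base_path="/data/spatialfusion",       # placeholder directory
    device="cuda:0",
    spatial_key="spatial",
    celltype_key="celltypes",
    k=30,
)
```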

### graphs_from_embeddings_and_adata(z_joint, adata_by_sample, spatial_key='spatial', k=30)

Construct spatial KNN graphs from joint embeddings and AnnData objects.

For each sample, this function:

- Aligns cells between `z_joint` and the AnnData object
- Standardizes joint embeddings
- Builds a k-nearest-neighbor graph using spatial coordinates
- Attaches node features to the graph

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `z_joint` | `DataFrame` | Joint AE embeddings indexed by cell ID. | required |
| `adata_by_sample` | `Dict[str, AnnData]` | Mapping from sample ID to AnnData. | required |
| `spatial_key` | `str` | Key in `adata.obsm` containing spatial coordinates. | `'spatial'` |
| `k` | `int` | Number of nearest neighbors for graph construction. | `30` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[List['dgl.DGLGraph'], List[str]]` | A tuple `(graphs, keep_samples)` where `graphs` is a list of DGL graphs, one per retained sample, and `keep_samples` is the list of sample IDs corresponding to the graphs. |
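A sketch with the same placeholder inputs as above; inspecting the returned graphs directly uses the standard DGL graph API:

```python
from spatialfusion.embed.embed import graphs_from_embeddings_and_adata

graphs, keep_samples = graphs_from_embeddings_and_adata(
    z_joint=z_joint_df,
    adata_by_sample={"sample_01": adata},
    spatial_key="spatial",
    k=30,
)
for sample_id, graph in zip(keep_samples, graphs):
    # Each graph is a dgl.DGLGraph with node features attached.
    print(sample_id, graph.num_nodes(), graph.num_edges())
```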

### infer_input_dims(sample_list, base_path, uni_path=None, scgpt_path=None)

Infer AE input dimensions from disk.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `sample_list` | `Iterable[str]` | Iterable of sample identifiers. | required |
| `base_path` | `Union[str, Path]` | Base directory containing sample subfolders. | required |
| `uni_path` | `Optional[Union[str, Path]]` | Optional explicit UNI file path. | `None` |
| `scgpt_path` | `Optional[Union[str, Path]]` | Optional explicit scGPT file path. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[int, int]` | Tuple of (UNI dimension, scGPT dimension). |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If no valid embeddings are found. |
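A sketch with placeholder sample IDs and base directory:

```python
from spatialfusion.embed.embed import infer_input_dims

uni_dim, scgpt_dim = infer_input_dims(
    sample_list=["sample_01", "sample_02"],
    base_path="/data/spatialfusion",
)
```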

### infer_input_dims_from_files(uni_path, scgpt_path)

Infer embedding dimensions from UNI and scGPT files.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `uni_path` | `Union[str, Path]` | Path to UNI embedding file. | required |
| `scgpt_path` | `Union[str, Path]` | Path to scGPT embedding file. | required |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[int, int]` | Tuple of (UNI dimension, scGPT dimension). |
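The file-based variant of dimension inference; the file paths and formats below are placeholders:

```python
from spatialfusion.embed.embed import infer_input_dims_from_files

uni_dim, scgpt_dim = infer_input_dims_from_files(
    "embeddings/sample_01_uni.parquet",    # placeholder UNI embedding file
    "embeddings/sample_01_scgpt.parquet",  # placeholder scGPT embedding file
)
```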

### load_gcn(gcn_ckpt, in_dim, device='cuda:0')

Load a pretrained GCN autoencoder from disk.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `gcn_ckpt` | `Union[str, Path]` | Path to the GCN checkpoint file. | required |
| `in_dim` | `int` | Input feature dimensionality. | required |
| `device` | `str` | Torch device on which to load the model. | `'cuda:0'` |

Returns:

| Type | Description |
| --- | --- |
| `GCNAutoencoder` | A GCNAutoencoder instance in evaluation mode. |
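A loading sketch; the checkpoint path is a placeholder, and `in_dim` is assumed here to match the joint embedding width used as node features (an assumption, not stated by the API):

```python
from spatialfusion.embed.embed import load_gcn

gcn_model = load_gcn(
    "checkpoints/gcn.pt",        # hypothetical checkpoint path
    in_dim=z_joint_df.shape[1],  # assumed input feature dimensionality
    device="cuda:0",
)
```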

### load_paired_ae(ae_ckpt, d1_dim, d2_dim, latent_dim=64, device='cuda:0')

Load a pretrained PairedAE model from disk.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ae_ckpt` | `Union[str, Path]` | Path to the AE checkpoint. | required |
| `d1_dim` | `int` | Input dimension of modality 1. | required |
| `d2_dim` | `int` | Input dimension of modality 2. | required |
| `latent_dim` | `int` | Latent dimension size. | `64` |
| `device` | `str` | Torch device string. | `'cuda:0'` |

Returns:

| Type | Description |
| --- | --- |
| `PairedAE` | Loaded PairedAE model in evaluation mode. |
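A loading sketch that pairs `infer_input_dims` with `load_paired_ae`; the checkpoint path, sample IDs, and base directory are placeholders:

```python
from spatialfusion.embed.embed import infer_input_dims, load_paired_ae

d1_dim, d2_dim = infer_input_dims(
    sample_list=["sample_01", "sample_02"],
    base_path="/data/spatialfusion",
)
model = load_paired_ae(
    "checkpoints/paired_ae.pt",  # hypothetical checkpoint path
    d1_dim=d1_dim,               # UNI input dimension
    d2_dim=d2_dim,               # scGPT input dimension
    latent_dim=64,
    device="cuda:0",
)
```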

### run_full_embedding(*, ae_inputs_by_sample=None, sample_list=None, base_path=None, ae_model_path=None, gcn_model_path=None, ae_model=None, gcn_model=None, latent_dim=64, device='cuda:0', spatial_key='spatial_px', k=30, celltype_key='celltypes', combine_mode='average', uni_path=None, scgpt_path=None, save_ae_dir=None)

Run the full SpatialFusion embedding pipeline.

This function supports two execution modes:

1. In-memory mode using `ae_inputs_by_sample`
2. Disk-based mode using `sample_list` and `base_path`

In both cases, it:

- Runs the paired autoencoder (AE)
- Combines modality embeddings
- Constructs spatial graphs
- Runs the GCN to produce final embeddings

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ae_inputs_by_sample` | `Optional[Dict[str, AEInputs]]` | Optional in-memory AE inputs per sample. | `None` |
| `sample_list` | `Optional[Iterable[str]]` | Sample identifiers for disk-based execution. | `None` |
| `base_path` | `Optional[Union[str, Path]]` | Base directory containing sample data. | `None` |
| `ae_model_path` | `Optional[Union[str, Path]]` | Path to AE checkpoint (if AE model not provided). | `None` |
| `gcn_model_path` | `Optional[Union[str, Path]]` | Path to GCN checkpoint (if GCN model not provided). | `None` |
| `ae_model` | `Optional[PairedAE]` | Optional preloaded AE model. | `None` |
| `gcn_model` | `Optional[GCNAutoencoder]` | Optional preloaded GCN model. | `None` |
| `latent_dim` | `int` | Latent dimensionality of the AE. | `64` |
| `device` | `str` | Torch device used for inference. | `'cuda:0'` |
| `spatial_key` | `str` | Key in `AnnData.obsm` for spatial coordinates. | `'spatial_px'` |
| `k` | `int` | Number of neighbors for spatial graph construction. | `30` |
| `celltype_key` | `str` | Key in `AnnData.obs` for cell type labels. | `'celltypes'` |
| `combine_mode` | `Literal['average', 'concat', 'z1', 'z2']` | Strategy for combining modality embeddings. | `'average'` |
| `uni_path` | `Optional[Union[str, Path]]` | Optional UNI file path for dimension inference. | `None` |
| `scgpt_path` | `Optional[Union[str, Path]]` | Optional scGPT file path for dimension inference. | `None` |
| `save_ae_dir` | `Optional[Union[str, Path]]` | Optional directory to save AE outputs. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `DataFrame` | DataFrame containing final GCN embeddings with metadata. |
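A disk-based invocation sketch; all paths and sample IDs are placeholders. The in-memory mode would instead pass `ae_inputs_by_sample` (a mapping of sample ID to `AEInputs`):

```python
from spatialfusion.embed.embed import run_full_embedding

emb_df = run_full_embedding(
    sample_list=["sample_01", "sample_02"],
    base_path="/data/spatialfusion",
    ae_model_path="checkpoints/paired_ae.pt",  # hypothetical checkpoint paths
    gcn_model_path="checkpoints/gcn.pt",
    latent_dim=64,
    device="cuda:0",
    spatial_key="spatial_px",
    celltype_key="celltypes",
    combine_mode="average",
    k=30,
    save_ae_dir="/data/spatialfusion/ae_out",  # optional
)
# emb_df: final GCN embeddings with associated metadata.
```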