spatialfusion.embed.embed
Autoencoder (AE) and Graph Convolutional Network (GCN) embedding pipeline for SpatialFusion.
This module provides utilities for:

- Running a paired autoencoder on UNI / scGPT embeddings
- Combining embeddings across modalities
- Constructing spatial graphs
- Running a GCN to produce final embeddings
- Orchestrating the full end-to-end embedding workflow
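For orientation, a minimal disk-based quick start might look like the sketch below. The sample identifiers, directory layout, and checkpoint paths are hypothetical placeholders; the import path assumes these functions are exposed by spatialfusion.embed.embed as this page documents.

```python
from spatialfusion.embed.embed import run_full_embedding

# Hypothetical sample IDs, data directory, and checkpoints; adjust to your layout.
final_df = run_full_embedding(
    sample_list=["sample_A", "sample_B"],
    base_path="/data/spatialfusion/samples",
    ae_model_path="/models/paired_ae.pt",
    gcn_model_path="/models/gcn.pt",
    device="cuda:0",
    combine_mode="average",
)
print(final_df.head())  # final GCN embeddings plus metadata, indexed by cell ID
```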
AEInputs
dataclass
Container for in-memory autoencoder inputs.
Attributes:
| Name | Type | Description |
|---|---|---|
| adata | AnnData | AnnData object containing spatial metadata. |
| z_uni | DataFrame | UNI embeddings indexed by cell. |
| z_scgpt | Optional[DataFrame] | Optional scGPT embeddings indexed by cell. |
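As a sketch, an AEInputs container can be assembled from an AnnData object and per-cell embedding tables; the file names below are hypothetical, and the field names follow the attributes table above.

```python
import anndata as ad
import pandas as pd

from spatialfusion.embed.embed import AEInputs

# Hypothetical files: one AnnData plus UNI / scGPT embedding tables indexed by cell ID.
adata = ad.read_h5ad("sample_A.h5ad")
z_uni = pd.read_csv("sample_A_uni.csv", index_col=0)
z_scgpt = pd.read_csv("sample_A_scgpt.csv", index_col=0)  # may be omitted (Optional)

inputs = AEInputs(adata=adata, z_uni=z_uni, z_scgpt=z_scgpt)
```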
GCNInputs
dataclass
Container for inputs to the GCN stage.
Attributes:
| Name | Type | Description |
|---|---|---|
| z_joint | DataFrame | Joint AE embedding indexed by cell ID. |
| adata_by_sample | Dict[str, AnnData] | Mapping from sample ID to AnnData objects, each containing spatial coordinates and metadata. |
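Similarly, a GCNInputs container pairs the joint AE embedding with per-sample AnnData objects; the inputs below are hypothetical placeholders.

```python
import anndata as ad
import pandas as pd

from spatialfusion.embed.embed import GCNInputs

# Hypothetical inputs: joint AE embeddings and one AnnData per sample.
z_joint = pd.read_csv("z_joint.csv", index_col=0)
adata_by_sample = {"sample_A": ad.read_h5ad("sample_A.h5ad")}

gcn_inputs = GCNInputs(z_joint=z_joint, adata_by_sample=adata_by_sample)
```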
ae_from_arrays(model, inputs, device='cuda:0', combine_mode='average')
Run a pretrained PairedAE on in-memory UNI and scGPT embeddings.
This function standardizes the input embeddings using the same preprocessing
applied during training, runs the paired autoencoder, and combines the
resulting latent representations according to combine_mode.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | PairedAE | Pretrained PairedAE model. | required |
| inputs | AEInputs | In-memory inputs containing AnnData and modality embeddings. | required |
| device | str | Torch device used for inference. | 'cuda:0' |
| combine_mode | Literal['average', 'concat', 'z1', 'z2'] | Strategy for combining modality embeddings. One of {"average", "concat", "z1", "z2"}. | 'average' |
Returns:
| Type | Description |
|---|---|
| DataFrame | A tuple |
| DataFrame | All DataFrames are indexed by cell IDs aligned to |
Raises:
| Type | Description |
|---|---|
| ValueError | If required embeddings are missing or no overlapping cells are found between AnnData and embeddings. |
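A hedged usage sketch, assuming the checkpoint path and file names are placeholders and that the AE's input dimensions match the widths of the two embedding tables:

```python
import anndata as ad
import pandas as pd

from spatialfusion.embed.embed import AEInputs, ae_from_arrays, load_paired_ae

# Hypothetical inputs for a single sample.
adata = ad.read_h5ad("sample_A.h5ad")
z_uni = pd.read_csv("sample_A_uni.csv", index_col=0)
z_scgpt = pd.read_csv("sample_A_scgpt.csv", index_col=0)

# Load the pretrained PairedAE; d1_dim / d2_dim must match the embedding widths.
model = load_paired_ae(
    "paired_ae.pt",
    d1_dim=z_uni.shape[1],
    d2_dim=z_scgpt.shape[1],
    latent_dim=64,
    device="cuda:0",
)

inputs = AEInputs(adata=adata, z_uni=z_uni, z_scgpt=z_scgpt)
# Returns DataFrames indexed by cell ID (see the Returns table above).
ae_outputs = ae_from_arrays(model, inputs, device="cuda:0", combine_mode="average")
```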
ae_from_disk_for_samples(model, sample_list, base_path, device='cuda:0', combine_mode='average', save_dir=None)
Run the paired autoencoder using disk-based inputs for multiple samples.
This function loads UNI and scGPT embeddings from disk for each sample, runs the PairedAE model, combines modality embeddings, and optionally saves the outputs to disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | PairedAE | Pretrained PairedAE model. | required |
| sample_list | Iterable[str] | Iterable of sample identifiers. | required |
| base_path | Union[str, Path] | Base directory containing sample subfolders. | required |
| device | str | Torch device used for inference. | 'cuda:0' |
| combine_mode | Literal['average', 'concat', 'z1', 'z2'] | Strategy for combining modality embeddings. One of {"average", "concat", "z1", "z2"}. | 'average' |
| save_dir | Optional[Union[str, Path]] | Optional directory in which to save AE outputs. | None |
Returns:
| Type | Description |
|---|---|
|  | A tuple |
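A minimal sketch of the disk-based path, assuming a hypothetical directory of per-sample subfolders and a placeholder checkpoint; infer_input_dims is used here only to size the AE before loading it:

```python
from spatialfusion.embed.embed import (
    ae_from_disk_for_samples,
    infer_input_dims,
    load_paired_ae,
)

samples = ["sample_A", "sample_B"]         # hypothetical sample identifiers
base_path = "/data/spatialfusion/samples"  # hypothetical base directory

# Size the AE from the on-disk embeddings, then load the pretrained weights.
d1_dim, d2_dim = infer_input_dims(samples, base_path)
model = load_paired_ae("paired_ae.pt", d1_dim=d1_dim, d2_dim=d2_dim, device="cuda:0")

# Run the AE per sample; outputs are optionally written under save_dir.
ae_outputs = ae_from_disk_for_samples(
    model,
    samples,
    base_path,
    device="cuda:0",
    combine_mode="average",
    save_dir="/results/ae_outputs",
)
```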
gcn_embeddings_from_joint(gcn_model, z_joint, adata_by_sample, base_path, device='cuda:0', spatial_key='spatial', celltype_key='celltypes', k=30)
Generate GCN embeddings from joint AE embeddings and spatial graphs.
This function constructs spatial graphs for each sample and applies a pretrained GCN model to produce final embeddings with associated metadata.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| gcn_model | GCNAutoencoder | Pretrained GCN autoencoder. | required |
| z_joint | DataFrame | Joint AE embedding indexed by cell ID. | required |
| adata_by_sample | Dict[str, AnnData] | Mapping from sample ID to AnnData. | required |
| base_path | Union[str, Path] | Base path used for metadata resolution. | required |
| device | str | Torch device used for inference. | 'cuda:0' |
| spatial_key | str | Key in AnnData.obsm containing spatial coordinates. | 'spatial' |
| celltype_key | str | Key in AnnData.obs containing cell type annotations. | 'celltypes' |
| k | int | Number of neighbors for KNN graph construction. | 30 |
Returns:
| Type | Description |
|---|---|
| DataFrame | DataFrame containing GCN embeddings and associated metadata. |
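A hedged sketch, assuming a placeholder checkpoint and that the GCN's input dimension equals the joint embedding width (its node features are derived from z_joint):

```python
import anndata as ad
import pandas as pd

from spatialfusion.embed.embed import gcn_embeddings_from_joint, load_gcn

# Hypothetical inputs: joint AE embeddings and one AnnData per sample.
z_joint = pd.read_csv("z_joint.csv", index_col=0)
adata_by_sample = {"sample_A": ad.read_h5ad("sample_A.h5ad")}

gcn_model = load_gcn("gcn.pt", in_dim=z_joint.shape[1], device="cuda:0")
embeddings = gcn_embeddings_from_joint(
    gcn_model,
    z_joint,
    adata_by_sample,
    base_path="/data/spatialfusion/samples",  # hypothetical; used for metadata resolution
    spatial_key="spatial",
    celltype_key="celltypes",
    k=30,
)
```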
graphs_from_embeddings_and_adata(z_joint, adata_by_sample, spatial_key='spatial', k=30)
Construct spatial KNN graphs from joint embeddings and AnnData objects.
For each sample, this function:
- Aligns cells between z_joint and AnnData
- Standardizes joint embeddings
- Builds a k-nearest-neighbor graph using spatial coordinates
- Attaches node features to the graph
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| z_joint | DataFrame | Joint AE embeddings indexed by cell ID. | required |
| adata_by_sample | Dict[str, AnnData] | Mapping from sample ID to AnnData. | required |
| spatial_key | str | Key in AnnData.obsm containing spatial coordinates. | 'spatial' |
| k | int | Number of nearest neighbors for graph construction. | 30 |
Returns:
| Type | Description |
|---|---|
| Tuple[List['dgl.DGLGraph'], List[str]] | A tuple |
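A brief sketch with hypothetical inputs; given the per-sample construction described above, the second returned list of strings is presumably the sample identifiers in graph order:

```python
import anndata as ad
import pandas as pd

from spatialfusion.embed.embed import graphs_from_embeddings_and_adata

# Hypothetical inputs; spatial coordinates are expected in adata.obsm["spatial"].
z_joint = pd.read_csv("z_joint.csv", index_col=0)
adata_by_sample = {"sample_A": ad.read_h5ad("sample_A.h5ad")}

graphs, names = graphs_from_embeddings_and_adata(
    z_joint, adata_by_sample, spatial_key="spatial", k=30
)
print(names, [g.num_nodes() for g in graphs])  # one DGL graph per sample
```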
infer_input_dims(sample_list, base_path, uni_path=None, scgpt_path=None)
Infer AE input dimensions from disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sample_list | Iterable[str] | Iterable of sample identifiers. | required |
| base_path | Union[str, Path] | Base directory containing sample subfolders. | required |
| uni_path | Optional[Union[str, Path]] | Optional explicit UNI file path. | None |
| scgpt_path | Optional[Union[str, Path]] | Optional explicit scGPT file path. | None |
Returns:
| Type | Description |
|---|---|
| Tuple[int, int] | Tuple of (UNI dimension, scGPT dimension). |
Raises:
| Type | Description |
|---|---|
| ValueError | If no valid embeddings are found. |
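A one-line sketch assuming a hypothetical base directory of per-sample subfolders:

```python
from spatialfusion.embed.embed import infer_input_dims

# Hypothetical sample IDs and base directory.
uni_dim, scgpt_dim = infer_input_dims(["sample_A", "sample_B"], "/data/spatialfusion/samples")
```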
infer_input_dims_from_files(uni_path, scgpt_path)
Infer embedding dimensions from UNI and scGPT files.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| uni_path | Union[str, Path] | Path to UNI embedding file. | required |
| scgpt_path | Union[str, Path] | Path to scGPT embedding file. | required |
Returns:
| Type | Description |
|---|---|
| Tuple[int, int] | Tuple of (UNI dimension, scGPT dimension). |
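A sketch with explicit placeholder file paths; the expected on-disk format of the embedding files is not specified on this page, so the extensions are illustrative only:

```python
from spatialfusion.embed.embed import infer_input_dims_from_files

# Placeholder paths; point these at a single sample's UNI and scGPT embedding files.
uni_dim, scgpt_dim = infer_input_dims_from_files("sample_A_uni.csv", "sample_A_scgpt.csv")
```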
load_gcn(gcn_ckpt, in_dim, device='cuda:0')
Load a pretrained GCN autoencoder from disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| gcn_ckpt | Union[str, Path] | Path to the GCN checkpoint file. | required |
| in_dim | int | Input feature dimensionality. | required |
| device | str | Torch device on which to load the model. | 'cuda:0' |
Returns:
| Type | Description |
|---|---|
| GCNAutoencoder | A GCNAutoencoder instance in evaluation mode. |
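A minimal sketch; the checkpoint path is a placeholder, and in_dim=64 simply mirrors the AE's default latent dimension, since the GCN consumes the joint embedding as node features:

```python
from spatialfusion.embed.embed import load_gcn

# Placeholder checkpoint; in_dim must match the node-feature width fed to the GCN.
gcn_model = load_gcn("gcn.pt", in_dim=64, device="cuda:0")
```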
load_paired_ae(ae_ckpt, d1_dim, d2_dim, latent_dim=64, device='cuda:0')
Load a pretrained PairedAE model from disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ae_ckpt | Union[str, Path] | Path to the AE checkpoint. | required |
| d1_dim | int | Input dimension of modality 1. | required |
| d2_dim | int | Input dimension of modality 2. | required |
| latent_dim | int | Latent dimension size. | 64 |
| device | str | Torch device string. | 'cuda:0' |
Returns:
| Type | Description |
|---|---|
| PairedAE | Loaded PairedAE model in evaluation mode. |
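A minimal sketch; the checkpoint path and modality dimensions are placeholders and must match the UNI and scGPT embedding widths the model was trained with:

```python
from spatialfusion.embed.embed import load_paired_ae

# Placeholder checkpoint and illustrative dimensions (d1 = UNI width, d2 = scGPT width).
ae_model = load_paired_ae("paired_ae.pt", d1_dim=1024, d2_dim=512, latent_dim=64, device="cuda:0")
```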
run_full_embedding(*, ae_inputs_by_sample=None, sample_list=None, base_path=None, ae_model_path=None, gcn_model_path=None, ae_model=None, gcn_model=None, latent_dim=64, device='cuda:0', spatial_key='spatial_px', k=30, celltype_key='celltypes', combine_mode='average', uni_path=None, scgpt_path=None, save_ae_dir=None)
Run the full SpatialFusion embedding pipeline.
This function supports two execution modes:
1. In-memory mode using ae_inputs_by_sample
2. Disk-based mode using sample_list and base_path
In both cases, it:

- Runs the paired autoencoder (AE)
- Combines modality embeddings
- Constructs spatial graphs
- Runs the GCN to produce final embeddings
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ae_inputs_by_sample | Optional[Dict[str, AEInputs]] | Optional in-memory AE inputs per sample. | None |
| sample_list | Optional[Iterable[str]] | Sample identifiers for disk-based execution. | None |
| base_path | Optional[Union[str, Path]] | Base directory containing sample data. | None |
| ae_model_path | Optional[Union[str, Path]] | Path to AE checkpoint (if AE model not provided). | None |
| gcn_model_path | Optional[Union[str, Path]] | Path to GCN checkpoint (if GCN model not provided). | None |
| ae_model | Optional[PairedAE] | Optional preloaded AE model. | None |
| gcn_model | Optional[GCNAutoencoder] | Optional preloaded GCN model. | None |
| latent_dim | int | Latent dimensionality of the AE. | 64 |
| device | str | Torch device used for inference. | 'cuda:0' |
| spatial_key | str | Key in AnnData.obsm for spatial coordinates. | 'spatial_px' |
| k | int | Number of neighbors for spatial graph construction. | 30 |
| celltype_key | str | Key in AnnData.obs for cell type labels. | 'celltypes' |
| combine_mode | Literal['average', 'concat', 'z1', 'z2'] | Strategy for combining modality embeddings. | 'average' |
| uni_path | Optional[Union[str, Path]] | Optional UNI file path for dimension inference. | None |
| scgpt_path | Optional[Union[str, Path]] | Optional scGPT file path for dimension inference. | None |
| save_ae_dir | Optional[Union[str, Path]] | Optional directory to save AE outputs. | None |
Returns:
| Type | Description |
|---|---|
| DataFrame | DataFrame containing final GCN embeddings with metadata. |
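To complement the disk-based quick start near the top of this page, here is a sketch of the in-memory mode using ae_inputs_by_sample; all file names and checkpoint paths are hypothetical placeholders:

```python
import anndata as ad
import pandas as pd

from spatialfusion.embed.embed import AEInputs, run_full_embedding

# Hypothetical per-sample inputs assembled in memory.
ae_inputs_by_sample = {
    "sample_A": AEInputs(
        adata=ad.read_h5ad("sample_A.h5ad"),
        z_uni=pd.read_csv("sample_A_uni.csv", index_col=0),
        z_scgpt=pd.read_csv("sample_A_scgpt.csv", index_col=0),
    ),
}

final_df = run_full_embedding(
    ae_inputs_by_sample=ae_inputs_by_sample,
    ae_model_path="paired_ae.pt",  # placeholder checkpoints
    gcn_model_path="gcn.pt",
    device="cuda:0",
    spatial_key="spatial_px",      # the documented default; coordinates read from adata.obsm
    combine_mode="average",
)
```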