spatialfusion.utils.embed_gcn_utils

Utility functions for extracting GCN embeddings and metadata for downstream analysis.

This module provides:

  • extract_gcn_embeddings_with_metadata: Extracts GCN embeddings and merges them with cell type, spatial, and ligand-receptor metadata. Supports both full-graph and batched subgraph inference.

expand_k_hop(g, seeds, k)

Expand a set of seed nodes to include all nodes reachable within k hops.

This function performs exact k-hop expansion for both incoming and outgoing edges and is used for memory-efficient subgraph inference.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
g | DGLGraph | Input graph. | required
seeds | Tensor | Tensor of seed node indices. | required
k | int | Number of hops. | required

Returns:

Type | Description
--- | ---
torch.Tensor | Sorted tensor of expanded node indices.
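The expansion described above can be illustrated without DGL. The following is a minimal pure-Python sketch, not the module's implementation: the name `expand_k_hop_sketch` and the edge-list input are assumptions made for illustration, whereas the real function takes a DGLGraph and a tensor of seeds. It follows edges in both directions for `k` rounds and returns the sorted node set.

```python
from collections import defaultdict


def expand_k_hop_sketch(edges, seeds, k):
    """Return the sorted node ids within k hops of `seeds`.

    `edges` is a list of (src, dst) pairs (a hypothetical stand-in for a
    DGLGraph). Edges are followed in both directions, mirroring the
    incoming-and-outgoing behavior described for expand_k_hop.
    """
    neighbors = defaultdict(set)
    for src, dst in edges:
        neighbors[src].add(dst)  # outgoing direction
        neighbors[dst].add(src)  # incoming direction

    visited = set(seeds)
    frontier = set(seeds)
    for _ in range(k):
        reached = set()
        for node in frontier:
            reached |= neighbors[node]
        frontier = reached - visited  # only nodes not seen before
        if not frontier:
            break  # nothing new to expand
        visited |= frontier
    return sorted(visited)


# Directed chain 0 -> 1 -> 2 -> 3 -> 4, expanded from the middle node.
chain = [(0, 1), (1, 2), (2, 3), (3, 4)]
one_hop = expand_k_hop_sketch(chain, [2], 1)
two_hop = expand_k_hop_sketch(chain, [2], 2)
```

Because both directions are followed, one hop from node 2 reaches its predecessor and its successor even though the chain is directed.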

extract_gcn_embeddings_with_metadata(model, graphs, sample_list, base_path, z_joint, device='cuda:0', spatial_key='spatial_px', celltype_key='celltypes', adata_by_sample=None, batch_size=None, k_hop=2)

Extract GCN embeddings along with metadata such as spatial coordinates and cell types.

This function supports two inference modes:

  • Full-graph inference (default, exact behavior): If batch_size is None, the entire graph is moved to the device and processed in a single forward pass.

  • Batched subgraph inference (memory-efficient): If batch_size is an integer, nodes are processed in batches. For each batch of seed nodes, a k-hop subgraph is constructed to preserve exact receptive fields for multi-layer GCNs.

Supports both in-memory AnnData inputs and on-disk loading via base_path.
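To see why a k-hop subgraph preserves the receptive field, the sketch below compares the two modes on a toy graph. It is an illustration under stated assumptions, not the library's code: `toy_encode` is a hypothetical mean-aggregation stand-in for `model.encode`, graphs are plain edge lists, and features are scalars. With `k_hop` equal to the number of layers, the per-seed outputs of the batched path match the full-graph pass; with a smaller `k_hop` they can differ.

```python
import math
from collections import defaultdict


def toy_encode(nodes, edges, feats, num_layers):
    """Hypothetical stand-in for a GCN: each layer averages a node's value
    with the values of its in-neighbors inside the given node set."""
    node_set = set(nodes)
    in_nbrs = defaultdict(set)
    for src, dst in edges:
        if src in node_set and dst in node_set:  # induced subgraph only
            in_nbrs[dst].add(src)
    h = {n: feats[n] for n in nodes}
    for _ in range(num_layers):
        new_h = {}
        for n in nodes:
            group = sorted(in_nbrs[n] | {n})
            new_h[n] = sum(h[m] for m in group) / len(group)
        h = new_h
    return h


def expand(edges, seeds, k):
    """Nodes within k hops of the seeds, following edges in both directions."""
    nbrs = defaultdict(set)
    for src, dst in edges:
        nbrs[src].add(dst)
        nbrs[dst].add(src)
    visited, frontier = set(seeds), set(seeds)
    for _ in range(k):
        frontier = {m for n in frontier for m in nbrs[n]} - visited
        visited |= frontier
    return sorted(visited)


def batched_encode(num_nodes, edges, feats, num_layers, batch_size, k_hop):
    """Process seeds in batches; keep only the seed rows of each subgraph."""
    out = {}
    for start in range(0, num_nodes, batch_size):
        seeds = list(range(start, min(start + batch_size, num_nodes)))
        sub_nodes = expand(edges, seeds, k_hop)
        sub_h = toy_encode(sub_nodes, edges, feats, num_layers)
        for s in seeds:
            out[s] = sub_h[s]
    return out


# Ring of 8 nodes with edges in both directions and distinct scalar features.
n = 8
edges = [(i, (i + 1) % n) for i in range(n)] + [((i + 1) % n, i) for i in range(n)]
feats = {i: float(i * i) for i in range(n)}

full = toy_encode(list(range(n)), edges, feats, num_layers=2)
batched = batched_encode(n, edges, feats, num_layers=2, batch_size=3, k_hop=2)
matches_full = all(math.isclose(full[i], batched[i]) for i in range(n))

# k_hop=1 is shallower than the 2-layer receptive field, so seeds near a
# subgraph boundary see truncated neighborhoods.
shallow = batched_encode(n, edges, feats, num_layers=2, batch_size=3, k_hop=1)
shallow_matches = all(math.isclose(full[i], shallow[i]) for i in range(n))
```

In this toy setup `matches_full` holds while `shallow_matches` does not, which is the reason `k_hop` should be chosen to cover the depth of the encoder.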

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
model | Module | Trained GCN model with an encode(graph, features) method. | required
graphs | list[DGLGraph] | List of DGL graphs, one per sample, containing node features in ndata["feat"]. | required
sample_list | list[str] | List of sample IDs corresponding to the graphs. | required
base_path | str or Path | Root directory containing per-sample subdirectories. Used only if adata_by_sample is not provided. | required
z_joint | DataFrame | Joint embeddings from the autoencoder step. Used to align graph nodes with cell identifiers. | required
device | str | Device string for inference (e.g., "cuda:0" or "cpu"). | 'cuda:0'
spatial_key | str | Key name for spatial coordinates stored in adata.obsm. | 'spatial_px'
celltype_key | str | Column name in adata.obs or celltypes.csv for cell type annotation. | 'celltypes'
adata_by_sample | Optional[Dict[str, AnnData]] | Optional mapping from sample IDs to preloaded AnnData objects. If provided, this takes precedence over disk loading. | None
batch_size | Optional[int] | Number of seed nodes per subgraph batch. If None, full-graph inference is used. | None
k_hop | int | Number of hops for subgraph expansion when batching. Should match the effective receptive field of the GCN. | 2

Returns:

Type | Description
--- | ---
pd.DataFrame | Concatenated DataFrame of GCN embeddings across all samples.

The DataFrame includes the following metadata:

  • sample_id
  • cell_id
  • celltype (and optional subtype/niche labels)
  • spatial coordinates (X_coord, Y_coord)
  • optional ligand–receptor features if present

Raises:

Type | Description
--- | ---
FileNotFoundError | If AnnData cannot be loaded from memory or disk.
ValueError | If graph node count does not match aligned cell identifiers.
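The ValueError condition amounts to a length check between a graph and the cell identifiers aligned to it. A minimal sketch of such a check follows; the helper name `check_alignment_sketch`, its arguments, and the message wording are assumptions for illustration, not the module's actual code.

```python
def check_alignment_sketch(num_graph_nodes, cell_ids, sample_id):
    """Raise ValueError when a graph's node count and the aligned cell
    identifiers disagree (hypothetical helper, for illustration only)."""
    if num_graph_nodes != len(cell_ids):
        raise ValueError(
            f"Sample {sample_id!r}: graph has {num_graph_nodes} nodes "
            f"but {len(cell_ids)} aligned cell identifiers were found."
        )


# A matching pair passes silently.
check_alignment_sketch(3, ["c1", "c2", "c3"], "sample_A")

# A mismatch is reported with both counts in the message.
try:
    check_alignment_sketch(2, ["c1", "c2", "c3"], "sample_A")
    mismatch_message = None
except ValueError as err:
    mismatch_message = str(err)
```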

infer_output_dim(model, feat_dim)

Infer the output dimensionality of a GCN encoder.

This function runs a dummy forward pass through model.encode using a single-node graph to determine the latent output size.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
model | Module | GCN model with an encode method. | required
feat_dim | int | Input feature dimensionality. | required

Returns:

Type | Description
--- | ---
int | Output embedding dimensionality.
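The dummy-forward-pass idea can be sketched with a duck-typed model. Everything below is a hypothetical stand-in: `ToyEncoder` is not a real GCN, the dictionary "graph" replaces a single-node DGLGraph, and nested lists replace tensors. It only shows the pattern of encoding one placeholder node and reading the width of the result.

```python
class ToyEncoder:
    """Hypothetical model exposing encode(graph, features); it applies a
    fixed linear projection and ignores the graph structure."""

    def __init__(self, feat_dim, out_dim):
        # feat_dim x out_dim weight matrix filled with a constant.
        self.weights = [[0.1] * out_dim for _ in range(feat_dim)]

    def encode(self, graph, features):
        columns = list(zip(*self.weights))  # out_dim columns of length feat_dim
        return [
            [sum(x * w for x, w in zip(row, col)) for col in columns]
            for row in features
        ]


def infer_output_dim_sketch(model, feat_dim):
    """Encode a single placeholder node and report the output width."""
    dummy_graph = {"num_nodes": 1, "edges": []}  # stand-in for a 1-node graph
    dummy_features = [[0.0] * feat_dim]          # one row of zeros
    output = model.encode(dummy_graph, dummy_features)
    return len(output[0])


latent_dim = infer_output_dim_sketch(ToyEncoder(feat_dim=16, out_dim=4), feat_dim=16)
```

Probing the encoder this way avoids hard-coding the latent size, at the cost of one extra forward pass.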