spatialfusion.utils.embed_gcn_utils
Utility functions for extracting GCN embeddings and metadata for downstream analysis.
This module provides:

- `extract_gcn_embeddings_with_metadata`: extracts GCN embeddings and merges them with cell-type, spatial, and ligand-receptor metadata. Supports both full-graph and batched subgraph inference.
expand_k_hop(g, seeds, k)
Expand a set of seed nodes to include all nodes reachable within k hops.
This function performs exact k-hop expansion for both incoming and outgoing edges and is used for memory-efficient subgraph inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `g` | `DGLGraph` | Input graph. | required |
| `seeds` | `Tensor` | Tensor of seed node indices. | required |
| `k` | `int` | Number of hops. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | `torch.Tensor`: Sorted tensor of expanded node indices. |
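To make the expansion semantics concrete, the following is a framework-free sketch that mirrors the documented behavior on a plain list of `(src, dst)` edges instead of a `DGLGraph` and a `Tensor`. The edge-list representation and the name `expand_k_hop_edges` are assumptions made for illustration; this is not the library's implementation.

```python
from collections import defaultdict


def expand_k_hop_edges(edges, seeds, k):
    """Return the sorted IDs of all nodes within k hops of `seeds`.

    Hypothetical edge-list version: edges are (src, dst) pairs, and both
    directions are followed, matching the "incoming and outgoing" behavior
    described for expand_k_hop.
    """
    neighbors = defaultdict(set)
    for src, dst in edges:
        neighbors[src].add(dst)  # follow the outgoing edge
        neighbors[dst].add(src)  # follow the incoming edge

    visited = set(seeds)
    frontier = set(seeds)
    for _ in range(k):
        # One breadth-first step: neighbors of the frontier not yet visited.
        frontier = {n for node in frontier for n in neighbors[node]} - visited
        if not frontier:
            break  # neighborhood is closed; more hops add nothing
        visited |= frontier
    return sorted(visited)


# Directed path 0 -> 1 -> 2 -> 3 -> 4
path_edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
print(expand_k_hop_edges(path_edges, [2], 1))  # [1, 2, 3]
print(expand_k_hop_edges(path_edges, [2], 2))  # [0, 1, 2, 3, 4]
```

Following only outgoing edges would give `[2, 3]` for the first call; treating edges as bidirectional is what pulls in node 1, which feeds into node 2.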
extract_gcn_embeddings_with_metadata(model, graphs, sample_list, base_path, z_joint, device='cuda:0', spatial_key='spatial_px', celltype_key='celltypes', adata_by_sample=None, batch_size=None, k_hop=2)
Extract GCN embeddings along with metadata such as spatial coordinates and cell types.
This function supports two inference modes:

- Full-graph inference (default, exact behavior): if `batch_size` is `None`, the entire graph is moved to the device and processed in a single forward pass.
- Batched subgraph inference (memory-efficient): if `batch_size` is an integer, nodes are processed in batches. For each batch of seed nodes, a k-hop subgraph is constructed so that a multi-layer GCN still sees each seed's exact receptive field.
Sample data can come either from in-memory AnnData objects (`adata_by_sample`) or from disk via `base_path`.
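The claim that k-hop subgraphs preserve exact receptive fields can be checked on a toy model. The sketch below is a framework-free illustration, not the library's code: it uses a hypothetical sum-aggregation "GCN" (`toy_gcn`, where each layer computes h'[v] = h[v] + sum of h[u] over edges u -> v) with integer features so results compare exactly. It runs full-graph inference, then batched inference over induced k-hop subgraphs, keeping only the seed nodes' outputs from each batch. All function names here are hypothetical.

```python
from collections import defaultdict


def k_hop_nodes(edges, seeds, k):
    # Bidirectional BFS, as described for expand_k_hop (hypothetical re-implementation).
    nbrs = defaultdict(set)
    for s, d in edges:
        nbrs[s].add(d)
        nbrs[d].add(s)
    seen, frontier = set(seeds), set(seeds)
    for _ in range(k):
        frontier = {n for v in frontier for n in nbrs[v]} - seen
        seen |= frontier
    return seen


def toy_gcn(nodes, edges, feats, num_layers):
    # Hypothetical layer: h'[v] = h[v] + sum of h[u] over edges u -> v.
    h = {v: feats[v] for v in nodes}
    for _ in range(num_layers):
        new_h = dict(h)
        for u, v in edges:
            new_h[v] += h[u]
        h = new_h
    return h


def full_graph_inference(num_nodes, edges, feats, num_layers):
    # One pass over the whole graph (analogous to batch_size=None).
    return toy_gcn(range(num_nodes), edges, feats, num_layers)


def batched_inference(num_nodes, edges, feats, num_layers, batch_size, k_hop):
    out = {}
    for start in range(0, num_nodes, batch_size):
        seeds = range(start, min(start + batch_size, num_nodes))
        keep = k_hop_nodes(edges, seeds, k_hop)
        # Induced subgraph: only edges whose endpoints both lie in the expanded set.
        sub_edges = [(u, v) for u, v in edges if u in keep and v in keep]
        h = toy_gcn(keep, sub_edges, feats, num_layers)
        for s in seeds:
            out[s] = h[s]  # boundary nodes are discarded; only seed outputs are kept
    return out


edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 0), (2, 5), (6, 1)]
feats = list(range(1, 9))  # one integer feature per node

full = full_graph_inference(8, edges, feats, num_layers=2)
exact = batched_inference(8, edges, feats, num_layers=2, batch_size=3, k_hop=2)
too_small = batched_inference(8, edges, feats, num_layers=2, batch_size=3, k_hop=1)
print(exact == full)      # True: k_hop matches the two layers
print(too_small == full)  # False: some seeds lose 2-hop information
```

The second comparison shows why `k_hop` should match the model's depth: with one hop, seed node 3 sees node 2 but not node 1, so its second-layer value is truncated. With DGL, an induced subgraph can be built with `dgl.node_subgraph`, though this docstring does not say which DGL call the module uses.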
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Trained GCN model with an … | required |
| `graphs` | `list[DGLGraph]` | List of DGL graphs, one per sample, containing node features in … | required |
| `sample_list` | `list[str]` | List of sample IDs corresponding to the graphs. | required |
| `base_path` | `str` or `Path` | Root directory containing per-sample subdirectories. Used only if … | required |
| `z_joint` | `DataFrame` | Joint embeddings from the autoencoder step, used to align graph nodes with cell identifiers. | required |
| `device` | `str` | Device string for inference (e.g., …). | `'cuda:0'` |
| `spatial_key` | `str` | Key name for spatial coordinates stored in … | `'spatial_px'` |
| `celltype_key` | `str` | Column name in … | `'celltypes'` |
| `adata_by_sample` | `Optional[Dict[str, AnnData]]` | Optional mapping from sample IDs to preloaded AnnData objects. If provided, it takes precedence over loading from disk. | `None` |
| `batch_size` | `Optional[int]` | Number of seed nodes per subgraph batch. If `None`, full-graph inference is used. | `None` |
| `k_hop` | `int` | Number of hops for subgraph expansion when batching. Should match the effective receptive field of the GCN (for a standard message-passing GCN, the number of layers). | `2` |
Returns:
| Type | Description |
|---|---|
| `DataFrame` | `pd.DataFrame`: Concatenated DataFrame of GCN embeddings across all samples, including metadata: … |
Raises:
| Type | Description |
|---|---|
| `FileNotFoundError` | If AnnData cannot be loaded from memory or disk. |
| `ValueError` | If the graph's node count does not match the number of aligned cell identifiers. |
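The lookup order and the two documented exceptions can be illustrated with a small sketch. Everything here is hypothetical: plain dicts stand in for AnnData, `load_fn` is a caller-supplied loader, the "one subdirectory named after each sample" layout is an assumed reading of `base_path`, and the output column names (`gcn_0`, `celltype`, `x`, `y`, `sample_id`) are illustrative rather than the function's real schema.

```python
from pathlib import Path

import pandas as pd


def resolve_sample_data(sample_id, adata_by_sample, base_path, load_fn):
    """Hypothetical lookup order: in-memory objects first, then disk."""
    if adata_by_sample is not None and sample_id in adata_by_sample:
        return adata_by_sample[sample_id]  # in-memory data takes precedence
    sample_dir = Path(base_path) / sample_id  # assumed layout: one subdirectory per sample
    if not sample_dir.is_dir():
        raise FileNotFoundError(f"{sample_id}: not in memory and {sample_dir} does not exist")
    return load_fn(sample_dir)  # caller-supplied loader (hypothetical)


def assemble_sample_frame(sample_id, embeddings, cell_ids, celltypes, coords):
    """Hypothetical per-sample assembly; column names are illustrative only."""
    if len(embeddings) != len(cell_ids):
        # Mirrors the documented ValueError: graph nodes must match aligned cell IDs.
        raise ValueError(f"{sample_id}: {len(embeddings)} nodes vs {len(cell_ids)} cells")
    dim = len(embeddings[0])
    df = pd.DataFrame(embeddings, index=cell_ids, columns=[f"gcn_{i}" for i in range(dim)])
    df["sample_id"] = sample_id
    df["celltype"] = [celltypes[c] for c in cell_ids]
    df["x"] = [coords[c][0] for c in cell_ids]
    df["y"] = [coords[c][1] for c in cell_ids]
    return df


# Toy in-memory stand-ins for AnnData: plain dicts keyed by cell ID.
in_memory = {
    "s1": {"celltypes": {"c1": "T", "c2": "B"}, "coords": {"c1": (0, 0), "c2": (1, 2)}},
    "s2": {"celltypes": {"c3": "Tumor"}, "coords": {"c3": (5, 5)}},
}
toy_embeddings = {"s1": [[0.1, 0.2], [0.3, 0.4]], "s2": [[0.5, 0.6]]}
toy_cell_ids = {"s1": ["c1", "c2"], "s2": ["c3"]}

frames = []
for sid in ["s1", "s2"]:
    meta = resolve_sample_data(sid, in_memory, "/path/not/used", load_fn=None)
    frames.append(assemble_sample_frame(sid, toy_embeddings[sid], toy_cell_ids[sid],
                                        meta["celltypes"], meta["coords"]))
combined = pd.concat(frames)
print(combined.shape)  # (3, 6)
```

Because both samples are found in memory, `base_path` is never touched, which matches the documented precedence of `adata_by_sample` over disk loading.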
infer_output_dim(model, feat_dim)
Infer the output dimensionality of a GCN encoder.
This function runs a dummy forward pass through `model.encode` on a single-node graph to determine the latent output size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | GCN model with an `encode` method. | required |
| `feat_dim` | `int` | Input feature dimensionality. | required |
Returns:
| Type | Description |
|---|---|
| `int` | Output embedding dimensionality. |
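The probing idea (feed one dummy node through the encoder and read off the width of the result) can be sketched without PyTorch or DGL. `ToyEncoder`, the dict-based graph, and `infer_output_dim_sketch` are hypothetical stand-ins. In a real DGL/PyTorch setting the equivalent inputs would typically be `dgl.graph(([0], [0]))` and `torch.zeros(1, feat_dim)`, evaluated under `torch.no_grad()`. One plausible reason to give the dummy node a self-loop is that DGL's `GraphConv` raises an error on zero-in-degree nodes unless `allow_zero_in_degree=True`; this docstring does not say whether spatialfusion does this.

```python
class ToyEncoder:
    """Hypothetical stand-in for a GCN encoder exposing encode(graph, feats).

    It ignores graph structure and applies one dense linear map per node.
    """

    def __init__(self, in_dim, out_dim):
        self.out_dim = out_dim
        self.weight = [[0.01 * (i + j) for j in range(out_dim)] for i in range(in_dim)]

    def encode(self, graph, feats):
        # feats: one list of length in_dim per node -> one list of length out_dim per node
        return [
            [sum(x * self.weight[i][j] for i, x in enumerate(node)) for j in range(self.out_dim)]
            for node in feats
        ]


def infer_output_dim_sketch(model, feat_dim):
    # Single-node graph with a self-loop, in a hypothetical dict format.
    graph = {"num_nodes": 1, "edges": [(0, 0)]}
    dummy_feats = [[0.0] * feat_dim]  # one node with feat_dim zero features
    out = model.encode(graph, dummy_feats)
    return len(out[0])  # width of the single node's embedding


print(infer_output_dim_sketch(ToyEncoder(in_dim=16, out_dim=8), feat_dim=16))  # 8
```

The probe never reads `model.out_dim`; it only inspects the shape of what `encode` returns, which is what makes the approach independent of how the encoder is built.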