spatialfusion.finetune.finetune

ae_from_arrays_finetune(model, feat1, feat2, adata=None, device='cuda:0')

Run a PairedAE model on in-memory feature matrices for fine-tuning.

This function mirrors the preprocessing used in disk-based embedding extraction:

- Converts indices to strings
- Aligns features across modalities (and AnnData if provided)
- Applies safe standardization
- Uses encoder networks directly (no reconstruction loss)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | | Pretrained PairedAE model. | required |
| feat1 | DataFrame | Feature matrix for modality 1 (cells × features). | required |
| feat2 | DataFrame | Feature matrix for modality 2 (cells × features). | required |
| adata | Optional[AnnData] | Optional AnnData object for additional index alignment. | None |
| device | str | Torch device used for inference. | 'cuda:0' |

Returns:

| Type | Description |
| --- | --- |
| Tuple[DataFrame, DataFrame, DataFrame] | A tuple (z1_df, z2_df, z_joint_df), where z1_df holds the latent embeddings from encoder1, z2_df holds the latent embeddings from encoder2, and z_joint_df is the average of the z1 and z2 embeddings. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If no overlapping cell identifiers are found. |
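The alignment and standardization steps listed above can be sketched with pandas alone. This is a minimal illustration, not the library's implementation: the helper names `align_and_standardize` and `joint_embedding` are hypothetical, "safe standardization" is assumed to mean a z-score with an `eps` guard and NaN/inf cleanup, and the encoder forward pass is left out entirely.

```python
import numpy as np
import pandas as pd


def align_and_standardize(feat1: pd.DataFrame, feat2: pd.DataFrame, eps: float = 1e-6):
    """Hypothetical helper: align two feature tables on shared cell IDs, then z-score."""
    # Cast indices to strings so integer and string cell IDs compare equal.
    feat1 = feat1.copy()
    feat2 = feat2.copy()
    feat1.index = feat1.index.astype(str)
    feat2.index = feat2.index.astype(str)

    # Keep only cells present in both modalities.
    shared = feat1.index.intersection(feat2.index)
    if len(shared) == 0:
        raise ValueError("No overlapping cell identifiers found.")
    feat1, feat2 = feat1.loc[shared], feat2.loc[shared]

    def safe_z(df: pd.DataFrame) -> pd.DataFrame:
        # Assumed meaning of "safe": eps guards constant columns,
        # and any remaining NaN/inf values become 0.
        z = (df - df.mean(axis=0)) / (df.std(axis=0, ddof=0) + eps)
        return z.replace([np.inf, -np.inf], np.nan).fillna(0.0)

    return safe_z(feat1), safe_z(feat2)


def joint_embedding(z1_df: pd.DataFrame, z2_df: pd.DataFrame) -> pd.DataFrame:
    """Average two latent embeddings that share the same index and columns."""
    return (z1_df + z2_df) / 2.0
```

In the documented function, `z1_df` and `z2_df` come from the model's two encoders; here `joint_embedding` only shows how the averaged output relates to them.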

build_ae_dataset(samples, base_path=None, preloaded_data=None, batch_size=128, max_cells=10 ** 6)

Build a PyTorch DataLoader for autoencoder fine-tuning.

This function supports two data sources:

- Disk-based loading using base_path
- In-memory loading using preloaded_data

In both cases, the same preprocessing steps are applied to ensure consistency with the original AE training pipeline.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| samples | | List of sample identifiers. | required |
| base_path | | Root directory containing sample data on disk. | None |
| preloaded_data | | Optional dict mapping sample name to (feat1_df, feat2_df) DataFrames. | None |
| batch_size | | Number of cell pairs per batch. | 128 |
| max_cells | | Maximum number of cells loaded per sample (disk mode). | 10 ** 6 |

Returns:

| Type | Description |
| --- | --- |
| | A tuple (loader, d1_dim, d2_dim), where loader is a PyTorch DataLoader over paired features, d1_dim is the feature dimension of modality 1, and d2_dim is the feature dimension of modality 2. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If neither base_path nor preloaded_data is provided. |
| KeyError | If a requested sample is missing from preloaded_data. |
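The two data sources and the documented error conditions can be illustrated with a small dispatch sketch. The helper `select_sources` is hypothetical. Two details are assumptions rather than documented behavior: that in-memory data takes precedence when both arguments are given, and that each sample lives in a directory named after it under `base_path`.

```python
from typing import Dict, List, Optional, Tuple

import pandas as pd

FeaturePair = Tuple[pd.DataFrame, pd.DataFrame]


def select_sources(
    samples: List[str],
    base_path: Optional[str] = None,
    preloaded_data: Optional[Dict[str, FeaturePair]] = None,
) -> Dict[str, object]:
    """Hypothetical helper: decide, per sample, where its features come from."""
    if base_path is None and preloaded_data is None:
        raise ValueError("Provide either base_path or preloaded_data.")

    sources: Dict[str, object] = {}
    for sample in samples:
        if preloaded_data is not None:
            # Assumption: in-memory data wins when both sources are given.
            if sample not in preloaded_data:
                raise KeyError(f"Sample '{sample}' missing from preloaded_data.")
            sources[sample] = preloaded_data[sample]
        else:
            # Assumption: one directory per sample; the files inside it
            # are not specified in this reference.
            sources[sample] = f"{base_path}/{sample}"
    return sources
```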

build_graphs(samples, z_joint_df, base_path=None, adatas=None, pathway_data=None, knn_k=30, subgraph_size=5000, stride=2500, use_cls_loss=True, spatial_key='spatial_he')

Construct spatial graphs (and overlapping subgraphs) for GCN fine-tuning.

For each sample, this function:

- Loads or uses preloaded AnnData
- Aligns cells with joint AE embeddings
- Builds a kNN spatial graph
- Optionally attaches pathway activation labels
- Generates overlapping subgraphs

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| samples | List[str] | List of sample identifiers. | required |
| z_joint_df | DataFrame | Joint AE embedding indexed by cell ID. | required |
| base_path | Optional[str] | Root directory for loading AnnData and labels. | None |
| adatas | Optional[Dict[str, Any]] | Optional dict of preloaded AnnData objects. | None |
| pathway_data | Optional[Dict[str, DataFrame]] | Optional dict of pathway activation DataFrames. | None |
| knn_k | int | Number of neighbors for kNN graph. | 30 |
| subgraph_size | int | Number of nodes per subgraph. | 5000 |
| stride | int | Stride between subgraph centers. | 2500 |
| use_cls_loss | bool | Whether classification labels are used. | True |
| spatial_key | str | Key in AnnData.obsm for spatial coordinates. | 'spatial_he' |

Returns:

| Type | Description |
| --- | --- |
| List[DGLGraph] | List of DGL graphs ready for GCN training. |

Raises:

| Type | Description |
| --- | --- |
| RuntimeError | If no graphs are successfully built. |
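To make the roles of knn_k, subgraph_size, and stride concrete, here is a NumPy-only sketch of a brute-force kNN edge list and of overlapping windows. Both helpers (`knn_edges`, `window_starts`) are hypothetical, and the windows are taken over a flat node ordering as a simplification. How the library actually orders nodes, picks subgraph centers, or builds DGL graphs is not described in this reference.

```python
import numpy as np


def knn_edges(coords, k: int) -> np.ndarray:
    """Brute-force kNN: return an (n * k, 2) array of (source, neighbor) pairs."""
    coords = np.asarray(coords, dtype=float)
    # Pairwise squared Euclidean distances; O(n^2) memory, so small n only.
    diff = coords[:, None, :] - coords[None, :, :]
    dist = (diff ** 2).sum(axis=-1)
    np.fill_diagonal(dist, np.inf)  # exclude self-matches
    neighbors = np.argsort(dist, axis=1)[:, :k]
    sources = np.repeat(np.arange(len(coords)), k)
    return np.stack([sources, neighbors.ravel()], axis=1)


def window_starts(n_nodes: int, subgraph_size: int, stride: int) -> list:
    """Start offsets of overlapping windows that together cover all nodes."""
    if n_nodes <= subgraph_size:
        return [0]
    starts = list(range(0, n_nodes - subgraph_size + 1, stride))
    if starts[-1] + subgraph_size < n_nodes:
        # Add a final window flush with the end so no node is left out.
        starts.append(n_nodes - subgraph_size)
    return starts
```

With the defaults subgraph_size=5000 and stride=2500, consecutive windows in this sketch share half of their nodes; for 12,000 nodes, `window_starts(12000, 5000, 2500)` yields the offsets 0, 2500, 5000, and 7000.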

finetune_autoencoder(loader, d1_dim, d2_dim, pretrained_ae, save_dir, device, latent_dim=64, enc_hidden_dims=[64], dec_hidden_dims=[64], epochs=5, lr=0.0001, weight_decay=0.0, grad_clip=1.0, lambda_recon1=0.5, lambda_recon2=0.5, lambda_cross12=0.25, lambda_cross21=0.25, lambda_align=1.0)

Fine-tune a pretrained paired autoencoder.

The training objective includes reconstruction losses for each modality, cross-modality reconstruction losses, and a latent alignment loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loader | DataLoader | DataLoader yielding paired feature tensors. | required |
| d1_dim | int | Input dimension of modality 1. | required |
| d2_dim | int | Input dimension of modality 2. | required |
| pretrained_ae | str | Path to pretrained AE checkpoint. | required |
| save_dir | Path | Directory to save fine-tuned model and loss history. | required |
| device | device | Torch device for training. | required |
| latent_dim | int | Latent space dimensionality. | 64 |
| enc_hidden_dims | List[int] | Hidden layer sizes for encoders. | [64] |
| dec_hidden_dims | List[int] | Hidden layer sizes for decoders. | [64] |
| epochs | int | Number of training epochs. | 5 |
| lr | float | Learning rate. | 0.0001 |
| weight_decay | float | Weight decay coefficient. | 0.0 |
| grad_clip | Optional[float] | Optional gradient clipping norm. | 1.0 |
| lambda_recon1 | float | Weight for modality 1 reconstruction loss. | 0.5 |
| lambda_recon2 | float | Weight for modality 2 reconstruction loss. | 0.5 |
| lambda_cross12 | float | Weight for cross reconstruction (1→2). | 0.25 |
| lambda_cross21 | float | Weight for cross reconstruction (2→1). | 0.25 |
| lambda_align | float | Weight for latent alignment loss. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Module | The fine-tuned PairedAE model. |
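The five lambda_* weights combine into a single objective. The sketch below shows one plausible way to form that weighted sum, using NumPy arrays in place of tensors. It assumes mean squared error for every term, and it assumes that "1→2" means reconstructing modality 2 from the modality 1 latent; neither choice is confirmed by this reference, and `paired_ae_loss` is a hypothetical name.

```python
import numpy as np


def mse(a, b) -> float:
    """Mean squared error between two arrays of the same shape."""
    return float(np.mean((np.asarray(a, dtype=float) - np.asarray(b, dtype=float)) ** 2))


def paired_ae_loss(
    x1, x2, recon1, recon2, cross12, cross21, z1, z2,
    lambda_recon1=0.5, lambda_recon2=0.5,
    lambda_cross12=0.25, lambda_cross21=0.25, lambda_align=1.0,
):
    """Hypothetical weighted sum of the five loss terms (MSE assumed throughout)."""
    terms = {
        "recon1": mse(recon1, x1),    # modality 1 decoded from its own latent
        "recon2": mse(recon2, x2),    # modality 2 decoded from its own latent
        "cross12": mse(cross12, x2),  # assumed: modality 2 decoded from z1
        "cross21": mse(cross21, x1),  # assumed: modality 1 decoded from z2
        "align": mse(z1, z2),         # pulls the two latents together
    }
    total = (
        lambda_recon1 * terms["recon1"]
        + lambda_recon2 * terms["recon2"]
        + lambda_cross12 * terms["cross12"]
        + lambda_cross21 * terms["cross21"]
        + lambda_align * terms["align"]
    )
    return total, terms
```

With the default weights, the alignment term carries as much weight as the two self-reconstruction terms combined, so the defaults favor a shared latent space over exact per-modality reconstruction.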

finetune_gcn(graphs, pretrained_gcn, save_dir, device, hidden_dim=10, num_layers=2, node_mask_ratio=0.9, epochs=10, batch_size=2, lr=0.0001, lambda_reg=0.001, lambda_cls=1.0, use_cls_loss=True, use_huber=True)

Fine-tune a pretrained GCN autoencoder on spatial graphs.

The loss includes feature reconstruction, latent regularization, and optional pathway classification.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| graphs | List[DGLGraph] | List of DGL graphs for training. | required |
| pretrained_gcn | str | Path to pretrained GCN checkpoint. | required |
| save_dir | Path | Directory to save fine-tuned model and loss history. | required |
| device | device | Torch device for training. | required |
| hidden_dim | int | Hidden layer dimensionality. | 10 |
| num_layers | int | Number of GCN layers. | 2 |
| node_mask_ratio | float | Fraction of masked nodes. | 0.9 |
| epochs | int | Number of training epochs. | 10 |
| batch_size | int | Graph batch size. | 2 |
| lr | float | Learning rate. | 0.0001 |
| lambda_reg | float | Weight for latent regularization loss. | 0.001 |
| lambda_cls | float | Weight for classification loss. | 1.0 |
| use_cls_loss | bool | Whether to include classification loss. | True |
| use_huber | bool | Whether to use Huber loss for classification. | True |

Returns:

| Type | Description |
| --- | --- |
| Module | Fine-tuned GCNAutoencoder model. |
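The parameters node_mask_ratio, lambda_reg, lambda_cls, and use_huber suggest a masked-reconstruction objective with an optional label term. The following NumPy sketch shows one way these pieces could fit together. All three helpers are hypothetical, the random masking scheme and the Huber threshold `delta` are assumptions, and the reconstruction and regularization terms are passed in as plain numbers because their exact form is not given here.

```python
import numpy as np


def mask_nodes(n_nodes: int, node_mask_ratio: float = 0.9, seed: int = 0) -> np.ndarray:
    """Boolean mask with roughly node_mask_ratio of the nodes set to True (masked)."""
    rng = np.random.default_rng(seed)
    n_masked = int(round(node_mask_ratio * n_nodes))
    mask = np.zeros(n_nodes, dtype=bool)
    mask[rng.choice(n_nodes, size=n_masked, replace=False)] = True
    return mask


def huber(pred, target, delta: float = 1.0) -> float:
    """Mean Huber loss: quadratic for small errors, linear beyond delta."""
    err = np.abs(np.asarray(pred, dtype=float) - np.asarray(target, dtype=float))
    quad = np.minimum(err, delta)
    return float(np.mean(0.5 * quad ** 2 + delta * (err - quad)))


def gcn_loss(recon_loss, reg_loss, cls_loss=0.0,
             lambda_reg=0.001, lambda_cls=1.0, use_cls_loss=True) -> float:
    """Hypothetical total: reconstruction + weighted regularization (+ weighted labels)."""
    total = recon_loss + lambda_reg * reg_loss
    if use_cls_loss:
        total += lambda_cls * cls_loss
    return total
```

Pathway activations are continuous scores, which is presumably why a regression-style loss such as Huber is offered for the label term.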

finetune_models(samples, base_path=None, pretrained_ae='', pretrained_gcn='', save_dir='./finetuned_outputs', preloaded_data=None, adatas=None, preloaded_pathway_data=None, latent_dim=64, enc_hidden_dims=[64], dec_hidden_dims=[64], ae_epochs=5, ae_batch_size=128, ae_lr=0.0001, ae_weight_decay=0.0, ae_grad_clip=1.0, lambda_recon1=0.5, lambda_recon2=0.5, lambda_cross12=0.25, lambda_cross21=0.25, lambda_align=1.0, knn_k=30, subgraph_size=5000, stride=2500, gcn_hidden_dim=10, gcn_num_layers=2, node_mask_ratio=0.9, gcn_epochs=10, gcn_batch_size=2, gcn_lr=0.0001, lambda_reg=0.001, lambda_cls=1.0, use_cls_loss=True, use_huber=True, spatial_key='spatial')

End-to-end fine-tuning pipeline for SpatialFusion models.

This function fine-tunes a paired autoencoder followed by a GCN, using either disk-based inputs or fully preloaded data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| samples | List[str] | List of sample identifiers. | required |
| base_path | Optional[str] | Root directory for disk-based data loading. | None |
| pretrained_ae | str | Path to pretrained AE checkpoint. | '' |
| pretrained_gcn | str | Path to pretrained GCN checkpoint. | '' |
| save_dir | str | Output directory for fine-tuned models and logs. | './finetuned_outputs' |
| preloaded_data | Optional[Dict[str, Tuple[DataFrame, DataFrame]]] | Optional in-memory feature data per sample. | None |
| adatas | Optional[Dict[str, Any]] | Optional preloaded AnnData objects. | None |
| preloaded_pathway_data | Optional[Dict[str, DataFrame]] | Optional pathway activation labels. | None |
| latent_dim | int | AE latent dimensionality. | 64 |
| enc_hidden_dims | List[int] | Encoder hidden layer sizes. | [64] |
| dec_hidden_dims | List[int] | Decoder hidden layer sizes. | [64] |
| ae_epochs | int | Number of AE fine-tuning epochs. | 5 |
| ae_batch_size | int | AE batch size. | 128 |
| ae_lr | float | AE learning rate. | 0.0001 |
| ae_weight_decay | float | AE weight decay. | 0.0 |
| ae_grad_clip | Optional[float] | AE gradient clipping norm. | 1.0 |
| lambda_recon1 | float | AE reconstruction loss weight (modality 1). | 0.5 |
| lambda_recon2 | float | AE reconstruction loss weight (modality 2). | 0.5 |
| lambda_cross12 | float | AE cross reconstruction loss weight (1→2). | 0.25 |
| lambda_cross21 | float | AE cross reconstruction loss weight (2→1). | 0.25 |
| lambda_align | float | AE latent alignment loss weight. | 1.0 |
| knn_k | int | Number of neighbors for spatial graph construction. | 30 |
| subgraph_size | int | Size of spatial subgraphs. | 5000 |
| stride | int | Stride between subgraphs. | 2500 |
| gcn_hidden_dim | int | GCN hidden layer size. | 10 |
| gcn_num_layers | int | Number of GCN layers. | 2 |
| node_mask_ratio | float | GCN node masking ratio. | 0.9 |
| gcn_epochs | int | Number of GCN fine-tuning epochs. | 10 |
| gcn_batch_size | int | GCN batch size. | 2 |
| gcn_lr | float | GCN learning rate. | 0.0001 |
| lambda_reg | float | GCN latent regularization weight. | 0.001 |
| lambda_cls | float | GCN classification loss weight. | 1.0 |
| use_cls_loss | bool | Whether to use pathway classification loss. | True |
| use_huber | bool | Whether to use Huber loss for classification. | True |
| spatial_key | str | Key for spatial coordinates in AnnData. | 'spatial' |

Returns:

| Type | Description |
| --- | --- |
| | A tuple (ae_model, gcn_model) containing the fine-tuned models. |

get_coords(adata, eps=1e-06, key='spatial')

Extract and standardize spatial coordinates from an AnnData object.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| adata | | AnnData containing spatial coordinates. | required |
| eps | | Small constant to avoid division by zero. | 1e-06 |
| key | | Key in adata.obsm storing spatial coordinates. | 'spatial' |

Returns:

| Type | Description |
| --- | --- |
| | Standardized coordinate array. |

Raises:

| Type | Description |
| --- | --- |
| KeyError | If the specified spatial key is not present. |
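A minimal sketch of the lookup and standardization follows, using a plain dict as a stand-in for `adata.obsm`. The function name `get_coords_sketch` is hypothetical, and the per-axis z-score with `eps` added to the standard deviation is an assumption: this reference does not say whether the library scales each axis separately or both axes jointly.

```python
import numpy as np


def get_coords_sketch(obsm, eps: float = 1e-6, key: str = "spatial") -> np.ndarray:
    """Hypothetical stand-in: fetch coordinates by key and z-score each axis."""
    if key not in obsm:
        raise KeyError(f"Spatial key '{key}' not found in obsm.")
    coords = np.asarray(obsm[key], dtype=np.float32)
    # Assumed form: center each axis and divide by (std + eps), so an axis
    # with zero spread yields zeros instead of a division by zero.
    return (coords - coords.mean(axis=0)) / (coords.std(axis=0) + eps)
```

Note that the default key here is 'spatial', whereas build_graphs defaults to 'spatial_he', so the key usually needs to be passed explicitly when the two are used together.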

get_device()

Select an available computation device.

standardize_pathways(df, method='robust_z', eps=1e-06, tol=0.001)

Standardize pathway activation scores column-wise.

Supported methods:

- 'robust_z': (x - median) / IQR
- 'z': (x - mean) / std

Columns with near-zero variation are set to zero, and all NaN or infinite values are replaced with 0.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| df | DataFrame | Pathway activation DataFrame. | required |
| method | str | Standardization method ('robust_z' or 'z'). | 'robust_z' |
| eps | float | Numerical stability constant. | 1e-06 |
| tol | float | Threshold for detecting near-zero columns. | 0.001 |

Returns:

| Type | Description |
| --- | --- |
| DataFrame | Standardized pathway activation DataFrame (float32). |
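The two methods and the cleanup rules described above can be sketched in pandas as follows. This is an illustration under assumptions, not the library's code: the name `standardize_pathways_sketch` is hypothetical, `eps` is assumed to be added to the scale, `tol` is assumed to be compared against the scale (IQR or standard deviation), and the population standard deviation (ddof=0) is an arbitrary choice.

```python
import numpy as np
import pandas as pd


def standardize_pathways_sketch(
    df: pd.DataFrame, method: str = "robust_z", eps: float = 1e-6, tol: float = 1e-3
) -> pd.DataFrame:
    """Hypothetical column-wise standardization of pathway scores."""
    if method == "robust_z":
        center = df.median(axis=0)
        scale = df.quantile(0.75) - df.quantile(0.25)  # interquartile range
    elif method == "z":
        center = df.mean(axis=0)
        scale = df.std(axis=0, ddof=0)
    else:
        raise ValueError(f"Unknown method: {method!r}")

    out = (df - center) / (scale + eps)
    # Columns with near-zero variation carry no signal; set them to zero.
    out.loc[:, scale < tol] = 0.0
    # Replace any remaining NaN or infinite values with 0.
    out = out.replace([np.inf, -np.inf], np.nan).fillna(0.0)
    return out.astype(np.float32)
```

The robust variant is less sensitive to a few extreme pathway scores, since the median and IQR ignore the tails that dominate the mean and standard deviation.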