spatialfusion.finetune.finetune
ae_from_arrays_finetune(model, feat1, feat2, adata=None, device='cuda:0')
Run a PairedAE model on in-memory feature matrices for fine-tuning.
This function mirrors the preprocessing used in disk-based embedding extraction:
- Converts indices to strings
- Aligns features across modalities (and AnnData if provided)
- Applies safe standardization
- Uses encoder networks directly (no reconstruction loss)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | | Pretrained PairedAE model. | required |
| feat1 | DataFrame | Feature matrix for modality 1 (cells × features). | required |
| feat2 | DataFrame | Feature matrix for modality 2 (cells × features). | required |
| adata | Optional[AnnData] | Optional AnnData object for additional index alignment. | None |
| device | str | Torch device used for inference. | 'cuda:0' |
Returns:
| Type | Description |
|---|---|
| Tuple[DataFrame, DataFrame, DataFrame] | A tuple |
Raises:
| Type | Description |
|---|---|
| ValueError | If no overlapping cell identifiers are found. |
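The preprocessing steps listed for `ae_from_arrays_finetune` can be sketched with pandas alone. This is a minimal illustration, not the library's implementation: the helper name `align_and_standardize`, the per-feature z-scoring, and the `eps` guard are all assumptions about what "safe standardization" means here, and the encoder step is left out.

```python
import numpy as np
import pandas as pd


def align_and_standardize(feat1, feat2, obs_names=None, eps=1e-6):
    """Hypothetical helper: align two feature tables on shared cell IDs."""
    feat1 = feat1.copy()
    feat2 = feat2.copy()
    # Convert indices to strings so integer and string IDs compare equal.
    feat1.index = feat1.index.astype(str)
    feat2.index = feat2.index.astype(str)

    # Keep only cells present in both modalities (and in AnnData, if given).
    shared = feat1.index.intersection(feat2.index)
    if obs_names is not None:
        shared = shared.intersection(pd.Index(obs_names).astype(str))
    if len(shared) == 0:
        raise ValueError("No overlapping cell identifiers found.")

    def safe_standardize(df):
        # Assumed meaning of "safe": eps guards zero variance, NaN/inf become 0.
        z = (df - df.mean(axis=0)) / (df.std(axis=0, ddof=0) + eps)
        return z.replace([np.inf, -np.inf], np.nan).fillna(0.0)

    return safe_standardize(feat1.loc[shared]), safe_standardize(feat2.loc[shared])


# Modality 1 uses integer IDs, modality 2 uses string IDs for the same cells.
f1 = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [5.0, 5.0, 5.0]}, index=[0, 1, 2])
f2 = pd.DataFrame({"c": [0.0, 4.0]}, index=["1", "2"])
x1, x2 = align_and_standardize(f1, f2)
```

After the call, both outputs share the same two string-indexed cells, and the constant column `b` is all zeros rather than NaN.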
build_ae_dataset(samples, base_path=None, preloaded_data=None, batch_size=128, max_cells=10 ** 6)
Build a PyTorch DataLoader for autoencoder fine-tuning.
This function supports two data sources:
- Disk-based loading using base_path
- In-memory loading using preloaded_data
In both cases, the same preprocessing steps are applied to ensure consistency with the original AE training pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| samples | | List of sample identifiers. | required |
| base_path | | Root directory containing sample data on disk. | None |
| preloaded_data | | Optional dict mapping sample name to | None |
| batch_size | | Number of cell pairs per batch. | 128 |
| max_cells | | Maximum number of cells loaded per sample (disk mode). | 10 ** 6 |
Returns:
| Type | Description |
|---|---|
| | A tuple |
Raises:
| Type | Description |
|---|---|
| ValueError | If neither |
| KeyError | If a requested sample is missing from |
build_graphs(samples, z_joint_df, base_path=None, adatas=None, pathway_data=None, knn_k=30, subgraph_size=5000, stride=2500, use_cls_loss=True, spatial_key='spatial_he')
Construct spatial graphs (and overlapping subgraphs) for GCN fine-tuning.
For each sample, this function:
- Loads or uses preloaded AnnData
- Aligns cells with joint AE embeddings
- Builds a kNN spatial graph
- Optionally attaches pathway activation labels
- Generates overlapping subgraphs
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| samples | List[str] | List of sample identifiers. | required |
| z_joint_df | DataFrame | Joint AE embedding indexed by cell ID. | required |
| base_path | Optional[str] | Root directory for loading AnnData and labels. | None |
| adatas | Optional[Dict[str, Any]] | Optional dict of preloaded AnnData objects. | None |
| pathway_data | Optional[Dict[str, DataFrame]] | Optional dict of pathway activation DataFrames. | None |
| knn_k | int | Number of neighbors for kNN graph. | 30 |
| subgraph_size | int | Number of nodes per subgraph. | 5000 |
| stride | int | Stride between subgraph centers. | 2500 |
| use_cls_loss | bool | Whether classification labels are used. | True |
| spatial_key | str | Key in AnnData.obsm for spatial coordinates. | 'spatial_he' |
Returns:
| Type | Description |
|---|---|
| List[DGLGraph] | List of DGL graphs ready for GCN training. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If no graphs are successfully built. |
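To make the roles of `knn_k`, `subgraph_size`, and `stride` concrete, here is a small NumPy sketch of the two geometric steps in `build_graphs`. It is an illustration under stated assumptions rather than the library's code: the helper names are hypothetical, the neighbour search is brute force, and the windows slide over cells in index order, which may differ from how subgraph centers are actually chosen.

```python
import numpy as np


def knn_edges(coords, k):
    """Return (src, dst) arrays linking each cell to its k nearest neighbours."""
    # Brute-force pairwise squared distances; fine for a small illustration.
    diff = coords[:, None, :] - coords[None, :, :]
    dist = (diff ** 2).sum(axis=-1)
    np.fill_diagonal(dist, np.inf)  # exclude self-loops
    nbrs = np.argsort(dist, axis=1)[:, :k]
    src = np.repeat(np.arange(coords.shape[0]), k)
    return src, nbrs.ravel()


def window_starts(n_nodes, subgraph_size, stride):
    """Start offsets of overlapping windows that together cover all nodes."""
    if n_nodes <= subgraph_size:
        return [0]
    starts = list(range(0, n_nodes - subgraph_size + 1, stride))
    if starts[-1] + subgraph_size < n_nodes:
        starts.append(n_nodes - subgraph_size)  # final window flush with the end
    return starts


rng = np.random.default_rng(0)
coords = rng.normal(size=(12, 2))
src, dst = knn_edges(coords, k=3)

# With the documented defaults, consecutive windows overlap by half their size.
starts = window_starts(n_nodes=12000, subgraph_size=5000, stride=2500)
```

With a stride of half the subgraph size, every cell away from the boundaries falls into two windows, so no neighbourhood is seen only at a window edge.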
finetune_autoencoder(loader, d1_dim, d2_dim, pretrained_ae, save_dir, device, latent_dim=64, enc_hidden_dims=[64], dec_hidden_dims=[64], epochs=5, lr=0.0001, weight_decay=0.0, grad_clip=1.0, lambda_recon1=0.5, lambda_recon2=0.5, lambda_cross12=0.25, lambda_cross21=0.25, lambda_align=1.0)
Fine-tune a pretrained paired autoencoder.
The training objective includes reconstruction losses for each modality, cross-modality reconstruction losses, and a latent alignment loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| loader | DataLoader | DataLoader yielding paired feature tensors. | required |
| d1_dim | int | Input dimension of modality 1. | required |
| d2_dim | int | Input dimension of modality 2. | required |
| pretrained_ae | str | Path to pretrained AE checkpoint. | required |
| save_dir | Path | Directory to save fine-tuned model and loss history. | required |
| device | device | Torch device for training. | required |
| latent_dim | int | Latent space dimensionality. | 64 |
| enc_hidden_dims | List[int] | Hidden layer sizes for encoders. | [64] |
| dec_hidden_dims | List[int] | Hidden layer sizes for decoders. | [64] |
| epochs | int | Number of training epochs. | 5 |
| lr | float | Learning rate. | 0.0001 |
| weight_decay | float | Weight decay coefficient. | 0.0 |
| grad_clip | Optional[float] | Optional gradient clipping norm. | 1.0 |
| lambda_recon1 | float | Weight for modality 1 reconstruction loss. | 0.5 |
| lambda_recon2 | float | Weight for modality 2 reconstruction loss. | 0.5 |
| lambda_cross12 | float | Weight for cross reconstruction (1→2). | 0.25 |
| lambda_cross21 | float | Weight for cross reconstruction (2→1). | 0.25 |
| lambda_align | float | Weight for latent alignment loss. | 1.0 |
Returns:
| Type | Description |
|---|---|
| Module | The fine-tuned PairedAE model. |
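The five `lambda_*` weights combine the loss terms described for `finetune_autoencoder` into one objective. The sketch below shows that weighted sum in NumPy. Using mean squared error for every term, including the alignment term, is an assumption made for illustration, and `paired_ae_loss` is a hypothetical name, not a function from the library.

```python
import numpy as np


def mse(a, b):
    """Mean squared error between two arrays of the same shape."""
    return float(np.mean((a - b) ** 2))


def paired_ae_loss(x1, x2, rec1, rec2, cross12, cross21, z1, z2,
                   lambda_recon1=0.5, lambda_recon2=0.5,
                   lambda_cross12=0.25, lambda_cross21=0.25,
                   lambda_align=1.0):
    """Weighted sum of the five terms; MSE for each term is an assumption."""
    return (
        lambda_recon1 * mse(rec1, x1)          # modality 1 reconstructed from z1
        + lambda_recon2 * mse(rec2, x2)        # modality 2 reconstructed from z2
        + lambda_cross12 * mse(cross12, x2)    # modality 2 decoded from z1
        + lambda_cross21 * mse(cross21, x1)    # modality 1 decoded from z2
        + lambda_align * mse(z1, z2)           # pull the two latents together
    )


x1, x2 = np.zeros((4, 3)), np.zeros((4, 2))
total = paired_ae_loss(
    x1, x2,
    rec1=np.ones((4, 3)), rec2=x2,
    cross12=2 * np.ones((4, 2)), cross21=x1,
    z1=np.zeros((4, 2)), z2=np.ones((4, 2)),
)
# 0.5 * 1 + 0.5 * 0 + 0.25 * 4 + 0.25 * 0 + 1.0 * 1 = 2.5
```

With the default weights, the alignment term counts twice as much as either same-modality reconstruction and four times as much as either cross reconstruction.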
finetune_gcn(graphs, pretrained_gcn, save_dir, device, hidden_dim=10, num_layers=2, node_mask_ratio=0.9, epochs=10, batch_size=2, lr=0.0001, lambda_reg=0.001, lambda_cls=1.0, use_cls_loss=True, use_huber=True)
Fine-tune a pretrained GCN autoencoder on spatial graphs.
The loss includes feature reconstruction, latent regularization, and optional pathway classification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| graphs | List[DGLGraph] | List of DGL graphs for training. | required |
| pretrained_gcn | str | Path to pretrained GCN checkpoint. | required |
| save_dir | Path | Directory to save fine-tuned model and loss history. | required |
| device | device | Torch device for training. | required |
| hidden_dim | int | Hidden layer dimensionality. | 10 |
| num_layers | int | Number of GCN layers. | 2 |
| node_mask_ratio | float | Fraction of masked nodes. | 0.9 |
| epochs | int | Number of training epochs. | 10 |
| batch_size | int | Graph batch size. | 2 |
| lr | float | Learning rate. | 0.0001 |
| lambda_reg | float | Weight for latent regularization loss. | 0.001 |
| lambda_cls | float | Weight for classification loss. | 1.0 |
| use_cls_loss | bool | Whether to include classification loss. | True |
| use_huber | bool | Whether to use Huber loss for classification. | True |
Returns:
| Type | Description |
|---|---|
| Module | Fine-tuned GCNAutoencoder model. |
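The following NumPy sketch illustrates how `node_mask_ratio`, `lambda_reg`, `lambda_cls`, and `use_huber` could fit together in `finetune_gcn`. Every concrete choice in it is an assumption: masking by zeroing feature rows, MSE for feature reconstruction, an L2 penalty as the latent regularizer, and the Huber `delta` of 1.0. The helper names are hypothetical.

```python
import numpy as np


def mask_node_features(features, node_mask_ratio=0.9, seed=0):
    """Zero a random fraction of node feature rows; return (masked copy, mask)."""
    rng = np.random.default_rng(seed)
    n_nodes = features.shape[0]
    n_masked = int(round(node_mask_ratio * n_nodes))
    masked_idx = rng.choice(n_nodes, size=n_masked, replace=False)
    mask = np.zeros(n_nodes, dtype=bool)
    mask[masked_idx] = True
    masked = features.copy()
    masked[mask] = 0.0
    return masked, mask


def huber(pred, target, delta=1.0):
    """Huber loss: quadratic for small residuals, linear for large ones."""
    r = np.abs(pred - target)
    return float(np.mean(np.where(r <= delta, 0.5 * r ** 2, delta * (r - 0.5 * delta))))


def gcn_loss(recon, target, latent, cls_loss=None, lambda_reg=0.001, lambda_cls=1.0):
    """Reconstruction + latent regularization + optional classification term."""
    loss = np.mean((recon - target) ** 2) + lambda_reg * np.mean(latent ** 2)
    if cls_loss is not None:  # corresponds to use_cls_loss=True
        loss += lambda_cls * cls_loss
    return float(loss)


features = np.ones((20, 4))
masked, mask = mask_node_features(features, node_mask_ratio=0.9)

cls = huber(np.array([0.0, 3.0]), np.array([0.0, 0.0]))  # (0 + 2.5) / 2 = 1.25
total = gcn_loss(np.zeros((20, 4)), features, 2 * np.ones((20, 10)), cls_loss=cls)
# 1.0 + 0.001 * 4.0 + 1.0 * 1.25 = 2.254
```

With the default ratio of 0.9, only two of the twenty nodes keep their input features, so the model has to rebuild most of the graph from spatial context.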
finetune_models(samples, base_path=None, pretrained_ae='', pretrained_gcn='', save_dir='./finetuned_outputs', preloaded_data=None, adatas=None, preloaded_pathway_data=None, latent_dim=64, enc_hidden_dims=[64], dec_hidden_dims=[64], ae_epochs=5, ae_batch_size=128, ae_lr=0.0001, ae_weight_decay=0.0, ae_grad_clip=1.0, lambda_recon1=0.5, lambda_recon2=0.5, lambda_cross12=0.25, lambda_cross21=0.25, lambda_align=1.0, knn_k=30, subgraph_size=5000, stride=2500, gcn_hidden_dim=10, gcn_num_layers=2, node_mask_ratio=0.9, gcn_epochs=10, gcn_batch_size=2, gcn_lr=0.0001, lambda_reg=0.001, lambda_cls=1.0, use_cls_loss=True, use_huber=True, spatial_key='spatial')
End-to-end fine-tuning pipeline for SpatialFusion models.
This function fine-tunes a paired autoencoder followed by a GCN, using either disk-based inputs or fully preloaded data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| samples | List[str] | List of sample identifiers. | required |
| base_path | Optional[str] | Root directory for disk-based data loading. | None |
| pretrained_ae | str | Path to pretrained AE checkpoint. | '' |
| pretrained_gcn | str | Path to pretrained GCN checkpoint. | '' |
| save_dir | str | Output directory for fine-tuned models and logs. | './finetuned_outputs' |
| preloaded_data | Optional[Dict[str, Tuple[DataFrame, DataFrame]]] | Optional in-memory feature data per sample. | None |
| adatas | Optional[Dict[str, Any]] | Optional preloaded AnnData objects. | None |
| preloaded_pathway_data | Optional[Dict[str, DataFrame]] | Optional pathway activation labels. | None |
| latent_dim | int | AE latent dimensionality. | 64 |
| enc_hidden_dims | List[int] | Encoder hidden layer sizes. | [64] |
| dec_hidden_dims | List[int] | Decoder hidden layer sizes. | [64] |
| ae_epochs | int | Number of AE fine-tuning epochs. | 5 |
| ae_batch_size | int | AE batch size. | 128 |
| ae_lr | float | AE learning rate. | 0.0001 |
| ae_weight_decay | float | AE weight decay. | 0.0 |
| ae_grad_clip | Optional[float] | AE gradient clipping norm. | 1.0 |
| lambda_recon1 | float | AE reconstruction loss weight (modality 1). | 0.5 |
| lambda_recon2 | float | AE reconstruction loss weight (modality 2). | 0.5 |
| lambda_cross12 | float | AE cross reconstruction loss weight (1→2). | 0.25 |
| lambda_cross21 | float | AE cross reconstruction loss weight (2→1). | 0.25 |
| lambda_align | float | AE latent alignment loss weight. | 1.0 |
| knn_k | int | Number of neighbors for spatial graph construction. | 30 |
| subgraph_size | int | Size of spatial subgraphs. | 5000 |
| stride | int | Stride between subgraphs. | 2500 |
| gcn_hidden_dim | int | GCN hidden layer size. | 10 |
| gcn_num_layers | int | Number of GCN layers. | 2 |
| node_mask_ratio | float | GCN node masking ratio. | 0.9 |
| gcn_epochs | int | Number of GCN fine-tuning epochs. | 10 |
| gcn_batch_size | int | GCN batch size. | 2 |
| gcn_lr | float | GCN learning rate. | 0.0001 |
| lambda_reg | float | GCN latent regularization weight. | 0.001 |
| lambda_cls | float | GCN classification loss weight. | 1.0 |
| use_cls_loss | bool | Whether to use pathway classification loss. | True |
| use_huber | bool | Whether to use Huber loss for classification. | True |
| spatial_key | str | Key for spatial coordinates in AnnData. | 'spatial' |
Returns:
| Type | Description |
|---|---|
| | A tuple |
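A disk-based call to `finetune_models` might look like the sketch below. The wrapper `run_finetuning` is hypothetical. The keyword names and the import path come from the signature and module name on this page, while the directory layout expected under `base_path` and the contents of the returned tuple are not described here, so the result is passed back untouched.

```python
def run_finetuning(samples, base_path, pretrained_ae, pretrained_gcn,
                   save_dir="./finetuned_outputs"):
    """Hypothetical wrapper around the end-to-end pipeline (disk-based mode)."""
    # Imported lazily so this sketch loads even where the package is absent.
    from spatialfusion.finetune.finetune import finetune_models

    return finetune_models(
        samples=samples,
        base_path=base_path,
        pretrained_ae=pretrained_ae,      # path to the pretrained AE checkpoint
        pretrained_gcn=pretrained_gcn,    # path to the pretrained GCN checkpoint
        save_dir=save_dir,
        ae_epochs=5,                      # documented defaults, spelled out
        gcn_epochs=10,
        spatial_key="spatial",
    )
```

All sample names and paths passed to the wrapper are supplied by the caller. Any remaining hyperparameters fall back to the defaults listed in the table.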
get_coords(adata, eps=1e-06, key='spatial')
Extract and standardize spatial coordinates from an AnnData object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| adata | | AnnData containing spatial coordinates. | required |
| eps | | Small constant to avoid division by zero. | 1e-06 |
| key | | Key in | 'spatial' |
Returns:
| Type | Description |
|---|---|
| | Standardized coordinate array. |
Raises:
| Type | Description |
|---|---|
| KeyError | If the specified spatial key is not present. |
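As a rough picture of what `get_coords` does, the sketch below standardizes coordinates stored under a key in `obsm`. It assumes per-axis z-scoring with `eps` added to the standard deviation, which this page does not confirm. `FakeAnnData` is a hypothetical stand-in so the example runs without the anndata package.

```python
import numpy as np


class FakeAnnData:
    """Hypothetical stand-in exposing only the .obsm mapping used here."""

    def __init__(self, obsm):
        self.obsm = obsm


def standardized_coords(adata, eps=1e-6, key="spatial"):
    if key not in adata.obsm:
        raise KeyError(f"Spatial key '{key}' not found in adata.obsm")
    coords = np.asarray(adata.obsm[key], dtype=np.float32)
    # Assumed scheme: z-score each axis; eps guards an axis with zero spread.
    return (coords - coords.mean(axis=0)) / (coords.std(axis=0) + eps)


adata = FakeAnnData({"spatial": np.array([[0.0, 10.0], [2.0, 10.0], [4.0, 10.0]])})
xy = standardized_coords(adata)
```

The second axis is constant in this toy input, so it comes out as zeros instead of dividing by zero.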
get_device()
Select an available computation device.
standardize_pathways(df, method='robust_z', eps=1e-06, tol=0.001)
Standardize pathway activation scores column-wise.
Supported methods:
- 'robust_z': (x - median) / IQR
- 'z': (x - mean) / std
Columns with near-zero variation are set to zero, and all NaN or infinite values are replaced with 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | Pathway activation DataFrame. | required |
| method | str | Standardization method ('robust_z' or 'z'). | 'robust_z' |
| eps | float | Numerical stability constant. | 1e-06 |
| tol | float | Threshold for detecting near-zero columns. | 0.001 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Standardized pathway activation DataFrame (float32). |
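The two methods and the clean-up rules described for `standardize_pathways` can be sketched in pandas as follows. This is an approximation under assumptions: `standardize_columns` is a hypothetical name, `eps` is added to the scale, and a column counts as "near-zero" when its scale (IQR or standard deviation) falls below `tol`. The library may test for this differently.

```python
import numpy as np
import pandas as pd


def standardize_columns(df, method="robust_z", eps=1e-6, tol=1e-3):
    values = df.astype("float64")
    if method == "robust_z":
        center = values.median(axis=0)
        scale = values.quantile(0.75) - values.quantile(0.25)  # IQR per column
    elif method == "z":
        center = values.mean(axis=0)
        scale = values.std(axis=0, ddof=0)
    else:
        raise ValueError(f"Unknown method: {method!r}")

    out = (values - center) / (scale + eps)
    # Columns whose spread is below tol carry no usable signal; zero them.
    out.loc[:, scale < tol] = 0.0
    # Replace any remaining NaN or infinite entries with 0.
    out = out.replace([np.inf, -np.inf], np.nan).fillna(0.0)
    return out.astype("float32")


scores = pd.DataFrame({
    "p1": [1.0, 2.0, 3.0, 4.0, 5.0],
    "flat": [7.0] * 5,                    # no variation at all
    "gap": [0.0, np.nan, 2.0, 4.0, 6.0],  # contains a missing value
})
z = standardize_columns(scores)
```

In the result, `flat` is all zeros, the missing entry in `gap` becomes 0, and `p1` runs from about -1 to 1 because its median is 3 and its IQR is 2.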