spatialfusion.utils.embed_ae_utils

Utility functions for extracting and saving AE embeddings and metadata.

This module provides:

- safe_standardize: Robust z-score standardization for DataFrames.
- extract_embeddings_for_all_samples: Extract embeddings for all samples using a trained AE model.
- save_embeddings_separately: Save embeddings and metadata to disk.

extract_embeddings_for_all_samples(model, sample_list, base_path, device='cpu')

Extract embeddings for all samples using a trained AE model. Loads UNI and scGPT embeddings, matches cell IDs, standardizes features, and computes model embeddings. Also extracts cell type labels and sample names.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | | Trained AE model with encoder1 and encoder2. | required |
| sample_list | list | List of sample info (str or dict). | required |
| base_path | str or Path | Base directory for samples. | required |
| device | str | Device for model inference. | 'cpu' |

Returns:

| Type | Description |
|------|-------------|
| tuple | (z1_df, z2_df, z_joint_df, celltypes, samples), with the elements listed below. |

- z1_df (pd.DataFrame): Embeddings from encoder1.
- z2_df (pd.DataFrame): Embeddings from encoder2.
- z_joint_df (pd.DataFrame): Averaged joint embeddings.
- celltypes (np.ndarray): Cell type labels.
- samples (np.ndarray): Sample names.
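The z_joint_df return value is described only as "averaged joint embeddings". The sketch below assumes this means the element-wise mean of the encoder1 and encoder2 outputs for the same cells. The helper `joint_embeddings`, the cell IDs, and the column names are hypothetical and not part of spatialfusion.

```python
import pandas as pd


def joint_embeddings(z1_df: pd.DataFrame, z2_df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: element-wise mean of two aligned embedding tables."""
    # Both encoders must describe the same cells (same index, same order)
    # in a latent space of the same dimensionality.
    if not z1_df.index.equals(z2_df.index):
        raise ValueError("z1_df and z2_df must share the same cell index")
    if z1_df.shape[1] != z2_df.shape[1]:
        raise ValueError("z1_df and z2_df must have the same number of dimensions")
    # Average the raw arrays so differing column labels cannot misalign values.
    joint = (z1_df.to_numpy() + z2_df.to_numpy()) / 2.0
    return pd.DataFrame(joint, index=z1_df.index, columns=z1_df.columns)


# Toy embeddings for two cells in a 2-D latent space (hypothetical values).
cells = ["cell_a", "cell_b"]
z1 = pd.DataFrame([[1.0, 2.0], [3.0, 4.0]], index=cells, columns=["z0", "z1"])
z2 = pd.DataFrame([[3.0, 0.0], [1.0, 0.0]], index=cells, columns=["z0", "z1"])
z_joint = joint_embeddings(z1, z2)
print(z_joint)
```

Checking the index first matters because averaging rows that belong to different cells would silently produce meaningless joint embeddings.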

safe_standardize(df, fill_value=0.0, min_std=1e-05)

Standardizes a DataFrame (z-score per column) while avoiding NaNs and extreme values. Unsafe float16 input is cast to float32 first.

Any column with std < min_std is filled with fill_value.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| df | DataFrame | Input DataFrame. | required |
| fill_value | float | Value to fill for low-variance columns. | 0.0 |
| min_std | float | Minimum allowed std for columns. | 1e-05 |

Returns:

| Type | Description |
|------|-------------|
| DataFrame | Standardized DataFrame (float32) with no NaNs. |
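A minimal pandas sketch of the behaviour described above. `safe_standardize_sketch` is a hypothetical stand-in, not the library's code; using pandas' default sample standard deviation (ddof=1) and replacing any remaining NaN or infinite values with fill_value are assumptions.

```python
import numpy as np
import pandas as pd


def safe_standardize_sketch(df: pd.DataFrame, fill_value: float = 0.0,
                            min_std: float = 1e-5) -> pd.DataFrame:
    """Hypothetical re-implementation of a NaN-safe per-column z-score."""
    # float16 overflows easily (max ~65504), so compute in float32.
    x = df.astype(np.float32)
    mean = x.mean(axis=0)
    std = x.std(axis=0)  # pandas default: sample std (ddof=1), NaNs skipped
    # (Near-)constant columns, or columns with < 2 valid values, are unsafe to scale.
    low_var = std.isna() | (std < min_std)
    # Divide those columns by 1 to avoid inf/NaN; they are overwritten below.
    z = (x - mean) / std.mask(low_var, 1.0)
    for col in std.index[low_var.to_numpy()]:
        z[col] = fill_value
    # Any leftover NaN/inf (e.g. from NaN inputs) is replaced as well.
    z = z.replace([np.inf, -np.inf], np.nan).fillna(fill_value)
    return z.astype(np.float32)


df = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [5.0, 5.0, 5.0]}, dtype=np.float16)
print(safe_standardize_sketch(df))
```

Guarding the divisor with min_std rather than testing for exact zero is what prevents the "large numbers" case: a column whose std is tiny but non-zero would otherwise be scaled to huge values.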

save_embeddings_separately(z1_df, z2_df, z_joint_df, celltypes, samples, out_dir, mode='train', compression='gzip')

Save embeddings and metadata to disk as Parquet and HDF5 files.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| z1_df | DataFrame | Embeddings from encoder1. | required |
| z2_df | DataFrame | Embeddings from encoder2. | required |
| z_joint_df | DataFrame | Joint embeddings. | required |
| celltypes | ndarray | Cell type labels. | required |
| samples | ndarray | Sample names. | required |
| out_dir | str or Path | Output directory. | required |
| mode | str | Mode string for filenames (e.g., 'train'). | 'train' |
| compression | str | Compression type for HDF5 datasets. | 'gzip' |

spatialfusion.utils.ae_data_loader

Utility functions for loading and preprocessing multi-modal AE data.

This module provides:

- load_file_with_fallback: Load DataFrame from CSV or Parquet with fallback.
- safe_standardize: Robust z-score standardization for DataFrames.
- load_and_preprocess_sample: Load, intersect, impute, and standardize paired sample embeddings.

load_and_preprocess_sample(sample_name, base_path, max_cells=30000)

Loads and preprocesses paired sample embeddings for AE training:

- Loads UNI and scGPT embeddings for a sample.
- Intersects the cell IDs of the two modalities and samples up to max_cells cells.
- Imputes NaNs with mean values.
- Standardizes features robustly.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| sample_name | str | Sample identifier. | required |
| base_path | str or Path | Directory containing sample data. | required |
| max_cells | int | Maximum number of cells to sample. | 30000 |

Returns:

| Type | Description |
|------|-------------|
| tuple | (std_feat_1, std_feat_2, selected_ids), with the elements listed below. |

- std_feat_1 (pd.DataFrame): Standardized UNI features.
- std_feat_2 (pd.DataFrame): Standardized scGPT features.
- selected_ids (list): List of selected cell IDs.
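The four preprocessing steps can be sketched on in-memory DataFrames, with file loading omitted. `preprocess_pair` is hypothetical; random subsampling with a fixed seed, per-column mean imputation, and the simplified zero-variance handling (instead of the full safe_standardize logic) are assumptions rather than the library's exact behaviour.

```python
import numpy as np
import pandas as pd


def preprocess_pair(feat_1, feat_2, max_cells=30000, seed=0):
    """Hypothetical sketch of intersect -> subsample -> impute -> standardize."""
    # 1. Keep cells present in both modalities, in a deterministic order.
    shared = feat_1.index.intersection(feat_2.index).sort_values()
    # 2. Randomly subsample (without replacement) down to max_cells.
    if len(shared) > max_cells:
        rng = np.random.default_rng(seed)
        picked = rng.choice(len(shared), size=max_cells, replace=False)
        shared = shared[np.sort(picked)]
    selected_ids = list(shared)
    standardized = []
    for feat in (feat_1, feat_2):
        sub = feat.loc[selected_ids].astype(np.float32)
        # 3. Impute NaNs with each column's mean.
        sub = sub.fillna(sub.mean())
        # 4. Simplified z-score: constant or all-NaN columns become 0.
        std = sub.std().replace(0.0, np.nan)
        z = ((sub - sub.mean()) / std).fillna(0.0)
        standardized.append(z.astype(np.float32))
    return standardized[0], standardized[1], selected_ids
```

Both feature tables are indexed with the same selected_ids list, so row i of std_feat_1 and row i of std_feat_2 always refer to the same cell.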

load_file_with_fallback(base_path, filename_base)

Attempts to load a DataFrame from CSV or Parquet. Raises FileNotFoundError if neither is available.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| base_path | Path | Directory containing the file. | required |
| filename_base | str | Base filename (without extension). | required |

Returns:

| Type | Description |
|------|-------------|
| DataFrame | Loaded DataFrame. |
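A sketch of the fallback logic, assuming the files are named `<filename_base>.csv` and `<filename_base>.parquet`, that CSV is tried first, and that the CSV's first column holds the index (cell IDs). These details and the helper name are assumptions, not the library's documented behaviour.

```python
from pathlib import Path

import pandas as pd


def load_file_with_fallback_sketch(base_path, filename_base):
    """Hypothetical loader: try <base>.csv, then <base>.parquet."""
    base_path = Path(base_path)
    csv_path = base_path / f"{filename_base}.csv"
    parquet_path = base_path / f"{filename_base}.parquet"
    if csv_path.exists():
        # Assumes the first CSV column holds the cell IDs.
        return pd.read_csv(csv_path, index_col=0)
    if parquet_path.exists():
        return pd.read_parquet(parquet_path)
    raise FileNotFoundError(
        f"Neither {csv_path.name} nor {parquet_path.name} found in {base_path}"
    )
```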

safe_standardize(df, fill_value=0.0, min_std=1e-05)

Standardizes a DataFrame (z-score per column) while avoiding NaNs and extreme values. Unsafe float16 input is cast to float32 first.

Any column with std < min_std is filled with fill_value.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| df | DataFrame | Input DataFrame. | required |
| fill_value | float | Value to fill for low-variance columns. | 0.0 |
| min_std | float | Minimum allowed std for columns. | 1e-05 |

Returns:

| Type | Description |
|------|-------------|
| DataFrame | Standardized DataFrame (float32) with no NaNs. |