spatialfusion.utils.embed_ae_utils
Utility functions for extracting and saving AE embeddings and metadata.
This module provides:

- safe_standardize: Robust z-score standardization for DataFrames.
- extract_embeddings_for_all_samples: Extract embeddings for all samples using a trained AE model.
- save_embeddings_separately: Save embeddings and metadata to disk.
extract_embeddings_for_all_samples(model, sample_list, base_path, device='cpu')
Extract embeddings for all samples using a trained AE model. Loads UNI and scGPT embeddings, matches cell IDs, standardizes features, and computes model embeddings. Also extracts cell type labels and sample names.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | | Trained AE model with `encoder1` and `encoder2`. | required |
| `sample_list` | `list` | List of sample info (str or dict). | required |
| `base_path` | `str or Path` | Base directory for samples. | required |
| `device` | `str` | Device for model inference. | `'cpu'` |
Returns:
A tuple `(z1_df, z2_df, z_joint_df, celltypes, samples)`:

| Name | Type | Description |
|---|---|---|
| `z1_df` | `pd.DataFrame` | Embeddings from encoder1. |
| `z2_df` | `pd.DataFrame` | Embeddings from encoder2. |
| `z_joint_df` | `pd.DataFrame` | Averaged joint embeddings. |
| `celltypes` | `np.ndarray` | Cell type labels. |
| `samples` | `np.ndarray` | Sample names. |
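The "averaged joint embeddings" returned above can be illustrated with a small sketch. It assumes the two encoder outputs are already available as NumPy arrays with matching row order and equal shape, and that "averaged" means an element-wise mean. The function name `joint_embedding_sketch` and the `z0, z1, ...` column names are hypothetical and not part of the library.

```python
import numpy as np
import pandas as pd


def joint_embedding_sketch(z1, z2, cell_ids):
    """Build three DataFrames from two per-cell embedding arrays.

    Assumption: z1 and z2 are the outputs of encoder1 and encoder2 for
    the same cells, in the same order, with the same latent dimension.
    """
    z1 = np.asarray(z1, dtype=np.float32)
    z2 = np.asarray(z2, dtype=np.float32)
    if z1.shape != z2.shape:
        raise ValueError("z1 and z2 must have the same shape to be averaged")

    # Element-wise mean of the two latent representations.
    z_joint = (z1 + z2) / 2.0

    # Hypothetical column names, used only for this illustration.
    cols = [f"z{i}" for i in range(z1.shape[1])]
    index = pd.Index(cell_ids, name="cell_id")
    z1_df = pd.DataFrame(z1, index=index, columns=cols)
    z2_df = pd.DataFrame(z2, index=index, columns=cols)
    z_joint_df = pd.DataFrame(z_joint, index=index, columns=cols)
    return z1_df, z2_df, z_joint_df
```

Indexing all three frames by the same cell IDs keeps them aligned with the `celltypes` and `samples` arrays, which are per-cell as well.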
safe_standardize(df, fill_value=0.0, min_std=1e-05)
Standardizes a DataFrame (z-score per column) while avoiding NaNs and large numbers. Handles unsafe float16 input by casting to float32 first.
Any column with std < min_std is filled with fill_value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `df` | `DataFrame` | Input DataFrame. | required |
| `fill_value` | `float` | Value to fill for low-variance columns. | `0.0` |
| `min_std` | `float` | Minimum allowed std for columns. | `1e-05` |
Returns:
| Type | Description |
|---|---|
| `DataFrame` | Standardized DataFrame (float32), no NaNs. |
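The behaviour described for safe_standardize (cast to float32, z-score per column, fill low-variance columns) might look roughly like the sketch below. This is not the library's implementation: the name `safe_standardize_sketch`, the use of the pandas default sample standard deviation, and the handling of infinities are assumptions made for illustration.

```python
import numpy as np
import pandas as pd


def safe_standardize_sketch(df, fill_value=0.0, min_std=1e-5):
    """Illustrative per-column z-score; not the library's actual code."""
    # Cast up first: float16 overflows easily when sums and squares are taken.
    values = df.astype(np.float32)

    mean = values.mean(axis=0)
    # Assumption: pandas' default sample std (ddof=1); the real function may differ.
    std = values.std(axis=0)

    # Columns that are (nearly) constant, or whose std is undefined.
    low_var = ((std < min_std) | std.isna()).to_numpy()

    # Divide by 1.0 for those columns to avoid inf/NaN; they are overwritten below.
    safe_std = std.mask(low_var, 1.0)
    out = (values - mean) / safe_std
    out.loc[:, low_var] = fill_value

    # Assumption: any remaining NaN or inf (e.g. from NaN inputs) also becomes fill_value.
    out = out.replace([np.inf, -np.inf], np.nan).fillna(fill_value)
    return out.astype(np.float32)
```

Calling it on a float16 frame with one constant column returns a float32 frame in which the constant column is entirely `fill_value` and the other columns have zero mean.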
save_embeddings_separately(z1_df, z2_df, z_joint_df, celltypes, samples, out_dir, mode='train', compression='gzip')
Save embeddings and metadata to disk as Parquet and HDF5 files.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `z1_df` | `DataFrame` | Embeddings from encoder1. | required |
| `z2_df` | `DataFrame` | Embeddings from encoder2. | required |
| `z_joint_df` | `DataFrame` | Joint embeddings. | required |
| `celltypes` | `ndarray` | Cell type labels. | required |
| `samples` | `ndarray` | Sample names. | required |
| `out_dir` | `str or Path` | Output directory. | required |
| `mode` | `str` | Mode string for filenames (e.g., 'train'). | `'train'` |
| `compression` | `str` | Compression type for HDF5 datasets. | `'gzip'` |
spatialfusion.utils.ae_data_loader
Utility functions for loading and preprocessing multi-modal AE data.
This module provides:

- load_file_with_fallback: Load DataFrame from CSV or Parquet with fallback.
- safe_standardize: Robust z-score standardization for DataFrames.
- load_and_preprocess_sample: Load, intersect, impute, and standardize paired sample embeddings.
load_and_preprocess_sample(sample_name, base_path, max_cells=30000)
Loads and preprocesses paired sample embeddings for AE training.

- Loads UNI and scGPT embeddings for a sample.
- Intersects cell IDs and samples up to max_cells of them.
- Imputes NaNs with mean values.
- Standardizes features robustly.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sample_name` | `str` | Sample identifier. | required |
| `base_path` | `str or Path` | Directory containing sample data. | required |
| `max_cells` | `int` | Maximum number of cells to sample. | `30000` |
Returns:
A tuple `(std_feat_1, std_feat_2, selected_ids)`:

| Name | Type | Description |
|---|---|---|
| `std_feat_1` | `pd.DataFrame` | Standardized UNI features. |
| `std_feat_2` | `pd.DataFrame` | Standardized scGPT features. |
| `selected_ids` | `list` | List of selected cell IDs. |
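The intersect, subsample, and impute steps listed for load_and_preprocess_sample can be sketched on two in-memory DataFrames. This is an illustration under assumptions, not the library's code: the name `preprocess_pair_sketch`, the `seed` parameter, uniform sampling without replacement, and imputing with per-column means after subsetting are all hypothetical choices. File loading and the final standardization (see safe_standardize) are left out.

```python
import numpy as np
import pandas as pd


def preprocess_pair_sketch(feat_1, feat_2, max_cells=30000, seed=0):
    """Align two per-cell feature tables; both are indexed by cell ID."""
    # Keep only the cells present in both modalities.
    shared = feat_1.index.intersection(feat_2.index)

    # Subsample uniformly without replacement if there are too many cells.
    if len(shared) > max_cells:
        rng = np.random.default_rng(seed)
        positions = np.sort(rng.choice(len(shared), size=max_cells, replace=False))
        shared = shared[positions]

    feat_1 = feat_1.loc[shared]
    feat_2 = feat_2.loc[shared]

    # Impute missing values with the mean of each column.
    # A column that is entirely NaN stays NaN at this stage.
    feat_1 = feat_1.fillna(feat_1.mean())
    feat_2 = feat_2.fillna(feat_2.mean())

    return feat_1, feat_2, list(shared)
```

Selecting both tables with the same `shared` index guarantees that row `i` of each output refers to the same cell, which a paired autoencoder input requires.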
load_file_with_fallback(base_path, filename_base)
Attempts to load a DataFrame from CSV or Parquet. Raises FileNotFoundError if neither is available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `base_path` | `Path` | Directory containing the file. | required |
| `filename_base` | `str` | Base filename (without extension). | required |
Returns:
| Type | Description |
|---|---|
| `DataFrame` | Loaded DataFrame. |
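A fallback loader of this kind could be sketched as follows. The order in which formats are tried (Parquet first, then CSV), the `.parquet` and `.csv` extensions, and reading the first CSV column as the index are assumptions; the documentation above only states that both formats are attempted and that FileNotFoundError is raised when neither exists. The name `load_with_fallback_sketch` is hypothetical.

```python
from pathlib import Path

import pandas as pd


def load_with_fallback_sketch(base_path, filename_base):
    """Load <filename_base>.parquet if present, else <filename_base>.csv."""
    base_path = Path(base_path)
    parquet_path = base_path / f"{filename_base}.parquet"
    csv_path = base_path / f"{filename_base}.csv"

    # Assumption: Parquet is preferred when both files exist.
    if parquet_path.exists():
        # Requires a Parquet engine such as pyarrow to be installed.
        return pd.read_parquet(parquet_path)

    if csv_path.exists():
        # Assumption: the first CSV column holds the cell IDs.
        return pd.read_csv(csv_path, index_col=0)

    raise FileNotFoundError(
        f"Neither {parquet_path.name} nor {csv_path.name} found in {base_path}"
    )
```

Checking for existence before reading keeps the error message specific: a missing file raises FileNotFoundError, while a corrupt file still surfaces the reader's own exception.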
safe_standardize(df, fill_value=0.0, min_std=1e-05)
Standardizes a DataFrame (z-score per column) while avoiding NaNs and large numbers. Handles unsafe float16 input by casting to float32 first.
Any column with std < min_std is filled with fill_value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `df` | `DataFrame` | Input DataFrame. | required |
| `fill_value` | `float` | Value to fill for low-variance columns. | `0.0` |
| `min_std` | `float` | Minimum allowed std for columns. | `1e-05` |
Returns:
| Type | Description |
|---|---|
| `DataFrame` | Standardized DataFrame (float32), no NaNs. |