API Reference
This page documents the public Python API of cetsax. The package is organized around the analytical flow of the
project: configuration and data access, curve fitting, sensitivity scoring, systems-level analysis, sequence learning,
and visualization.
The API is intentionally modular. Most functions operate on pandas.DataFrame objects with a shared schema built around
protein identifiers, experimental conditions, and dose columns. In practice, this means individual modules can be used
independently, but they are designed to compose into a coherent pipeline.
Package layout
- cetsax.config defines shared constants and configuration loading.
- cetsax.dataio handles data ingestion and QC.
- cetsax.models contains the core ITDR response model.
- cetsax.fit performs curve fitting.
- cetsax.hits and cetsax.viz_hits handle hit calling and diagnostics.
- cetsax.sensitivity computes protein-level sensitivity metrics.
- cetsax.enrichment, cetsax.network, cetsax.redox, cetsax.latent, and cetsax.mixture support systems-level interpretation.
- cetsax.ml, cetsax.bayes, and cetsax.deeplearn.* provide statistical and learning-based extensions.
- cetsax.plotting, cetsax.viz, and cetsax.viz_predict provide plotting and result interpretation.
Core package
cetsax
CETSAx – CETSA-MS modelling toolkit.
This package currently implements ITDR-based binding curve fitting (EC50, Hill, Emax) for proteome-wide NADPH CETSA data. It also includes modules for hit calling, pathway enrichment, latent factor analysis, mixture modelling, redox role analysis, and sequence-based deep learning models of NADPH responsiveness.
load_cetsa_csv(path)
Load the CETSA NADPH ITDR dataset from a CSV file.
Drops the column 'Unnamed: 0' if present (typically a saved pandas index).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | Path to the CSV file. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Loaded dataset with appropriate columns. |
apply_basic_qc(df)
Apply simple QC criteria at the protein-replicate level. Filters proteins based on minimum unique peptides, PSMs, and count number.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | Input CETSA dataset. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Filtered dataset passing QC criteria. |
train_seq_model(csv_path, cfg, patience=True)
Returns: model, metrics, and paths (a dict of caches).
itdr_model(c, E0, Emax, logEC50, h)
4-parameter logistic ITDR model for CETSA:
f(c) = E0 + (Emax - E0) / (1 + (EC50 / c)^h)
where EC50 = 10 ** logEC50 (parameterized in log10 space for stability).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| c | ndarray | Concentration array. | required |
| E0 | float | Response at zero concentration. | required |
| Emax | float | Maximum response at infinite concentration. | required |
| logEC50 | float | Log10 of the concentration at half-maximum response. | required |
| h | float | Hill slope. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Response values at concentrations c. |
fit_all_proteins(df)
Fit ITDR curves for all proteins and replicates in a QC-checked DataFrame.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | Input data with columns ID_COL, COND_COL, and the dose columns in DOSE_COLS. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Fit results with columns ID_COL, COND_COL, E0, Emax, EC50, log10_EC50, Hill, R2, delta_max. |
call_hits(fit_df, r2_min=0.8, delta_min=0.1)
Filter fitted curves to keep only high-confidence hits.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fit_df | DataFrame | Fitted curve parameters with at least the following columns: 'R2' (coefficient of determination of the fit), 'delta_max' (maximum change in response), 'EC50' (half-maximal effective concentration). | required |
| r2_min | float | Minimum R² for a fit to count as a hit. | 0.8 |
| delta_min | float | Minimum delta_max for a fit to count as a hit. | 0.1 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Only the fits that meet the specified criteria. |
summarize_hits(hits_df, min_reps=2)
Aggregate hit information at the protein level across replicates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hits_df | DataFrame | Hit table with at least the following columns: ID_COL (protein or target identifier), COND_COL (condition or replicate identifier), 'EC50' (half-maximal effective concentration), 'Emax' (maximum effect), 'Hill' (Hill coefficient). | required |
| min_reps | int | Minimum number of replicates required to include a protein in the summary. | 2 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Per-protein summary of aggregated hit information. |
plot_protein_curve(df, fit_df, protein_id, condition=None, ax=None)
Plot raw ITDR data and the fitted curve for a given protein (and, optionally, a given condition).
plot_goodness_of_fit(df, fit_df, ax=None)
Global goodness-of-fit plot: observed vs. predicted values for all proteins, colored by condition (r1, r2).
Returns: fig, ax: matplotlib Figure and Axes objects.
bayesian_fit_ec50(df, protein_id, draws=1000, tune=1000, chains=4, cores=1, progressbar=True)
Fit a hierarchical Bayesian EC50 model for a single protein across replicates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | DataFrame containing CETSA data. | required |
| protein_id | str | ID of the protein to fit. | required |
| draws | int | Number of posterior samples to draw per chain. | 1000 |
| tune | int | Number of tuning (burn-in) steps. | 1000 |
| chains | int | Number of independent MCMC chains to run. | 4 |
| cores | int | Number of CPU cores used for parallel sampling. Set to 1 when calling this function inside a joblib loop; set equal to chains when fitting a single protein. | 1 |
| progressbar | bool | Whether to show the PyMC progress bar. | True |
Returns:
| Type | Description |
|---|---|
| dict | PyMC model, posterior samples (trace), and summary. |
summarize_pathway_effects(metric_df, annot_df, id_col='id', path_col='pathway', metrics=('NSS', 'EC50', 'delta_max', 'R2'))
Summarize NADPH responsiveness per pathway/module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metric_df | DataFrame | Per-protein metric table, e.g. the output of sensitivity.compute_sensitivity_scores or a custom per-protein summary with columns id, EC50, delta_max, Hill, R2, NSS, ... | required |
| annot_df | DataFrame | Annotation table mapping proteins to pathways/modules. Must contain columns id_col and path_col. | required |
| metrics | iterable of str | Column names in metric_df to summarize per pathway. | ('NSS', 'EC50', 'delta_max', 'R2') |
Returns:
| Type | Description |
|---|---|
| DataFrame | One row per pathway with columns path_col, N_proteins, and per-metric summary columns (e.g. NSS_mean, delta_max_median). |
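A minimal pandas sketch of this kind of per-pathway summary follows. It assumes the function joins the annotation onto the metric table and reports a mean and a median for each metric. The helper name `summarize_pathway_effects_sketch` and the choice of statistics are illustrative, not the package's implementation.

```python
import pandas as pd

def summarize_pathway_effects_sketch(metric_df, annot_df, id_col="id",
                                     path_col="pathway", metrics=("NSS", "EC50")):
    # Attach pathway labels to each protein; a protein may map to several pathways.
    merged = metric_df.merge(annot_df[[id_col, path_col]], on=id_col, how="inner")
    grouped = merged.groupby(path_col)
    # Number of distinct proteins per pathway.
    out = grouped[id_col].nunique().rename("N_proteins").to_frame()
    # Per-metric mean and median, named like NSS_mean / NSS_median.
    for m in metrics:
        out[f"{m}_mean"] = grouped[m].mean()
        out[f"{m}_median"] = grouped[m].median()
    return out.reset_index().sort_values("N_proteins", ascending=False)
```

Because the merge is inner, proteins without an annotation are silently dropped. This is usually what a pathway summary wants.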
enrich_overrepresentation(hits_df, annot_df, id_col='id', path_col='pathway', hit_col='hit_class', strong_labels=('strong',), min_genes=3)
Perform over-representation analysis for pathways using a binary hit set.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hits_df | DataFrame | Per-protein hit classification table; must contain id_col and hit_col. | required |
| annot_df | DataFrame | id-to-pathway mapping. | required |
| id_col | str | Column name for protein IDs. | 'id' |
| path_col | str | Column name for pathway/module names. | 'pathway' |
| hit_col | str | Column name for hit classification. | 'hit_class' |
| strong_labels | iterable of str | Labels in hit_col to consider as "hits". | ('strong',) |
| min_genes | int | Minimum number of genes in a pathway to consider it for enrichment. | 3 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Columns: path_col, n_path, n_hits, n_bg, odds_ratio, pval, qval. |
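Over-representation analysis is commonly done with a one-sided Fisher exact test on a 2x2 table, followed by Benjamini-Hochberg correction. The sketch below assumes that approach; the source does not name the test. `ora_one_pathway` and `bh_qvalues` are hypothetical helpers, not cetsax functions.

```python
import numpy as np
from scipy.stats import fisher_exact

def ora_one_pathway(n_hits_in_path, n_path, n_hits_total, n_universe):
    # 2x2 table: rows = in pathway / not in pathway, columns = hit / non-hit.
    a = n_hits_in_path
    b = n_path - n_hits_in_path
    c = n_hits_total - n_hits_in_path
    d = n_universe - n_path - c
    # One-sided: are hits enriched in this pathway?
    odds_ratio, pval = fisher_exact([[a, b], [c, d]], alternative="greater")
    return odds_ratio, pval

def bh_qvalues(pvals):
    # Benjamini-Hochberg: q_(i) = min over j >= i of p_(j) * m / j, capped at 1.
    p = np.asarray(pvals, dtype=float)
    m = p.size
    order = np.argsort(p)
    ranked = p[order] * m / np.arange(1, m + 1)
    ranked = np.minimum.accumulate(ranked[::-1])[::-1]
    q = np.empty(m)
    q[order] = np.minimum(ranked, 1.0)
    return q
```

For example, 8 hits among 10 pathway members, with 9 hits in a universe of 16 proteins, gives the table [[8, 2], [1, 5]] and a sample odds ratio of 8·5 / (2·1) = 20.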
enrich_continuous_mannwhitney(sens_df, annot_df, score_col='NSS', id_col='id', path_col='pathway', min_genes=3)
Continuous enrichment per pathway using Mann–Whitney U tests.
For each pathway, compares the distribution of score_col (e.g. NSS, delta_max, or -log10(EC50)) between proteins in the pathway and all other proteins.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sens_df | DataFrame | Per-protein sensitivity scores; must contain id_col and score_col. | required |
| annot_df | DataFrame | id-to-pathway mapping. | required |
| score_col | str | Column name for the continuous score to test. | 'NSS' |
| id_col | str | Column name for protein IDs. | 'id' |
| path_col | str | Column name for pathway/module names. | 'pathway' |
| min_genes | int | Minimum number of genes in a pathway to consider it for enrichment. | 3 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Columns: pathway, n_path, score_mean, score_median, U_stat, pval, qval. |
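The per-pathway test can be sketched with `scipy.stats.mannwhitneyu`, comparing pathway members against all other proteins. `pathway_mannwhitney` is a hypothetical helper. The two-sided alternative is an assumption, and multiple-testing correction across pathways (the qval column) is left out.

```python
import numpy as np
from scipy.stats import mannwhitneyu

def pathway_mannwhitney(scores, in_pathway):
    scores = np.asarray(scores, dtype=float)
    mask = np.asarray(in_pathway, dtype=bool)
    inside, outside = scores[mask], scores[~mask]
    # Two-sided test: does the pathway's score distribution differ from the rest?
    res = mannwhitneyu(inside, outside, alternative="two-sided")
    return {"n_path": int(mask.sum()),
            "score_mean": float(inside.mean()),
            "score_median": float(np.median(inside)),
            "U_stat": float(res.statistic),  # U for the in-pathway sample
            "pval": float(res.pvalue)}
```

A rank-based test is a reasonable choice here because NSS and EC50-derived scores are rarely normally distributed.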
build_feature_matrix(sens_df, redox_df=None, id_col='id', base_features=None, include_redox_axes=True)
Build a standardized feature matrix per protein.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sens_df | DataFrame | Output of sensitivity.compute_sensitivity_scores; expected columns: id, EC50, delta_max, Hill, R2, NSS, EC50_scaled, ... | required |
| redox_df | DataFrame | Output of redox.build_redox_axes; expected columns: id, axis_direct, axis_indirect, axis_network, redox_role, ... | None |
| base_features | list of str or None | Numeric columns from sens_df to include as features. If None, defaults to ["EC50", "delta_max", "Hill", "R2", "NSS", "EC50_scaled", "delta_max_scaled", "Hill_scaled", "R2_scaled"]. | None |
| include_redox_axes | bool | If True and redox_df is provided, also includes ["axis_direct", "axis_indirect", "axis_network"] as features. | True |
| id_col | str | Column name for protein identifier in sens_df and redox_df. | 'id' |
Returns:
| Type | Description |
|---|---|
| DataFrame | Feature matrix indexed by id; columns are the selected features, standardized. |
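"Standardized" most likely means per-column z-scoring. The sketch below assumes that. It uses the pandas sample standard deviation (ddof=1), whereas scikit-learn's StandardScaler uses ddof=0. The helper name is hypothetical.

```python
import pandas as pd

def standardize_features(sens_df, id_col="id", features=("EC50", "delta_max")):
    X = sens_df.set_index(id_col)[list(features)].astype(float)
    # z-score each column: (x - mean) / std, with pandas' default ddof=1
    return (X - X.mean()) / X.std()
```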
fit_pca(feat_df, n_components=3)
Fit PCA on the standardized feature matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feat_df | DataFrame | Feature matrix as returned by build_feature_matrix, indexed by id. | required |
| n_components | int | Number of principal components to compute. | 3 |
Returns:
| Type | Description |
|---|---|
| dict with keys: | |
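To show what PCA computes on such a matrix, here is a sketch using a plain NumPy SVD of the centered features. It is not necessarily what fit_pca uses internally (it may wrap scikit-learn), and `pca_scores_svd` is a hypothetical name. The score columns follow the PC1, PC2, ... naming that plot_pca_scores expects.

```python
import numpy as np
import pandas as pd

def pca_scores_svd(feat_df, n_components=2):
    X = feat_df.to_numpy(dtype=float)
    Xc = X - X.mean(axis=0)  # center columns (already ~0 if standardized)
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    scores = U[:, :n_components] * S[:n_components]  # projections onto the PCs
    explained = S**2 / np.sum(S**2)  # explained-variance ratio per component
    cols = [f"PC{i + 1}" for i in range(n_components)]
    return (pd.DataFrame(scores, index=feat_df.index, columns=cols),
            explained[:n_components])
```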
fit_factor_analysis(feat_df, n_components=3)
Fit Factor Analysis (FA) on the standardized feature matrix.
FA often yields more interpretable latent factors than PCA when features are noisy and partially redundant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feat_df | DataFrame | Feature matrix as returned by build_feature_matrix, indexed by id. | required |
| n_components | int | Number of latent factors. | 3 |
Returns:
| Type | Description |
|---|---|
| dict with keys: | |
attach_latent_to_metadata(meta_df, latent_df, id_col='id')
Merge latent coordinates back into a per-protein metadata table.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| meta_df | DataFrame | Any per-protein table with column id_col (e.g. sens_df, redox_df). | required |
| latent_df | DataFrame | Latent representation indexed by id (e.g. PCA scores or FA scores). | required |
| id_col | str | Column name for protein identifier in meta_df and latent_df. | 'id' |
Returns:
| Type | Description |
|---|---|
| DataFrame | meta_df with latent columns appended. |
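Because latent_df is indexed by id while meta_df carries id as a column, the join needs the index turned into a column first. A minimal sketch follows. The left join, which keeps every metadata row and fills missing scores with NaN, is an assumption.

```python
import pandas as pd

def attach_latent_sketch(meta_df, latent_df, id_col="id"):
    # Name the id index, move it into a column, then left-join onto the metadata.
    latent = latent_df.rename_axis(id_col).reset_index()
    return meta_df.merge(latent, on=id_col, how="left")
```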
build_mixture_features(sens_df, redox_df=None, id_col='id', feature_cols=None, include_redox_axes=True, log_transform_ec50=True)
Build the feature matrix for mixture modelling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sens_df | DataFrame | Typically the output of sensitivity.compute_sensitivity_scores, with at least id, EC50, delta_max, R2, NSS, ... | required |
| redox_df | DataFrame | Output of redox.build_redox_axes, with columns id, axis_direct, axis_indirect, axis_network, ... | None |
| feature_cols | list of str or None | Columns from sens_df to use. If None, defaults to ["EC50", "delta_max", "NSS", "R2"]. | None |
| include_redox_axes | bool | If True and redox_df is provided, adds ["axis_direct", "axis_indirect", "axis_network"]. | True |
| log_transform_ec50 | bool | If True and "EC50" is among the features, replaces it with -log10(EC50) (so higher means stronger / more sensitive). | True |
Returns:
| Type | Description |
|---|---|
| DataFrame | id-indexed standardized feature matrix for mixture modelling. |
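The EC50 transform can be sketched directly. -log10(EC50) flips the direction so that a lower EC50 (binding at lower concentration) gives a larger value, and it compresses EC50's wide dynamic range before standardization. The renamed column `neg_log10_EC50` and the z-scoring step are assumptions of this sketch.

```python
import numpy as np
import pandas as pd

def mixture_features_sketch(sens_df, id_col="id", feature_cols=("EC50", "delta_max")):
    X = sens_df.set_index(id_col)[list(feature_cols)].astype(float)
    if "EC50" in X.columns:
        # Higher -log10(EC50) = responds at lower concentration = more sensitive.
        X["EC50"] = -np.log10(X["EC50"])
        X = X.rename(columns={"EC50": "neg_log10_EC50"})
    return (X - X.mean()) / X.std()  # z-score each feature
```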
fit_gmm_bic_grid(feat_df, n_components_grid=None, covariance_type='full', random_state=0)
Fit Gaussian Mixture Models over a grid of component counts and select the model with the lowest BIC.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feat_df | DataFrame | Standardized feature matrix indexed by id. | required |
| n_components_grid | list of int or None | Component counts to evaluate. If None, defaults to [1, 2, 3, 4, 5]. | None |
| covariance_type | str, one of 'full', 'tied', 'diag', 'spherical' | Covariance structure for the GMM. | 'full' |
Returns:
| Type | Description |
|---|---|
| dict with keys: | |
assign_mixture_clusters(feat_df, gmm, id_col='id')
Assign mixture clusters and posterior responsibilities for each protein.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feat_df | DataFrame | Standardized feature matrix indexed by id. | required |
| gmm | GaussianMixture | Fitted GMM. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Columns: id, cluster, and resp_k (one responsibility column per mixture component). |
label_clusters_by_sensitivity(sens_df, cluster_df, id_col='id', score_col='NSS')
Assign human-readable labels ("high", "medium", "low") to mixture clusters by ranking them by the mean of a chosen score (e.g. NSS, -log10 EC50).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sens_df | DataFrame | Per-protein scores; must contain id_col and score_col. | required |
| cluster_df | DataFrame | Output of assign_mixture_clusters; must contain id_col and 'cluster'. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Columns: cluster, mean_score, label. |
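The ranking step can be sketched in a few lines of pandas. This version assumes at most three clusters, so that each gets one of "high", "medium", "low". How the packaged function labels other cluster counts is not stated here. The helper name is hypothetical.

```python
import pandas as pd

def label_clusters_sketch(sens_df, cluster_df, id_col="id", score_col="NSS"):
    merged = cluster_df[[id_col, "cluster"]].merge(sens_df[[id_col, score_col]], on=id_col)
    # Rank clusters by mean score, most sensitive first.
    means = merged.groupby("cluster")[score_col].mean().sort_values(ascending=False)
    out = means.rename("mean_score").reset_index()
    if len(out) > 3:
        raise ValueError("this sketch only handles up to three clusters")
    out["label"] = ["high", "medium", "low"][: len(out)]
    return out
```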
extract_curve_features(df, n_components=3)
Reduce dose-response curves to principal components.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | Dose-response data with an 'id' column and dose columns. | required |
| n_components | int | Number of PCA components to extract. | 3 |
Returns:
| Type | Description |
|---|---|
| DataFrame | PCA features per protein ID. |
classify_curves_kmeans(features, k=4)
Apply KMeans to curve embeddings (e.g. PCA features).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| features | DataFrame | Feature matrix per protein (e.g. output of extract_curve_features). | required |
| k | int | Number of clusters. | 4 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Input features with an additional 'cluster' column. |
detect_outliers(features)
Simple z-score heuristic for outlier curves.
Note: this heuristic may later be replaced with a robust Mahalanobis distance or an isolation forest.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| features | DataFrame | Feature matrix per protein (e.g. output of extract_curve_features). | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Input features with a boolean 'outlier' column. |
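A z-score heuristic of this kind could look like the sketch below. The threshold of 3 standard deviations and the "any feature exceeds it" rule are assumptions; the source does not state the cutoff.

```python
import numpy as np
import pandas as pd

def zscore_outliers(features, threshold=3.0):
    num = features.select_dtypes(include=[np.number])
    z = (num - num.mean()) / num.std(ddof=0)
    out = features.copy()
    # Flag a curve if any of its features lies more than `threshold` SDs from the mean.
    out["outlier"] = (z.abs() > threshold).any(axis=1)
    return out
```

One caveat motivates the robust alternatives in the note above: with n points, no z-score can exceed sqrt(n - 1). The extreme points also inflate the standard deviation they are measured against.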
compute_costab_matrix(df)
Compute the co-stabilization (correlation) matrix across proteins. Each protein is represented by its dose-response vector.
Returns: DataFrame, the (n_proteins x n_proteins) correlation matrix.
make_network_from_matrix(corr_matrix, cutoff=0.7)
Convert a correlation matrix into a network graph. Edges are created for correlations >= cutoff.
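The two steps above can be sketched without a graph library. First compute pairwise Pearson correlations between dose-response vectors, then keep protein pairs at or above the cutoff as weighted edges. `costab_edges` is a hypothetical helper. The packaged function returns a graph object; if that is a networkx graph, an edge list like this one can be loaded with `G.add_weighted_edges_from(edges)`.

```python
import pandas as pd

def costab_edges(curves, cutoff=0.7):
    # curves: rows = proteins, columns = dose levels
    corr = curves.T.corr()  # protein x protein Pearson correlation
    ids = corr.index.to_list()
    edges = []
    for i in range(len(ids)):
        for j in range(i + 1, len(ids)):  # upper triangle only: undirected edges
            r = corr.iat[i, j]
            if r >= cutoff:
                edges.append((ids[i], ids[j], float(r)))
    return corr, edges
```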
detect_modules(G)
Apply a community detection algorithm (Louvain or greedy modularity) to identify co-stabilization modules.
Returns: dict mapping protein_id to module_id.
build_redox_axes(fits_df, sens_df, hits_df, net_df=None, id_col='id', hit_col='dominant_class', degree_col='degree', betweenness_col='betweenness')
Build redox axes for each protein.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fits_df | DataFrame | Per-protein or per-(id, condition) CETSA parameters; must contain id_col, EC50, delta_max, R2, Hill. Per-condition tables should be aggregated to protein level (e.g. median across replicates) before calling this function. | required |
| sens_df | DataFrame | Output of sensitivity.compute_sensitivity_scores, with columns id_col, NSS, EC50, delta_max, Hill, R2, ... | required |
| hits_df | DataFrame | Per-protein hit classification; must contain id_col and hit_col (e.g. "strong", "medium", "weak"). | required |
| net_df | DataFrame or None | Network centrality statistics; if provided, must contain id_col, degree_col, betweenness_col. If None, network centrality is ignored. | None |
| id_col | str | Column name for protein identifier. | 'id' |
| hit_col | str | Column name for hit classification. | 'dominant_class' |
| degree_col | str | Column name for network degree centrality. | 'degree' |
| betweenness_col | str | Column name for network betweenness centrality. | 'betweenness' |
Returns:
| Type | Description |
|---|---|
| DataFrame | Columns: id, EC50, delta_max, R2, NSS, axis_direct, axis_indirect, axis_network, redox_role. |
summarize_redox_by_pathway(redox_df, annot_df, id_col='id', path_col='pathway')
Summarize redox axes and roles per pathway/module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| redox_df | DataFrame | Output of build_redox_axes, with columns id_col, axis_direct, axis_indirect, axis_network, redox_role, ... | required |
| annot_df | DataFrame | id-to-pathway mapping; must contain id_col and path_col. | required |
| id_col | str | Column name for protein identifier. | 'id' |
| path_col | str | Column name for pathway/module annotation. | 'pathway' |
Returns:
| Type | Description |
|---|---|
| DataFrame | Columns: pathway, N, direct_mean, indirect_mean, network_mean, frac_direct, frac_indirect, frac_network, frac_peripheral. |
compute_sensitivity_scores(fits_df, id_col='id', cond_col='condition', agg='median', weights=None)
Compute a unified NADPH Sensitivity Score (NSS) per protein.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fits_df | DataFrame | Output of fit_all_proteins, containing id, condition, EC50, delta_max, Hill, R2. | required |
| agg | str, one of 'median', 'mean' | How to aggregate across replicates. | 'median' |
| weights | dict or None | Optional component weights {"EC50": w1, "delta_max": w2, "Hill": w3, "R2": w4}. If None, defaults to EC50: 0.45, delta_max: 0.30, Hill: 0.15, R2: 0.10. | None |
Returns:
| Type | Description |
|---|---|
| DataFrame | Columns: id, EC50, delta_max, Hill, R2, NSS, NSS_rank. |
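A weighted score like NSS can be sketched as a weighted sum of per-component scores on a common scale. This sketch uses the default weights documented above, which sum to 1. Everything else is an assumption: it min-max scales each component, scores EC50 as -log10(EC50) so that lower EC50 counts as more sensitive, and uses absolute values for delta_max and Hill. The input is assumed to be already aggregated to one row per protein. `nss_sketch` is hypothetical.

```python
import numpy as np
import pandas as pd

DEFAULT_WEIGHTS = {"EC50": 0.45, "delta_max": 0.30, "Hill": 0.15, "R2": 0.10}

def _minmax(s):
    rng = s.max() - s.min()
    return (s - s.min()) / rng if rng > 0 else s * 0.0

def nss_sketch(prot_df, weights=None):
    w = weights or DEFAULT_WEIGHTS
    scaled = pd.DataFrame(index=prot_df.index)
    # Lower EC50 means higher sensitivity, so score -log10(EC50).
    scaled["EC50"] = _minmax(-np.log10(prot_df["EC50"]))
    scaled["delta_max"] = _minmax(prot_df["delta_max"].abs())
    scaled["Hill"] = _minmax(prot_df["Hill"].abs())
    scaled["R2"] = _minmax(prot_df["R2"])
    out = prot_df.copy()
    out["NSS"] = sum(w[k] * scaled[k] for k in w)  # weighted sum in [0, 1]
    out["NSS_rank"] = out["NSS"].rank(ascending=False, method="min").astype(int)
    return out
```

With weights summing to 1 and every component scaled to [0, 1], the score itself stays in [0, 1].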
summarize_sensitivity_by_pathway(sens_df, annot_df, id_col='id', path_col='pathway')
Summarize NADPH sensitivity per pathway or complex.
annot_df must map id to pathway/module.
Returns: DataFrame with columns pathway, N, NSS_mean, NSS_top25, EC50_median, delta_max_median.
compute_sensitivity_heterogeneity(sens_df, bins=50)
Quantify how spread out sensitivity is across the proteome.
Returns: dict with:
- histogram
- Gini coefficient (inequality)
- top-10% NSS threshold
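The three returned quantities can be sketched with NumPy. The Gini coefficient uses the standard rank-based formula: 0 means every protein has the same score, and values near 1 mean sensitivity is concentrated in a few proteins. The formula assumes non-negative NSS values that are not all zero. Reading "top-10% threshold" as the 90th percentile, and the dictionary keys used, are assumptions of this sketch.

```python
import numpy as np

def gini(values):
    x = np.sort(np.asarray(values, dtype=float))
    n = x.size
    i = np.arange(1, n + 1)
    # Rank-based Gini for non-negative values: 0 = equal, towards 1 = concentrated.
    return float(2.0 * np.sum(i * x) / (n * np.sum(x)) - (n + 1) / n)

def heterogeneity_sketch(nss, bins=50):
    nss = np.asarray(nss, dtype=float)
    counts, edges = np.histogram(nss, bins=bins)
    return {"histogram": (counts, edges),
            "gini": gini(nss),
            "top10_threshold": float(np.quantile(nss, 0.9))}
```

As a check, four equal scores give a Gini of 0. The scores [0, 0, 0, 1] give 2·4 / (4·1) - 5/4 = 0.75.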
run_hit_calling_and_plots(fits_df, out_dir, id_col='id', cond_col='condition', ec50_strong=0.01, delta_strong=0.1, r2_strong=0.7)
Full hit-calling and plotting pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fits_df | DataFrame | Output of fit_all_proteins with columns id, condition, EC50, delta_max, R2, ... | required |
| out_dir | Path or str | Directory where plots and the ranked hits table will be saved. | required |
Returns:
| Type | Description |
|---|---|
| dict with keys: | |
plot_pathway_effects_bar(path_df, metric='NSS_mean', top_n=20, ax=None)
Horizontal barplot of pathway-level effects.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path_df | DataFrame | Output of summarize_pathway_effects; must contain 'pathway' (or custom) and the chosen metric. | required |
| metric | str | Column name to plot, e.g. 'NSS_mean', 'delta_max_median'. | 'NSS_mean' |
| top_n | int | Number of top pathways to show. | 20 |
| ax | Axes or None | If None, creates a new figure/axis. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | Matplotlib Figure and Axes objects. |
plot_pathway_enrichment_volcano(enr_df, ax=None, label_top_n=10)
Volcano plot of pathway over-representation analysis results.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| enr_df | DataFrame | Output of enrich_overrepresentation; must contain 'odds_ratio', 'qval', and 'pathway' (or custom). | required |
| ax | Axes or None | If None, creates a new figure/axis. | None |
| label_top_n | int | Number of top pathways (by smallest qval) to label. | 10 |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | Matplotlib Figure and Axes objects. |
plot_redox_axes_scatter(redox_df, x_axis='axis_direct', y_axis='axis_indirect', color_by='redox_role', ax=None)
Scatter plot of two redox axes (e.g. direct vs. indirect), colored by role.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| redox_df | DataFrame | Output of build_redox_axes; must contain x_axis, y_axis, and color_by. | required |
| x_axis | str | Column to use as the x coordinate. | 'axis_direct' |
| y_axis | str | Column to use as the y coordinate. | 'axis_indirect' |
| color_by | str | Categorical column used for coloring (e.g. 'redox_role'). | 'redox_role' |
| ax | Axes or None | If None, creates a new figure/axis. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | Matplotlib Figure and Axes objects. |
plot_redox_role_composition(path_redox_df, top_n=20, ax=None)
Stacked barplot of redox role composition per pathway.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path_redox_df | DataFrame | Output of summarize_pathway_redox_roles; must contain 'pathway' (or custom), 'frac_direct_core', 'frac_indirect_responder', 'frac_network_mediator', 'frac_peripheral'. | required |
| top_n | int | Number of top pathways to show. | 20 |
| ax | Axes or None | If None, creates a new figure/axis. | None |
plot_pca_scores(scores_df, meta_df=None, id_col='id', color_by=None, pc_x='PC1', pc_y='PC2', ax=None)
Scatter plot of PCA scores (PC1 vs. PC2), optionally colored by a metadata column.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| scores_df | DataFrame | PCA scores indexed by id, with columns including pc_x and pc_y. | required |
| meta_df | DataFrame | Metadata DataFrame with id_col and the color_by column. | None |
| id_col | str | Column name for protein identifier. | 'id' |
| color_by | str or None | Column in meta_df to color points by. If None, no coloring. | None |
| pc_x | str | Column in scores_df to use as the x coordinate. | 'PC1' |
| pc_y | str | Column in scores_df to use as the y coordinate. | 'PC2' |
| ax | Axes or None | If None, creates a new figure/axis. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | Matplotlib Figure and Axes objects. |
plot_factor_scores(scores_df, meta_df=None, id_col='id', color_by=None, f_x='F1', f_y='F2', ax=None)
Scatter plot of factor analysis scores (F1 vs. F2), optionally colored by a metadata column.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| scores_df | DataFrame | Factor analysis scores indexed by id, with columns including f_x and f_y. | required |
| meta_df | DataFrame | Metadata DataFrame with id_col and the color_by column. | None |
| id_col | str | Column name for protein identifier. | 'id' |
| color_by | str or None | Column in meta_df to color points by. If None, no coloring. | None |
| f_x | str | Column in scores_df to use as the x coordinate. | 'F1' |
| f_y | str | Column in scores_df to use as the y coordinate. | 'F2' |
| ax | Axes or None | If None, creates a new figure/axis. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | Matplotlib Figure and Axes objects. |
plot_mixture_clusters_in_pca(pca_scores, cluster_df, id_col='id', pc_x='PC1', pc_y='PC2', ax=None)
Plot mixture clusters in PCA space (PC1 vs. PC2).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pca_scores | DataFrame | PCA scores indexed by id, with columns including pc_x and pc_y. | required |
| cluster_df | DataFrame | Output of assign_mixture_clusters; must contain id_col and 'cluster'. | required |
| id_col | str | Column name for protein identifier. | 'id' |
| pc_x | str | Column in pca_scores to use as the x coordinate. | 'PC1' |
| pc_y | str | Column in pca_scores to use as the y coordinate. | 'PC2' |
| ax | Axes or None | If None, creates a new figure/axis. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | Matplotlib Figure and Axes objects. |
plot_cluster_size_bar(cluster_df, ax=None)
Barplot of mixture cluster sizes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cluster_df | DataFrame | Output of assign_mixture_clusters; must contain 'cluster'. | required |
| ax | Axes or None | If None, creates a new figure/axis. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | Matplotlib Figure and Axes objects. |
visualize_predictions(pred_file, truth_file, out_dir)
Generate a series of plots comparing model predictions with ground-truth labels and experimental data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred_file | str | Path to CSV file containing model predictions. | required |
| truth_file | str | Path to CSV file containing ground truth labels and experimental data. | required |
- Global Residue Importance (Aggregated IG)
analyze_fitting_data(fits_file, pred_file, out_dir)
Analyze and visualize the quality of curve fitting data in relation to model predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fits_file | str | Path to CSV file containing curve fitting parameters. | required |
| pred_file | str | Path to CSV file containing model predictions. | required |
- Curve Reconstruction for Top Predicted Targets
generate_bio_insight(pred_file, truth_file, annot_file, out_dir)
Generate biological insight plots based on model predictions, experimental data, and pathway annotations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred_file | str | Path to CSV file containing model predictions. | required |
| truth_file | str | Path to CSV file containing ground truth labels and experimental data. | required |
| annot_file | str | Path to CSV file containing protein annotations (e.g., pathways). | required |
- EC50 Validation Across Predicted Classes
plot_training_loop(history_file, out_dir)
Plot training and validation loss/accuracy curves from the model training history.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| history_file | str | Path to CSV file containing training history. | required |
- Accuracy Curves (if applicable)
Configuration and data I/O
cetsax.config
This module exposes the shared experimental schema. Most downstream modules rely on the constants defined here, so this is effectively the contract layer of the package.
Configuration and constants for CETSA EC50/KD modelling. Dynamically loaded from config.yaml.
load_yaml_config(path)
Load the YAML config file using a safe loader.
cetsax.dataio
This module loads CETSA input tables and applies the first QC filter before any modeling is attempted.
Data loading and basic filtering for CETSA NADPH ITDR dataset.
load_cetsa_csv(path)
Load the CETSA NADPH ITDR dataset from a CSV file.
Drops the column 'Unnamed: 0' if present (typically a saved pandas index).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | Path to the CSV file. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Loaded dataset with appropriate columns. |
apply_basic_qc(df)
Apply simple QC criteria at the protein-replicate level. Filters proteins based on minimum unique peptides, PSMs, and count number.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | Input CETSA dataset. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Filtered dataset passing QC criteria. |
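A minimal sketch of the load-then-QC flow follows. The QC column names (`sumUniPeps`, `sumPSMs`, `countNum`) and the thresholds are hypothetical placeholders for the unique-peptide, PSM, and count-number criteria. The real column names and cutoffs are not documented here.

```python
import io
import pandas as pd

def load_csv_sketch(path_or_buf):
    df = pd.read_csv(path_or_buf)
    # A leading unnamed column usually comes from a previously saved index.
    return df.drop(columns=["Unnamed: 0"], errors="ignore")

# Hypothetical QC column names and minimum values.
QC_MIN = {"sumUniPeps": 2, "sumPSMs": 4, "countNum": 3}

def basic_qc_sketch(df, qc_min=QC_MIN):
    keep = pd.Series(True, index=df.index)
    for col, minimum in qc_min.items():
        keep &= (df[col] >= minimum)  # every criterion must pass
    return df.loc[keep].reset_index(drop=True)

if __name__ == "__main__":
    demo = io.StringIO("Unnamed: 0,id,sumUniPeps,sumPSMs,countNum\n0,P1,3,5,4\n1,P2,1,2,4\n")
    print(basic_qc_sketch(load_csv_sketch(demo)))
```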
cetsax.annotate
This module supports annotation and sequence retrieval workflows, including ID normalization, MyGene-based annotation, and UniProt FASTA fetching.
annotate.py
Builds the following files from a list of protein IDs, using mygene and UniProt REST, with parallelized FASTA retrieval:
1) protein_annotations.csv
2) protein_sequences.fasta
strip_isoform_suffix(acc)
Remove UniProt isoform suffixes such as '-1', '-2', '-10'. Example: O00231-2 → O00231.
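Suffix stripping of this kind is a one-line regular expression: drop a trailing hyphen followed by digits. The sketch below assumes that is the whole rule. Accessions without a suffix pass through unchanged.

```python
import re

_ISOFORM_RE = re.compile(r"-\d+$")

def strip_isoform_suffix_sketch(acc: str) -> str:
    # Drop a trailing "-<digits>" (UniProt isoform marker); other IDs are unchanged.
    return _ISOFORM_RE.sub("", acc.strip())
```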
fetch_annotations_with_mygene(ids, species='human', chunk_size=1000)
Use mygene to retrieve basic annotations.
Queries multiple scopes (symbol, entrezgene, uniprot). Returned fields: symbol, name, entrezgene, uniprot, go.BP, pathway.
fetch_uniprot_fasta(acc, timeout=10.0)
Fetch a single UniProt FASTA entry by accession.
Returns the FASTA string (header and sequence), or None on failure.
fetch_fastas_parallel(accessions, max_workers=8)
Parallel UniProt FASTA retrieval for a list of accessions.
Returns a dict {acc: fasta_text}.
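FASTA retrieval is I/O-bound, so a thread pool is a natural fit. The sketch below takes the single-accession fetcher as a parameter, so it runs without network access. In practice that role would be played by something like fetch_uniprot_fasta, which returns None on failure. De-duplicating and dropping failed accessions are assumptions of this sketch.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Iterable, Optional

def fetch_parallel_sketch(accessions: Iterable[str],
                          fetch_one: Callable[[str], Optional[str]],
                          max_workers: int = 8) -> Dict[str, str]:
    unique = list(dict.fromkeys(accessions))  # de-duplicate, keep order
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = pool.map(fetch_one, unique)  # threads suit I/O-bound requests
        # Keep only successful fetches (fetch_one returns None on failure).
        return {acc: fa for acc, fa in zip(unique, results) if fa is not None}
```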
write_fastas_with_ids(annot_df, acc_to_fasta, out_fasta, id_col='id', uniprot_col='uniprot')
Write a FASTA file whose headers are your own IDs (from id_col) rather than the UniProt headers.
For each row in annot_df:
- take the original id (e.g. O00231-2)
- take the UniProt accession (e.g. O00231)
- get the sequence from acc_to_fasta[accession]
- strip the original '>' header
- write a new header: >{id}
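The re-headering steps listed above can be sketched as a pure function that returns the FASTA text. The pairs would come from something like `annot_df[[id_col, uniprot_col]].itertuples(index=False)`, and the result would be written to out_fasta. Two behaviours are assumptions: accessions missing from acc_to_fasta are skipped, and each sequence is joined onto a single line.

```python
from typing import Dict, Iterable, Tuple

def reheader_fasta(pairs: Iterable[Tuple[str, str]], acc_to_fasta: Dict[str, str]) -> str:
    records = []
    for own_id, acc in pairs:
        fasta = acc_to_fasta.get(acc)
        if not fasta:
            continue  # skip accessions whose FASTA could not be fetched
        # Drop the UniProt '>' header line and keep only sequence lines.
        seq = "".join(line.strip() for line in fasta.splitlines()
                      if line.strip() and not line.startswith(">"))
        records.append(f">{own_id}\n{seq}\n")
    return "".join(records)
```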
Mathematical modeling and fitting
cetsax.models
This module contains the canonical ITDR model function used for CETSA dose–response evaluation.
models.py
Mathematical models for ITDR CETSA curves. Provides the 4-parameter logistic ITDR model function.
itdr_model(c, E0, Emax, logEC50, h)
4-parameter logistic ITDR model for CETSA:
f(c) = E0 + (Emax - E0) / (1 + (EC50 / c)^h)
where EC50 = 10 ** logEC50 (parameterized in log10 space for stability).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| c | ndarray | Concentration array. | required |
| E0 | float | Response at zero concentration. | required |
| Emax | float | Maximum response at infinite concentration. | required |
| logEC50 | float | Log10 of the concentration at half-maximum response. | required |
| h | float | Hill slope. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Response values at concentrations c. |
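The formula translates directly to NumPy. This is a sketch written from the documented equation, not the module's source. It requires strictly positive concentrations, because EC50 / c is undefined at c = 0.

```python
import numpy as np

def itdr_model_sketch(c, E0, Emax, logEC50, h):
    c = np.asarray(c, dtype=float)
    ec50 = 10.0 ** logEC50
    # f(c) = E0 + (Emax - E0) / (1 + (EC50 / c)^h); c must be > 0.
    return E0 + (Emax - E0) / (1.0 + (ec50 / c) ** h)
```

A quick sanity check of the parameterization: at c = EC50 the denominator is 2, so the response is the midpoint (E0 + Emax) / 2. As c grows, it approaches Emax.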
cetsax.fit
This is the main fitting engine. It handles monotonic smoothing, bounded optimization, fit diagnostics, and bulk fitting across proteins and conditions.
fit.py
ITDR curve fitting for CETSA data. Fit ITDR curves to dose-response data using a robust logistic model with monotonic smoothing and Hill slope regularization.
fit_all_proteins(df)
Fit ITDR curves for all proteins and replicates in a QC-checked DataFrame.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | Input data with columns ID_COL, COND_COL, and the dose columns in DOSE_COLS. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Fit results with columns ID_COL, COND_COL, E0, Emax, EC50, log10_EC50, Hill, R2, delta_max. |
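To illustrate the core of a single-curve fit, here is a bounded least-squares sketch using `scipy.optimize.curve_fit`. It omits the monotonic smoothing and Hill slope regularization that cetsax.fit describes. The bounds (EC50 within the tested dose range, Hill in [0.1, 5]) are hypothetical, and so is defining delta_max as Emax - E0.

```python
import numpy as np
from scipy.optimize import curve_fit

def itdr(c, E0, Emax, logEC50, h):
    return E0 + (Emax - E0) / (1.0 + (10.0 ** logEC50 / c) ** h)

def fit_one_curve(doses, response):
    doses = np.asarray(doses, dtype=float)
    y = np.asarray(response, dtype=float)
    # Initial guesses: plateaus from the curve ends, EC50 at the median dose.
    p0 = [y[0], y[-1], np.log10(np.median(doses)), 1.0]
    # Hypothetical bounds: EC50 inside the tested range, Hill slope in [0.1, 5].
    lo = [-np.inf, -np.inf, np.log10(doses.min()), 0.1]
    hi = [np.inf, np.inf, np.log10(doses.max()), 5.0]
    popt, _ = curve_fit(itdr, doses, y, p0=p0, bounds=(lo, hi))
    pred = itdr(doses, *popt)
    r2 = 1.0 - np.sum((y - pred) ** 2) / np.sum((y - y.mean()) ** 2)
    E0, Emax, logEC50, h = popt
    return {"E0": E0, "Emax": Emax, "EC50": 10.0 ** logEC50, "log10_EC50": logEC50,
            "Hill": h, "R2": r2, "delta_max": Emax - E0}
```

Fitting in log10(EC50), as itdr_model does, keeps the optimizer's steps comparable across the several orders of magnitude a dose series usually spans.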
cetsax.bayes
This module provides Bayesian alternatives for parameter inference when posterior uncertainty is of interest rather than only point estimates.
bayes.py
Purpose: Implement hierarchical Bayesian modelling of CETSA EC50 parameters.
bayesian_fit_ec50(df, protein_id, draws=1000, tune=1000, chains=4, cores=1, progressbar=True)
Fit a hierarchical Bayesian EC50 model for a single protein across replicates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | DataFrame containing CETSA data. | required |
| protein_id | str | ID of the protein to fit. | required |
| draws | int | Number of posterior samples to draw per chain. | 1000 |
| tune | int | Number of tuning (burn-in) steps. | 1000 |
| chains | int | Number of independent MCMC chains to run. | 4 |
| cores | int | Number of CPU cores used for parallel sampling. Set to 1 when calling this function inside a joblib loop; set equal to chains when fitting a single protein. | 1 |
| progressbar | bool | Whether to show the PyMC progress bar. | True |
Returns:
| Type | Description |
|---|---|
| dict | PyMC model, posterior samples (trace), and summary. |
Hit calling and sensitivity scoring
cetsax.hits
This module provides direct hit filtering and hit summarization from fitted parameters.
hits.py
Hit calling and summary for CETSA ITDR binding models.
call_hits(fit_df, r2_min=0.8, delta_min=0.1)
Filter fitted curves to keep only high-confidence hits.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fit_df | DataFrame | DataFrame of fitted curve parameters with at least the columns `R2` (coefficient of determination of the fit), `delta_max` (maximum change in response), and `EC50` (half-maximal effective concentration). | required |
| r2_min | float | Minimum R² for a fit to count as a hit. | 0.8 |
| delta_min | float | Minimum delta_max for a fit to count as a hit. | 0.1 |
Returns:
| Type | Description |
|---|---|
| DataFrame | DataFrame containing only the fits that meet the specified criteria. |
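A minimal pandas sketch of the filtering rule, assuming inclusive thresholds and that rows without an EC50 estimate are dropped. Both details are assumptions; the docstring does not state whether the comparisons are strict or how missing EC50 values are handled.

```python
import pandas as pd


def call_hits_sketch(fit_df: pd.DataFrame, r2_min: float = 0.8, delta_min: float = 0.1) -> pd.DataFrame:
    # Assumption: thresholds are inclusive and fits with no EC50 are discarded.
    mask = (
        (fit_df["R2"] >= r2_min)
        & (fit_df["delta_max"] >= delta_min)
        & fit_df["EC50"].notna()
    )
    return fit_df.loc[mask].copy()


fits = pd.DataFrame({
    "id": ["P1", "P2", "P3"],
    "condition": ["NADPH.r1"] * 3,
    "R2": [0.95, 0.60, 0.90],        # P2 fails the R2 cut
    "delta_max": [0.25, 0.30, 0.05],  # P3 fails the delta_max cut
    "EC50": [0.004, 0.02, 0.1],
})
hits = call_hits_sketch(fits)
print(hits["id"].tolist())  # ['P1']
```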
summarize_hits(hits_df, min_reps=2)
Aggregate hit information at the protein level across replicates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hits_df | DataFrame | DataFrame of hits with at least the columns ID_COL (protein or target identifier), COND_COL (condition or replicate identifier), `EC50` (half-maximal effective concentration), `Emax` (maximum effect), and `Hill` (Hill coefficient). | required |
| min_reps | int | Minimum number of replicates required to include a protein in the summary. | 2 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Summary DataFrame with aggregated hit information per protein. |
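The replicate aggregation can be pictured as a grouped summary followed by a replicate-count filter. This is a sketch that assumes per-protein medians and the column names `id`/`condition`; the aggregates and output column names that cetsax actually produces may differ.

```python
import pandas as pd


def summarize_hits_sketch(hits_df, min_reps=2, id_col="id", cond_col="condition"):
    # Assumption: medians across replicates; output column names are illustrative.
    summary = (
        hits_df.groupby(id_col)
        .agg(
            n_reps=(cond_col, "nunique"),
            EC50_median=("EC50", "median"),
            Emax_median=("Emax", "median"),
            Hill_median=("Hill", "median"),
        )
        .reset_index()
    )
    # Keep only proteins that were called as hits in enough replicates.
    return summary[summary["n_reps"] >= min_reps].reset_index(drop=True)


hits = pd.DataFrame({
    "id": ["P1", "P1", "P2"],
    "condition": ["NADPH.r1", "NADPH.r2", "NADPH.r1"],
    "EC50": [0.004, 0.006, 0.02],
    "Emax": [1.3, 1.4, 1.2],
    "Hill": [1.0, 1.2, 0.8],
})
print(summarize_hits_sketch(hits))  # only P1 has two replicates
```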
cetsax.sensitivity¶
This module computes the NADPH Sensitivity Score and related protein-level summary metrics.
sensitivity.py
Proteome-wide NADPH sensitivity modelling. This module provides functions to compute a unified NADPH Sensitivity Score (NSS) per protein by integrating dose-response fit parameters, summarize sensitivity at the pathway/module level, and quantify sensitivity heterogeneity across the proteome.
compute_sensitivity_scores(fits_df, id_col='id', cond_col='condition', agg='median', weights=None)
Compute a unified NADPH Sensitivity Score (NSS) per protein.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fits_df | DataFrame | Output of fit_all_proteins, containing the columns id, condition, EC50, delta_max, Hill, and R2. | required |
| id_col | str | Column name for protein IDs. | 'id' |
| cond_col | str | Column name for the condition/replicate label. | 'condition' |
| agg | {"median", "mean"} | How to aggregate across replicates. | "median" |
| weights | dict or None | Optional component weights `{"EC50": w1, "delta_max": w2, "Hill": w3, "R2": w4}`. If None, the defaults are EC50: 0.45, delta_max: 0.30, Hill: 0.15, R2: 0.10. | None |
Returns:
| Type | Description |
|---|---|
| DataFrame | Per-protein table with columns id, EC50, delta_max, Hill, R2, NSS, NSS_rank. |
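The docstring gives the default weights but not how the components are put on a common scale. The sketch below assumes each replicate-aggregated component is min-max scaled to [0, 1] (with EC50 scored as -log10(EC50), so a lower EC50 counts as more sensitive) and then combined as a weighted sum. The scaling scheme is an illustrative assumption, not the documented cetsax formula.

```python
import numpy as np
import pandas as pd

# Default weights as documented above; they sum to 1.
DEFAULT_WEIGHTS = {"EC50": 0.45, "delta_max": 0.30, "Hill": 0.15, "R2": 0.10}


def _minmax(s):
    span = s.max() - s.min()
    return (s - s.min()) / span if span > 0 else s * 0.0


def nss_sketch(fits_df, weights=None, agg="median", id_col="id"):
    w = DEFAULT_WEIGHTS if weights is None else weights
    per_protein = fits_df.groupby(id_col)[list(w)].agg(agg)
    # Assumption: lower EC50 = higher sensitivity, so score -log10(EC50).
    scaled = pd.DataFrame({
        "EC50": _minmax(-np.log10(per_protein["EC50"])),
        "delta_max": _minmax(per_protein["delta_max"]),
        "Hill": _minmax(per_protein["Hill"]),
        "R2": _minmax(per_protein["R2"]),
    })
    per_protein["NSS"] = sum(w[k] * scaled[k] for k in w)
    per_protein["NSS_rank"] = per_protein["NSS"].rank(ascending=False, method="min").astype(int)
    return per_protein.reset_index()


fits = pd.DataFrame({
    "id": ["P1", "P1", "P2", "P2"],
    "condition": ["NADPH.r1", "NADPH.r2"] * 2,
    "EC50": [0.001, 0.001, 0.1, 0.1],
    "delta_max": [0.40, 0.40, 0.10, 0.10],
    "Hill": [1.5, 1.5, 0.8, 0.8],
    "R2": [0.95, 0.95, 0.80, 0.80],
})
print(nss_sketch(fits)[["id", "NSS", "NSS_rank"]])
```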
summarize_sensitivity_by_pathway(sens_df, annot_df, id_col='id', path_col='pathway')
Summarize NADPH sensitivity per pathway or complex.
annot_df must map id -> pathway/module.
Output columns: pathway, N, NSS_mean, NSS_top25, EC50_median, delta_max_median.
compute_sensitivity_heterogeneity(sens_df, bins=50)
Quantify how spread out sensitivity is across the proteome.
Returns a dict with:
- histogram
- Gini coefficient (inequality)
- top-10% NSS threshold
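The three returned quantities can be sketched with NumPy. The Gini formula below is the standard sorted-rank form for non-negative values, and the top-10% threshold is taken as the 90th percentile; whether cetsax uses exactly these definitions (and the same dict keys) is an assumption.

```python
import numpy as np


def gini(values):
    """Gini coefficient of non-negative values: 0 = perfectly even, near 1 = concentrated."""
    x = np.sort(np.asarray(values, dtype=float))
    n = x.size
    total = x.sum()
    if n == 0 or total == 0:
        return 0.0
    ranks = np.arange(1, n + 1)
    return float(2.0 * np.sum(ranks * x) / (n * total) - (n + 1) / n)


def heterogeneity_sketch(nss, bins=50):
    nss = np.asarray(nss, dtype=float)
    counts, edges = np.histogram(nss, bins=bins)
    return {
        "histogram": (counts, edges),
        "gini": gini(nss),
        "top10_threshold": float(np.quantile(nss, 0.9)),  # 90th percentile
    }


print(gini([1, 1, 1, 1]))  # 0.0
print(gini([0, 0, 0, 1]))  # 0.75
```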
cetsax.viz_hits¶
This module combines hit-calling logic with diagnostic visualization utilities and export helpers.
viz_hits.py
Hit-calling and visualization functions for CETSA ITDR data. This module provides functions to classify hits based on EC50, delta_max, and R2 thresholds, generate ranked hit tables, and create diagnostic plots to visualize hit characteristics and replicate consistency.
classify_hit(row, ec50_strong=0.01, delta_strong=0.1, r2_strong=0.7)
Classify a single (EC50, delta_max, R2) triplet as strong or weak.
Thresholds are configurable. With the defaults, a row is "strong" when EC50 < 0.01, delta_max > 0.10, and R2 > 0.70; otherwise it is "weak".
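The default rule translates directly into a small function; the sketch below mirrors the strict inequalities stated in the docstring. It is written to accept any mapping-like row, so it can be applied to a DataFrame with `df.apply(classify_hit_sketch, axis=1)`.

```python
def classify_hit_sketch(row, ec50_strong=0.01, delta_strong=0.1, r2_strong=0.7):
    # All three conditions must hold for a "strong" call (strict inequalities).
    if (
        row["EC50"] < ec50_strong
        and row["delta_max"] > delta_strong
        and row["R2"] > r2_strong
    ):
        return "strong"
    return "weak"


print(classify_hit_sketch({"EC50": 0.004, "delta_max": 0.25, "R2": 0.9}))  # strong
print(classify_hit_sketch({"EC50": 0.05, "delta_max": 0.25, "R2": 0.9}))   # weak
```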
build_hits_table(fits_df, ec50_strong=0.01, delta_strong=0.1, r2_strong=0.7, id_col='id', cond_col='condition')
Add a hit_class label to each row and aggregate into a per-protein ranked table.
Returns a DataFrame with columns: id, n_reps, class_counts, EC50_median, delta_max_median, R2_median, dominant_class.
plot_ec50_vs_delta(df, ec50_cut=0.01, delta_cut=0.1, ax=None)
Scatter plot of EC50 vs delta_max with quadrant lines.
plot_ec50_replicates(df, id_col='id', cond_col='condition', cond_r1='NADPH.r1', cond_r2='NADPH.r2', ax=None)
EC50 replicate-consistency plot: EC50 in replicate 1 vs EC50 in replicate 2 (log-log axes).
Returns (fig, ax), or None if either required condition is missing.
plot_r2_vs_delta(df, ax=None)
Scatter plot of R2 vs delta_max.
plot_ec50_vs_r2(df, ax=None)
Scatter plot of EC50 (log scale) vs R2.
run_hit_calling_and_plots(fits_df, out_dir, id_col='id', cond_col='condition', ec50_strong=0.01, delta_strong=0.1, r2_strong=0.7)
Full hit-calling and plotting pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fits_df | DataFrame | Output of fit_all_proteins with columns id, condition, EC50, delta_max, R2, ... | required |
| out_dir | Path or str | Directory where the plots and the ranked hits table are saved. | required |
Returns:
| Type | Description |
|---|---|
dict with keys:
|
|
Systems-level analysis¶
cetsax.enrichment¶
This module implements pathway-level summarization and both binary and continuous enrichment testing.
enrichment.py
Pathway-level effect size integration and enrichment for NADPH CETSA data. Build pathway/module-level summaries of NADPH responsiveness and perform enrichment tests using:
- Binary hit sets (e.g. strong / medium hits)
- Continuous scores (e.g. NSS, delta_max, EC50)
summarize_pathway_effects(metric_df, annot_df, id_col='id', path_col='pathway', metrics=('NSS', 'EC50', 'delta_max', 'R2'))
Summarize NADPH responsiveness per pathway/module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metric_df | DataFrame | Per-protein metric table, e.g. the output of sensitivity.compute_sensitivity_scores or a custom per-protein summary with columns id, EC50, delta_max, Hill, R2, NSS, ... | required |
| annot_df | DataFrame | Annotation table mapping proteins to pathways/modules. Must contain the columns id_col and path_col. | required |
| metrics | iterable of str | Column names in metric_df to summarize per pathway. | ('NSS', 'EC50', 'delta_max', 'R2') |
Returns:
| Type | Description |
|---|---|
DataFrame
|
path_col | N_proteins | |
enrich_overrepresentation(hits_df, annot_df, id_col='id', path_col='pathway', hit_col='hit_class', strong_labels=('strong',), min_genes=3)
Perform over-representation analysis for pathways using a binary hit set.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hits_df | DataFrame | Per-protein hit classification table; must contain id_col and hit_col. | required |
| annot_df | DataFrame | id-to-pathway mapping. | required |
| id_col | str | Column name for protein IDs. | 'id' |
| path_col | str | Column name for pathway/module names. | 'pathway' |
| hit_col | str | Column name for the hit classification. | 'hit_class' |
| strong_labels | iterable of str | Labels in hit_col to treat as hits. | ('strong',) |
| min_genes | int | Minimum number of genes a pathway must contain to be tested. | 3 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Table with columns path_col, n_path, n_hits, n_bg, odds_ratio, pval, qval. |
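For one pathway, over-representation reduces to a 2x2 table (in/out of pathway vs hit/non-hit). The sketch below computes the odds ratio, a one-sided hypergeometric p-value (the same quantity as a one-sided Fisher exact test, e.g. `scipy.stats.fisher_exact(..., alternative="greater")`), and Benjamini-Hochberg q-values. That cetsax uses a one-sided test and BH correction for `qval` is an assumption.

```python
from math import comb

import numpy as np


def hypergeom_upper_tail(k, M, n, N):
    """P(X >= k) for X ~ Hypergeometric(population M, n successes, N draws)."""
    top = min(n, N)
    return sum(comb(n, i) * comb(M - n, N - i) for i in range(k, top + 1)) / comb(M, N)


def ora_2x2(n_hits_in_path, n_path, n_hits_total, n_total):
    a = n_hits_in_path           # hits inside the pathway
    b = n_path - a               # non-hits inside the pathway
    c = n_hits_total - a         # hits outside the pathway
    d = n_total - n_path - c     # non-hits outside the pathway
    odds_ratio = (a * d) / (b * c) if b * c > 0 else float("inf")
    pval = hypergeom_upper_tail(a, n_total, n_hits_total, n_path)
    return odds_ratio, pval


def bh_qvalues(pvals):
    """Benjamini-Hochberg adjusted p-values, returned in the input order."""
    p = np.asarray(pvals, dtype=float)
    m = p.size
    order = np.argsort(p)
    scaled = p[order] * m / np.arange(1, m + 1)
    # Enforce monotonicity, working down from the largest p-value.
    q_sorted = np.minimum.accumulate(scaled[::-1])[::-1]
    q = np.empty(m)
    q[order] = np.minimum(q_sorted, 1.0)
    return q


# 20 proteins, 6 hits overall; a 5-protein pathway contains 4 of them.
print(ora_2x2(4, 5, 6, 20))  # odds ratio 26.0, p = 216/15504
print(bh_qvalues([0.01, 0.04, 0.03, 0.2]))
```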
enrich_continuous_mannwhitney(sens_df, annot_df, score_col='NSS', id_col='id', path_col='pathway', min_genes=3)
Continuous enrichment per pathway using Mann–Whitney U tests.
For each pathway, compares the distribution of score_col (e.g. NSS, delta_max, or -log10(EC50)) between proteins in the pathway and all other proteins.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sens_df | DataFrame | Per-protein sensitivity scores; must contain id_col and score_col. | required |
| annot_df | DataFrame | id-to-pathway mapping. | required |
| score_col | str | Column name of the continuous score to test. | 'NSS' |
| id_col | str | Column name for protein IDs. | 'id' |
| path_col | str | Column name for pathway/module names. | 'pathway' |
| min_genes | int | Minimum number of genes a pathway must contain to be tested. | 3 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Table with columns pathway, n_path, score_mean, score_median, U_stat, pval, qval. |
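The U statistic for one pathway can be computed from ranks of the pooled scores, with tied values receiving average ranks. The sketch shows only the statistic; in practice `scipy.stats.mannwhitneyu` also provides the p-value. Which tail or variant cetsax reports is not stated here.

```python
import numpy as np
import pandas as pd


def mann_whitney_u(in_path, background):
    """U statistic for `in_path` vs `background`, with average ranks for ties."""
    x = np.asarray(in_path, dtype=float)
    y = np.asarray(background, dtype=float)
    ranks = pd.Series(np.concatenate([x, y])).rank(method="average").to_numpy()
    rank_sum_x = ranks[: x.size].sum()
    return rank_sum_x - x.size * (x.size + 1) / 2.0


# Every pathway member scores above the background, so U = n_x * n_y.
print(mann_whitney_u([0.9, 0.8, 0.7], [0.1, 0.2]))  # 6.0
print(mann_whitney_u([0.1, 0.2], [0.9, 0.8, 0.7]))  # 0.0
```

A U close to n_x * n_y means the pathway's scores sit systematically above the rest of the proteome; a U near n_x * n_y / 2 means no shift.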
cetsax.network¶
This module derives co-stabilization networks from CETSA response profiles and detects response modules.
network.py
Network construction and module detection based on co-stabilization. This module provides functions to compute co-stabilization (correlation) matrices from CETSA dose-response data, build protein-protein interaction networks based on correlation cutoffs, and detect co-stabilization modules using community detection algorithms.
compute_costab_matrix(df)
Compute the co-stabilization (correlation) matrix across proteins, with each protein represented by its dose-response vector.
Returns: DataFrame (n_proteins x n_proteins) correlation matrix.
make_network_from_matrix(corr_matrix, cutoff=0.7)
Convert a correlation matrix into a network graph, with an edge wherever the correlation is >= cutoff.
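The two steps above can be sketched with pandas alone: correlate the dose-response rows, then keep pairs at or above the cutoff as an edge list. The sketch assumes Pearson correlation and a hypothetical wide table with one row per protein. The resulting `(u, v, weight)` tuples can be passed to networkx's `Graph.add_weighted_edges_from` to obtain a graph object.

```python
import numpy as np
import pandas as pd

# Hypothetical wide table: one row per protein, one column per dose.
curves = pd.DataFrame(
    {"d1": [1.00, 1.00, 1.30], "d2": [1.10, 1.05, 1.20], "d3": [1.30, 1.20, 1.00]},
    index=["P1", "P2", "P3"],
)

# Pearson correlation between protein dose-response vectors (rows).
corr = curves.T.corr()


def edges_from_corr(corr_matrix, cutoff=0.7):
    """Edge list (u, v, r) for protein pairs with r >= cutoff, excluding self-pairs."""
    ids = corr_matrix.index.to_list()
    values = corr_matrix.to_numpy()
    return [
        (ids[i], ids[j], float(values[i, j]))
        for i in range(len(ids))
        for j in range(i + 1, len(ids))
        if values[i, j] >= cutoff
    ]


# P1 and P2 rise together; P3 falls, so only P1-P2 passes the cutoff.
print(edges_from_corr(corr))
```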
detect_modules(G)
Apply a community-detection algorithm (Louvain or greedy modularity) to identify co-stabilization modules.
Returns: dict mapping {protein_id: module_id}.
cetsax.redox¶
This module constructs interpretable redox axes and pathway-level redox summaries from sensitivity, hit, and optional network information.
redox.py
Redox-axis reconstruction from NADPH CETSA data. Defines functions to build redox axes per protein based on sensitivity scores, hit classifications, and network centrality, as well as to summarize redox roles at the pathway level.
build_redox_axes(fits_df, sens_df, hits_df, net_df=None, id_col='id', hit_col='dominant_class', degree_col='degree', betweenness_col='betweenness')
Build redox axes for each protein.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fits_df | DataFrame | Per-protein or per-(id, condition) CETSA parameters; must contain id_col, EC50, delta_max, R2, and Hill. Per-condition data should be aggregated to protein level (e.g. the median across replicates) before calling this function. | required |
| sens_df | DataFrame | Output of sensitivity.compute_sensitivity_scores with columns id_col, NSS, EC50, delta_max, Hill, R2, ... | required |
| hits_df | DataFrame | Per-protein hit classification; must contain id_col and hit_col (e.g. "strong", "medium", "weak"). | required |
| net_df | DataFrame or None | Network centrality statistics; if provided, must contain id_col, degree_col, and betweenness_col. If None, network centrality is ignored. | None |
| id_col | str | Column name for the protein identifier. | 'id' |
| hit_col | str | Column name for the hit classification. | 'dominant_class' |
| degree_col | str | Column name for network degree centrality. | 'degree' |
| betweenness_col | str | Column name for network betweenness centrality. | 'betweenness' |
Returns:
| Type | Description |
|---|---|
| DataFrame | Per-protein table with columns id, EC50, delta_max, R2, NSS, axis_direct, axis_indirect, axis_network, redox_role. |
summarize_redox_by_pathway(redox_df, annot_df, id_col='id', path_col='pathway')
Summarize redox axes and roles per pathway/module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| redox_df | DataFrame | Output of build_redox_axes, with columns id_col, axis_direct, axis_indirect, axis_network, redox_role, ... | required |
| annot_df | DataFrame | id-to-pathway mapping; must contain id_col and path_col. | required |
| id_col | str | Column name for the protein identifier. | 'id' |
| path_col | str | Column name for the pathway/module annotation. | 'pathway' |
Returns:
| Type | Description |
|---|---|
| DataFrame | Per-pathway table with columns pathway, N, direct_mean, indirect_mean, network_mean, frac_direct, frac_indirect, frac_network, frac_peripheral. |
cetsax.latent¶
This module builds feature matrices and fits low-dimensional latent representations such as PCA and factor analysis.
latent.py
Latent-factor modelling of CETSA NADPH responsiveness. This module provides functions to build a feature matrix from CETSA sensitivity scores and redox axes, and to perform dimensionality reduction using PCA and Factor Analysis (FA). It also includes utilities to merge latent representations back to protein metadata tables.
build_feature_matrix(sens_df, redox_df=None, id_col='id', base_features=None, include_redox_axes=True)
Build a standardized feature matrix per protein.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sens_df | DataFrame | Output of sensitivity.compute_sensitivity_scores; expected columns id, EC50, delta_max, Hill, R2, NSS, EC50_scaled, ... | required |
| redox_df | DataFrame or None | Output of redox.build_redox_axes; expected columns id, axis_direct, axis_indirect, axis_network, redox_role, ... | None |
| base_features | list of str or None | Numeric columns from sens_df to include as features. If None, defaults to ["EC50", "delta_max", "Hill", "R2", "NSS", "EC50_scaled", "delta_max_scaled", "Hill_scaled", "R2_scaled"]. | None |
| include_redox_axes | bool | If True and redox_df is provided, also include ["axis_direct", "axis_indirect", "axis_network"] as features. | True |
| id_col | str | Column name for the protein identifier in sens_df and redox_df. | 'id' |
Returns:
| Type | Description |
|---|---|
| DataFrame | Feature matrix indexed by id, with one column per selected (standardized) feature. |
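"Standardized" here presumably means per-feature z-scoring. The sketch below assumes population standard deviation (ddof=0, as scikit-learn's `StandardScaler` uses) and leaves constant columns at zero instead of dividing by zero; both choices are assumptions about cetsax.

```python
import numpy as np
import pandas as pd


def standardize_features(df, id_col="id", columns=("EC50", "delta_max", "R2")):
    """Z-score the selected columns; returns a matrix indexed by protein id."""
    feat = df.set_index(id_col)[list(columns)].astype(float)
    # Assumption: population std (ddof=0); constant columns are left centered at 0.
    std = feat.std(ddof=0).replace(0.0, 1.0)
    return (feat - feat.mean()) / std


sens = pd.DataFrame({
    "id": ["P1", "P2", "P3", "P4"],
    "EC50": [0.001, 0.01, 0.1, 1.0],
    "delta_max": [0.4, 0.3, 0.2, 0.1],
    "R2": [0.95, 0.9, 0.85, 0.8],
})
X = standardize_features(sens)
print(X.round(3))
```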
fit_pca(feat_df, n_components=3)
Fit PCA on the standardized feature matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feat_df | DataFrame | Feature matrix as returned by build_feature_matrix, indexed by id. | required |
| n_components | int | Number of principal components to compute. | 3 |
Returns:
| Type | Description |
|---|---|
dict with keys:
|
|
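To show what PCA does to the feature matrix, here is a NumPy-only sketch via the SVD of the column-centered data. It yields the same scores as `sklearn.decomposition.PCA` up to the sign of each component; the `PC1`, `PC2`, ... column names follow the defaults used by the plotting helpers and are otherwise an assumption.

```python
import numpy as np
import pandas as pd


def pca_scores(feat_df, n_components=3):
    """PCA via SVD of the centered matrix; returns (scores, explained-variance ratios)."""
    X = feat_df.to_numpy(dtype=float)
    Xc = X - X.mean(axis=0)
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    k = min(n_components, Vt.shape[0])
    scores = pd.DataFrame(
        U[:, :k] * S[:k],
        index=feat_df.index,
        columns=[f"PC{i + 1}" for i in range(k)],
    )
    explained = S**2 / np.sum(S**2)
    return scores, explained[:k]


rng = np.random.default_rng(0)
feat = pd.DataFrame(rng.normal(size=(50, 4)), columns=["EC50", "delta_max", "R2", "NSS"])
scores, ratio = pca_scores(feat, n_components=2)
print(scores.shape, ratio.round(3))
```

The scores are uncorrelated by construction, which is why PC1 and PC2 give non-redundant axes for the scatter plots in cetsax.viz.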
fit_factor_analysis(feat_df, n_components=3)
Fit Factor Analysis (FA) on the standardized feature matrix.
FA often yields more interpretable latent factors than PCA when features are noisy and partially redundant.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feat_df | DataFrame | Feature matrix as returned by build_feature_matrix, indexed by id. | required |
| n_components | int | Number of latent factors. | 3 |
Returns:
| Type | Description |
|---|---|
dict with keys:
|
|
attach_latent_to_metadata(meta_df, latent_df, id_col='id')
Merge latent coordinates back into a per-protein metadata table.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| meta_df | DataFrame | Any per-protein table with column id_col (e.g. sens_df, redox_df). | required |
| latent_df | DataFrame | Latent representation indexed by id (e.g. PCA or FA scores). | required |
| id_col | str | Column name for the protein identifier in meta_df and latent_df. | 'id' |
Returns:
| Type | Description |
|---|---|
| DataFrame | meta_df with the latent columns appended. |
cetsax.mixture¶
This module supports Gaussian mixture modeling for response-state discovery and soft cluster assignment.
mixture.py
Mixture modelling for CETSA NADPH response populations. This module provides functions to build a standardized feature matrix from CETSA sensitivity scores and redox axes, fit Gaussian Mixture Models (GMMs) with BIC-based model selection, assign cluster labels and posterior responsibilities, and optionally label clusters by sensitivity level.
build_mixture_features(sens_df, redox_df=None, id_col='id', feature_cols=None, include_redox_axes=True, log_transform_ec50=True)
Build the feature matrix for mixture modelling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sens_df | DataFrame | Typically the output of sensitivity.compute_sensitivity_scores, with at least id, EC50, delta_max, R2, NSS, ... | required |
| redox_df | DataFrame or None | Output of redox.build_redox_axes, with columns id, axis_direct, axis_indirect, axis_network, ... | None |
| feature_cols | list of str or None | Columns from sens_df to use. If None, defaults to ["EC50", "delta_max", "NSS", "R2"]. | None |
| include_redox_axes | bool | If True and redox_df is provided, add ["axis_direct", "axis_indirect", "axis_network"]. | True |
| log_transform_ec50 | bool | If True and "EC50" is among the features, replace it with -log10(EC50), so that higher values mean stronger binding / greater sensitivity. | True |
Returns:
| Type | Description |
|---|---|
| DataFrame | id-indexed standardized feature matrix for mixture modelling. |
fit_gmm_bic_grid(feat_df, n_components_grid=None, covariance_type='full', random_state=0)
Fit Gaussian Mixture Models over a grid of component counts and select the model with the lowest BIC.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feat_df | DataFrame | Standardized feature matrix indexed by id. | required |
| n_components_grid | list of int or None | Component counts to evaluate. If None, defaults to [1, 2, 3, 4, 5]. | None |
| covariance_type | {"full", "tied", "diag", "spherical"} | Covariance structure of the GMM. | "full" |
Returns:
| Type | Description |
|---|---|
dict with keys:
|
|
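BIC trades goodness of fit against model size: BIC = k ln(n) - 2 ln(L), with lower values preferred. The sketch below counts the free parameters of a full-covariance GMM (means, covariance matrices, and mixing weights, the same count scikit-learn's `GaussianMixture.bic` uses for `covariance_type="full"`) and picks the best component count. The log-likelihood values are hypothetical placeholders, not results.

```python
import math


def gmm_n_params_full(n_components, n_features):
    """Free parameters of a full-covariance GMM: means + covariances + mixing weights."""
    means = n_components * n_features
    covariances = n_components * n_features * (n_features + 1) // 2
    weights = n_components - 1  # weights sum to 1, so one is not free
    return means + covariances + weights


def bic(log_likelihood, n_params, n_samples):
    """Bayesian information criterion; lower is better."""
    return n_params * math.log(n_samples) - 2.0 * log_likelihood


# Hypothetical total log-likelihoods for 100 proteins with 2 features each.
n, d = 100, 2
candidates = {
    1: bic(-120.0, gmm_n_params_full(1, d), n),
    2: bic(-90.0, gmm_n_params_full(2, d), n),
}
best_k = min(candidates, key=candidates.get)
print(gmm_n_params_full(1, 2), gmm_n_params_full(3, 4), best_k)  # 5 44 2
```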
assign_mixture_clusters(feat_df, gmm, id_col='id')
Assign mixture clusters and posterior responsibilities for each protein.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feat_df | DataFrame | Standardized feature matrix indexed by id. | required |
| gmm | GaussianMixture | Fitted GMM. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Table with columns id, cluster, and one resp_k column per mixture component. |
label_clusters_by_sensitivity(sens_df, cluster_df, id_col='id', score_col='NSS')
Assign human-readable labels ("high", "medium", "low") to mixture clusters by ranking them on the mean of a chosen score (e.g. NSS or -log10 EC50).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sens_df | DataFrame | Per-protein scores; must contain id_col and score_col. | required |
| cluster_df | DataFrame | Output of assign_mixture_clusters; must contain id_col and 'cluster'. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Table with columns cluster, mean_score, label. |
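A sketch of the labelling step: join scores to cluster assignments, rank clusters by mean score, and map ranks to labels. How cetsax handles a cluster count other than three is not documented; this sketch splits the ranked clusters into three near-equal groups, which is an assumption.

```python
import numpy as np
import pandas as pd


def label_clusters_sketch(sens_df, cluster_df, id_col="id", score_col="NSS"):
    merged = cluster_df[[id_col, "cluster"]].merge(sens_df[[id_col, score_col]], on=id_col)
    summary = (
        merged.groupby("cluster")[score_col].mean()
        .rename("mean_score")
        .sort_values(ascending=False)
        .reset_index()
    )
    # Assumption: ranked clusters are split into three groups (high / medium / low).
    labels = np.empty(len(summary), dtype=object)
    groups = np.array_split(np.arange(len(summary)), 3)
    for name, idx in zip(["high", "medium", "low"], groups):
        labels[idx] = name
    summary["label"] = labels
    return summary


ids = ["P1", "P2", "P3", "P4", "P5", "P6"]
sens = pd.DataFrame({"id": ids, "NSS": [0.9, 0.8, 0.1, 0.2, 0.5, 0.4]})
clusters = pd.DataFrame({"id": ids, "cluster": [0, 0, 1, 1, 2, 2]})
print(label_clusters_sketch(sens, clusters))
```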
cetsax.ml¶
This module contains classical machine-learning utilities for curve feature extraction, clustering, and outlier detection.
ml.py
Machine learning utilities for CETSA dose-response curve analysis. This module provides functions to extract features from dose-response curves using PCA, classify curves using KMeans clustering, and detect outlier curves based on z-score heuristics.
extract_curve_features(df, n_components=3)
Reduce dose-response curves to principal components.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | DataFrame containing dose-response data with an 'id' column and the dose columns. | required |
| n_components | int | Number of PCA components to extract. | 3 |
Returns:
| Type | Description |
|---|---|
| DataFrame | DataFrame of PCA features per protein ID. |
classify_curves_kmeans(features, k=4)
Apply KMeans to curve embeddings (e.g. PCA features).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| features | DataFrame | Feature matrix per protein (e.g. the output of extract_curve_features). | required |
| k | int | Number of clusters. | 4 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Input features with an additional 'cluster' column. |
detect_outliers(features)
Simple z-score heuristic for flagging outlier curves.
TODO: consider replacing this with a robust Mahalanobis distance or an isolation forest.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| features | DataFrame | Feature matrix per protein (e.g. the output of extract_curve_features). | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | DataFrame with a boolean 'outlier' column. |
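A minimal version of the z-score heuristic: standardize each numeric column and flag any row whose absolute z exceeds a threshold in at least one column. The threshold of 3 and the "any column" rule are assumptions; the docstring does not give them.

```python
import numpy as np
import pandas as pd


def zscore_outliers(features, threshold=3.0):
    """Flag rows with |z| > threshold in any numeric column (threshold is an assumption)."""
    numeric = features.select_dtypes("number")
    std = numeric.std(ddof=0).replace(0.0, 1.0)  # constant columns never trigger
    z = (numeric - numeric.mean()) / std
    out = features.copy()
    out["outlier"] = (z.abs() > threshold).any(axis=1)
    return out


values = np.append(np.linspace(-1.0, 1.0, 20), 50.0)  # one extreme curve
feats = pd.DataFrame({"PC1": values, "PC2": 0.0}, index=[f"P{i}" for i in range(21)])
flagged = zscore_outliers(feats)
print(flagged.index[flagged["outlier"]].tolist())  # ['P20']
```

One limitation motivates the TODO above: with population standard deviation, no point in a sample of n can exceed |z| = sqrt(n - 1), so on very small feature tables a threshold of 3 can never fire, and a single extreme curve also inflates the standard deviation it is judged against.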
Visualization¶
cetsax.plotting¶
This module provides direct plotting helpers for raw curves and global fit diagnostics.
plotting.py
Plotting functions for ITDR data and model fits.
plot_protein_curve(df, fit_df, protein_id, condition=None, ax=None)
Plot the raw ITDR data and the fitted curve for a given protein (and, optionally, a single condition).
plot_goodness_of_fit(df, fit_df, ax=None)
Global goodness-of-fit plot: observed vs predicted values for all proteins, colored by condition (r1, r2).
Returns: (fig, ax), the matplotlib Figure and Axes objects.
cetsax.viz¶
This module provides higher-level visualization for pathway effects, redox axes, latent structure, and mixture-model outputs.
viz.py
Plotting functions for pathway effects, redox axes, latent factors, and mixture clusters.
plot_pathway_effects_bar(path_df, metric='NSS_mean', top_n=20, ax=None)
Horizontal bar plot of pathway-level effects.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path_df | DataFrame | Output of summarize_pathway_effects; must contain 'pathway' (or a custom pathway column) and the chosen metric. | required |
| metric | str | Column name to plot, e.g. 'NSS_mean' or 'delta_max_median'. | 'NSS_mean' |
| top_n | int | Number of top pathways to show. | 20 |
| ax | Axes or None | Axes to draw on. If None, a new figure and axes are created. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | The matplotlib Figure and Axes. |
plot_pathway_enrichment_volcano(enr_df, ax=None, label_top_n=10)
Volcano plot of pathway over-representation results.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| enr_df | DataFrame | Output of enrich_overrepresentation; must contain 'odds_ratio', 'qval', and 'pathway' (or a custom pathway column). | required |
| ax | Axes or None | Axes to draw on. If None, a new figure and axes are created. | None |
| label_top_n | int | Number of top pathways (by smallest qval) to label. | 10 |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | The matplotlib Figure and Axes. |
plot_redox_axes_scatter(redox_df, x_axis='axis_direct', y_axis='axis_indirect', color_by='redox_role', ax=None)
Scatter plot of two redox axes (e.g. direct vs indirect), colored by role.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| redox_df | DataFrame | Output of build_redox_axes; must contain x_axis, y_axis, and color_by. | required |
| x_axis | str | Column to use as the x coordinate. | 'axis_direct' |
| y_axis | str | Column to use as the y coordinate. | 'axis_indirect' |
| color_by | str | Categorical column used for coloring (e.g. 'redox_role'). | 'redox_role' |
| ax | Axes or None | Axes to draw on. If None, a new figure and axes are created. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | The matplotlib Figure and Axes. |
plot_redox_role_composition(path_redox_df, top_n=20, ax=None)
Stacked bar plot of redox-role composition per pathway.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path_redox_df | DataFrame | Output of summarize_pathway_redox_roles; must contain 'pathway' (or a custom pathway column), 'frac_direct_core', 'frac_indirect_responder', 'frac_network_mediator', and 'frac_peripheral'. | required |
| top_n | int | Number of top pathways to show. | 20 |
| ax | Axes or None | Axes to draw on. If None, a new figure and axes are created. | None |
plot_pca_scores(scores_df, meta_df=None, id_col='id', color_by=None, pc_x='PC1', pc_y='PC2', ax=None)
Scatter plot of PCA scores (PC1 vs PC2 by default), optionally colored by a metadata column.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| scores_df | DataFrame | PCA scores indexed by id, with columns including pc_x and pc_y. | required |
| meta_df | DataFrame or None | Metadata table with id_col and the color_by column. | None |
| id_col | str | Column name for the protein identifier. | 'id' |
| color_by | str or None | Column in meta_df used to color points. If None, points are not colored. | None |
| pc_x | str | Column in scores_df to use as the x coordinate. | 'PC1' |
| pc_y | str | Column in scores_df to use as the y coordinate. | 'PC2' |
| ax | Axes or None | Axes to draw on. If None, a new figure and axes are created. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | The matplotlib Figure and Axes. |
plot_factor_scores(scores_df, meta_df=None, id_col='id', color_by=None, f_x='F1', f_y='F2', ax=None)
Scatter plot of factor-analysis scores (F1 vs F2 by default), optionally colored by a metadata column.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| scores_df | DataFrame | Factor-analysis scores indexed by id, with columns including f_x and f_y. | required |
| meta_df | DataFrame or None | Metadata table with id_col and the color_by column. | None |
| id_col | str | Column name for the protein identifier. | 'id' |
| color_by | str or None | Column in meta_df used to color points. If None, points are not colored. | None |
| f_x | str | Column in scores_df to use as the x coordinate. | 'F1' |
| f_y | str | Column in scores_df to use as the y coordinate. | 'F2' |
| ax | Axes or None | Axes to draw on. If None, a new figure and axes are created. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | The matplotlib Figure and Axes. |
plot_mixture_clusters_in_pca(pca_scores, cluster_df, id_col='id', pc_x='PC1', pc_y='PC2', ax=None)
Plot mixture clusters in PCA space (PC1 vs PC2 by default).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pca_scores | DataFrame | PCA scores indexed by id, with columns including pc_x and pc_y. | required |
| cluster_df | DataFrame | Output of assign_mixture_clusters; must contain id_col and 'cluster'. | required |
| id_col | str | Column name for the protein identifier. | 'id' |
| pc_x | str | Column in pca_scores to use as the x coordinate. | 'PC1' |
| pc_y | str | Column in pca_scores to use as the y coordinate. | 'PC2' |
| ax | Axes or None | Axes to draw on. If None, a new figure and axes are created. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | The matplotlib Figure and Axes. |
plot_cluster_size_bar(cluster_df, ax=None)
Bar plot of mixture cluster sizes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cluster_df | DataFrame | Output of assign_mixture_clusters; must contain 'cluster'. | required |
| ax | Axes or None | Axes to draw on. If None, a new figure and axes are created. | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | The matplotlib Figure and Axes. |
cetsax.viz_predict¶
This module focuses on downstream evaluation of predictive models, including confusion matrices, ROC curves, saliency-style plots, and biologically oriented diagnostic summaries.
viz_predict.py
Visualization and Analysis of Model Predictions for Protein-Ligand Binding
visualize_predictions(pred_file, truth_file, out_dir)
Generate a series of plots comparing model predictions with ground-truth labels and experimental data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred_file | str | Path to a CSV file containing model predictions. | required |
| truth_file | str | Path to a CSV file containing ground-truth labels and experimental data. | required |
- Global Residue Importance (Aggregated IG)
analyze_fitting_data(fits_file, pred_file, out_dir)
Analyze and visualize the quality of the curve-fitting data in relation to model predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fits_file | str | Path to a CSV file containing curve-fitting parameters. | required |
| pred_file | str | Path to a CSV file containing model predictions. | required |
- Curve Reconstruction for Top Predicted Targets
generate_bio_insight(pred_file, truth_file, annot_file, out_dir)
Generate biological-insight plots from model predictions, experimental data, and pathway annotations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred_file | str | Path to a CSV file containing model predictions. | required |
| truth_file | str | Path to a CSV file containing ground-truth labels and experimental data. | required |
| annot_file | str | Path to a CSV file containing protein annotations (e.g. pathways). | required |
- EC50 Validation Across Predicted Classes
plot_training_loop(history_file, out_dir)
Plot training and validation loss/accuracy curves from the model training history.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| history_file | str | Path to a CSV file containing the training history. | required |
- Accuracy Curves (if applicable)
Deep learning¶
cetsax.deeplearn¶
This subpackage contains sequence-based learning modules built around protein language models and custom training utilities.
CETSAX Deep Learning Module. This module contains functions and classes for training deep learning models that predict the NADPH responsiveness of proteins from their amino acid sequences. It also includes utilities for building supervised training tables from sequence data and NADPH response measurements.
train_seq_model(csv_path, cfg, patience=True)
Returns: model, metrics, and paths (a dict of caches).
cetsax.deeplearn.esm_seq_nadph¶
This module contains the ESM-based training, caching, inference, and explainability workflow.
esm_seq_nadph.py
Sequence-based modelling of NADPH responsiveness using protein language models (ESM2).
What this script supports (fully working):
- Build a supervised table from EC50 fits + FASTA (adds seq, label_cls, label_reg).
- Cache:
  1. tokens (per protein)
  2. pooled embeddings (per protein), for the fastest training
  3. residue representations (per protein) + masks, for explainability without rerunning ESM
- Train in three modes:
  - train_mode="pooled": uses cached pooled embeddings (fast)
  - train_mode="reps": uses cached per-residue reps (fast-ish; enables explainability)
  - train_mode="tokens": runs the ESM forward pass during training (slowest)
- Focal loss + class weights + WeightedRandomSampler
- Early stopping + ReduceLROnPlateau scheduler
- Saliency and Integrated Gradients from cached reps (no ESM required) or from tokens (uses ESM)
Hardware note (reference setup: 1x CUDA GPU, 15 GB):
- Caching pooled reps: the automatic batch size is 2 on CUDA, falling back to 1 on out-of-memory errors.
- Caching residue reps: the automatic batch size is 1 on CUDA (safest).
BSD 3-Clause License Copyright (c) 2025, Abhinav Mishra Email: mishraabhinav36@gmail.com
NADPHSeqDataset
Bases: Dataset
Tokenize from CSV on-the-fly (slowest); kept for completeness.
NADPHSeqDatasetTokenCached
Bases: Dataset
Load tokenized sequences from token cache.
NADPHPooledDataset
Bases: Dataset
Load cached pooled embeddings (N, D).
NADPHRepsDataset
Bases: Dataset
Load cached residue reps (list of (L,D)) + masks (list of (L,)).
FocalLoss
Bases: Module
Focal loss for multi-class classification.
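Focal loss down-weights well-classified examples by a factor (1 - p_t)^gamma, where p_t is the predicted probability of the true class; with gamma = 0 it reduces to (optionally class-weighted) cross-entropy. The NumPy sketch below illustrates the formula only. The cetsax class is a PyTorch module, and its gamma default, weighting, and reduction are not documented here, so the values used are assumptions.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def focal_loss(logits, targets, gamma=2.0, class_weights=None):
    """Mean multi-class focal loss: -w_y * (1 - p_y)**gamma * log(p_y)."""
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets, dtype=int)
    p = softmax(logits)
    p_t = p[np.arange(len(targets)), targets]
    if class_weights is None:
        w = np.ones(logits.shape[1])
    else:
        w = np.asarray(class_weights, dtype=float)
    loss = -w[targets] * (1.0 - p_t) ** gamma * np.log(p_t)
    return float(loss.mean())


logits = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
targets = np.array([0, 1])
print(focal_loss(logits, targets, gamma=0.0), focal_loss(logits, targets, gamma=2.0))
```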
NADPHSeqModel
Bases: Module
Single model that supports:
- tokens path (runs ESM)
- reps path (cached per-residue reps)
- pooled path (cached pooled embeddings)
read_fasta_to_dict(fasta_path)
Read a FASTA file into a dict {id: sequence}. Supported header forms:
- P12345
- sp|P12345|...
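A standard-library sketch of FASTA parsing that handles both header forms listed above. For UniProt-style headers it assumes the accession between the first two `|` characters is the id; for plain headers it takes the first whitespace-delimited token. This is an illustration, not the cetsax implementation. It takes an iterable of lines, so a file can be parsed with `with open(path) as fh: read_fasta_sketch(fh)`.

```python
def read_fasta_sketch(lines):
    """Parse FASTA lines into {id: sequence} (header parsing rules are assumptions)."""
    records, current_id, chunks = {}, None, []
    for raw in lines:
        line = raw.strip()
        if not line:
            continue
        if line.startswith(">"):
            if current_id is not None:
                records[current_id] = "".join(chunks)
            token = (line[1:].split() or [""])[0]
            parts = token.split("|")
            # "sp|P12345|NAME" -> "P12345"; "P12345" -> "P12345"
            current_id = parts[1] if len(parts) >= 3 else token
            chunks = []
        else:
            chunks.append(line)  # sequences may span several lines
    if current_id is not None:
        records[current_id] = "".join(chunks)
    return records


fasta = """>P12345
MKTAYIAKQR
QISFVKSHFS
>sp|Q9XYZ1|TEST_HUMAN Example protein
MSSHHHHHH
""".splitlines()
print(read_fasta_sketch(fasta))
# {'P12345': 'MKTAYIAKQRQISFVKSHFS', 'Q9XYZ1': 'MSSHHHHHH'}
```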
compute_residue_saliency_from_reps(model, reps, mask, target_class=None)
Per-residue saliency from cached representations.
- reps: (B, L, D) float32 tensor with requires_grad
- mask: (B, L) bool
- returns: (B, L) saliency scores
compute_residue_integrated_gradients_from_reps(model, reps, mask, target_class=None, steps=50)
Integrated gradients in representation space.
- reps: (B, L, D)
- mask: (B, L)
- returns: (B, L) IG scores
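Integrated gradients attribute a model output to its inputs by averaging gradients along a straight path from a baseline to the input and multiplying by (input - baseline). The NumPy sketch below uses a toy differentiable function on a single (L, D) representation, a zero baseline, and a midpoint Riemann sum, then sums over D and applies the mask to get per-residue scores. The baseline choice and the sum over D are assumptions; cetsax works on batched PyTorch tensors and a trained model.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline=None, steps=50):
    """IG = (x - x') * average of grad F along x' + alpha * (x - x'), alpha in (0, 1)."""
    x = np.asarray(x, dtype=float)
    baseline = np.zeros_like(x) if baseline is None else np.asarray(baseline, dtype=float)
    alphas = (np.arange(steps) + 0.5) / steps  # midpoint Riemann sum
    avg_grad = np.mean([grad_fn(baseline + a * (x - baseline)) for a in alphas], axis=0)
    return (x - baseline) * avg_grad


def residue_scores(ig, mask):
    """Collapse (L, D) attributions to (L,) per-residue scores; padded positions become 0."""
    return ig.sum(axis=-1) * mask


def toy_model(reps):
    return float(np.sum(reps**2))  # stand-in for a model output


def toy_grad(reps):
    return 2.0 * reps  # analytic gradient of toy_model


reps = np.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]])  # (L=3, D=2)
mask = np.array([1.0, 1.0, 0.0])                        # last position is padding
ig = integrated_gradients(toy_grad, reps, steps=10)
print(residue_scores(ig, mask))  # ≈ [5.0, 1.25, 0.0]
```

The completeness property (attributions summing to F(x) - F(baseline)) is a useful sanity check when choosing the number of integration steps.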
build_token_cache(csv_path, cfg)
Tokenize sequences once and save {tokens: List[Tensor(L)], label_cls, label_reg (optional)}.
build_pooled_cache(token_cache_pt, cfg)
Compute a pooled (D,) embedding per protein using ESM reps + AttentionPooling. Saves: {"pooled": (N, D), "label_cls": ..., "label_reg" (optional): ...}.
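Attention pooling turns a variable-length (L, D) residue representation into a single (D,) vector by scoring each residue, normalizing the scores with a masked softmax, and taking the weighted sum. The NumPy sketch below uses a single linear scoring vector `w`, a hypothetical stand-in for whatever the cetsax `AttentionPooling` layer learns; padded residues receive exactly zero weight.

```python
import numpy as np


def attention_pool(reps, mask, w):
    """Masked attention pooling over residues.

    reps: (L, D) residue representations; mask: (L,) bool; w: (D,) scoring vector.
    Returns (pooled (D,), weights (L,)).
    """
    scores = reps @ w
    scores = np.where(mask, scores, -np.inf)  # padded residues get zero weight
    scores = scores - scores[mask].max()      # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum()
    return weights @ reps, weights


reps = np.array([[1.0, 0.0], [3.0, 2.0], [100.0, 100.0]])
mask = np.array([True, True, False])
pooled, weights = attention_pool(reps, mask, np.zeros(2))
print(pooled, weights)  # zero scores -> uniform weights over the two real residues
```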
build_reps_cache(token_cache_pt, cfg)
Cache per-residue reps so that saliency/IG can run without an ESM forward pass. Saves:
train_seq_model(csv_path, cfg, patience=True)
Returns: model, metrics, and paths (a dict of caches).
cetsax.deeplearn.my_seq_nadph¶
This module contains the Transformers/Hugging Face based sequence-modeling workflow with support for pooled, residue-level, and token-level training modes.
seq_nadph.py
Sequence-based modelling of NADPH responsiveness using Hugging Face Transformers (ESM2).
Enhancements over the original:
- Backend: migrated from 'esm' to 'transformers' (AutoModel, AutoTokenizer).
- Optimization: added Automatic Mixed Precision (AMP) for faster training.
- Memory: added gradient accumulation to simulate large batches on small GPUs.
- Flexibility: added a 'freeze_backbone' option to allow fine-tuning (inspired by HF examples).
- Tokenization: robust handling via AutoTokenizer.
Modes:
- train_mode="pooled": cached embeddings (fastest).
- train_mode="reps": cached residue reps (fast-ish; supports saliency/IG).
- train_mode="tokens": end-to-end ESM forward pass (slowest; supports fine-tuning).
BSD 3-Clause License Copyright (c) 2025, Abhinav Mishra
Notes on expected data structures¶
Most analytical functions in cetsax assume one of a few standard tabular forms.
- Raw CETSA tables usually include:
  - id
  - condition
  - dose columns defined in cetsax.config.DOSE_COLS
- Fit tables usually include:
  - EC50
  - log10_EC50
  - Hill
  - R2
  - delta_max
- Protein-level summary tables usually include:
  - id
  - sensitivity metrics such as NSS
  - optional annotations such as pathway or redox role
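Before chaining modules, a quick column check against these conventions can catch schema mismatches early. In the sketch below the dose column names are hypothetical placeholders (the real names live in `cetsax.config.DOSE_COLS`), and `missing_columns` is an illustrative helper, not part of cetsax.

```python
import pandas as pd

# Hypothetical dose column names; the real ones are defined in cetsax.config.DOSE_COLS.
DOSE_COLS = ["dose_0", "dose_1", "dose_2", "dose_3"]

REQUIRED = {
    "raw": ["id", "condition", *DOSE_COLS],
    "fit": ["EC50", "log10_EC50", "Hill", "R2", "delta_max"],
    "protein": ["id", "NSS"],
}


def missing_columns(df, kind):
    """Return the expected columns for `kind` that `df` lacks (empty list if none)."""
    return [c for c in REQUIRED[kind] if c not in df.columns]


raw = pd.DataFrame(
    [["P12345", "NADPH.r1", 1.00, 1.05, 1.20, 1.30]],
    columns=["id", "condition", *DOSE_COLS],
)
print(missing_columns(raw, "raw"))  # []
print(missing_columns(raw, "fit"))  # ['EC50', 'log10_EC50', 'Hill', 'R2', 'delta_max']
```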