import warnings
from pathlib import Path
from typing import Any, Sequence
import uuid
import numpy as np
import pandas as pd
import anndata as ad
from pandas.api.types import is_string_dtype, is_categorical_dtype
from scipy import sparse
from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import leaves_list, linkage
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.patches import Patch
from matplotlib.ticker import MaxNLocator
import seaborn as sns
from proteopy.utils.anndata import check_proteodata, is_proteodata
from proteopy.utils.matplotlib import _resolve_color_scheme
from proteopy.utils.functools import partial_with_docsig
from proteopy.utils.string import sanitize_string
from proteopy.pp.stats import calculate_cv
def _validate_completeness_args(  # noqa: C901
    adata,
    axis,
    layer,
    order,
    group_by_resolution,
    group_by_partition,
    min_count,
    min_fraction,
    fraction_thresh,
    bin_width,
):
    """Validate inputs and derive working variables for completeness.

    Parameters
    ----------
    adata : AnnData
        Object checked with :func:`check_proteodata`.
    axis : int
        ``0`` computes completeness per variable, ``1`` per observation.
    layer : str or None
        Layer to read instead of ``.X``.
    order, group_by_resolution, group_by_partition, min_count, \
min_fraction, fraction_thresh, bin_width
        Raw user arguments of :func:`completeness`.

    Returns
    -------
    list
        ``[matrix, axis_labels, n_items, axis_length, grouping_frame,
        min_count, min_fraction]``. ``min_count`` and ``min_fraction``
        are reset to ``None`` when ``group_by_resolution`` is not given.

    Raises
    ------
    ValueError
        On an invalid axis, mutually exclusive arguments, out-of-range
        or NaN thresholds, a non-positive ``bin_width`` or an empty
        matrix/axis.
    KeyError
        If ``layer`` is not present in ``adata.layers``.
    """
    check_proteodata(adata)
    if axis not in (0, 1):
        raise ValueError(
            "`axis` must be either 0 (var) or 1 (obs)."
        )
    if (
        group_by_resolution is not None
        and group_by_partition is not None
    ):
        raise ValueError(
            "`group_by_resolution` and `group_by_partition` "
            "are mutually exclusive. Provide one or neither."
        )
    if min_count is not None and min_fraction is not None:
        raise ValueError(
            "`min_count` and `min_fraction` are mutually exclusive. "
            "Provide one or neither."
        )
    # Chained comparison so NaN is rejected as well (every comparison
    # with NaN is False, which the former ``< 0 or > 1`` let through).
    if fraction_thresh is not None and not 0 <= fraction_thresh <= 1:
        raise ValueError(
            "`fraction_thresh` must be between 0 and 1."
        )
    # ``not > 0`` also rejects NaN.
    if bin_width is not None and not bin_width > 0:
        raise ValueError(
            "`bin_width` must be a positive number."
        )
    if (
        group_by_resolution is None
        and (min_count is not None or min_fraction is not None)
    ):
        warnings.warn(
            "`min_count` and `min_fraction` are only used when "
            "`group_by_resolution` is provided. They will be "
            "ignored."
        )
        min_count = None
        min_fraction = None
    # Range checks only for thresholds that will actually be used.
    if min_count is not None and min_count < 0:
        raise ValueError(
            "`min_count` must be a non-negative number."
        )
    if min_fraction is not None and not 0 <= min_fraction <= 1:
        raise ValueError(
            "`min_fraction` must be between 0 and 1."
        )
    if layer is None:
        matrix = adata.X
    else:
        if layer not in adata.layers:
            raise KeyError(
                f"Layer '{layer}' not found in adata.layers."
            )
        matrix = adata.layers[layer]
    if matrix is None:
        raise ValueError(
            "Selected matrix is empty; cannot compute "
            "completeness."
        )
    n_obs, n_vars = matrix.shape
    if axis == 0:
        # Completeness per variable, measured along observations.
        axis_labels = ("var", "obs")
        n_items = n_vars
        axis_length = n_obs
        grouping_frame = adata.obs
    else:
        # Completeness per observation, measured along variables.
        axis_labels = ("obs", "var")
        n_items = n_obs
        axis_length = n_vars
        grouping_frame = adata.var
    if axis_length == 0:
        raise ValueError(
            "Cannot compute completeness on empty axis."
        )
    if n_items == 0:
        raise ValueError(
            "No items to compute completeness for."
        )
    if order is not None and group_by_partition is None:
        warnings.warn(
            "`order` is only used when "
            "`group_by_partition` is provided. "
            "It will be ignored."
        )
    return [
        matrix, axis_labels, n_items, axis_length,
        grouping_frame, min_count, min_fraction,
    ]
def _summary_stats(values):
"""Return a single-row DataFrame of summary statistics."""
s = pd.Series(values) if not isinstance(
values, pd.Series,
) else values
return pd.DataFrame({
"count": [s.count()],
"mean": [s.mean()],
"median": [s.median()],
"std": [s.std()],
"min": [s.min()],
"max": [s.max()],
})
def _count_nonmissing(mat, ax, zero_to_na):
"""Count non-missing values along the given axis."""
if sparse.issparse(mat):
mat_coo = mat.tocoo()
data = mat_coo.data
rows = mat_coo.row
cols = mat_coo.col
if zero_to_na:
valid = (~np.isnan(data)) & (data != 0)
if ax == 0:
return np.bincount(
cols[valid],
minlength=mat.shape[1],
)
else:
return np.bincount(
rows[valid],
minlength=mat.shape[0],
)
else:
nan_mask = np.isnan(data)
if ax == 0:
nan_c = np.bincount(
cols[nan_mask],
minlength=mat.shape[1],
)
return mat.shape[0] - nan_c
else:
nan_c = np.bincount(
rows[nan_mask],
minlength=mat.shape[0],
)
return mat.shape[1] - nan_c
else:
values = np.asarray(mat)
valid_mask = ~np.isnan(values)
if zero_to_na:
valid_mask &= values != 0
return valid_mask.sum(axis=ax)
def _resolve_partition_order(order, available):
"""Resolve and validate group order for partition plots."""
if order is not None:
if isinstance(order, str):
order = [order]
else:
order = list(order)
missing = [
g for g in order if g not in available
]
if missing:
raise ValueError(
"Unknown group(s) in `order`: "
f"{', '.join(map(str, missing))}.",
)
return order
return sorted(available, key=str)
def _group_completeness_counts(
    matrix, axis, g_mask, zero_to_na,
):
    """Count non-missing values per item inside one group.

    Parameters
    ----------
    matrix : array-like or scipy sparse matrix
        Full data matrix ``(n_obs, n_vars)``.
    axis : int
        ``0``: *g_mask* selects rows and counts are per column;
        otherwise *g_mask* selects columns and counts are per row.
    g_mask : numpy.ndarray of bool
        Membership mask of the group along the selected axis.
    zero_to_na : bool
        Treat zeros as missing.

    Returns
    -------
    tuple
        ``(counts, group_size)``: float counts per item and the number
        of group members.
    """
    subset = matrix[g_mask, :] if axis == 0 else matrix[:, g_mask]
    per_item = np.asarray(
        _count_nonmissing(subset, axis, zero_to_na), dtype=float,
    )
    group_size = int(g_mask.sum())
    return per_item, group_size
def _plot_completeness_partition(
    matrix,
    axis,
    axis_labels,
    zero_to_na,
    grouping_frame,
    group_by_partition,
    order,
    fraction_thresh,
    print_stats,
    xlabel_rotation,
    figsize,
    ax,
):
    """Plot boxplots of completeness partitioned by a grouping column.

    For every group of ``group_by_partition``, the completeness of each
    item is computed as its fraction of non-missing entries among the
    group's members. Each group is drawn as one box.

    Parameters
    ----------
    matrix : array-like or scipy sparse matrix
        Data matrix ``(n_obs, n_vars)``.
    axis : int
        ``0`` partitions observations (column looked up in ``.obs``);
        ``1`` partitions variables (column looked up in ``.var``).
    axis_labels : tuple of str
        ``(item_label, other_label)`` used in title and axis labels.
    zero_to_na : bool
        Treat zeros as missing.
    grouping_frame : pandas.DataFrame
        ``.obs`` or ``.var`` frame holding ``group_by_partition``.
    group_by_partition : str
        Name of the grouping column.
    order : sequence or None
        Explicit group order/subset, resolved by
        :func:`_resolve_partition_order`.
    fraction_thresh : float or None
        When given, a horizontal dashed reference line is drawn.
    print_stats : bool
        Print global and per-group summary statistics.
    xlabel_rotation : float
        Rotation of the x tick labels.
    figsize : tuple of float
        Figure size used when ``ax`` is None.
    ax : matplotlib.axes.Axes or None
        Target axes; a new figure is created when None.

    Returns
    -------
    tuple
        ``(fig, ax)``.

    Raises
    ------
    KeyError
        If the grouping column is missing.
    ValueError
        If no groups are found or ``order`` names unknown groups.
    """
    if group_by_partition not in grouping_frame.columns:
        raise KeyError(
            f"Column '{group_by_partition}' not found "
            f"in {'.obs' if axis == 0 else '.var'}",
        )
    group_series = grouping_frame[group_by_partition]
    available = list(group_series.dropna().unique())
    unique_groups = _resolve_partition_order(
        order, available,
    )
    if len(unique_groups) == 0:
        raise ValueError(
            "No groups found for the given "
            "`group_by_partition` column.",
        )
    group_labels = [str(g) for g in unique_groups]

    # -- compute completeness per item within each group. One frame per
    # group (vectorized) instead of one Python dict per item.
    frames = []
    for g, label in zip(unique_groups, group_labels):
        g_mask = (group_series == g).values
        counts_g, g_size = _group_completeness_counts(
            matrix, axis, g_mask, zero_to_na,
        )
        frames.append(pd.DataFrame({
            "Group": label,
            "Completeness": counts_g / g_size,
        }))
    long_df = pd.concat(frames, ignore_index=True)

    if print_stats:
        print("Global:")
        print(_summary_stats(
            long_df["Completeness"],
        ).to_string(
            index=False, float_format="%.4f",
        ))
        per_group = (
            long_df.groupby("Group")["Completeness"]
            .agg(["count", "mean", "median",
                  "std", "min", "max"])
            .reindex(group_labels)
        )
        print(f"\nPer {group_by_partition}:")
        print(per_group.to_string(
            float_format="%.4f",
        ))
        print()

    if ax is None:
        fig, _ax = plt.subplots(figsize=figsize)
    else:
        _ax = ax
        fig = _ax.get_figure()
    sns.boxplot(
        data=long_df,
        x="Group",
        y="Completeness",
        order=group_labels,
        ax=_ax,
    )
    _ax.set_title(
        f"Completeness per {axis_labels[0]} "
        f"by '{group_by_partition}'",
    )
    _ax.set_xlabel(group_by_partition)
    _ax.set_ylabel(
        f"Fraction of non-missing {axis_labels[1]} "
        f"values per {axis_labels[0]}",
    )
    if fraction_thresh is not None:
        _ax.axhline(
            fraction_thresh,
            color="red",
            linestyle="--",
            label=f"fraction_thresh={fraction_thresh}",
        )
        _ax.legend()
    plt.setp(
        _ax.get_xticklabels(),
        rotation=xlabel_rotation,
    )
    return fig, _ax
def _plot_completeness_ungrouped(
    matrix,
    axis,
    axis_labels,
    axis_length,
    zero_to_na,
    fraction_thresh,
    print_stats,
    bin_edges,
    xlabel_rotation,
    figsize,
    ax,
):
    """Draw a histogram of per-item completeness without grouping.

    Parameters
    ----------
    matrix : array-like or scipy sparse matrix
        Data matrix ``(n_obs, n_vars)``.
    axis : int
        Axis passed to :func:`_count_nonmissing`.
    axis_labels : tuple of str
        ``(item_label, other_label)`` for the title and x label.
    axis_length : int
        Number of entries per item; the denominator of each fraction.
    zero_to_na : bool
        Treat zeros as missing.
    fraction_thresh : float or None
        When given, a vertical dashed reference line is drawn.
    print_stats : bool
        Print a global summary of the fractions before plotting.
    bin_edges : numpy.ndarray
        Histogram bin edges.
    xlabel_rotation : float
        Rotation of the x tick labels.
    figsize : tuple of float
        Figure size used when ``ax`` is None.
    ax : matplotlib.axes.Axes or None
        Target axes; a new figure is created when None.

    Returns
    -------
    tuple
        ``(fig, ax)``.
    """
    item_label = axis_labels[0]
    other_label = axis_labels[1]
    n_valid = _count_nonmissing(matrix, axis, zero_to_na)
    fractions = np.asarray(n_valid, dtype=float) / axis_length

    if print_stats:
        print("Global:")
        summary = _summary_stats(fractions)
        print(summary.to_string(index=False, float_format="%.4f"))
        print()

    if ax is not None:
        fig, axes = ax.get_figure(), ax
    else:
        fig, axes = plt.subplots(figsize=figsize)

    sns.histplot(fractions, bins=bin_edges, ax=axes)
    axes.set_title(f"Completeness per {item_label}")
    axes.set_xlabel(
        f"Fraction of non-missing {other_label} values per {item_label}"
    )
    if fraction_thresh is not None:
        axes.axvline(
            fraction_thresh,
            color="red",
            linestyle="--",
            label=f"fraction_thresh={fraction_thresh}",
        )
        axes.legend()
    plt.setp(axes.get_xticklabels(), rotation=xlabel_rotation)
    return fig, axes
def _plot_completeness_resolution(
    matrix,
    axis,
    axis_labels,
    n_items,
    zero_to_na,
    grouping_frame,
    group_by_resolution,
    min_count,
    min_fraction,
    fraction_thresh,
    print_stats,
    bin_edges,
    xlabel_rotation,
    figsize,
    ax,
):
    """Histogram of the fraction of groups in which each item is detected.

    An item counts as detected in a group when its number of non-missing
    values there reaches ``min_count``, or when its fraction reaches
    ``min_fraction``. If neither threshold is given, ``min_count=1`` is
    used.

    Parameters
    ----------
    matrix : array-like or scipy sparse matrix
        Data matrix ``(n_obs, n_vars)``.
    axis : int
        ``0`` groups observations (column in ``.obs``); ``1`` groups
        variables (column in ``.var``).
    axis_labels : tuple of str
        ``(item_label, other_label)``; the item label appears in texts.
    n_items : int
        Number of items receiving a detection fraction.
    zero_to_na : bool
        Treat zeros as missing.
    grouping_frame : pandas.DataFrame
        ``.obs`` or ``.var`` frame holding ``group_by_resolution``.
    group_by_resolution : str
        Name of the grouping column.
    min_count, min_fraction : number or None
        Detection thresholds (mutually exclusive).
    fraction_thresh : float or None
        When given, a vertical dashed reference line is drawn.
    print_stats : bool
        Print a global summary before plotting.
    bin_edges : numpy.ndarray
        Histogram bin edges.
    xlabel_rotation : float
        Rotation of the x tick labels.
    figsize : tuple of float
        Figure size used when ``ax`` is None.
    ax : matplotlib.axes.Axes or None
        Target axes; a new figure is created when None.

    Returns
    -------
    tuple
        ``(fig, ax)``.

    Raises
    ------
    KeyError
        If the grouping column is missing.
    ValueError
        If the column contains no non-missing groups.
    """
    if group_by_resolution not in grouping_frame.columns:
        where = ".obs" if axis == 0 else ".var"
        raise KeyError(
            f"Column '{group_by_resolution}' not found in {where}",
        )
    labels = grouping_frame[group_by_resolution]
    groups = list(labels.dropna().unique())
    if not groups:
        raise ValueError(
            "No groups found for the given `group_by_resolution` column.",
        )

    by_fraction = min_fraction is not None
    if min_fraction is None and min_count is None:
        # Default: a single non-missing value means "detected".
        min_count = 1

    n_detected = np.zeros(n_items, dtype=int)
    for grp in groups:
        counts_g, size_g = _group_completeness_counts(
            matrix, axis, (labels == grp).values, zero_to_na,
        )
        if by_fraction:
            hit = counts_g / size_g >= min_fraction
        else:
            hit = counts_g >= min_count
        n_detected += hit.astype(int)
    detection_fractions = n_detected / len(groups)

    if print_stats:
        print("Global:")
        summary = _summary_stats(detection_fractions)
        print(summary.to_string(index=False, float_format="%.4f"))
        print()

    if ax is not None:
        fig, axes = ax.get_figure(), ax
    else:
        fig, axes = plt.subplots(figsize=figsize)

    sns.histplot(detection_fractions, bins=bin_edges, ax=axes)
    threshold_label = (
        f"min_fraction={min_fraction}" if by_fraction
        else f"min_count={min_count}"
    )
    item_label = axis_labels[0]
    axes.set_title(
        f"'{group_by_resolution}' completeness per {item_label}",
    )
    axes.set_xlabel(
        f"Fraction of '{group_by_resolution}' groups where {item_label} "
        f"is detected ({threshold_label})",
    )
    if fraction_thresh is not None:
        axes.axvline(
            fraction_thresh,
            color="red",
            linestyle="--",
            label=f"fraction_thresh={fraction_thresh}",
        )
        axes.legend()
    plt.setp(axes.get_xticklabels(), rotation=xlabel_rotation)
    return fig, axes
def completeness(
    adata: ad.AnnData,
    axis: int,
    layer: str | None = None,
    zero_to_na: bool = False,
    order: Sequence[Any] | None = None,
    group_by_partition: str | None = None,
    group_by_resolution: str | None = None,
    min_count: int | None = None,
    min_fraction: float | None = None,
    fraction_thresh: float | None = None,
    print_stats: bool = False,
    bin_width: float = 0.01,
    xlabel_rotation: float = 0.0,
    figsize: tuple[float, float] = (6.0, 5.0),
    show: bool = True,
    ax: Axes | None = None,
    save: str | Path | None = None,
) -> Axes:
    """
    Plot completeness across observations or variables.

    By default, draws a histogram of per-item completeness, i.e. the
    fraction of non-missing values. When ``group_by_resolution`` is
    provided, the histogram instead shows the fraction of groups in
    which each item is "detected". An item is detected in a group when
    it has at least ``min_count`` or ``min_fraction`` non-missing
    values within the group. When ``group_by_partition`` is provided,
    one boxplot of per-item completeness is drawn per group instead.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` object in proteodata format.
    axis
        ``0`` plots completeness per variable, ``1`` per observation.
    layer
        Name of the layer to use instead of ``.X``.
    zero_to_na
        Treat zero entries as missing values when True.
    order
        Explicit ordering and subsetting of groups when
        ``group_by_partition`` is provided. Groups not listed
        are excluded.
    group_by_partition
        Column in ``.obs`` (axis 0) or ``.var`` (axis 1) used to
        partition the opposite axis. For each partition group,
        completeness fractions are computed per item and displayed
        as side-by-side boxplots. Mutually exclusive with
        ``group_by_resolution``.
    group_by_resolution
        Column in ``.obs`` (axis 0) or ``.var`` (axis 1) used to define
        detection groups. When provided, the plot shows the fraction of
        groups in which each item is detected.
    min_count : int or None, optional
        Minimum number of non-missing values within a group for an item
        to be considered detected. Mutually exclusive with
        ``min_fraction``. Only used when ``group_by_resolution`` is
        provided; defaults to 1 when neither threshold is given.
    min_fraction : float or None, optional
        Minimum fraction of non-missing values within a group for an
        item to be considered detected. Mutually exclusive with
        ``min_count``. Only used when ``group_by_resolution`` is
        provided.
    fraction_thresh : float or None, optional
        Completeness fraction threshold in ``[0, 1]``. Drawn as a
        vertical dashed line on histograms or a horizontal dashed
        line on boxplots (``group_by_partition``).
    print_stats : bool, optional
        Print completeness distribution statistics before plotting.
        When ``group_by_partition`` is provided, per-group statistics
        are printed below the global summary.
    bin_width : float, optional
        Width of each histogram bin on the fraction axis. Edges start
        at 0.0 and are spaced by ``bin_width``. The last edge lies one
        bin beyond the first edge that is >= 1.0, which is
        ``1.0 + bin_width`` when ``bin_width`` divides 1.
        Defaults to 0.01.
    xlabel_rotation
        Rotation angle in degrees applied to x-axis tick labels.
    figsize
        Tuple ``(width, height)`` controlling figure size in inches.
    show
        Display the plot with ``plt.show()`` when True.
    ax : Axes or None, optional
        Matplotlib Axes object to plot onto. If ``None``, a new
        figure and axes are created.
    save : str or Path or None, optional
        File path to save the figure. If ``None``, do not save.

    Returns
    -------
    Axes
        The Matplotlib Axes object used for plotting.

    Raises
    ------
    ValueError
        For invalid or mutually exclusive arguments, an empty matrix or
        axis, or when the grouping column holds no groups.
    KeyError
        If ``layer`` or the grouping column does not exist.
    """
    (
        matrix, axis_labels, n_items, axis_length,
        grouping_frame, min_count, min_fraction,
    ) = _validate_completeness_args(
        adata, axis, layer, order,
        group_by_resolution, group_by_partition,
        min_count, min_fraction, fraction_thresh,
        bin_width,
    )

    # Same edges as ``np.arange(0.0, 1.0 + 2 * bin_width, bin_width)``,
    # but the edge count is computed explicitly. The length of a float
    # ``arange`` depends on how (stop - start) / step rounds, so it can
    # gain a spurious extra edge. The small tolerance keeps widths that
    # divide 1 (e.g. 0.01 or 1/3) from being rounded up to an extra bin.
    n_steps = int(np.ceil(1.0 / bin_width - 1e-9))
    bin_edges = np.arange(n_steps + 2) * bin_width

    if group_by_partition is not None:
        fig, _ax = _plot_completeness_partition(
            matrix=matrix,
            axis=axis,
            axis_labels=axis_labels,
            zero_to_na=zero_to_na,
            grouping_frame=grouping_frame,
            group_by_partition=group_by_partition,
            order=order,
            fraction_thresh=fraction_thresh,
            print_stats=print_stats,
            xlabel_rotation=xlabel_rotation,
            figsize=figsize,
            ax=ax,
        )
    elif group_by_resolution is None:
        fig, _ax = _plot_completeness_ungrouped(
            matrix=matrix,
            axis=axis,
            axis_labels=axis_labels,
            axis_length=axis_length,
            zero_to_na=zero_to_na,
            fraction_thresh=fraction_thresh,
            print_stats=print_stats,
            bin_edges=bin_edges,
            xlabel_rotation=xlabel_rotation,
            figsize=figsize,
            ax=ax,
        )
    else:
        fig, _ax = _plot_completeness_resolution(
            matrix=matrix,
            axis=axis,
            axis_labels=axis_labels,
            n_items=n_items,
            zero_to_na=zero_to_na,
            grouping_frame=grouping_frame,
            group_by_resolution=group_by_resolution,
            min_count=min_count,
            min_fraction=min_fraction,
            fraction_thresh=fraction_thresh,
            print_stats=print_stats,
            bin_edges=bin_edges,
            xlabel_rotation=xlabel_rotation,
            figsize=figsize,
            ax=ax,
        )
    if save is not None:
        fig.savefig(save, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    return _ax
[docs]
def completeness_per_var(
    adata: ad.AnnData,
    layer: str | None = None,
    zero_to_na: bool = False,
    order: Sequence[Any] | None = None,
    group_by_partition: str | None = None,
    group_by_resolution: str | None = None,
    min_count: int | None = None,
    min_fraction: float | None = None,
    fraction_thresh: float | None = None,
    print_stats: bool = False,
    bin_width: float = 0.01,
    xlabel_rotation: float = 0.0,
    figsize: tuple[float, float] = (6.0, 5.0),
    show: bool = True,
    ax: Axes | None = None,
    save: str | Path | None = None,
) -> Axes:
    """
    Plot the completeness of each variable.

    The completeness of a variable (column) is the fraction of
    observations (rows) in which it has a non-missing value. By default,
    a histogram of these fractions is drawn. With
    ``group_by_resolution``, the histogram instead shows, for each
    variable, the fraction of observation groups in which it is
    detected. With ``group_by_partition``, one boxplot of per-variable
    completeness is drawn for each observation group.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` object in proteodata format.
    layer
        Layer to read instead of ``.X``.
    zero_to_na
        If True, zeros are counted as missing values.
    order
        Groups to show, in display order, when ``group_by_partition``
        is set. Groups left out are not plotted.
    group_by_partition
        ``.obs`` column that splits the observations into groups; each
        group yields one boxplot of per-variable completeness. Cannot be
        combined with ``group_by_resolution``.
    group_by_resolution
        ``.obs`` column defining detection groups; the plot then shows,
        for each variable, the fraction of groups in which it is
        detected.
    min_count : int or None, optional
        Detection threshold: minimum number of non-missing observations
        within a group. Cannot be combined with ``min_fraction``; only
        used with ``group_by_resolution``.
    min_fraction : float or None, optional
        Detection threshold: minimum fraction of non-missing
        observations within a group. Cannot be combined with
        ``min_count``; only used with ``group_by_resolution``.
    fraction_thresh : float or None, optional
        Reference completeness in ``[0, 1]``, drawn as a dashed line
        (vertical on histograms, horizontal on the
        ``group_by_partition`` boxplots).
    print_stats : bool, optional
        Print summary statistics of the completeness distribution before
        plotting. With ``group_by_partition``, per-group statistics
        follow the global summary.
    bin_width : float, optional
        Histogram bin width on the fraction axis; bins start at 0.0 and
        extend to 1.0 + ``bin_width``. Defaults to 0.01.
    xlabel_rotation
        Rotation of the x-axis tick labels, in degrees.
    figsize
        Figure size ``(width, height)`` in inches.
    show
        Call ``plt.show()`` after plotting when True.
    ax : Axes or None, optional
        Axes to draw on; when ``None``, a new figure and axes are
        created.
    save : str or Path or None, optional
        Path the figure is saved to; nothing is saved when ``None``.

    Returns
    -------
    Axes
        The Matplotlib Axes holding the plot.

    Examples
    --------
    >>> import proteopy as pr
    >>> adata = pr.datasets.example_peptide_data()
    >>> pr.pl.completeness_per_var(adata, fraction_thresh=0.7)
    >>> pr.pl.completeness_per_var(
    ...     adata,
    ...     group_by_resolution="condition",
    ...     min_count=1,
    ... )
    >>> pr.pl.completeness_per_var(
    ...     adata,
    ...     group_by_partition="condition",
    ...     order=["control", "treatment"],
    ... )
    """
    # Variables are the items, so completeness runs along axis 0.
    options = dict(
        layer=layer,
        zero_to_na=zero_to_na,
        order=order,
        group_by_partition=group_by_partition,
        group_by_resolution=group_by_resolution,
        min_count=min_count,
        min_fraction=min_fraction,
        fraction_thresh=fraction_thresh,
        print_stats=print_stats,
        bin_width=bin_width,
        xlabel_rotation=xlabel_rotation,
        figsize=figsize,
        show=show,
        ax=ax,
        save=save,
    )
    return completeness(adata, axis=0, **options)
[docs]
def completeness_per_sample(
    adata: ad.AnnData,
    layer: str | None = None,
    zero_to_na: bool = False,
    order: Sequence[Any] | None = None,
    group_by_partition: str | None = None,
    group_by_resolution: str | None = None,
    min_count: int | None = None,
    min_fraction: float | None = None,
    fraction_thresh: float | None = None,
    print_stats: bool = False,
    bin_width: float = 0.01,
    xlabel_rotation: float = 0.0,
    figsize: tuple[float, float] = (6.0, 5.0),
    show: bool = True,
    ax: Axes | None = None,
    save: str | Path | None = None,
) -> Axes:
    """
    Plot the completeness of each sample (observation).

    The completeness of a sample (row) is the fraction of variables
    (columns) in which it has a non-missing value. By default, a
    histogram of these fractions is drawn. With ``group_by_resolution``,
    the histogram instead shows, for each sample, the fraction of
    variable groups in which it is detected. With
    ``group_by_partition``, one boxplot of per-sample completeness is
    drawn for each variable group.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` object in proteodata format.
    layer
        Layer to read instead of ``.X``.
    zero_to_na
        If True, zeros are counted as missing values.
    order
        Groups to show, in display order, when ``group_by_partition``
        is set. Groups left out are not plotted.
    group_by_partition
        ``.var`` column that splits the variables into groups; each
        group yields one boxplot of per-sample completeness. Cannot be
        combined with ``group_by_resolution``.
    group_by_resolution
        ``.var`` column defining detection groups; the plot then shows,
        for each sample, the fraction of groups in which it is detected.
    min_count : int or None, optional
        Detection threshold: minimum number of non-missing variables
        within a group. Cannot be combined with ``min_fraction``; only
        used with ``group_by_resolution``.
    min_fraction : float or None, optional
        Detection threshold: minimum fraction of non-missing variables
        within a group. Cannot be combined with ``min_count``; only used
        with ``group_by_resolution``.
    fraction_thresh : float or None, optional
        Reference completeness in ``[0, 1]``, drawn as a dashed line
        (vertical on histograms, horizontal on the
        ``group_by_partition`` boxplots).
    print_stats : bool, optional
        Print summary statistics of the completeness distribution before
        plotting. With ``group_by_partition``, per-group statistics
        follow the global summary.
    bin_width : float, optional
        Histogram bin width on the fraction axis; bins start at 0.0 and
        extend to 1.0 + ``bin_width``. Defaults to 0.01.
    xlabel_rotation
        Rotation of the x-axis tick labels, in degrees.
    figsize
        Figure size ``(width, height)`` in inches.
    show
        Call ``plt.show()`` after plotting when True.
    ax : Axes or None, optional
        Axes to draw on; when ``None``, a new figure and axes are
        created.
    save : str or Path or None, optional
        Path the figure is saved to; nothing is saved when ``None``.

    Returns
    -------
    Axes
        The Matplotlib Axes holding the plot.

    Examples
    --------
    >>> import proteopy as pr
    >>> adata = pr.datasets.example_peptide_data()
    >>> pr.pl.completeness_per_sample(adata, fraction_thresh=0.5)

    With peptide-level proteodata, grouping by ``protein_id`` yields
    the fraction of proteins detected per sample.

    >>> pr.pl.completeness_per_sample(
    ...     adata,
    ...     group_by_resolution="protein_id",
    ...     min_count=1,
    ... )
    """
    # Samples are the items, so completeness runs along axis 1.
    options = dict(
        layer=layer,
        zero_to_na=zero_to_na,
        order=order,
        group_by_partition=group_by_partition,
        group_by_resolution=group_by_resolution,
        min_count=min_count,
        min_fraction=min_fraction,
        fraction_thresh=fraction_thresh,
        print_stats=print_stats,
        bin_width=bin_width,
        xlabel_rotation=xlabel_rotation,
        figsize=figsize,
        show=show,
        ax=ax,
        save=save,
    )
    return completeness(adata, axis=1, **options)
def _contains_value(seq, value) -> bool:
"""Check if *value* is in *seq*, treating NaN as equal."""
for item in seq:
if pd.isna(item) and pd.isna(value):
return True
if item == value:
return True
return False
def _append_unique(seq, value) -> None:
    """Append *value* to the list *seq* in place unless already present.

    Membership is decided by :func:`_contains_value`, so NaN-like
    values count as duplicates of each other.
    """
    if _contains_value(seq, value):
        return
    seq.append(value)
def _n_var_summary_stats(series):
"""Return a one-row DataFrame of count summary stats."""
return pd.DataFrame({
"mean_count": [series.mean()],
"std_count": [series.std()],
"median_count": [series.median()],
"min_count": [series.min()],
"max_count": [series.max()],
})
def _add_pct_cols(df, total):
"""Add percentage columns to *df* in place."""
for col in [
"mean", "std", "median", "min", "max",
]:
df[f"{col}_pct"] = (
df[f"{col}_count"] / total * 100
)
def _print_stats_df(df):
"""Print a DataFrame with one-decimal formatting."""
print(df.to_string(
index=False, float_format="%.1f",
))
_AGG_STATS = {
"mean_count": "mean",
"std_count": "std",
"median_count": "median",
"min_count": "min",
"max_count": "max",
}
def _validate_n_var_per_sample_args(  # noqa: C901
    adata,
    level,
    group_by,
    order_by,
    order,
    layer,
):
    """Check the arguments of :func:`n_var_per_sample`.

    Returns
    -------
    tuple
        ``(data_level, level, matrix)``. ``data_level`` is the level
        reported by :func:`is_proteodata`; ``matrix`` is ``.X`` or the
        requested layer.

    Raises
    ------
    ValueError
        For an unknown ``level``, peptide counting on protein-level
        data, combining ``group_by`` with ``order_by``, an empty matrix,
        or ``order`` values that do not occur in the relevant source.
    KeyError
        For a missing layer or a missing ``.obs`` column.
    """
    _is_valid, data_level = is_proteodata(adata, raise_error=True)

    # -- level
    if level not in {"peptide", "protein", None}:
        raise ValueError(
            f"Invalid level '{level}'. Must be "
            "'peptide', 'protein', or None."
        )
    if level == "peptide" and data_level == "protein":
        raise ValueError(
            "Cannot count peptides from protein-level data."
        )

    # -- grouping vs. ordering are mutually exclusive
    if group_by is not None and order_by is not None:
        raise ValueError(
            "`group_by` and `order_by` cannot be used together."
        )

    # -- data matrix
    if layer is not None and layer not in adata.layers:
        raise KeyError(f"Layer '{layer}' not found in adata.layers.")
    matrix = adata.X if layer is None else adata.layers[layer]
    if matrix is None:
        raise ValueError(
            "Selected layer is empty; cannot compute variable counts."
        )

    # -- referenced .obs columns (at most one is set at this point)
    for column in (group_by, order_by):
        if column is not None and column not in adata.obs.columns:
            raise KeyError(f"Column '{column}' not found in adata.obs.")

    # -- order entries must exist in the column (or obs_names)
    if order is not None:
        column = group_by if group_by is not None else order_by
        if column is not None:
            valid = set(adata.obs[column].dropna().unique())
            source = f"adata.obs['{column}']"
        else:
            valid = set(adata.obs_names)
            source = "adata.obs_names"
        invalid = [o for o in order if o not in valid]
        if invalid:
            invalid_str = ", ".join(map(str, invalid))
            raise ValueError(
                f"Unknown value(s) in `order`: {invalid_str}. "
                f"Valid values come from {source}."
            )
    return data_level, level, matrix
def _valid_mask(matrix, zero_to_na):
"""Return a dense boolean mask of valid (non-missing) entries."""
if sparse.issparse(matrix):
arr = matrix.toarray()
else:
arr = np.asarray(matrix)
mask = ~np.isnan(arr)
if zero_to_na:
mask &= arr != 0
return mask
def _n_var_count_per_sample(
    matrix, zero_to_na, level, data_level, adata,
):
    """Count non-missing vars per sample.

    When *level* is ``None`` or equals *data_level*, returns the number
    of non-missing entries per row. When *level* is ``"protein"`` on
    peptide-level data, it instead counts the unique
    ``adata.var["protein_id"]`` values with at least one non-missing
    peptide. Peptides whose ``protein_id`` is missing cannot be
    attributed to a protein and are ignored.

    Raises
    ------
    ValueError
        For any other *level* / *data_level* combination.
    """
    valid = _valid_mask(matrix, zero_to_na)
    # -- Count at native level
    if level is None or level == data_level:
        return valid.sum(axis=1)
    # -- Protein count from peptide data
    if level == "protein" and data_level == "peptide":
        protein_codes, protein_ids = pd.factorize(
            adata.var["protein_id"].to_numpy(),
            sort=False,
        )
        # ``pd.factorize`` codes missing labels as -1. Used as an index,
        # -1 would silently alias the last protein column, so drop it.
        has_protein = protein_codes >= 0
        # OR-reduce peptide columns into protein columns
        prot_detected = np.zeros(
            (valid.shape[0], len(protein_ids)), dtype=bool,
        )
        np.maximum.at(
            prot_detected,
            (slice(None), protein_codes[has_protein]),
            valid[:, has_protein],
        )
        return prot_detected.sum(axis=1)
    raise ValueError(
        f"Requested level '{level}' is "
        f"incompatible with "
        f"'{data_level}' data."
    )
def _n_var_derive_totals(
counts_array, level, data_level,
percentage, ylabel, title, adata,
):
"""Derive totals, percentage, ylabel, and title."""
if level == "protein" and data_level == "peptide":
total_vars = adata.var["protein_id"].nunique()
else:
total_vars = adata.n_vars
if percentage:
if total_vars == 0:
raise ValueError(
"Cannot compute percentage: "
"no variables found."
)
counts_array = (
counts_array / total_vars
) * 100
# -- Resolve y-axis label
if ylabel is None:
ylabel = "%" if percentage else "#"
# -- Resolve title
if title is None:
if level == "protein" or (
level is None
and data_level == "protein"
):
entity = "proteins"
elif level == "peptide" or (
level is None
and data_level == "peptide"
):
entity = "peptides"
else:
entity = "variables"
title = f"Number of detected {entity}"
return total_vars, counts_array, ylabel, title
def _n_var_print_group_stats(
    counts, stats_df, group_by, total_vars,
):
    """Print global and per-group count statistics with percentages.

    Parameters
    ----------
    counts : pandas.DataFrame
        Per-observation table with a ``"count"`` column.
    stats_df : pandas.DataFrame
        Per-group ``*_count`` statistics; it is copied, not modified.
    group_by : str
        Grouping column name, used in the section header.
    total_vars : int
        Denominator for the ``*_pct`` columns.
    """
    overall = _n_var_summary_stats(counts["count"])
    _add_pct_cols(overall, total_vars)
    print("Global:")
    _print_stats_df(overall)

    # Work on a copy so the caller's frame does not gain *_pct columns.
    per_group = stats_df.copy()
    _add_pct_cols(per_group, total_vars)
    print(f"\nPer {group_by}:")
    _print_stats_df(per_group)
def _n_var_resolve_bar_colors(
    color_scheme, group_order, stats_df, group_by,
):
    """Map each row of *stats_df* to a bar color, or return None.

    ``_resolve_color_scheme`` assigns colors to *group_order*; each bar
    then takes the color at its group's position in *group_order*.
    ``None`` is returned when no scheme is given or it resolves to
    ``None``, so the default colors are used.
    """
    colors = (
        None if color_scheme is None
        else _resolve_color_scheme(color_scheme, group_order)
    )
    if colors is None:
        return None
    bar_groups = stats_df[group_by]
    return [colors[group_order.index(grp)] for grp in bar_groups]
def _n_var_group_by_path(
    counts, adata, group_by, order,
    color_scheme, total_vars, ylabel, title,
    print_stats, figsize, xlabel_rotation,
    save, show, ax=None,
):
    """Plot a mean +/- std bar chart of per-sample counts grouped by an obs column.

    Parameters
    ----------
    counts : pandas.DataFrame
        Per-observation table with an ``"obs"`` column (observation
        names) and a ``"count"`` column.
    adata : AnnData
        Source of the ``group_by`` labels in ``adata.obs``.
    group_by : str
        Column of ``adata.obs`` used to group observations.
    order : sequence or None
        Preferred group order; groups not listed are appended after it.
    color_scheme
        Passed to ``_resolve_color_scheme``; ``None`` keeps the default
        colors.
    total_vars : int
        Denominator of the percentage columns in printed statistics.
    ylabel, title : str
        Y-axis label and figure suptitle.
    print_stats : bool
        Print global and per-group statistics.
    figsize : tuple of float
        Used only when ``ax`` is None.
    xlabel_rotation : float
        Rotation of the x tick labels.
    save : str, Path or None
        Save the figure to this path when given.
    show : bool
        Call ``plt.show()`` when True.
    ax : matplotlib.axes.Axes or None
        Target axes; a new figure is created when None.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the bar chart.

    Raises
    ------
    ValueError
        If no observation has a non-missing ``group_by`` label.
    """
    group_df = adata.obs[[group_by]].copy()
    group_df = group_df.rename_axis(
        "obs",
    ).reset_index()
    counts = pd.merge(
        counts, group_df, on="obs", how="left",
    )
    # ``.copy()`` so the column assignment below targets an independent
    # frame (avoids pandas' SettingWithCopyWarning on filtered frames).
    counts = counts.dropna(subset=[group_by]).copy()
    if counts.empty:
        raise ValueError(
            "No observations remain after "
            "aligning `group_by` labels.",
        )
    group_values = counts[group_by]
    if isinstance(
        group_values.dtype, pd.CategoricalDtype,
    ):
        group_values = (
            group_values.cat
            .remove_unused_categories()
        )
        counts[group_by] = group_values
    available_groups: list[Any] = []
    for value in group_values:
        _append_unique(available_groups, value)
    group_order = _n_var_resolve_group_order(
        order, available_groups, group_values,
    )
    # Append any groups not yet in order
    for value in available_groups:
        _append_unique(group_order, value)
    # -- Compute per-group statistics
    stats_df = (
        counts.groupby(group_by, observed=True)[
            "count"
        ]
        .agg(**_AGG_STATS)
        .reindex(group_order)
    )
    # Groups requested in ``order`` but absent from the data have NaN
    # stats after the reindex; drop them.
    stats_df = stats_df.dropna(
        subset=["mean_count"],
    ).copy()
    # Single-member groups have an undefined (NaN) std; plot no error bar.
    stats_df["std_count"] = (
        stats_df["std_count"].fillna(0.0)
    )
    stats_df = stats_df.reset_index()
    if print_stats:
        _n_var_print_group_stats(
            counts, stats_df, group_by, total_vars,
        )
    # -- Plot grouped bar chart
    bar_colors = _n_var_resolve_bar_colors(
        color_scheme, group_order,
        stats_df, group_by,
    )
    if ax is not None:
        _ax = ax
        fig = _ax.get_figure()
    else:
        fig, _ax = plt.subplots(figsize=figsize)
    bar_labels = stats_df[group_by].astype(str)
    _ax.bar(
        bar_labels,
        stats_df["mean_count"],
        yerr=stats_df["std_count"],
        color=bar_colors,
        capsize=4.0,
        edgecolor="black",
    )
    plt.setp(
        _ax.get_xticklabels(),
        rotation=xlabel_rotation,
        ha="right",
    )
    _ax.set_xlabel(group_by)
    _ax.set_ylabel(ylabel)
    fig.suptitle(title, y=0.95)
    # Lay out the figure that owns ``_ax``. ``plt.tight_layout()`` would
    # act on the *current* figure, which differs when a user ``ax`` from
    # another figure is passed.
    fig.tight_layout()
    if save is not None:
        fig.savefig(
            save, dpi=300,
            bbox_inches="tight",
        )
    if show:
        plt.show()
    return _ax
def _n_var_resolve_group_order(
order, available_groups, group_values,
):
"""Resolve group ordering from order arg or categories."""
if order:
# Deduplicate while preserving order
group_order: list[Any] = []
for grp in order:
if not _contains_value(
group_order, grp,
):
group_order.append(grp)
return group_order
if isinstance(
group_values.dtype,
pd.CategoricalDtype,
):
return list(
group_values.cat.categories,
)
return available_groups.copy()
def _n_var_resolve_obs_ordering(
counts, obs_df, group_key, order,
available_groups, ascending,
):
"""Resolve observation ordering for the per-obs bar path."""
has_grouping = group_key != "_group"
if has_grouping:
group_order = _n_var_resolve_group_order(
order, available_groups, obs_df[group_key],
)
for grp in available_groups:
_append_unique(group_order, grp)
cat_index_map: dict[str, list[str]] = {}
for grp in group_order:
obs_list = obs_df.loc[
obs_df[group_key] == grp, "obs"
].tolist()
if obs_list:
cat_index_map[str(grp)] = obs_list
x_ordered = [
obs
for obs_list in cat_index_map.values()
for obs in obs_list
]
else:
if order:
# Deduplicate, then append remaining obs
x_ordered: list[Any] = []
for obs_name in order:
_append_unique(
x_ordered, obs_name,
)
for obs_name in counts["obs"]:
_append_unique(
x_ordered, obs_name,
)
else:
if ascending is not None:
sorted_counts = counts.sort_values(
"count",
ascending=ascending,
kind="mergesort",
)
x_ordered = sorted_counts[
"obs"
].tolist()
else:
x_ordered = counts[
"obs"
].tolist()
cat_index_map = {"all": x_ordered}
return x_ordered, cat_index_map
def _n_var_plot_per_obs(
    counts, x_ordered, cat_index_map,
    group_key, order_by, total_vars,
    color_scheme, ylabel, title,
    print_stats, figsize, xlabel_rotation,
    order_by_label_rotation, save, show,
    ax=None,
):
    """Plot per-observation bars with group labels.

    Parameters
    ----------
    counts : pandas.DataFrame
        Rows in plotting order with columns ``"obs"``, ``"count"`` and
        ``group_key``. The frame is not modified.
    x_ordered : list
        Observation names in plotting order (bar positions).
    cat_index_map : dict[str, list]
        Group label -> observations. One colour is resolved per key and
        one bold text label is drawn above the middle of each group.
    group_key : str
        Grouping column in ``counts``; ``"_group"`` means ungrouped.
    order_by : str or None
        Column used for the per-group statistics when grouped.
    total_vars
        Forwarded to ``_add_pct_cols`` for the printed statistics.
    color_scheme
        Passed to ``_resolve_color_scheme`` with the group labels.
    ylabel, title : str
        Y-axis label and figure suptitle.
    print_stats : bool
        Print global (and, when grouped, per-group) statistics.
    figsize : tuple of float
        Size of a newly created figure (ignored when ``ax`` is given).
    xlabel_rotation, order_by_label_rotation : float
        Rotation of the x tick labels and of the group labels.
    save : str or Path or None
        Path to save the figure to.
    show : bool
        Call :func:`matplotlib.pyplot.show`.
    ax : matplotlib.axes.Axes or None
        Axes to draw on; a new figure is created when ``None``.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes the bars were drawn on.
    """
    has_grouping = group_key != "_group"

    # -- Print statistics
    if print_stats:
        if has_grouping:
            global_df = _n_var_summary_stats(counts["count"])
            _add_pct_cols(global_df, total_vars)
            print("Global:")
            _print_stats_df(global_df)
            print_df = (
                counts.groupby(order_by, observed=True)["count"]
                .agg(**_AGG_STATS)
                .reset_index()
            )
            _add_pct_cols(print_df, total_vars)
            print(f"\nPer {order_by}:")
            _print_stats_df(print_df)
        else:
            print_df = _n_var_summary_stats(counts["count"])
            _add_pct_cols(print_df, total_vars)
            _print_stats_df(print_df)

    # -- Resolve colors: one per group label, mapped onto each bar
    unique_groups = list(cat_index_map.keys())
    colors = _resolve_color_scheme(color_scheme, unique_groups)
    plot_kwargs = {}
    if colors is not None:
        color_map = {
            str(grp): colors[i]
            for i, grp in enumerate(unique_groups)
        }
        # Match on the string form (the keys of cat_index_map are str)
        # without mutating the caller's frame.
        group_labels = counts[group_key].astype(str)
        plot_kwargs["color"] = group_labels.map(color_map).to_list()

    # -- Plot per-observation bars
    if ax is not None:
        _ax = ax
        fig = _ax.get_figure()
    else:
        fig, _ax = plt.subplots(figsize=figsize)
    counts.plot(
        kind="bar",
        x="obs",
        y="count",
        ax=_ax,
        legend=False,
        **plot_kwargs,
    )
    plt.setp(
        _ax.get_xticklabels(),
        rotation=xlabel_rotation,
        ha="right",
    )
    _ax.set_xlabel("")
    _ax.set_ylabel(ylabel)

    # -- Add group labels above bars
    obs_idx_map = {obs: i for i, obs in enumerate(x_ordered)}
    ymax = counts['count'].max()
    for cat, obs_list in cat_index_map.items():
        if not obs_list:
            continue
        start_idx = obs_idx_map[obs_list[0]]
        end_idx = obs_idx_map[obs_list[-1]]
        mid_idx = (start_idx + end_idx) / 2
        _ax.text(
            x=mid_idx,
            y=ymax * 1.05,
            s=cat,
            ha='center',
            va='bottom',
            fontsize=8,
            fontweight='bold',
            rotation=order_by_label_rotation,
        )

    fig.suptitle(title, y=0.95)
    # Lay out the figure that owns ``_ax``; ``plt.tight_layout()`` would act
    # on the *current* figure, which differs when the caller passes ``ax``.
    fig.tight_layout()
    if save is not None:
        fig.savefig(
            save, dpi=300, bbox_inches='tight',
        )
    if show:
        plt.show()
    return _ax
def n_var_per_sample(
    adata: ad.AnnData,
    *,
    layer: str | None = None,
    zero_to_na: bool = False,
    level: str | None = None,
    percentage: bool = False,
    ascending: bool | None = None,
    order_by: str | None = None,
    order: Sequence[str] | None = None,
    group_by: str | None = None,
    print_stats: bool = False,
    figsize: tuple[float, float] = (6.0, 4.0),
    color_scheme: str | dict | Sequence | Colormap | callable | None = None,
    title: str | None = None,
    ylabel: str | None = None,
    xlabel_rotation: float = 90,
    order_by_label_rotation: float = 0,
    show: bool = True,
    ax: Axes | None = None,
    save: str | Path | None = None,
) -> Axes:
    """
    Plot the number of detected variables (peptides or proteins) per sample.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` object in proteodata format.
    layer : str or None, optional
        Key in ``adata.layers``; when set, uses that layer
        instead of ``.X``.
    zero_to_na : bool, optional
        If ``True``, zeros in the matrix are treated as
        missing values.
    level : str or None, optional
        ``"peptide"`` counts detected peptides.
        ``"protein"`` counts detected proteins.
        ``None`` follows the intrinsic level of the data (.vars).
    percentage : bool, optional
        Display y-axis values as a percentage of total
        variables instead of raw counts.
    ascending : bool or None, optional
        Sort observations by detected counts. ``True`` places
        lower counts to the left; ``False`` places higher counts
        to the left; ``None`` preserves the existing
        observation order. Ignored, with a :class:`UserWarning`,
        when ``group_by``, ``order`` or ``order_by`` is set.
    order_by : str or None, optional
        Column in ``adata.obs`` used for grouping and
        colouring bars.
    order : Sequence[str] or None, optional
        Controls ordering on the x-axis.
        Without ``group_by`` or ``order_by`` it lists
        observation names. With ``order_by`` it specifies
        the group order. With ``group_by`` it specifies
        the group order for the bar chart. On the
        per-observation plot (no ``group_by``), observations
        or groups that are not listed are appended after the
        listed ones.
    group_by : str or None, optional
        Column in ``adata.obs`` used to summarise
        observations into groups. When provided, a mean
        +/- std bar chart is shown. Mutually exclusive
        with ``order_by``.
    print_stats : bool, optional
        Print summary statistics as a DataFrame.
    figsize : tuple of float, optional
        Figure size ``(width, height)`` in inches passed
        to :func:`matplotlib.pyplot.subplots`.
    color_scheme
        Colour mapping for groups. Accepts a named
        Matplotlib colormap, a single colour, a
        list/tuple of colours, a dict mapping labels to
        colours, a :class:`~matplotlib.colors.Colormap`,
        or a callable.
    title : str or None, optional
        Plot title.
    ylabel : str or None, optional
        Label for the y-axis.
    xlabel_rotation : float, optional
        Rotation in degrees applied to x-axis tick labels.
    order_by_label_rotation : float, optional
        Rotation in degrees applied to group labels drawn
        above the plot.
    show : bool, optional
        Call :func:`matplotlib.pyplot.show` when ``True``.
    ax : Axes or None, optional
        Matplotlib Axes to plot onto. If ``None``, a new
        figure and axes are created.
    save : str or Path or None, optional
        File path to save the figure.

    Returns
    -------
    Axes
        The Matplotlib Axes object used for plotting.

    Examples
    --------
    >>> import proteopy as pr
    >>> adata = pr.datasets.karayel_2020()
    >>> pr.pl.n_var_per_sample(adata)

    Show mean +/- std per group:

    >>> pr.pl.n_var_per_sample(
    ...     adata,
    ...     group_by="cell_type",
    ... )

    Order bars by a grouping column:

    >>> pr.pl.n_var_per_sample(
    ...     adata,
    ...     order_by="cell_type",
    ...     order=["LBaso", "Ortho"],
    ... )
    """
    data_level, level, matrix = _validate_n_var_per_sample_args(
        adata, level, group_by, order_by,
        order, layer,
    )

    # -- Count non-missing vars per sample
    counts_array = _n_var_count_per_sample(
        matrix, zero_to_na, level, data_level, adata,
    )

    # -- Derive totals, percentage, ylabel, and title
    total_vars, counts_array, ylabel, title = _n_var_derive_totals(
        counts_array, level, data_level,
        percentage, ylabel, title, adata,
    )

    # -- Build counts DataFrame with columns "obs" and "count"
    counts = (
        pd.Series(counts_array, index=adata.obs_names, name="count")
        .rename_axis("obs")
        .reset_index()
    )

    # -- Warn when ascending has no effect
    if ascending is not None:
        if group_by is not None:
            ignored_reason = "`group_by` is set."
        elif order is not None:
            ignored_reason = "`order` is set explicitly."
        elif order_by is not None:
            # The grouped per-observation path orders by group, never by
            # count, so ``ascending`` would otherwise be dropped silently.
            ignored_reason = "`order_by` is set."
        else:
            ignored_reason = None
        if ignored_reason is not None:
            warnings.warn(
                "`ascending` is ignored when " + ignored_reason,
                UserWarning,
                stacklevel=2,
            )

    # -- group_by path: mean +/- std bar plot per group
    if group_by is not None:
        return _n_var_group_by_path(
            counts, adata, group_by, order,
            color_scheme, total_vars, ylabel,
            title, print_stats, figsize,
            xlabel_rotation, save, show, ax,
        )

    # -- Per-observation bar plot (with optional order_by)
    has_grouping = order_by is not None
    # "_group" is the sentinel the helpers use for "no grouping"
    group_key = order_by if has_grouping else "_group"

    # Attach grouping column to counts
    if not has_grouping:
        counts[group_key] = "all"
    elif group_key == "obs":
        counts[group_key] = counts["obs"]
    else:
        group_col = adata.obs[[group_key]].copy()
        group_col = group_col.rename_axis("obs").reset_index()
        counts = pd.merge(counts, group_col, on="obs", how="left")

    # Only the observation names and the grouping column are used
    # downstream, so avoid copying the whole of ``adata.obs``.
    if has_grouping and group_key in adata.obs.columns:
        obs_df = adata.obs[[group_key]].copy()
    else:
        obs_df = pd.DataFrame(index=adata.obs_names)
    obs_df = obs_df.rename_axis("obs").reset_index()
    if group_key not in obs_df.columns:
        obs_df[group_key] = "all"

    available_groups: list[Any] = []
    for value in obs_df[group_key]:
        _append_unique(available_groups, value)

    # -- Resolve observation ordering
    x_ordered, cat_index_map = _n_var_resolve_obs_ordering(
        counts, obs_df, group_key, order,
        available_groups, ascending,
    )
    counts["obs"] = pd.Categorical(
        counts["obs"],
        categories=x_ordered,
        ordered=True,
    )
    counts = counts.sort_values("obs")

    # -- Plot per-observation bars
    return _n_var_plot_per_obs(
        counts, x_ordered, cat_index_map,
        group_key, order_by, total_vars,
        color_scheme, ylabel, title,
        print_stats, figsize, xlabel_rotation,
        order_by_label_rotation, save, show, ax,
    )
# Peptide-level convenience wrapper: ``n_var_per_sample`` with
# ``level="peptide"`` pre-bound. ``partial_with_docsig`` receives a custom
# docstring header and examples for the wrapper.
# NOTE(review): how header/examples are merged with the wrapped function's
# docstring is defined in proteopy.utils.functools - confirm there.
n_peptides_per_sample = partial_with_docsig(
n_var_per_sample,
level="peptide",
docstr_header="""\
Plot the number of detected peptides per sample.
For each sample (observation), counts the number of
peptides with non-missing values. Requires peptide-level
proteodata.""",
docstr_examples="""\
>>> import proteopy as pr
>>> adata = pr.datasets.williams_2018()
>>> pr.pl.n_peptides_per_sample(adata)
Show mean +/- std per group:
>>> pr.pl.n_peptides_per_sample(
... adata,
... group_by="tissue",
... )""",
)
# Protein-level convenience wrapper: ``n_var_per_sample`` with
# ``level="protein"`` pre-bound. The examples below use it on both a
# protein-level and a peptide-level dataset.
# NOTE(review): the docstring assembly is done by ``partial_with_docsig``
# (proteopy.utils.functools) - confirm the header/examples merge there.
n_proteins_per_sample = partial_with_docsig(
n_var_per_sample,
level="protein",
docstr_header="""\
Plot the number of detected proteins per sample.
For each sample (observation), counts the number of
proteins with non-missing values.
""",
docstr_examples="""\
>>> import proteopy as pr
Protein-level data:
>>> adata = pr.datasets.karayel_2020()
>>> pr.pl.n_proteins_per_sample(adata)
Peptide-level data (aggregated to proteins):
>>> adata = pr.datasets.williams_2018()
>>> pr.pl.n_proteins_per_sample(adata)
Show mean +/- std per group:
>>> pr.pl.n_proteins_per_sample(
... adata,
... group_by="tissue",
... )""",
)
[docs]
def n_samples_per_category(
    adata: ad.AnnData,
    category_key: str | Sequence[str],
    categories: Sequence[Any] | None = None,
    ignore_na: bool = False,
    ascending: bool = False,
    order: Sequence[Any] | None = None,
    xlabel_rotation: float = 45.0,
    color_scheme: Any | None = None,
    figsize: tuple[float, float] = (6.0, 4.0),
    show: bool = True,
    save: str | Path | None = None,
    ax: bool = False,
) -> Axes | None:
    """
    Plot sample (obs) counts per category (optionally stratified).

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix with categorical obs annotations.
    category_key : str | Sequence[str]
        One or two distinct column names in ``adata.obs`` used to stratify
        observations. With two columns, bars are stacked by the second one.
    categories : Sequence[Any] | None
        Labels from the first category column to display on the x-axis. Rows
        whose first-column value is not listed are dropped.
    ignore_na : bool
        Drop observations with missing labels when ``True``; otherwise, missing
        values are shown as ``"missing"``.
    ascending : bool
        Sort categories by total counts when neither ``order`` nor
        ``categories`` is supplied. ``True`` places lower counts on the left.
    order : Sequence[Any] | None
        Explicit order for the x-axis labels (values of the first category
        column). Repeated entries are ignored; any levels not listed are
        appended afterwards in their intrinsic order. When provided,
        ``ascending`` is ignored.
    xlabel_rotation : float
        Rotation angle (degrees) applied to the x-axis tick labels.
    color_scheme : Any | None
        Mapping, sequence, colormap name, or callable used to colour categories.
    figsize : tuple[float, float]
        Figure size (width, height) in inches used for
        :func:`matplotlib.pyplot.subplots`.
    show : bool
        Call :func:`matplotlib.pyplot.show` when ``True``.
    save : str | Path | None
        Save the figure to the provided path (``str`` or :class:`~pathlib.Path`).
    ax : bool
        When ``True``, return the :class:`~matplotlib.axes.Axes` and leave the
        figure open. Otherwise the figure is closed and ``None`` is returned.

    Returns
    -------
    matplotlib.axes.Axes or None
        The Axes when ``ax=True``, otherwise ``None``.

    Raises
    ------
    KeyError
        If a column of ``category_key`` is missing from ``adata.obs``.
    ValueError
        For an empty or duplicated ``category_key``, empty ``categories``,
        unknown ``order`` values, or when no observations remain.
    NotImplementedError
        If more than two category columns are given.
    """
    check_proteodata(adata)

    def _is_categorical(series: pd.Series) -> bool:
        # ``pandas.api.types.is_categorical_dtype`` is deprecated (pandas 2.1)
        return isinstance(series.dtype, pd.CategoricalDtype)

    if isinstance(category_key, str):
        category_cols = [category_key]
    else:
        category_cols = list(category_key)
    if not category_cols:
        raise ValueError("category_key must contain at least one column name.")
    missing_label = "missing"
    unknown_cols = [col for col in category_cols if col not in adata.obs]
    if unknown_cols:
        raise KeyError(
            "Column(s) missing in adata.obs: "
            f"{', '.join(map(str, unknown_cols))}."
        )
    if len(set(category_cols)) != len(category_cols):
        # Duplicate labels would make ``obs[col]`` return a DataFrame below
        raise ValueError("category_key must not contain duplicate column names.")

    obs = adata.obs.loc[:, category_cols].copy()
    for col in category_cols:
        if not (is_string_dtype(obs[col]) or _is_categorical(obs[col])):
            obs[col] = obs[col].astype("string")
        if ignore_na:
            continue
        if _is_categorical(obs[col]) and missing_label not in obs[col].cat.categories:
            obs[col] = obs[col].cat.add_categories([missing_label])
        obs[col] = obs[col].fillna(missing_label)

    first_cat_col = category_cols[0]
    if ignore_na:
        obs = obs.dropna(subset=category_cols, how="any")

    selected_categories: list[Any] | None = None
    if categories is not None:
        if isinstance(categories, (str, bytes)):
            selected_categories = [categories]
        else:
            selected_categories = list(categories)
        if not selected_categories:
            raise ValueError("categories must contain at least one label.")
        mask = obs[first_cat_col].isin(selected_categories)
        if not mask.any():
            raise ValueError("No observations match the requested categories.")
        obs = obs.loc[mask].copy()

    if obs.empty:
        raise ValueError("No observations available after NA handling.")

    for col in category_cols:
        if _is_categorical(obs[col]):
            obs[col] = obs[col].cat.remove_unused_categories()

    def _ordered_categories(series: pd.Series) -> list[Any]:
        # Category order for categoricals, first appearance otherwise;
        # the "missing" placeholder is always moved to the end.
        if _is_categorical(series):
            ordered = list(series.cat.categories)
        else:
            ordered = list(pd.unique(series))
        if not ignore_na and missing_label in ordered:
            ordered = [
                value for value in ordered if value != missing_label
            ] + [missing_label]
        return ordered

    first_level_order = _ordered_categories(obs[first_cat_col])
    if selected_categories is not None:
        first_level_order = [
            category for category in selected_categories if category in first_level_order
        ]
    if order is not None:
        if isinstance(order, str):
            specified = [order]
        else:
            specified = list(order)
        # Repeated entries would otherwise produce duplicated bars
        specified = list(dict.fromkeys(specified))
        unknown_specified = [cat for cat in specified if cat not in first_level_order]
        if unknown_specified:
            raise ValueError(
                "Order values not present in the first category column: "
                f"{', '.join(map(str, unknown_specified))}."
            )
        remaining = [cat for cat in first_level_order if cat not in specified]
        first_level_order = specified + remaining

    use_count_sort = order is None and selected_categories is None

    # Check before creating the figure so an error does not leak it
    if len(category_cols) > 2:
        raise NotImplementedError(
            "Plotting more than two category columns is not implemented."
        )

    fig, _ax = plt.subplots(figsize=figsize)
    if len(category_cols) == 1:
        freq = obs[first_cat_col].value_counts(dropna=False)
        if use_count_sort:
            freq = freq.sort_values(ascending=ascending)
        else:
            freq = freq.reindex(first_level_order, fill_value=0)
        plot_kwargs: dict[str, Any] = {}
        if color_scheme is not None:
            colors = _resolve_color_scheme(color_scheme, freq.index)
            if colors is not None:
                plot_kwargs["color"] = colors
        freq.plot(kind="bar", ax=_ax, **plot_kwargs)
    else:
        second_cat_col = category_cols[1]
        second_level_order = _ordered_categories(obs[second_cat_col])
        df = (
            obs.groupby(category_cols, observed=False)
            .size()
            .unstack(fill_value=0)
        )
        df = df.reindex(first_level_order, fill_value=0)
        df = df.reindex(columns=second_level_order, fill_value=0)
        if use_count_sort:
            df = df.loc[df.sum(axis=1).sort_values(ascending=ascending).index]
        colors = _resolve_color_scheme(color_scheme, df.columns)
        plot_kwargs = {}
        if colors is not None:
            plot_kwargs["color"] = colors
        df.plot(kind="bar", stacked=True, ax=_ax, **plot_kwargs)
        if df.shape[1] > 1:
            # Move the stratum legend outside the axes
            _ax.legend(loc="center right", bbox_to_anchor=(1.4, 0.5))

    _ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    _ax.set_xlabel(first_cat_col)
    _ax.set_ylabel('#')
    ha = (
        'right' if xlabel_rotation > 0
        else 'left' if xlabel_rotation < 0
        else 'center'
    )
    plt.setp(_ax.get_xticklabels(), rotation=xlabel_rotation, ha=ha)
    fig.tight_layout()

    save_path: Path | None = Path(save) if save is not None else None
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    if ax:
        return _ax
    if not show and save_path is None:
        warnings.warn(
            "Function does not do anything. Enable `show`, provide a `save` path, "
            "or set `ax=True`."
        )
    plt.close(fig)
    return None
[docs]
def n_cat1_per_cat2_hist(
    adata: ad.AnnData,
    first_category: str,
    second_category: str,
    axis: int,
    bin_width: float | None = None,
    bin_range: tuple[float, float] | None = None,
    print_stats: bool = False,
    figsize: tuple[float, float] = (6.0, 4.0),
    show: bool = True,
    save: str | Path | None = None,
    ax: Axes | None = None,
) -> Axes:
    """
    Plot the distribution of the number of first-category entries per second
    category.

    For every value of ``second_category``, the number of distinct
    ``first_category`` values found together with it (on the same axis) is
    counted; the histogram of these counts is drawn.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    first_category : str
        Column whose distinct values are counted per ``second_category``
        value; read from the same frame as ``second_category``. Pass
        ``"index"`` to use ``adata.obs_names`` (``axis == 0``) or
        ``adata.var_names`` (``axis == 1``).
    second_category : str
        Column defining the groups being counted. Resolved from
        ``adata.obs`` when ``axis == 0`` and ``adata.var`` when ``axis == 1``.
        Passing ``"index"`` is not supported. If the column is categorical,
        unused categories are included with a count of 0.
    axis : int
        ``0`` to work on ``adata.obs``, ``1`` to work on ``adata.var``.
    bin_width : float | None
        Optional histogram bin width. Must be positive when provided. When
        ``None``, numpy's ``"auto"`` bin width is used, but at least 1.
    bin_range : tuple[float, float] | None
        Optional pair ``(lower, upper)`` (tuple or list) limiting the
        histogram bins. Both values must be finite numbers and ``lower``
        must be strictly smaller than ``upper``.
    print_stats : bool
        Print distribution statistics (mean, median, mode, variance, min, max).
    figsize : tuple[float, float]
        Size (width, height) in inches passed to
        :func:`matplotlib.pyplot.subplots`.
    show : bool
        Call :func:`matplotlib.pyplot.show` when ``True``.
    save : str | Path | None
        Save the figure to the provided path when given.
    ax : Axes | None
        Matplotlib Axes to plot onto. If ``None``, a new figure and axes
        are created.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes the histogram was drawn on.

    Raises
    ------
    ValueError
        Invalid ``axis``, ``second_category="index"``, non-positive
        ``bin_width``, ``lower >= upper``, or nothing to count.
    KeyError
        If a requested column is missing from the selected frame.
    TypeError
        If ``bin_range`` is not a pair of finite numbers.
    """
    # NOTE(review): check_proteodata presumably also guarantees unique
    # obs/var names, which the ``"index"`` option relies on - confirm.
    check_proteodata(adata)
    if axis not in (0, 1):
        raise ValueError("axis must be either 0 (.obs) or 1 (.var).")
    frame = adata.obs if axis == 0 else adata.var
    frame_label = ".obs" if axis == 0 else ".var"

    if second_category == "index":
        raise ValueError(
            "`second_category='index'` is not supported; pass 'index' via "
            "`first_category` instead."
        )
    if second_category not in frame:
        raise KeyError(
            f"Column '{second_category}' not found in adata{frame_label}."
        )
    if first_category != "index" and first_category not in frame:
        raise KeyError(
            f"Column '{first_category}' not found in adata{frame_label}."
        )

    if bin_width is not None:
        if bin_width <= 0:
            raise ValueError("bin_width must be a positive number.")
    if bin_range is not None:
        bin_range_msg = (
            "bin_range must be a tuple of two finite numbers (lower, upper)."
        )
        if not isinstance(bin_range, (tuple, list)) or len(bin_range) != 2:
            raise TypeError(bin_range_msg)
        try:
            # Non-numeric entries make numpy raise; report them uniformly.
            is_finite = bool(np.all(np.isfinite(np.asarray(bin_range))))
        except (TypeError, ValueError):
            is_finite = False
        if not is_finite:
            raise TypeError(bin_range_msg)
        lower, upper = bin_range
        if lower >= upper:
            raise ValueError("bin_range lower bound must be less than upper bound.")
        bin_range = (lower, upper)

    temp_col = "__proteopy_axis_index__" if first_category == "index" else first_category
    data = frame[[second_category]].copy()
    if first_category == "index":
        index_values = adata.obs_names if axis == 0 else adata.var_names
        data[temp_col] = index_values
    else:
        data[temp_col] = frame[first_category]

    # One row per (second, first) pair, so ``size`` counts distinct entries
    data = data.drop_duplicates(subset=[second_category, temp_col])
    counts = data.groupby(second_category, observed=False).size()
    if counts.empty:
        raise ValueError(
            "No entries available to compute counts for the requested categories."
        )

    if bin_width is None:
        edges = np.histogram_bin_edges(counts.values, bins="auto")
        auto_width = edges[1] - edges[0]
        # Counts are integers; bins narrower than 1 only add empty gaps
        bin_width = max(auto_width, 1.0)

    if print_stats:
        stats_df = pd.DataFrame(
            {
                "mean": [counts.mean()],
                "median": [counts.median()],
                "mode": [counts.mode().iloc[0]],
                "variance": [counts.var()],
                "min": [counts.min()],
                "max": [counts.max()],
            }
        )
        print(stats_df.to_string(index=False))

    if ax is None:
        fig, _ax = plt.subplots(figsize=figsize)
    else:
        _ax = ax
        fig = _ax.get_figure()

    if first_category == "index":
        entry_label = "observations" if axis == 0 else "variables"
    else:
        entry_label = first_category

    sns.histplot(
        counts,
        binwidth=bin_width,
        binrange=bin_range,
        ax=_ax,
    )
    _ax.set_xlabel(f"Number of {entry_label} per {second_category}")
    _ax.set_ylabel(f"# {second_category}")
    fig.tight_layout()
    if save is not None:
        fig.savefig(save, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    return _ax
# Generic header for ``n_cat1_per_cat2_hist``-based wrappers; kept as a
# module-level name for backward compatibility.
docstr_header = (
    "Plot the distribution of the number of first-category entries per second category."
)

# Each wrapper gets a header naming its actual categories, matching the
# specific headers used by the ``n_*_per_sample`` wrappers.
n_peptides_per_protein = partial_with_docsig(
    n_cat1_per_cat2_hist,
    first_category="peptide_id",
    second_category="protein_id",
    axis=1,
    docstr_header="Plot the distribution of the number of peptides per protein.",
)

n_proteoforms_per_protein = partial_with_docsig(
    n_cat1_per_cat2_hist,
    first_category="proteoform_id",
    second_category="protein_id",
    axis=1,
    docstr_header="Plot the distribution of the number of proteoforms per protein.",
)
[docs]
def cv_by_group(
    adata: ad.AnnData,
    group_by: str,
    layer: str | None = None,
    zero_to_na: bool = False,
    min_samples: int | None = None,
    force: bool = False,
    order: list | None = None,
    color_scheme=None,
    alpha: float = 0.8,
    hline: float | None = None,
    show_points: bool = False,
    point_alpha: float = 0.7,
    point_size: float = 1,
    xlabel_rotation: int | float = 0,
    figsize: tuple[float, float] = (6, 4),
    show: bool = True,
    ax: bool = False,
    save: str | None = None,
    print_stats: bool = False,
):
    """
    Compute per-group coefficients of variation and plot their distributions.

    Parameters
    ----------
    adata : AnnData
        AnnData object that contains proteomics quantifications.
    group_by : str
        Column in ``adata.obs`` used to define observation groups for CV
        calculation. Missing labels do not form a group.
    layer : str | None, optional
        AnnData layer to read intensities from. Defaults to ``adata.X``.
    zero_to_na : bool, optional
        Replace zero values with NaN before computing CVs. Default is ``False``.
    min_samples : int | None, optional
        Minimum number of observations per variable required to compute a CV.
        Variables with fewer non-NaN entries receive NaN. ``None`` (default)
        means ``3`` when CVs are computed here. Passing any value (including
        ``0``) together with precomputed CV data in ``adata.varm`` raises a
        :class:`ValueError` unless ``force=True``.
    force : bool, optional
        Force recomputation of CV values even if precomputed data exists in
        ``adata.varm``. When ``True``, uses a temporary slot that is deleted
        after extracting the data. Default is ``False``.
    order : list | None, optional
        Explicit order of group labels (without the ``cv_`` prefix) along the
        x-axis. When ``None`` the observed group order is used.
    color_scheme : sequence, dict | None, optional
        Color assignments for groups. When None, uses the Matplotlib default
        color cycle.
    alpha : float, optional
        Transparency for the violin bodies. Default is ``0.8``.
    hline : float | None, optional
        If set, draw a horizontal dashed line at this CV value.
    show_points : bool, optional
        Overlay individual variable CVs as a strip plot. Default is ``False``.
    point_alpha : float, optional
        Opacity for individual points when ``show_points`` is ``True``.
    point_size : float, optional
        Size of the individual CV points. Default is ``1``.
    xlabel_rotation : float, optional
        Rotation angle (degrees) for the x-axis group labels.
    figsize : tuple of float, optional
        Matplotlib figure size in inches. Default is ``(6, 4)``.
    show : bool, optional
        Call ``plt.show()`` when ``True``. Default is ``True``.
    ax : bool, optional
        Return the Matplotlib Axes if ``True``.
    save : str | None, optional
        Path to save the figure. When ``None`` the figure is not saved.
    print_stats : bool, optional
        Print CV summary statistics.

    Returns
    -------
    matplotlib.axes.Axes or None
        The violin-plot Axes when ``ax=True``, otherwise ``None``.

    Raises
    ------
    KeyError
        If ``group_by`` is not a column of ``adata.obs``.
    ValueError
        On empty data/groups, ``min_samples`` with precomputed data, or an
        ``order`` naming groups without CV data.
    RuntimeError
        If the CV data cannot be found in ``adata.varm`` after computation.
    """
    check_proteodata(adata)
    if group_by not in adata.obs.columns:
        raise KeyError(f"Column '{group_by}' not found in adata.obs.")
    if adata.n_obs == 0:
        raise ValueError(
            "AnnData object contains no observations; cannot compute CVs."
        )
    groups = adata.obs[group_by]
    if groups.dropna().empty:
        raise ValueError(
            f"Column '{group_by}' does not contain any non-missing group labels."
        )
    if isinstance(groups.dtype, pd.CategoricalDtype):
        observed_groups = groups.cat.remove_unused_categories().cat.categories
        unique_groups = [str(cat) for cat in observed_groups]
    else:
        # Drop missing labels first; ``astype(str)`` would otherwise turn
        # them into a spurious "nan" group (categories never contain NaN).
        unique_groups = pd.Index(groups.dropna().astype(str)).unique().tolist()
    if not unique_groups:
        raise ValueError(
            f"Column '{group_by}' does not contain any finite groups."
        )

    # Use existing CV data if available; otherwise compute temporarily
    layer_suffix = sanitize_string(layer) if layer is not None else "X"
    varm_key = f"cv_by_{sanitize_string(group_by)}_{layer_suffix}"
    use_precomputed = varm_key in adata.varm and not force
    temp_key_name = None
    if use_precomputed:
        # ``min_samples`` only applies when CVs are computed here
        if min_samples is not None:
            raise ValueError(
                f"Cannot use `min_samples={min_samples}` with precomputed CV "
                f"data in adata.varm['{varm_key}']. Either:\n"
                f" - Use `force=True` to recompute CV values with the new "
                f"`min_samples` setting, or\n"
                f" - Remove the precomputed data with "
                f"`del adata.varm['{varm_key}']` before calling this function."
            )
        print(f"Using existing CV data from adata.varm['{varm_key}'].")
        key_to_use = varm_key
    else:
        # Random key prevents overwriting existing varm slots
        temp_key_name = f"_temp_cv_{uuid.uuid4().hex[:8]}"
        key_to_use = temp_key_name
        if min_samples is None:
            min_samples = 3

    try:
        if temp_key_name is not None:
            calculate_cv(
                adata,
                group_by=group_by,
                layer=layer,
                zero_to_na=zero_to_na,
                min_samples=min_samples,
                key_added=temp_key_name,
                inplace=True,
            )
        if key_to_use not in adata.varm:
            raise RuntimeError(
                f"Failed to compute CV data: adata.varm['{key_to_use}'] not found."
            )
        check_proteodata(adata)
        cv_df = adata.varm[key_to_use].copy()
    finally:
        # Never leave the temporary slot behind, even when something failed
        if temp_key_name is not None and temp_key_name in adata.varm:
            del adata.varm[temp_key_name]

    df_melted = cv_df.melt(var_name="Group", value_name="CV", ignore_index=False)
    df_melted = df_melted.reset_index(drop=True)

    if order is None:
        order = unique_groups
    else:
        available_groups = set(df_melted["Group"].unique())
        missing = [grp for grp in order if grp not in available_groups]
        if missing:
            raise ValueError(
                "Requested ordering includes groups with no CV data: "
                f"{', '.join(map(str, missing))}."
            )

    resolved_colors = _resolve_color_scheme(color_scheme, order)
    if resolved_colors is None:
        palette = None
    else:
        palette = dict(zip(order, resolved_colors))

    if print_stats:
        _print_cv_by_group_stats(df_melted, order, hline)

    fig, ax_plot = plt.subplots(figsize=figsize, dpi=150)
    sns.violinplot(
        data=df_melted,
        x="Group",
        y="CV",
        hue="Group",
        order=order,
        palette=palette,
        cut=0,
        inner="box",
        alpha=alpha,
        legend=False,
        ax=ax_plot,
    )

    # Optionally overlay points
    if show_points:
        sns.stripplot(
            data=df_melted,
            x="Group",
            y="CV",
            order=order,
            color="black",
            alpha=point_alpha,
            size=point_size,
            jitter=0.2,
            dodge=False,
            ax=ax_plot,
        )

    # Optional horizontal dashed line, annotated with its value
    if hline is not None:
        ax_plot.axhline(
            y=hline,
            color="black",
            linestyle="--",
            linewidth=1,
            alpha=0.8,
        )
        ax_plot.text(
            x=-0.4,
            y=hline,
            s=f"{hline:.2f}",
            color="black",
            va="bottom",
            ha="left",
            fontsize=8,
        )

    ax_plot.set_xlabel("")
    ax_plot.set_ylabel("Coefficient of Variation (CV)")
    for label in ax_plot.get_xticklabels():
        label.set_rotation(xlabel_rotation)
    ax_plot.set_title("Distribution of CV across groups")
    sns.despine()
    fig.tight_layout()
    check_proteodata(adata)

    if save:
        fig.savefig(save, dpi=300, bbox_inches="tight")
        print(f"Figure saved to: {save}")
    if show:
        plt.show()
    if ax:
        return ax_plot
    return None


def _print_cv_by_group_stats(df_melted, order, hline):
    """Print global/per-group CV summaries and, if ``hline``, threshold tables."""
    cv_values = df_melted["CV"].dropna()
    global_summary = pd.DataFrame({
        "Count": [cv_values.count()],
        "Min": [round(cv_values.min(), 4)],
        "Max": [round(cv_values.max(), 4)],
        "Median": [round(cv_values.median(), 4)],
        "Mean": [round(cv_values.mean(), 4)],
        "Std": [round(cv_values.std(), 4)],
    })
    print("Global CV Summary:")
    print(global_summary.to_string(index=False))
    print()

    per_group = (
        df_melted.groupby("Group")["CV"]
        .agg(
            Count="count",
            Min="min",
            Max="max",
            Median="median",
            Mean="mean",
            Std="std",
        )
        .round(4)
        .reindex(order)
    )
    print("Per-Group CV Summary:")
    print(per_group.to_string())
    print()

    if hline is None:
        return

    below_count = (cv_values < hline).sum()
    total_count = cv_values.count()
    pct = (
        round(below_count / total_count * 100, 4)
        if total_count > 0
        else 0.0
    )
    global_thresh = pd.DataFrame({
        "Count below": [int(below_count)],
        "Percentage below": [pct],
    })
    print(f"Global Threshold Summary (hline={hline}):")
    print(global_thresh.to_string(index=False))
    print()

    def _thresh_stats(group_cv):
        n_below = (group_cv < hline).sum()
        n_total = group_cv.count()
        pct_below = (
            round(n_below / n_total * 100, 4)
            if n_total > 0
            else 0.0
        )
        return pd.Series({
            "Count below": int(n_below),
            "Percentage below": pct_below,
        })

    per_group_thresh = (
        df_melted.groupby("Group")["CV"]
        .apply(_thresh_stats)
        .unstack()
        .reindex(order)
    )
    print(f"Per-Group Threshold Summary (hline={hline}):")
    print(per_group_thresh.to_string())
    print()
[docs]
def sample_correlation_matrix(
    adata: ad.AnnData,
    method: str = "pearson",
    zero_to_na: bool = False,
    layer: str | None = None,
    fill_na: float | None = None,
    margin_color: str | None = None,
    color_scheme: Any = None,
    cmap: str = "coolwarm",
    linkage_method: str = "average",
    xticklabels: bool = False,
    yticklabels: bool = False,
    figsize: tuple[float, float] = (9.0, 7.0),
    show: bool = True,
    ax: bool = False,
    print_stats: bool = False,
    save: str | Path | None = None,
) -> Axes | None:
    """
    Plot a clustered correlation heatmap across samples (obs).

    Samples are correlated pairwise over all variables, clustered
    hierarchically on the distance ``1 - r`` and drawn with
    :func:`seaborn.clustermap`. The colormap is centred on the mean
    off-diagonal correlation.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` with proteomics annotations.
    method : str
        Correlation estimator passed to :meth:`pandas.DataFrame.corr`.
    zero_to_na : bool
        Replace zeros with missing values before computing correlations.
    layer : str | None
        Optional ``adata.layers`` key to draw quantification values from.
        When ``None`` the primary matrix ``adata.X`` is used.
    fill_na : float | None
        Constant used to replace remaining ``NaN`` values prior to
        correlation. When ``None`` (default), a :class:`ValueError` is raised
        if missing values are detected (suggesting ``fill_na=0``).
    margin_color : str | None
        Optional column in ``adata.obs`` used to color the row/column
        margins. Samples with a missing label are drawn in light gray and
        listed as ``"NA"`` in the legend.
    color_scheme : Any
        Color palette specification understood by
        :func:`proteopy.utils.matplotlib._resolve_color_scheme`.
    cmap : str
        Continuous colormap for the heatmap body.
    linkage_method : str
        Linkage criterion handed to :func:`scipy.cluster.hierarchy.linkage`.
    xticklabels, yticklabels : bool
        Whether to show x- and y-axis tick labels.
    figsize : tuple[float, float]
        Matplotlib figure size in inches.
    show : bool
        Display the figure with :func:`matplotlib.pyplot.show`.
    ax : bool
        Return the heatmap :class:`matplotlib.axes.Axes` when ``True``.
    print_stats : bool
        Print correlation summary statistics before drawing the plot.
        Includes overall off-diagonal statistics, per-sample mean
        correlation (in dendrogram order), and per-group within/between
        correlations when ``margin_color`` is provided.
    save : str | Path | None
        File path for saving the Seaborn cluster map. The file is written
        before the figure is shown. When ``None`` nothing is written.

    Returns
    -------
    Axes or None
        Heatmap axes when ``ax`` is ``True``; otherwise ``None``.

    Raises
    ------
    KeyError
        If ``layer`` is not in ``adata.layers`` or ``margin_color`` is not a
        column of ``adata.obs``.
    ValueError
        If the selected matrix is ``None`` or its shape does not match
        ``adata``; if it still contains missing values after optional zero
        replacement and ``fill_na`` is ``None``; if fewer than two samples
        are present; if the correlation matrix has non-finite off-diagonal
        entries; or if the resolved color scheme has no color for a
        ``margin_color`` category.
    """
    check_proteodata(adata)

    # ---- values from adata.X or a specified layer (obs x var)
    expected_shape = (adata.n_obs, adata.n_vars)
    if layer is None:
        matrix = adata.X
    else:
        if layer not in adata.layers:
            raise KeyError(f"Layer '{layer}' not found in adata.layers.")
        matrix = adata.layers[layer]
    if matrix is None:
        raise ValueError("Selected matrix is empty; cannot compute correlations.")
    if matrix.shape != expected_shape:
        raise ValueError(
            "Selected matrix shape "
            f"{matrix.shape} does not match adata dimensions {expected_shape}."
        )

    if isinstance(matrix, pd.DataFrame):
        vals = matrix.reindex(
            index=adata.obs_names, columns=adata.var_names
        ).copy()
    else:
        # correlation requires dense values; sparse input is densified
        dense_matrix = (
            matrix.toarray() if sparse.issparse(matrix) else np.asarray(matrix)
        )
        vals = pd.DataFrame(
            dense_matrix,
            index=adata.obs_names,
            columns=adata.var_names,
        )

    if zero_to_na:
        vals = vals.replace(0, np.nan)
    if fill_na is not None:
        vals = vals.fillna(fill_na)
    if vals.isna().to_numpy().any():
        raise ValueError(
            "Input matrix contains missing values; provide `fill_na` (e.g., "
            "`fill_na=0`) to replace them before computing correlations."
        )
    # Clustering needs at least one pairwise distance; fail with a clear
    # message instead of an opaque scipy error.
    if vals.shape[0] < 2:
        raise ValueError(
            "At least two samples are required to compute a clustered "
            "correlation matrix."
        )

    # ---- obs x obs correlation (rows of `vals` are samples)
    corr_df = vals.T.corr(method=method)
    corr_df.index = adata.obs_names
    corr_df.columns = adata.obs_names

    A = corr_df.to_numpy(dtype=float)
    n = A.shape[0]
    offdiag_mask = ~np.eye(n, dtype=bool)
    offdiag = A[offdiag_mask]
    if not np.isfinite(offdiag).all():
        # One likely cause is a sample whose values do not vary, which has
        # undefined correlation with every other sample.
        constant = vals.index[vals.nunique(axis=1) <= 1].tolist()
        detail = (
            f" Samples with constant values: {constant}." if constant else ""
        )
        raise ValueError(
            "Correlation matrix contains non-finite values; hierarchical "
            "clustering is not possible." + detail
        )
    # Centre the colormap on the mean off-diagonal correlation so the
    # self-correlations on the diagonal do not shift the color scale.
    center_val = np.nanmean(offdiag)

    # ---- optional row/col colors from obs[margin_color]
    groups = None
    row_colors = None
    legend_handles = None
    if margin_color is not None:
        if margin_color not in adata.obs.columns:
            raise KeyError(f"Column '{margin_color}' not found in adata.obs.")
        groups = adata.obs.loc[corr_df.index, margin_color]
        cats = pd.Categorical(groups.dropna()).categories
        resolved_colors = _resolve_color_scheme(color_scheme, cats)
        if resolved_colors is None:
            resolved_colors = (
                sns.color_palette(n_colors=len(cats)) if len(cats) > 0 else []
            )
        palette = {str(cat): color for cat, color in zip(cats, resolved_colors)}
        # Look colors up via string labels so the palette keys match
        row_color_series = groups.astype("string").map(palette)
        missing_mask = row_color_series.isna() & groups.notna()
        if missing_mask.any():
            missing_cats = sorted(groups[missing_mask].astype(str).unique())
            raise ValueError(
                "No color provided for categories: "
                f"{', '.join(missing_cats)} in '{margin_color}'."
            )
        legend_handles = [
            Patch(facecolor=palette[str(cat)], edgecolor="none", label=str(cat))
            for cat in cats
        ]
        row_colors = row_color_series.to_numpy(dtype=object)
        na_positions = np.flatnonzero(groups.isna().to_numpy())
        if na_positions.size > 0:
            na_color = mpl.colors.to_rgba("lightgray")
            # Assign element by element: a masked assignment of a tuple would
            # be treated by pandas/numpy as a list of values to spread.
            for pos in na_positions:
                row_colors[pos] = na_color
            legend_handles.append(
                Patch(facecolor=na_color, edgecolor="none", label="NA")
            )

    # ---- hierarchical clustering on the distance (1 - r)
    dist = 1 - A
    np.fill_diagonal(dist, 0.0)
    dist = np.clip(dist, 0, 2)  # numerical guard
    Z = linkage(squareform(dist), method=linkage_method)

    # ---- optional statistics printout
    if print_stats:
        # 1) Overall off-diagonal summary
        summary = pd.DataFrame({
            "min": [np.nanmin(offdiag)],
            "max": [np.nanmax(offdiag)],
            "mean": [np.nanmean(offdiag)],
            "median": [np.nanmedian(offdiag)],
            "std": [np.nanstd(offdiag)],
        })
        print(
            f"Sample correlation summary "
            f"(off-diagonal, {method}):"
        )
        print(summary.to_string(index=False))
        print()

        # 2) Per-sample mean correlation (dendrogram order)
        per_sample_mean = np.nanmean(
            np.where(offdiag_mask, A, np.nan), axis=1
        )
        heatmap_order = leaves_list(Z)
        per_sample_df = pd.DataFrame({
            "sample_id": corr_df.index[heatmap_order],
            "mean_corr": per_sample_mean[heatmap_order],
        })
        print("Per-sample mean correlation:")
        print(per_sample_df.to_string(index=False))
        print()

        # 3) Per-group correlation (if margin_color provided)
        if groups is not None:
            group_rows = []
            for grp in sorted(groups.dropna().unique()):
                grp_idx = groups[groups == grp].index
                other_idx = groups[(groups != grp) & groups.notna()].index
                within_vals = corr_df.loc[grp_idx, grp_idx].to_numpy()[
                    ~np.eye(len(grp_idx), dtype=bool)
                ]
                mean_within = (
                    np.nanmean(within_vals)
                    if len(within_vals) > 0
                    else np.nan
                )
                if len(other_idx) > 0:
                    mean_between = np.nanmean(
                        corr_df.loc[grp_idx, other_idx].to_numpy().ravel()
                    )
                else:
                    mean_between = np.nan
                group_rows.append({
                    "group": grp,
                    "mean_within": mean_within,
                    "mean_between": mean_between,
                })
            group_df = pd.DataFrame(group_rows)
            print("Per-group mean correlation:")
            print(group_df.to_string(index=False))
            print()

    # ---- clustermap (center at off-diagonal mean)
    g = sns.clustermap(
        corr_df,
        row_linkage=Z,
        col_linkage=Z,
        row_colors=row_colors,
        col_colors=row_colors,
        cmap=cmap,
        center=center_val,
        figsize=figsize,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        cbar_kws={"label": f"{method.capitalize()}"},
    )

    # ---- add legend for margin_color colors
    if legend_handles is not None:
        g.ax_heatmap.legend(
            handles=legend_handles,
            title=margin_color,
            bbox_to_anchor=(1.05, 1),
            loc='upper left',
            borderaxespad=0.,
            frameon=False,
        )

    g.ax_heatmap.set_xlabel("Samples")
    g.ax_heatmap.set_ylabel("Samples")
    plt.tight_layout()

    # Save before showing so the written file does not depend on what the
    # backend does with the figure during plt.show().
    if save:
        g.savefig(save, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    if ax:
        return g.ax_heatmap
    return None
[docs]
def hclustv_profiles_heatmap(
    adata: ad.AnnData,
    selected_vars: list[str] | None = None,
    group_by: str | None = None,
    summary_method: str = "median",
    linkage_method: str = "average",
    distance_metric: str = "euclidean",
    layer: str | None = None,
    zero_to_na: bool = False,
    fill_na: float | int | None = None,
    skip_na: bool = True,
    cmap: str = "coolwarm",
    margin_color: bool = False,
    order_by: str | None = None,
    order: str | list | None = None,
    color_scheme: str | dict | Sequence | Colormap | None = None,
    row_cluster: bool = True,
    col_cluster: bool = True,
    cbar_pos: tuple[float, float, float, float] | None = (
        0.02, 0.8, 0.05, 0.18
    ),
    tree_kws: dict | None = None,
    xticklabels: bool = True,
    yticklabels: bool = False,
    figsize: tuple[float, float] = (10.0, 8.0),
    title: str | None = None,
    show: bool = True,
    ax: bool = False,
    save: str | Path | None = None,
) -> Axes | None:
    """
    Plot a clustered heatmap of variable profiles across samples or groups.

    Each variable (row) is z-scored across the columns (samples, or group
    summaries when ``group_by`` is set). The z-scores are drawn with
    :func:`seaborn.clustermap`, optionally clustering rows and/or columns
    hierarchically. Z-scores that are missing (e.g. for variables with zero
    variance) are drawn as 0.

    Parameters
    ----------
    adata : AnnData
        :class:`~anndata.AnnData` with proteomics annotations.
    selected_vars : list[str] | None
        Explicit list of variables to include. When ``None``, all variables
        are used.
    group_by : str | None
        Column in ``adata.obs`` used to group observations. When provided,
        computes a summary statistic for each group rather than showing
        individual samples.
    summary_method : str
        Method for computing group summaries when ``group_by`` is specified.
        One of ``"median"`` or ``"mean"`` (alias ``"average"``).
    linkage_method : str
        Linkage criterion passed to :func:`scipy.cluster.hierarchy.linkage`.
    distance_metric : str
        Distance metric for clustering. One of ``"euclidean"``, ``"manhattan"``,
        or ``"cosine"``.
    layer : str | None
        Optional ``adata.layers`` key to draw quantification values from.
        When ``None`` the primary matrix ``adata.X`` is used.
    zero_to_na : bool
        Replace zeros with ``NaN`` before computing profiles.
    fill_na : float | int | None
        Replace ``NaN`` values with the specified constant.
    skip_na : bool
        Skip ``NaN`` values when computing group summaries and z-scores.
    cmap : str
        Colormap for the heatmap body.
    margin_color : bool
        Add a color bar between the column dendrogram and the heatmap.
        When ``True``, colors by sample (if ``group_by`` is ``None``) or by
        group (if ``group_by`` is set).
    order_by : str | None
        Column in ``adata.obs`` used to order samples (columns). When set,
        automatically disables column clustering and orders columns by the
        values of this column; samples sharing a value keep their original
        relative order. Also displays a margin color bar colored by this
        column. Cannot be used with ``group_by``.
    order : str | list | None
        The order by which to present samples, groups, or categories. A
        single string is treated as a one-element list. If ``order_by`` is
        ``None`` and ``order`` is ``None``, the existing order is used. If
        ``order_by`` is ``None`` and ``order`` is not ``None``, ``order``
        specifies the column order (samples or groups). If ``order_by`` is
        not ``None`` and ``order`` is ``None``, the unique values in
        ``order_by`` are used (categorical order if categorical, sorted
        order otherwise). If ``order_by`` is not ``None`` and ``order`` is
        not ``None``, ``order`` defines the order of the unique ``order_by``
        values. Values not in ``order`` are excluded. Setting ``order``
        disables column clustering.
    color_scheme : str | dict | Sequence | Colormap | None
        Palette specification for the margin color bar, forwarded to
        :func:`proteopy.utils.matplotlib._resolve_color_scheme`. Ignored
        when neither ``margin_color`` nor ``order_by`` is set.
    row_cluster : bool
        Perform hierarchical clustering on variables (rows).
    col_cluster : bool
        Perform hierarchical clustering on samples/groups (columns).
    cbar_pos : tuple of (left, bottom, width, height), optional
        Position of the colorbar axes in the figure. Setting to
        ``None`` will disable the colorbar.
    tree_kws : dict, optional
        Keyword arguments passed to
        :class:`matplotlib.collections.LineCollection` for the
        dendrogram lines (e.g. ``colors``, ``linewidths``).
    xticklabels : bool
        Show x-axis tick labels (sample/group names).
    yticklabels : bool
        Show y-axis tick labels (variable names).
    figsize : tuple[float, float]
        Matplotlib figure size in inches.
    title : str | None
        Title for the plot.
    show : bool
        Display the figure with :func:`matplotlib.pyplot.show`.
    ax : bool
        Return the heatmap :class:`matplotlib.axes.Axes` when ``True``.
    save : str | Path | None
        File path for saving the figure.

    Returns
    -------
    Axes or None
        Heatmap axes when ``ax`` is ``True``; otherwise ``None``.

    Raises
    ------
    ValueError
        If ``summary_method`` or ``distance_metric`` is unsupported, if
        ``order_by`` is combined with ``group_by``, if the selected matrix is
        ``None``, or if no variables remain after dropping all-NaN rows.
    KeyError
        If ``layer``, ``group_by``, ``order_by``, entries of
        ``selected_vars`` or entries of ``order`` cannot be found.

    Warns
    -----
    UserWarning
        When ``order_by`` or ``order`` overrides ``col_cluster=True``.
    """
    check_proteodata(adata)

    # Validate summary_method
    summary_method = summary_method.lower()
    if summary_method == "average":
        summary_method = "mean"
    if summary_method not in ("median", "mean"):
        raise ValueError(
            f"summary_method must be 'median' or 'mean', got '{summary_method}'."
        )

    # Validate distance_metric and map it to the scipy pdist name
    distance_metric = distance_metric.lower()
    metric_map = {
        "euclidean": "euclidean",
        "manhattan": "cityblock",
        "cosine": "cosine",
    }
    if distance_metric not in metric_map:
        raise ValueError(
            f"distance_metric must be 'euclidean', 'manhattan', or 'cosine', "
            f"got '{distance_metric}'."
        )
    scipy_metric = metric_map[distance_metric]

    # Validate order_by
    if order_by is not None:
        if group_by is not None:
            raise ValueError(
                "order_by cannot be used with group_by. When using group_by, "
                "columns represent groups, not individual samples."
            )
        if order_by not in adata.obs.columns:
            raise KeyError(f"Column '{order_by}' not found in adata.obs.")
        # order_by and col_cluster are mutually exclusive; disable clustering
        if col_cluster:
            warnings.warn(
                "`order_by` parameter is incompatible with `col_cluster=True`. "
                "`col_cluster` has been overridden.",
                UserWarning,
                stacklevel=2,
            )
            col_cluster = False

    # Validate group_by before anything below reads adata.obs[group_by]
    if group_by is not None and group_by not in adata.obs.columns:
        raise KeyError(f"Column '{group_by}' not found in adata.obs.")

    # Validate order parameter
    if order is not None:
        if col_cluster:
            warnings.warn(
                "`order` parameter is incompatible with `col_cluster=True`. "
                "`col_cluster` has been overridden.",
                UserWarning,
                stacklevel=2,
            )
            col_cluster = False
        # A single label is a one-element order; list() would split it into
        # characters.
        order = [order] if isinstance(order, str) else list(order)
        if order_by is None and group_by is None:
            # order specifies sample names
            available_samples = set(adata.obs_names)
            invalid_samples = [s for s in order if s not in available_samples]
            if invalid_samples:
                raise KeyError(
                    f"Samples not found in adata.obs_names: {invalid_samples}"
                )
        elif group_by is not None:
            # order specifies group names; validate against group_by column
            available_groups = set(adata.obs[group_by].dropna().unique())
            invalid_groups = [g for g in order if g not in available_groups]
            if invalid_groups:
                raise KeyError(
                    f"Groups not found in adata.obs['{group_by}']: {invalid_groups}"
                )
        # Validation for the order_by case is done once the data is available

    # Extract matrix
    if layer is None:
        matrix = adata.X
    else:
        if layer not in adata.layers:
            raise KeyError(f"Layer '{layer}' not found in adata.layers.")
        matrix = adata.layers[layer]
    if matrix is None:
        raise ValueError("Selected matrix is empty.")
    matrix = matrix.toarray() if sparse.issparse(matrix) else np.asarray(matrix)

    # Create DataFrame (obs x var)
    df = pd.DataFrame(
        matrix,
        index=adata.obs_names,
        columns=adata.var_names,
    )

    # Filter variables if specified
    if selected_vars is not None:
        missing_vars = [v for v in selected_vars if v not in df.columns]
        if missing_vars:
            raise KeyError(
                f"Variables not found in adata.var_names: {missing_vars}"
            )
        df = df[selected_vars]

    if zero_to_na:
        df = df.replace(0, np.nan)
    if fill_na is not None:
        df = df.fillna(fill_na)

    if group_by is not None:
        # Group on the column's raw values (a Categorical for categorical
        # columns, so category order and observed=True still apply). Using an
        # external key avoids adding a helper column to `df`, which may be a
        # slice of the original frame.
        group_keys = adata.obs[group_by].values
        summary_df = df.groupby(group_keys, observed=True).apply(
            lambda frame: getattr(frame, summary_method)(skipna=skip_na)
        )
        # Transpose to get var x group
        profile_df = summary_df.T
    else:
        # Transpose to get var x obs
        profile_df = df.T

    # Drop variables with all NaN
    profile_df = profile_df.dropna(how="all")
    if profile_df.empty:
        raise ValueError("No variables remain after removing all-NaN rows.")

    # Compute z-scores per variable (row)
    row_means = profile_df.mean(axis=1, skipna=skip_na)
    row_stds = profile_df.std(axis=1, skipna=skip_na, ddof=0)
    row_stds = row_stds.replace(0, np.nan)  # avoid division by zero
    z_df = profile_df.sub(row_means, axis=0).div(row_stds, axis=0)

    # Fill NaN with 0 for clustering
    z_df_filled = z_df.fillna(0)

    # Order columns based on order_by and/or order
    if order_by is not None:
        order_col_values = adata.obs.loc[z_df_filled.columns, order_by]
        if order is not None:
            # Validate that order values exist in the order_by column
            available_values = set(order_col_values.unique())
            invalid_values = [v for v in order if v not in available_values]
            if invalid_values:
                raise KeyError(
                    f"Values not found in adata.obs['{order_by}']: {invalid_values}"
                )
            # Keep only samples whose value is listed in `order`, ranked by
            # the position of that value in `order`
            kept = order_col_values[order_col_values.isin(order).to_numpy()]
            sort_keys = pd.Series(
                pd.Categorical(kept, categories=order, ordered=True),
                index=kept.index,
            )
        elif isinstance(order_col_values.dtype, pd.CategoricalDtype):
            # Rank by the column's own category order
            sort_keys = pd.Series(
                pd.Categorical(
                    order_col_values,
                    categories=list(order_col_values.cat.categories),
                    ordered=True,
                ),
                index=z_df_filled.columns,
            )
        else:
            sort_keys = order_col_values
        # A stable sort keeps the original sample order within tied values
        sorted_idx = sort_keys.sort_values(kind="stable").index
        z_df_filled = z_df_filled[sorted_idx]
    elif order is not None:
        # order specifies sample or group names directly; keep only those
        # columns, in the given order
        valid_cols = [c for c in order if c in z_df_filled.columns]
        z_df_filled = z_df_filled[valid_cols]

    # Build column colors for margin annotation
    col_colors = None
    col_names = z_df_filled.columns
    if order_by is not None:
        # Color by the order_by column
        categories = adata.obs.loc[col_names, order_by].values
    elif margin_color:
        # Color by sample or group
        categories = col_names
    else:
        categories = None

    if categories is not None:
        unique_cats = pd.Series(categories).unique()
        resolved_colors = _resolve_color_scheme(color_scheme, unique_cats)
        if resolved_colors is None:
            resolved_colors = (
                sns.color_palette("husl", n_colors=len(unique_cats))
                if len(unique_cats) > 0 else []
            )
        color_map = dict(zip(unique_cats, resolved_colors))
        col_colors = pd.Series(
            [color_map[c] for c in categories],
            index=col_names,
        )

    # Create clustermap
    clustermap_kws = dict(
        method=linkage_method,
        metric=scipy_metric,
        row_cluster=row_cluster,
        col_cluster=col_cluster,
        cmap=cmap,
        center=0,
        figsize=figsize,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        col_colors=col_colors,
        tree_kws=tree_kws,
    )
    if cbar_pos is not None:
        clustermap_kws["cbar_pos"] = cbar_pos
        clustermap_kws["cbar_kws"] = {"label": "Z-score"}
    else:
        clustermap_kws["cbar_pos"] = None

    g = sns.clustermap(z_df_filled, **clustermap_kws)
    g.ax_heatmap.set_xlabel("")

    # Remove y-axis ticks from the margin color bar if present
    if g.ax_col_colors is not None:
        g.ax_col_colors.set_yticks([])

    if title is not None:
        g.figure.suptitle(title, y=1.02)

    plt.tight_layout()

    if save:
        g.savefig(save, dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    if ax:
        return g.ax_heatmap
    return None