import numpy as np
import anndata as ad
from scipy import sparse
from proteopy.utils.anndata import check_proteodata
from proteopy.utils.array import _is_log_transformed_array
def _validate_impute_downshift_input(  # noqa: C901
    adata,
    downshift,
    width,
    zero_to_na,
    inplace,
    force,
    random_state,
    group_by,
    verbose,
    Y=None,
):
    """Validate and type-check arguments for ``impute_downshift``.

    Checks run in a fixed order and the first failure is raised.

    Parameters
    ----------
    adata : ad.AnnData
        Must be an ``AnnData`` instance; its ``.obs`` columns are
        consulted when ``group_by`` is given.
    downshift : int | float
        Finite real number. ``bool`` is rejected; NumPy integer and
        floating scalars are accepted.
    width : int | float
        Finite, strictly positive real number. ``bool`` is rejected;
        NumPy integer and floating scalars are accepted.
    zero_to_na, inplace, force, verbose : bool
        Must be real ``bool`` objects.
    random_state : int | None
        Integer seed (Python ``int`` or NumPy integer) or ``None``.
    group_by : str | None
        Name of a column in ``adata.obs`` or ``None``.
    Y : np.ndarray | None, optional
        Cleaned working matrix. When given and ``force`` is ``False``
        it is passed to ``_is_log_transformed_array``; a negative
        verdict raises. When ``None`` the log-transform check is
        skipped.

    Raises
    ------
    TypeError
        If an argument has an unexpected type.
    ValueError
        If ``width`` is not positive, ``width`` or ``downshift`` is
        not finite, or ``Y`` is judged non-log-transformed while
        ``force`` is ``False``.
    KeyError
        If ``group_by`` is not a column of ``adata.obs``.
    """
    # NumPy scalars such as np.int64 / np.float32 do not subclass the
    # Python builtins, so they are listed explicitly.
    real_types = (int, float, np.integer, np.floating)
    integer_types = (int, np.integer)

    if not isinstance(adata, ad.AnnData):
        raise TypeError(
            f"`adata` must be an AnnData object, "
            f"got {type(adata).__name__}."
        )

    # bool is a subclass of int, hence the explicit exclusion.
    if isinstance(downshift, bool) or not isinstance(downshift, real_types):
        raise TypeError(
            f"`downshift` must be a numeric value, "
            f"got {type(downshift).__name__}."
        )
    if isinstance(width, bool) or not isinstance(width, real_types):
        raise TypeError(
            f"`width` must be a numeric value, got {type(width).__name__}."
        )
    if width <= 0:
        raise ValueError("`width` must be positive.")
    # NaN compares False against 0 and +inf is "positive", so both pass
    # the check above; reject them here instead of letting non-finite
    # values flow into the sampling parameters.
    if not np.isfinite(width):
        raise ValueError("`width` must be a finite number.")
    if not np.isfinite(downshift):
        raise ValueError("`downshift` must be a finite number.")

    if not isinstance(zero_to_na, bool):
        raise TypeError(
            f"`zero_to_na` must be a bool, "
            f"got {type(zero_to_na).__name__}."
        )
    if not isinstance(inplace, bool):
        raise TypeError(
            f"`inplace` must be a bool, got {type(inplace).__name__}."
        )
    if not isinstance(force, bool):
        raise TypeError(f"`force` must be a bool, got {type(force).__name__}.")
    if random_state is not None and not isinstance(
        random_state, integer_types
    ):
        raise TypeError(
            f"`random_state` must be an int or None, "
            f"got {type(random_state).__name__}."
        )
    if group_by is not None and not isinstance(group_by, str):
        raise TypeError(
            f"`group_by` must be a string or None, "
            f"got {type(group_by).__name__}."
        )
    if not isinstance(verbose, bool):
        raise TypeError(
            f"`verbose` must be a bool, got {type(verbose).__name__}."
        )

    if group_by is not None and group_by not in adata.obs.columns:
        raise KeyError(f"`group_by`='{group_by}' not found in adata.obs")

    # -- Log-transform check on cleaned matrix so that
    #    zeros (now NaN) don't bias the heuristic
    if not force and Y is not None:
        is_log, _ = _is_log_transformed_array(Y)
        if not is_log:
            raise ValueError(
                "Imputation expects log-transformed "
                "intensities. Set force=True to "
                "proceed nevertheless."
            )
def _impute_rows(
Y_imp,
miss_mask,
row_indices,
median,
sd,
downshift,
width,
rng,
):
"""Impute NaNs in the given rows from a downshifted normal.
Draws random samples from
``N(median - downshift*sd, (width*sd)^2)`` and writes them into
``Y_imp`` wherever ``miss_mask`` is ``True``.
Parameters
----------
Y_imp : np.ndarray
Output matrix (obs x vars); imputed values are written
here in-place.
miss_mask : np.ndarray
Boolean mask of shape (obs x vars); ``True`` marks positions
to be imputed.
row_indices : array-like of int
Row indices to process.
median : float
Median of the reference distribution (center before shifting).
sd : float
Standard deviation of the reference distribution.
downshift : float
Number of standard deviations to shift the center leftward.
width : float
Scaling factor applied to ``sd`` to set the sampler width.
rng : np.random.Generator
NumPy random generator used for reproducible sampling.
"""
mu = median - downshift * sd
scale = width * sd
for i in row_indices:
miss = miss_mask[i, :]
if not miss.any():
continue
Y_imp[i, miss] = rng.normal(
loc=mu,
scale=scale,
size=int(miss.sum()),
)
def _impute_by_group(
    Y,
    Y_imp,
    miss_mask,
    groups,
    g_median,
    g_sd,
    downshift,
    width,
    rng,
    verbose=False,
):
    """Impute per group, falling back to global stats.

    For every distinct non-missing label in ``groups`` the median and
    standard deviation of the finite values of ``Y`` in that group's
    rows are computed and handed to ``_impute_rows``. The global pair
    ``(g_median, g_sd)`` is used instead when

    - the group has fewer than three finite values,
    - the group's standard deviation is zero or non-finite, or
    - a row has no group label at all (missing value in ``groups``).

    Rows without a label are imputed last, after all labelled groups,
    so the samples drawn for labelled groups do not depend on whether
    unlabelled rows exist.

    Parameters
    ----------
    Y : np.ndarray
        Working matrix (obs x vars); read only, used for group
        statistics.
    Y_imp : np.ndarray
        Output matrix; imputed values are written here in place.
    miss_mask : np.ndarray
        Boolean mask (obs x vars); ``True`` marks cells to impute.
    groups : pandas.Series
        One label per row of ``Y`` (assumed to be in the same row
        order as ``Y``).
    g_median, g_sd : float
        Global fallback statistics.
    downshift, width : float
        Forwarded unchanged to ``_impute_rows``.
    rng : np.random.Generator
        Forwarded unchanged to ``_impute_rows``.
    verbose : bool, optional
        When ``True``, print up to the first five group labels that
        trigger each fallback type, and the number of rows without a
        group label.
    """
    max_report = 5
    few_values = []
    constant = []

    # A missing label never compares equal to anything (itself
    # included), so `groups == label` cannot select unlabelled rows;
    # they are tracked separately and imputed after the loop.
    has_label = groups.notna()

    for label in groups[has_label].unique():
        # `& has_label` also turns any NA produced by the comparison on
        # nullable dtypes into False, so the mask is a plain bool array.
        in_group = ((groups == label) & has_label).to_numpy(dtype=bool)
        row_idx = np.flatnonzero(in_group)

        block = Y[row_idx, :]
        grp_vals = block[np.isfinite(block)]

        if grp_vals.size >= 3:
            grp_median = float(np.median(grp_vals))
            grp_sd = float(np.std(grp_vals))
            if not np.isfinite(grp_sd) or grp_sd <= 0:
                grp_median, grp_sd = g_median, g_sd
                if len(constant) < max_report:
                    constant.append(label)
        else:
            grp_median, grp_sd = g_median, g_sd
            if len(few_values) < max_report:
                few_values.append(label)

        _impute_rows(
            Y_imp,
            miss_mask,
            row_idx,
            grp_median,
            grp_sd,
            downshift,
            width,
            rng,
        )

    # Rows without a group label would otherwise keep their NaNs even
    # though they are flagged in `miss_mask`; use the global stats.
    unlabelled_idx = np.flatnonzero((~has_label).to_numpy(dtype=bool))
    if unlabelled_idx.size:
        _impute_rows(
            Y_imp,
            miss_mask,
            unlabelled_idx,
            g_median,
            g_sd,
            downshift,
            width,
            rng,
        )

    if verbose:
        if few_values:
            labels = ", ".join(f"'{x}'" for x in few_values)
            print(
                f"Groups with fewer than 3 finite values "
                f"(first {len(few_values)}): {labels}; "
                f"falling back to global stats."
            )
        if constant:
            labels = ", ".join(f"'{x}'" for x in constant)
            print(
                f"Groups with zero standard deviation "
                f"(constant values; first {len(constant)}): "
                f"{labels}; falling back to global stats."
            )
        if unlabelled_idx.size:
            print(
                f"Rows with a missing group label: "
                f"{unlabelled_idx.size}; "
                f"falling back to global stats."
            )
def _store_downshift_imputation_metadata(
target,
miss_mask,
n_missing,
width,
downshift,
group_by,
random_state,
):
"""Write imputation mask and run metadata to ``target``."""
target.layers["imputation_mask_X"] = miss_mask.astype(bool)
target.uns.setdefault("imputation", {})
target.uns["imputation"].update(
dict(
method="downshift_normal",
width=float(width),
downshift=float(downshift),
group_by=group_by,
random_state=(None if random_state is None else int(random_state)),
n_imputed=int(n_missing),
pct_imputed=float(n_missing / miss_mask.size * 100.0),
)
)
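# [docs] -- appears to be a documentation "view source" link marker left over from HTML extraction; not code.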
def impute_downshift(
    adata,
    zero_to_na: bool = False,
    downshift: float = 1.8,
    width: float = 0.3,
    group_by: str | None = None,
    inplace: bool = True,
    force: bool = False,
    random_state: int | None = 42,
    verbose: bool = False,
):
    """Impute missing values via a downshifted Gaussian.

    Replaces ``NaN`` (and optionally zero) entries by sampling from a
    Gaussian centered at ``median - downshift * sd`` with standard
    deviation ``width * sd``, simulating expression signals below the
    detection limit as popularised by the Perseus platform [1]_.

    The median and standard deviation are estimated from the observed
    values of the global distribution or of the distributions defined
    by the ``group_by`` parameter:

    - ``group_by=None`` — global distribution (all finite values
      in ``.X``). Recommended when sample-level distributions are
      similar.
    - ``group_by=<obs column>`` — per-group distribution pooled across
      all samples sharing the same label in that column.

    When ``group_by`` is set and a group contains fewer than three
    finite values, or its finite values are all constant (zero
    standard deviation), the global distribution (all finite values
    in ``.X``) is used as a fallback for that group.

    The function records an imputation mask in
    ``.layers["imputation_mask_X"]`` (``True`` where values were
    imputed) and stores run metadata in ``.uns["imputation"]``.

    It is recommended to work in log-transformed intensity space.

    Parameters
    ----------
    adata : ad.AnnData
        Proteodata-formatted :class:`~anndata.AnnData`.
    zero_to_na : bool, optional
        If ``True``, replace zeros in ``.X`` with ``NaN`` before
        imputation so they are treated as missing values.
    downshift : float, optional
        Number of standard deviations to shift the distribution
        center leftward from the observed median.
    width : float, optional
        Scaling factor applied to the observed standard deviation to
        set the width of the sampling distribution.
    group_by : str | None, optional
        Column in ``adata.obs`` defining groups over which the
        reference distribution is pooled. When ``None``, the global
        distribution across all samples is used.
    inplace : bool, optional
        If ``True``, modify ``adata`` in place and return ``None``.
        If ``False``, return an imputed copy without altering
        ``adata``.
    force : bool, optional
        If ``False``, raise a ``ValueError`` when the data are
        detected as non-log-transformed. Set to ``True`` to bypass
        this check and impute regardless.
    random_state : int | None, optional
        Seed for the NumPy random generator. Pass ``None`` for a
        non-deterministic run.
    verbose : bool, optional
        If ``True``, print summary statistics (measured / imputed
        counts) and, when ``group_by`` is set, up to the first five
        groups that trigger each per-group fallback to global stats.

    Returns
    -------
    ad.AnnData | None
        Imputed ``AnnData`` when ``inplace=False``; ``None`` otherwise.
        The returned or modified object contains:

        - ``.X`` — imputed intensity matrix (a CSR sparse matrix if
          the input was sparse).
        - ``.layers["imputation_mask_X"]`` — boolean mask; ``True``
          marks positions that were imputed.
        - ``.uns["imputation"]`` — dict with keys ``method``,
          ``downshift``, ``width``, ``group_by``, ``random_state``,
          ``n_imputed``, and ``pct_imputed``.

    Raises
    ------
    TypeError
        If any argument has an unexpected type.
    ValueError
        If ``width`` is not positive, fewer than three finite values
        exist globally, the global finite values are constant (zero
        standard deviation), or the data appear non-log-transformed
        and ``force=False``.
    KeyError
        If ``group_by`` is not a column in ``adata.obs``.

    Notes
    -----
    Every non-finite entry of ``.X`` (``NaN`` as well as ``inf``) is
    treated as missing and is imputed.

    References
    ----------
    .. [1] Tyanova S, Temu T, Sinitcyn P, Carlson A, Hein MY,
       Geiger T, Mann M, and Cox J. "The Perseus computational
       platform for comprehensive analysis of (prote)omics
       data." *Nature Methods*, 2016, 13(9):731-740.

    Examples
    --------
    >>> import numpy as np
    >>> import proteopy as pr
    >>> adata = pr.datasets.karayel_2020()
    >>> adata.layers["raw"] = adata.X.copy()
    >>> adata.X[adata.X == 0] = np.nan
    >>> adata.X = np.log2(adata.X)

    Simple imputation as popularized by Tyanova et al. (2016)
    (downshift=1.8, width=0.3):

    >>> pr.pp.impute_downshift(adata)

    Impute by drawing from sample-level Gaussian distributions
    instead of global:

    >>> pr.pp.impute_downshift(adata, group_by="sample_id")
    """
    # Argument checks that need no data run first, so a malformed call
    # fails before `.X` is touched, densified or copied. Without `Y`
    # the validator skips its log-transform heuristic.
    _validate_impute_downshift_input(
        adata,
        downshift,
        width,
        zero_to_na,
        inplace,
        force,
        random_state,
        group_by,
        verbose,
    )
    check_proteodata(adata)

    Xsrc = adata.X
    was_sparse = sparse.issparse(Xsrc)
    dense = Xsrc.toarray() if was_sparse else np.asarray(Xsrc)

    # -- Build working matrix (NaN = missing). The explicit copy keeps
    #    `adata.X` untouched while zeros / infs are rewritten below.
    Y = dense.astype(float, copy=True)
    if zero_to_na:
        Y[Y == 0] = np.nan
    Y[~np.isfinite(Y)] = np.nan

    # -- Second pass with the cleaned matrix: this is the call that
    #    performs the log-transform check (zeros are already NaN here,
    #    so they cannot bias the heuristic).
    _validate_impute_downshift_input(
        adata,
        downshift,
        width,
        zero_to_na,
        inplace,
        force,
        random_state,
        group_by,
        verbose,
        Y=Y,
    )

    miss_mask = ~np.isfinite(Y)
    n_missing = int(miss_mask.sum())
    rng = np.random.default_rng(random_state)

    # -- Global fallback stats
    y_finite = Y[~miss_mask]
    if y_finite.size < 3:
        raise ValueError(
            "Not enough finite values to estimate imputation parameters."
        )
    g_median = float(np.median(y_finite))
    g_sd = float(np.std(y_finite))
    if not np.isfinite(g_sd) or g_sd <= 0:
        raise ValueError(
            "Global standard deviation is zero or "
            "non-finite; cannot estimate imputation "
            "parameters. The data may lack variation."
        )

    # -- Imputation (Y stays read-only as the source of group stats;
    #    the draws are written into a separate copy)
    Y_imp = Y.copy()
    if group_by is None:
        _impute_rows(
            Y_imp,
            miss_mask,
            range(Y.shape[0]),
            g_median,
            g_sd,
            downshift,
            width,
            rng,
        )
    else:
        _impute_by_group(
            Y,
            Y_imp,
            miss_mask,
            adata.obs[group_by],
            g_median,
            g_sd,
            downshift,
            width,
            rng,
            verbose=verbose,
        )

    Z_out = sparse.csr_matrix(Y_imp) if was_sparse else Y_imp

    if verbose:
        total = miss_mask.size
        measured_n = total - n_missing
        print(
            f"Measured: {measured_n:,} values "
            f"({100 * measured_n / total:.1f}%)"
        )
        print(
            f"Imputed: {n_missing:,} values "
            f"({100 * n_missing / total:.1f}%)"
        )

    # -- Write results either onto the caller's object or onto a copy
    target = adata if inplace else adata.copy()
    target.X = Z_out
    _store_downshift_imputation_metadata(
        target,
        miss_mask,
        n_missing,
        width,
        downshift,
        group_by,
        random_state,
    )
    check_proteodata(target)

    return None if inplace else target