Asif Rahman

Automatic statistics selection

Automatically select appropriate statistical tests based on variable types and relationships.

Given a dataset with mixed variable types, how do you decide which statistical test to run? We can reduce the decision to three inputs: the type of the variable being tested (numeric or categorical), the type of the covariate (numeric, categorical, or none), and how many categories the covariate has. The remaining choice is the kind of comparison we care about: medians, means, whole distributions, or correlation.

The following decision table maps these inputs to the appropriate test. Each function takes a Polars DataFrame, runs the test, and returns a structured result with a test statistic, critical value, p-value, and a plain-language conclusion.

| Variable | Covariate | Categories | Type | Test | Example statement |
|---|---|---|---|---|---|
| Numeric | Categorical | >2 | Medians | Kruskal-Wallis | “Median {variable} does not vary across values of {covariate}.” |
| Numeric | Categorical | 2 | Medians | Mann-Whitney U test | “Median {variable} varies across values of {covariate}.” |
| Numeric | Categorical | >2 | Means | ANOVA | “Mean {variable} does not vary across values of {covariate}.” |
| Numeric | Categorical | 2 | Means | T-test | “Mean {variable} varies across values of {covariate}.” |
| Categorical | Numeric | 2+ | Chi-squared | Chi-squared test | “{variable} and quantile of {covariate} are not independent.” |
| Numeric | None | 1 | Distribution | Kolmogorov-Smirnov | “{variable} does not depart from a uniform distribution.” |
| Numeric | Numeric | 6 | Distribution | Kolmogorov-Smirnov | “The distribution of {variable} varies across quantiles of {covariate}.” |
| Numeric | Categorical | 2+ | Distribution | Kolmogorov-Smirnov | “The distribution of {variable} does not vary across values of {covariate}.” |
| Numeric | Numeric | 2 | Correlation | Pearson Correlation | “A positive correlation exists between {variable} and {covariate}.” |
| Numeric | Categorical | 2+ | Rank Correlation | Spearman Correlation | “A negative correlation exists between {variable} and {covariate}.” |
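The table can be read as a small decision function. The sketch below encodes it in plain Python as an illustration only: the name `select_test` and its string arguments (`"medians"`, `"means"`, `"distribution"`, `"correlation"`) are hypothetical choices, not the author’s dispatcher. Following the `chi_squared_test` docstring, it also routes a categorical variable with a categorical covariate to the Chi-squared test.

```python
from typing import Optional


def select_test(
    variable_type: str,
    covariate_type: Optional[str] = None,
    n_categories: int = 0,
    compare: str = "medians",
) -> str:
    """Return a test name from the decision table (hypothetical helper).

    variable_type:  "numeric" or "categorical"
    covariate_type: "numeric", "categorical", or None
    n_categories:   number of levels of a categorical covariate
    compare:        "medians", "means", "distribution", or "correlation"
    """
    if variable_type == "categorical":
        if covariate_type in ("numeric", "categorical"):
            # A numeric covariate is binned into quantiles before the test
            return "Chi-squared"
        raise ValueError("A categorical variable needs a covariate")
    if variable_type != "numeric":
        raise ValueError(f"Unknown variable type: {variable_type}")
    if covariate_type is None:
        # One-sample test against a uniform distribution
        return "Kolmogorov-Smirnov"
    if covariate_type == "numeric":
        if compare == "correlation":
            return "Pearson"
        if compare == "distribution":
            return "Kolmogorov-Smirnov"
        raise ValueError(f"Unsupported comparison for a numeric covariate: {compare}")
    if covariate_type != "categorical":
        raise ValueError(f"Unknown covariate type: {covariate_type}")
    if n_categories < 2:
        raise ValueError("A categorical covariate needs at least 2 categories")
    two_groups = n_categories == 2
    if compare == "medians":
        return "Mann-Whitney" if two_groups else "Kruskal-Wallis"
    if compare == "means":
        return "T-test" if two_groups else "ANOVA"
    if compare == "distribution":
        return "Kolmogorov-Smirnov"
    if compare == "correlation":
        return "Spearman"
    raise ValueError(f"Unknown comparison: {compare}")
```

Kolmogorov-Smirnov appears in three rows, so the kind of comparison, not the variable types alone, decides between it and the location or correlation tests.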

When the variable is categorical and the covariate is numeric, the test is a Chi-squared test of independence. The numeric covariate is binned into quantiles, and the Chi-squared test is performed on the contingency table of the categorical variable against the binned covariate. When that table is 2x2 (a variable with two levels and a covariate with two levels or bins), the chi-squared approximation is replaced by Fisher’s exact test, which computes the p-value exactly from the hypergeometric distribution, and we report the z-score and p-value of Fisher’s exact test. Otherwise, we report the Chi-squared test statistic and p-value.
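A minimal sketch of this routing with NumPy and SciPy, assuming the categorical variable and the numeric covariate arrive as plain arrays. The helper name `chi2_or_fisher` is hypothetical, and it uses plain quantile edges without the 1st/99th percentile trimming that `bin_numeric_column` applies. One deliberate difference: `scipy.stats.fisher_exact` returns an odds ratio and a p-value, so the sketch reports the odds ratio where the post reports a z-score.

```python
import numpy as np
from scipy import stats


def chi2_or_fisher(labels, x, n_quantiles=4):
    """Chi-squared test of a categorical array against a quantile-binned numeric array.

    Hypothetical helper: plain quantile edges (no outlier trimming), with a
    fallback to Fisher's exact test when the contingency table is 2x2.
    """
    labels = np.asarray(labels)
    x = np.asarray(x, dtype=float)
    # Interior quantile edges, e.g. the 25th, 50th and 75th percentiles for 4 bins
    edges = np.quantile(x, np.linspace(0, 1, n_quantiles + 1)[1:-1])
    # np.digitize maps each value to a bin index 0 .. n_quantiles - 1
    bins = np.digitize(x, edges)
    levels = np.unique(labels)
    # Rows: levels of the categorical variable; columns: quantile bins of x
    observed = np.array(
        [[np.sum((labels == level) & (bins == b)) for b in range(n_quantiles)] for level in levels]
    )
    if observed.shape == (2, 2):
        odds_ratio, p_value = stats.fisher_exact(observed)
        return "Fisher's exact", float(odds_ratio), float(p_value)
    chi2, p_value, dof, _expected = stats.chi2_contingency(observed)
    return "Chi-squared", float(chi2), float(p_value)
```

With a heavily tied covariate some quantile bins can come out empty, and `chi2_contingency` raises a `ValueError` when an expected frequency is zero, so empty columns would need to be dropped before the test.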

The Kolmogorov-Smirnov test can be used in three ways:

  1. Single variable: Tests whether a numeric variable follows a uniform distribution (normalized to [0,1] range).
  2. With numeric covariate: Uses two-sample KS tests to compare distributions across 6 quantiles of the covariate.
  3. With categorical covariate: Uses two-sample KS tests to compare distributions across category groups.
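The three uses can be sketched with `scipy.stats.kstest` and `scipy.stats.ks_2samp`. Two assumptions are labelled in the code: “normalized to [0,1]” is taken to mean min-max scaling, and comparing distributions across groups is implemented as each group against all remaining rows; the original may pair groups differently. The helper names `ks_uniform`, `quantile_groups`, and `ks_across_groups` are hypothetical.

```python
import numpy as np
from scipy import stats


def ks_uniform(x):
    """One-sample KS test of x against Uniform(0, 1) after min-max scaling (assumed normalization)."""
    x = np.asarray(x, dtype=float)
    span = x.max() - x.min()
    if span == 0:
        raise ValueError("A constant column cannot be normalized to [0, 1]")
    scaled = (x - x.min()) / span
    result = stats.kstest(scaled, "uniform")
    return float(result.statistic), float(result.pvalue)


def quantile_groups(covariate, n_quantiles=6):
    """Label each row with the quantile bin (0 .. n_quantiles - 1) of a numeric covariate."""
    c = np.asarray(covariate, dtype=float)
    edges = np.quantile(c, np.linspace(0, 1, n_quantiles + 1)[1:-1])
    return np.digitize(c, edges)


def ks_across_groups(x, groups):
    """Two-sample KS test of each group against all other rows (assumed pairing)."""
    x = np.asarray(x, dtype=float)
    groups = np.asarray(groups)
    results = {}
    for g in np.unique(groups):
        inside, outside = x[groups == g], x[groups != g]
        res = stats.ks_2samp(inside, outside)
        # .item() turns the NumPy scalar label into a plain Python key
        results[g.item()] = (float(res.statistic), float(res.pvalue))
    return results
```

For a numeric covariate, `ks_across_groups(x, quantile_groups(covariate))` gives the six-quantile comparison; for a categorical covariate, the category labels are passed as `groups` directly.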

The Pearson Correlation test examines the linear relationship between two numeric variables. It returns the correlation coefficient as the test statistic along with degrees of freedom (n-2). The test determines if a significant positive or negative correlation exists between the variables.
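A sketch of the Pearson test that keeps r as the test statistic with n - 2 degrees of freedom. To get a critical value on the same scale as r, it converts the two-sided t critical value using r_crit = t_crit / sqrt(df + t_crit^2), which is equivalent to the usual t-test on r; this conversion is an assumption about how the critical value could be expressed, not necessarily the post’s choice. `pearson_test` and the negative-result wording are hypothetical.

```python
import numpy as np
from scipy import stats


def pearson_test(x, y, variable="x", covariate="y", alpha=0.05):
    """Pearson correlation with r as the statistic and df = n - 2 (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    r, p_value = stats.pearsonr(x, y)
    dof = len(x) - 2
    # Two-sided t critical value, mapped onto the correlation scale
    t_crit = stats.t.ppf(1 - alpha / 2, dof)
    r_crit = t_crit / np.sqrt(dof + t_crit**2)
    significant = bool(abs(r) > r_crit)
    if significant:
        direction = "positive" if r > 0 else "negative"
        conclusion = f"A {direction} correlation exists between {variable} and {covariate}."
    else:
        conclusion = f"No significant correlation exists between {variable} and {covariate}."
    return {
        "r": float(r),
        "critical_value": float(r_crit),
        "dof": dof,
        "p_value": float(p_value),
        "significant": significant,
        "conclusion": conclusion,
    }
```

Because |r| > r_crit is the same event as |t| > t_crit, the `significant` flag agrees with `p_value < alpha`.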

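The Spearman Correlation test examines the monotonic relationship between a numeric variable and a categorical covariate. The covariate’s categories are converted to integer ranks in alphabetical order, and Spearman’s rank correlation is computed between the variable and those ranks. This is useful for detecting ordinal relationships between numeric and categorical variables, provided the alphabetical order matches the categories’ natural order: labels such as “low”, “mid”, “high” sort as “high”, “low”, “mid” and need renaming or an explicit ordering first.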
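A minimal sketch of the alphabetical encoding followed by `scipy.stats.spearmanr`. The helper `spearman_categorical` is hypothetical; it codes each label by its position in the sorted list of unique labels, which for strings is alphabetical order. The test data shows why the ordering matters: the same values give a negative rho with “low/mid/high” labels and a strongly positive rho once the labels are prefixed so that they sort in their natural order.

```python
import numpy as np
from scipy import stats


def spearman_categorical(values, categories):
    """Spearman correlation between a numeric array and alphabetically coded categories.

    Hypothetical helper: the code for each row is the position of its label
    in the sorted list of unique labels (alphabetical for strings).
    """
    values = np.asarray(values, dtype=float)
    # np.unique sorts the labels; return_inverse gives each row's position in that order
    levels, codes = np.unique(np.asarray(categories), return_inverse=True)
    rho, p_value = stats.spearmanr(values, codes)
    return float(rho), float(p_value), levels.tolist()
```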

We can turn these rules into a function that automatically selects and runs the appropriate test based on the variable types and covariate characteristics. The function will return a structured result with the test statistic, critical value, p-value, and a plain-language conclusion.

import numpy as np  # type: ignore
import polars as pl
from scipy import stats  # type: ignore
from pydantic import BaseModel
from typing import List, Union, Optional


class MeanDiffBootstrapResult(BaseModel):
    mean: float
    ci_lb: float
    ci_ub: float


class EffectSizeResult(BaseModel):
    standardized_difference: float
    ci_lower: float
    ci_upper: float


class EffectSizeCategoricalResult(BaseModel):
    standardized_difference: Optional[float]
    ci_lower: Optional[float]
    ci_upper: Optional[float]


class StandardizedDifferenceVariable(BaseModel):
    variable: str
    type: str
    standardized_difference: Optional[float]
    ci_lower: Optional[float]
    ci_upper: Optional[float]


class StandardizedDifferenceResult(BaseModel):
    variables: List[StandardizedDifferenceVariable]


class StatisticalTestResult(BaseModel):
    test_name: str
    test_statistic: float
    critical_value: float
    significant: bool
    p_value: float
    conclusion: str
    degrees_of_freedom: Optional[int] = None
    contingency_table: Optional[dict] = None


def bin_numeric_column(df: pl.DataFrame, column: str, n_bins: int = 6, method: str = "quantile") -> pl.DataFrame:
    """
    Bin a numeric column in a dataframe using either quantile or equidistant binning.

    For quantile binning, extreme outliers (below 1st percentile and above 99th percentile)
    are removed when calculating quantile thresholds, but all data points are included in
    the final bins. Outliers are placed in the appropriate edge bins.

    Args:
        df (pl.DataFrame): Input dataframe
        column (str): Name of the numeric column to bin
        n_bins (int): Number of bins to create (default: 6)
        method (str): Binning method - "quantile" or "equidistant" (default: "quantile")

    Returns:
        pl.DataFrame: Dataframe with an additional column "{column}_bins" containing bin labels
    """
    if method == "quantile":
        # Quantile-based binning; extreme outliers are trimmed when computing thresholds
        # Get non-null values for outlier detection
        non_null_values = df.filter(pl.col(column).is_not_null())[column].to_numpy()

        if len(non_null_values) == 0:
            # Handle all-null column
            return df.with_columns(pl.lit(None, dtype=pl.Utf8).alias(f"{column}_bins"))

        if len(set(non_null_values)) == 1:
            # Handle constant column case
            return df.with_columns(pl.lit("Q1").alias(f"{column}_bins"))

        # Trim extreme outliers (outside the 1st-99th percentile) for the threshold calculation only
        p1 = np.percentile(non_null_values, 1)
        p99 = np.percentile(non_null_values, 99)

        # Outliers are excluded from the thresholds but are still binned below
        trimmed_values = non_null_values[(non_null_values >= p1) & (non_null_values <= p99)]

        if len(trimmed_values) < n_bins:
            # If too few values remain after trimming, use all values
            trimmed_values = non_null_values

        # Calculate quantile thresholds on the trimmed data, then apply them to all data
        quantile_probs = [i / n_bins for i in range(n_bins + 1)]
        quantile_values = [np.percentile(trimmed_values, q * 100) for q in quantile_probs]

        # Adjust edges to include all original data (including outliers)
        quantile_values[0] = non_null_values.min() - 1e-10
        quantile_values[-1] = non_null_values.max() + 1e-10

        # Assign each value to the half-open interval [edge_i, edge_i+1) that contains it
        def assign_bin(value):
            if value is None:
                return None
            for i in range(len(quantile_values) - 1):
                if quantile_values[i] <= value < quantile_values[i + 1]:
                    return f"Q{i + 1}"
            # Handle edge case for maximum value
            return f"Q{n_bins}"

        # Apply binning using map_elements
        return df.with_columns(pl.col(column).map_elements(assign_bin, return_dtype=pl.Utf8).alias(f"{column}_bins"))

    elif method == "equidistant":
        # Use equidistant binning, with one label for each of the n_bins intervals
        col_min = df[column].min()
        col_max = df[column].max()

        # Handle null cases
        if col_min is None or col_max is None:
            return df.with_columns(pl.lit(None, dtype=pl.Utf8).alias(f"{column}_bins"))

        if col_min == col_max:
            # Handle constant column case
            return df.with_columns(pl.lit("Bin1").alias(f"{column}_bins"))

        # Create equidistant bins
        # For n_bins, we need n_bins-1 internal break points (excluding min/max)
        # Polars creates bins: (-inf, break1], (break1, break2], ..., (breakN-1, inf]
        if n_bins == 1:
            # Special case: only one bin
            return df.with_columns(pl.lit("Bin1").alias(f"{column}_bins"))
            
        bin_edges = np.linspace(float(col_min), float(col_max), n_bins + 1)[1:-1]  # exclude endpoints
        bin_labels = [f"Bin{i + 1}" for i in range(n_bins)]
        
        # Use cut for equidistant binning
        return df.with_columns(pl.col(column).cut(bin_edges.tolist(), labels=bin_labels).alias(f"{column}_bins"))
    else:
        raise ValueError(f"Unknown binning method: {method}. Use 'quantile' or 'equidistant'")


def mean_diff(group_a, group_b) -> float:
    return np.mean(group_a) - np.mean(group_b)


def drop_outliers_array(arr, quantiles):
    lower = np.percentile(arr, quantiles[0] * 100)
    upper = np.percentile(arr, quantiles[1] * 100)
    # Drop outliers outside the quantiles
    mask = (arr >= lower) & (arr <= upper)
    return arr[mask]


def bootstrap_mean_diff(
    df: Union[pl.DataFrame, pl.LazyFrame],
    metric_col: str,
    intervention_col: str,
    n_resamples: int = 1000,
    drop_outliers: bool = False,
    outlier_quantiles: List[float] = [0.01, 0.99],
) -> MeanDiffBootstrapResult:
    """Calculate the mean difference between two groups using bootstrapping.

    Args:
        df (Union[pl.DataFrame, pl.LazyFrame]): The input DataFrame containing the data.
        metric_col (str): The name of the column containing the metric to analyze.
        intervention_col (str): The name of the column indicating the intervention group (1 for treatment, 0 for control).
        n_resamples (int): The number of bootstrap resamples to perform. Default is 1000.
        drop_outliers (bool): Whether to drop outliers outside the specified quantiles. Default is False.
        outlier_quantiles (List[float]): The quantiles to use for outlier removal. Default is [0.01, 0.99].
            Values outside these quantiles will be dropped from the analysis.
    """
    # Separate the data into two groups based on the 'intervention' column
    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        dx = df.select([metric_col, intervention_col]).drop_nulls().collect()
    else:
        dx = df.lazy().select([metric_col, intervention_col]).drop_nulls().collect()

    # group_a is the treatment group (intervention_col == 1)
    # group_b is the control group (intervention_col == 0)
    group_a = dx.filter(pl.col(intervention_col) == 1)[metric_col].to_numpy()
    group_b = dx.filter(pl.col(intervention_col) == 0)[metric_col].to_numpy()
    if drop_outliers:
        # Trim each group at its own quantiles after splitting
        group_a = drop_outliers_array(group_a, outlier_quantiles)
        group_b = drop_outliers_array(group_b, outlier_quantiles)
    # Bootstrap resampling
    res = stats.bootstrap(
        (group_a, group_b),
        statistic=mean_diff,
        n_resamples=n_resamples,
        method="percentile",
    )
    ci = res.confidence_interval
    mu = float(np.mean(res.bootstrap_distribution))
    return MeanDiffBootstrapResult(mean=mu, ci_lb=ci.low, ci_ub=ci.high)


def effect_size_continuous(
    df: Union[pl.DataFrame, pl.LazyFrame], group_col: str, var_col: str, coverage: float = 0.95, decimals: int = 3
) -> EffectSizeResult:
    """Effect size for continuous variables.

    Args:
        df (pl.DataFrame): Dataframe
        group_col (str): Group column name with treatment and control assignments
        var_col (str): Variable column name
        coverage (float): Coverage of the confidence interval
        decimals (int): Number of decimals to round the effect size

    Returns:
        EffectSizeResult: Effect size and confidence interval

    Examples:

    ```python
    df = pl.DataFrame(
        {
            "group": ["A", "B", "A", "B", "A", "B", "A", "B"],
            "var": [0.1, 0.2, 0.13, 0.4, 0.25, 0.6, 0.17, 0.8],
        }
    )
    effect_size_continuous(df, "group", "var")  # EffectSizeResult(standardized_difference=1.793, ci_lower=0.152, ci_upper=3.434)
    ```

    """
    import math
    import scipy.stats as stats

    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df_subset = df.select([group_col, var_col]).drop_nulls().collect()
    else:
        df_subset = df.select([group_col, var_col]).drop_nulls()

    assert df_subset[group_col].n_unique() == 2, "Only 2 groups allowed"

    # Calculate means and variances by group
    group_stats = (
        df_subset.group_by(group_col)
        .agg(
            [
                pl.col(var_col).mean().alias("mean"),
                pl.col(var_col).var(ddof=1).alias("var"),
                pl.col(var_col).count().alias("count"),
            ]
        )
        .sort(group_col)
    )

    m = group_stats["mean"].to_numpy()
    v = group_stats["var"].to_numpy()
    n = group_stats["count"].to_numpy()

    stdiff = (m[1] - m[0]) / math.sqrt(v.sum() / 2)
    stdiff = round(stdiff, decimals)

    # compute confidence interval
    # Number of observations in group 1, group 0, and total
    n0, n1 = n[0], n[1]
    total = n0 + n1
    # Computing the corresponding value from the standard Normal for specified CI coverage
    percentile = 1 - ((1 - coverage) / 2)
    zscore = stats.norm.ppf(percentile)
    # Computing the standard deviation
    deviation = np.sqrt((total / (n0 * n1)) + ((stdiff**2) / (2 * total)))
    # Constructing the CIs using the Z-score and standard deviation
    lower_ci = stdiff - zscore * deviation
    upper_ci = stdiff + zscore * deviation
    lower_ci = round(lower_ci, decimals)
    upper_ci = round(upper_ci, decimals)
    return EffectSizeResult(standardized_difference=stdiff, ci_lower=lower_ci, ci_upper=upper_ci)


def effect_size_categorical(
    df: Union[pl.DataFrame, pl.LazyFrame], group_col: str, var_col: str, coverage: float = 0.95, decimals: int = 3
) -> EffectSizeCategoricalResult:
    """Effect size for categorical variables, binary (2 levels) or multi-level (>2 levels).

    Args:
        df (pl.DataFrame): Dataframe
        group_col (str): Group column name with treatment and control assignments
        var_col (str): Variable column name, binary (2 levels) or categorical (>2 levels)
        coverage (float): Coverage of the confidence interval
        decimals (int): Number of decimals to round the effect size

    Returns:
        EffectSizeCategoricalResult: A single standardized difference across all levels, with its confidence interval

    Examples:

    ```python
    df = pl.DataFrame({
        "group": ["T", "C", "T", "C", "T", "C"],
        "var": ["A", "B", "A", "A", "B", "C"],
    })
    effect_size_categorical(df, "group", "var")  # EffectSizeCategoricalResult(standardized_difference=1.069, ci_lower=-0.642, ci_upper=2.78)
    ```
    """
    import scipy.stats as stats

    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df = df.select([group_col, var_col]).drop_nulls().collect()
    else:
        df = df.select([group_col, var_col]).drop_nulls()

    # Calculate value counts for each group-variable combination
    value_counts = df.group_by([group_col, var_col]).agg(pl.len().alias("count"))

    # Calculate total counts per group to compute proportions
    group_totals = df.group_by(group_col).agg(pl.len().alias("total"))

    # Join to get proportions
    probs_df = value_counts.join(group_totals, on=group_col).with_columns(
        (pl.col("count") / pl.col("total")).alias("prob")
    )

    # Pivot to get groups as rows and variable levels as columns
    # Get unique values for both dimensions
    groups = df[group_col].unique().sort().to_list()
    levels = df[var_col].unique().sort().to_list()

    # Create a matrix of probabilities
    prob_matrix = np.zeros((len(groups), len(levels)))

    for i, group in enumerate(groups):
        for j, level in enumerate(levels):
            prob_row = probs_df.filter((pl.col(group_col) == group) & (pl.col(var_col) == level))
            if len(prob_row) > 0:
                prob_matrix[i, j] = prob_row["prob"][0]

    group0 = prob_matrix[0, :]
    group1 = prob_matrix[1, :]
    # prob_matrix has shape (2, levels); prob_matrix_t has shape (levels, 2)
    prob_matrix_t = prob_matrix.transpose()
    # Computing the probability difference between group 1 and group 0
    # Dropping the 1st difference as there are n-1 degrees of freedom
    prob_difference = np.subtract(group1, group0)
    prob_difference = np.delete(prob_difference, (0), axis=0)

    # Check if we have any probability differences left after dropping the first
    if len(prob_difference) == 0:
        # Handle case with only 2 levels (binary categorical)
        # Use simple difference between the probabilities for the second level
        prob_diff_scalar = group1[1] - group0[1]
        # For binary case, use simple variance formula
        var_group0 = group0[1] * (1 - group0[1])
        var_group1 = group1[1] * (1 - group1[1])
        pooled_var = (var_group0 + var_group1) / 2

        if pooled_var == 0:
            # Perfect separation case
            return EffectSizeCategoricalResult(standardized_difference=None, ci_lower=None, ci_upper=None)

        stdiff = abs(prob_diff_scalar) / np.sqrt(pooled_var)
        stdiff = round(stdiff, decimals)

        # compute confidence interval for binary case
        group_sizes = df.group_by(group_col).agg(pl.len().alias("size")).sort(group_col)
        n0, n1 = group_sizes["size"].to_list()
        total = n0 + n1
        percentile = 1 - ((1 - coverage) / 2)
        zscore = stats.norm.ppf(percentile)
        deviation = np.sqrt((total / (n0 * n1)) + ((stdiff**2) / (2 * total)))
        lower_ci = stdiff - zscore * deviation
        upper_ci = stdiff + zscore * deviation
        lower_ci = round(lower_ci, decimals)
        upper_ci = round(upper_ci, decimals)

        return EffectSizeCategoricalResult(standardized_difference=stdiff, ci_lower=lower_ci, ci_upper=upper_ci)

    # Computing the covariance matrix
    levels_count = prob_matrix_t.shape[0]
    covariance = np.zeros(shape=(levels_count, levels_count))
    for row in range(levels_count):
        for col in range(levels_count):
            if row == col:
                covariance[row][col] = (
                    prob_matrix_t[row][0] * (1 - prob_matrix_t[row][0])
                    + prob_matrix_t[row][1] * (1 - prob_matrix_t[row][1])
                ) / 2
            else:
                covariance[row][col] = (
                    -(prob_matrix_t[row][0] * prob_matrix_t[col][0] + prob_matrix_t[row][1] * prob_matrix_t[col][1]) / 2
                )
    # Dropping the 1st line and row as there are n-1 degrees of freedom
    # Computing the inverse of the covariance matrix
    covariance = np.delete(covariance, (0), axis=0)
    covariance = np.delete(covariance, (0), axis=1)
    try:
        inverse = np.linalg.inv(covariance)
    except np.linalg.LinAlgError:
        return EffectSizeCategoricalResult(standardized_difference=None, ci_lower=None, ci_upper=None)
    # Computing the standardized difference (using Mahalanobis distance)
    stdiff = np.sqrt(np.linalg.multi_dot([prob_difference.T, inverse, prob_difference]))
    stdiff = round(stdiff, decimals)
    # compute confidence interval
    # Number of observations in group 1, group 0, and total
    group_sizes = df.group_by(group_col).agg(pl.len().alias("size")).sort(group_col)
    n0, n1 = group_sizes["size"].to_list()
    total = n0 + n1
    # Computing the corresponding value from the standard Normal for specified CI coverage
    percentile = 1 - ((1 - coverage) / 2)
    zscore = stats.norm.ppf(percentile)
    # Computing the standard deviation
    deviation = np.sqrt((total / (n0 * n1)) + ((stdiff**2) / (2 * total)))
    # Constructing the CIs using the Z-score and standard deviation
    lower_ci = stdiff - zscore * deviation
    upper_ci = stdiff + zscore * deviation
    lower_ci = round(lower_ci, decimals)
    upper_ci = round(upper_ci, decimals)
    return EffectSizeCategoricalResult(standardized_difference=stdiff, ci_lower=lower_ci, ci_upper=upper_ci)


def standardized_difference(
    df: Union[pl.DataFrame, pl.LazyFrame],
    intervention: str,
    categorical: Optional[Union[str, List[str]]] = None,
    continuous: Optional[Union[str, List[str]]] = None,
    coverage: float = 0.95,
    decimals: int = 3,
) -> pl.DataFrame:
    """Standardized differences (effect sizes) for categorical and continuous variables.

    Args:
        df (pl.DataFrame): Dataframe
        intervention (str): Intervention column name with treatment and control assignments
        categorical (str or list): Categorical variable column name(s)
        continuous (str or list): Continuous variable column name(s)
        coverage (float): Coverage of the confidence interval, default 0.95
        decimals (int): Number of decimals to round the effect size, default 3

    Returns:
        (pl.DataFrame): Dataframe with one row per variable: effect size and confidence interval

    Examples:
    ```python
    df = pl.DataFrame({
        "intervention": ["A", "B", "A", "B", "A", "B"],
        "categorical_var": ["X", "Y", "Y", "Y", "X", "X"],
        "continuous_var": [1.2, 2.3, 1.5, 2.8, 1.7, 2.9],
    })
    standardized_difference(df, "intervention", categorical="categorical_var", continuous="continuous_var")
    # Returns a DataFrame with effect sizes for categorical and continuous variables
    # variable,type,SD,2.5%,97.5%
    # "categorical_var","categorical",0.707,-0.943,2.357
    # "continuous_var","continuous",4.157,1.312,7.002
    ```
    """
    upper = round(1 - ((1 - coverage) / 2), 3)
    lower = round(1 - upper, 3)
    upper, lower = upper * 100, lower * 100

    # Determine which columns we need and select them efficiently
    needed_cols = [intervention]
    if categorical is not None:
        cat_cols = categorical if isinstance(categorical, list) else [categorical]
        needed_cols.extend(cat_cols)
    if continuous is not None:
        cont_cols = continuous if isinstance(continuous, list) else [continuous]
        needed_cols.extend(cont_cols)

    # Select only needed columns
    if isinstance(df, pl.LazyFrame):
        df_subset = df.select(needed_cols).collect()
    else:
        df_subset = df.select(needed_cols)

    results = []

    if categorical is not None:
        categorical = categorical if isinstance(categorical, list) else [categorical]
        for col in categorical:
            es_result = effect_size_categorical(df_subset, intervention, col, coverage=coverage, decimals=decimals)
            results.append(
                {
                    "variable": col,
                    "type": "categorical",
                    "SD": es_result.standardized_difference,
                    f"{lower}%": es_result.ci_lower,
                    f"{upper}%": es_result.ci_upper,
                }
            )

    if continuous is not None:
        continuous = continuous if isinstance(continuous, list) else [continuous]
        for col in continuous:
            es_result = effect_size_continuous(df_subset, intervention, col, coverage=coverage, decimals=decimals)
            results.append(
                {
                    "variable": col,
                    "type": "continuous",
                    "SD": es_result.standardized_difference,
                    f"{lower}%": es_result.ci_lower,
                    f"{upper}%": es_result.ci_upper,
                }
            )

    if not results:
        return pl.DataFrame()

    return pl.DataFrame(results)


def kruskal_wallis_test(df: Union[pl.DataFrame, pl.LazyFrame], variable: str, covariate: str) -> StatisticalTestResult:
    """
    Perform Kruskal-Wallis test for numeric variable across categorical covariate (>2 categories).
    Tests whether the median of the variable varies across the levels of the covariate.
    """
    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df_clean = df.select([variable, covariate]).drop_nulls().collect()
    else:
        df_clean = df.select([variable, covariate]).drop_nulls()
    groups = []
    categories = df_clean[covariate].unique().sort().to_list()

    for category in categories:
        group_data = df_clean.filter(pl.col(covariate) == category)[variable].to_numpy()
        groups.append(group_data)

    statistic, p_value = stats.kruskal(*groups)

    # Critical value at 5% significance level (chi-squared distribution with k-1 df)
    df_val = len(categories) - 1
    critical_value = float(stats.chi2.ppf(0.95, df_val))

    significant = bool(statistic > critical_value)
    conclusion = f"Median {variable} {'varies' if significant else 'does not vary'} across values of {covariate}."

    return StatisticalTestResult(
        test_name="Kruskal-Wallis",
        test_statistic=float(statistic),
        critical_value=critical_value,
        significant=significant,
        p_value=float(p_value),
        conclusion=conclusion,
    )


def mann_whitney_test(df: Union[pl.DataFrame, pl.LazyFrame], variable: str, covariate: str) -> StatisticalTestResult:
    """
    Perform Mann-Whitney U test for numeric variable across categorical covariate (2 categories).
    Tests whether the median of the variable differs between the two levels of the covariate.
    """
    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df_clean = df.select([variable, covariate]).drop_nulls().collect()
    else:
        df_clean = df.select([variable, covariate]).drop_nulls()
    categories = df_clean[covariate].unique().sort().to_list()

    if len(categories) != 2:
        raise ValueError("Mann-Whitney test requires exactly 2 categories")

    group1 = df_clean.filter(pl.col(covariate) == categories[0])[variable].to_numpy()
    group2 = df_clean.filter(pl.col(covariate) == categories[1])[variable].to_numpy()

    statistic, p_value = stats.mannwhitneyu(group1, group2, alternative="two-sided")

    # Convert to z-score for comparison
    n1, n2 = len(group1), len(group2)
    mean_u = n1 * n2 / 2
    std_u = np.sqrt(n1 * n2 * (n1 + n2 + 1) / 12)
    z_score = abs((statistic - mean_u) / std_u) if std_u > 0 else 0

    critical_value = float(stats.norm.ppf(0.975))
    significant = bool(z_score > critical_value)
    conclusion = f"Median {variable} {'varies' if significant else 'does not vary'} across values of {covariate}."

    return StatisticalTestResult(
        test_name="Mann-Whitney",
        test_statistic=float(z_score),
        critical_value=critical_value,
        significant=significant,
        p_value=float(p_value),
        conclusion=conclusion,
    )


def anova_test(df: Union[pl.DataFrame, pl.LazyFrame], variable: str, covariate: str) -> StatisticalTestResult:
    """
    Perform one-way ANOVA test for numeric variable across categorical covariate (>2 categories).
    Tests whether the mean of the variable varies across the levels of the covariate.
    """
    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df_clean = df.select([variable, covariate]).drop_nulls().collect()
    else:
        df_clean = df.select([variable, covariate]).drop_nulls()
    groups = []
    categories = df_clean[covariate].unique().sort().to_list()

    for category in categories:
        group_data = df_clean.filter(pl.col(covariate) == category)[variable].to_numpy()
        groups.append(group_data)

    # Check if any group has only one observation
    group_sizes = [len(group) for group in groups]
    if all(size <= 1 for size in group_sizes):
        raise ValueError(
            f"ANOVA requires at least one group with more than 1 observation. All groups for '{covariate}' have only 1 observation each."
        )

    if any(size <= 1 for size in group_sizes):
        single_obs_groups = [categories[i] for i, size in enumerate(group_sizes) if size <= 1]
        raise ValueError(
            f"ANOVA requires all groups to have more than 1 observation. Groups with single observations: {single_obs_groups}"
        )

    statistic, p_value = stats.f_oneway(*groups)

    # Critical value at 5% significance level (F-distribution)
    df_between = len(categories) - 1
    df_within = len(df_clean) - len(categories)
    critical_value = float(stats.f.ppf(0.95, df_between, df_within))

    significant = bool(statistic > critical_value)
    conclusion = f"Mean {variable} {'varies' if significant else 'does not vary'} across values of {covariate}."

    return StatisticalTestResult(
        test_name="ANOVA",
        test_statistic=float(statistic),
        critical_value=critical_value,
        significant=significant,
        p_value=float(p_value),
        conclusion=conclusion,
    )


def t_test(df: Union[pl.DataFrame, pl.LazyFrame], variable: str, covariate: str) -> StatisticalTestResult:
    """
    Perform independent t-test for numeric variable across categorical covariate (2 categories).
    Tests whether the mean of the variable differs between the two levels of the covariate.
    """
    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df_clean = df.select([variable, covariate]).drop_nulls().collect()
    else:
        df_clean = df.select([variable, covariate]).drop_nulls()
    categories = df_clean[covariate].unique().sort().to_list()

    if len(categories) != 2:
        raise ValueError("T-test requires exactly 2 categories")

    group1 = df_clean.filter(pl.col(covariate) == categories[0])[variable].to_numpy()
    group2 = df_clean.filter(pl.col(covariate) == categories[1])[variable].to_numpy()

    statistic, p_value = stats.ttest_ind(group1, group2)

    # Critical value at 5% significance level (t-distribution)
    df_val = len(group1) + len(group2) - 2
    critical_value = float(stats.t.ppf(0.975, df_val))

    significant = bool(abs(statistic) > critical_value)
    conclusion = f"Mean {variable} {'varies' if significant else 'does not vary'} across values of {covariate}."

    return StatisticalTestResult(
        test_name="T-test",
        test_statistic=float(abs(statistic)),
        critical_value=critical_value,
        significant=significant,
        p_value=float(p_value),
        conclusion=conclusion,
    )


def chi_squared_test(
    df: Union[pl.DataFrame, pl.LazyFrame], variable: str, covariate: str, n_quantiles: int = 4
) -> StatisticalTestResult:
    """
    Perform Chi-squared test of independence between two variables.
    If the covariate is numeric, it is binned into quantiles for the test.
    If both variables are categorical, tests their independence directly.
    """
    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df_clean = df.select([variable, covariate]).drop_nulls().collect()
    else:
        df_clean = df.select([variable, covariate]).drop_nulls()

    # Check if covariate is numeric or categorical
    covariate_is_numeric = df_clean[covariate].dtype.is_numeric()  # any integer, unsigned, or float width

    if covariate_is_numeric:
        # Bin the numeric covariate into quantiles
        df_with_bins = bin_numeric_column(df_clean, covariate, n_quantiles, "quantile")
        df_with_quantiles = df_with_bins.with_columns(pl.col(f"{covariate}_bins").alias("quantile_bins"))
        covariate_col = "quantile_bins"
        covariate_values = [f"Q{i + 1}" for i in range(n_quantiles)]
    else:
        # Use categorical covariate directly
        df_with_quantiles = df_clean
        covariate_col = covariate
        covariate_values = df_clean[covariate].unique().sort().to_list()

    # Create contingency table
    contingency = df_with_quantiles.group_by([variable, covariate_col]).agg(pl.len().alias("count"))

    # Convert to matrix format for chi-squared test
    variables = df_clean[variable].unique().sort().to_list()
    quantiles = covariate_values

    observed = np.zeros((len(variables), len(quantiles)))

    for i, var_val in enumerate(variables):
        for j, quant_val in enumerate(quantiles):
            count_row = contingency.filter((pl.col(variable) == var_val) & (pl.col(covariate_col) == quant_val))
            if len(count_row) > 0:
                observed[i, j] = count_row["count"][0]

    # Handle case where we have 2x2 table - use Fisher's exact test
    if observed.shape == (2, 2):
        return fishers_exact_test(df, variable, covariate, n_quantiles)

    # Perform chi-squared test
    statistic, p_value, dof, _ = stats.chi2_contingency(observed)

    # Critical value at 5% significance level
    critical_value = float(stats.chi2.ppf(0.95, dof))

    significant = bool(statistic > critical_value)
    if covariate_is_numeric:
        conclusion = (
            f"{variable} and quantile of {covariate} are {'not independent' if significant else 'independent'}."
        )
    else:
        conclusion = f"{variable} and {covariate} are {'not independent' if significant else 'independent'}."

    # Create contingency table data for explanation
    contingency_data = {
        "observed": observed.tolist(),
        "variable_levels": variables,
        "covariate_levels": quantiles,
        "variable_name": variable,
        "covariate_name": covariate if not covariate_is_numeric else f"{covariate} (quantiles)",
    }

    return StatisticalTestResult(
        test_name="Chi-squared",
        test_statistic=float(statistic),
        critical_value=critical_value,
        significant=significant,
        p_value=float(p_value),
        conclusion=conclusion,
        contingency_table=contingency_data,
    )

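
# Illustrative sketch (hypothetical helper, not part of the original module): building the
# observed contingency table in one pass with collections.Counter, rather than filtering the
# grouped counts once per (variable level, covariate level) pair as chi_squared_test does.
# Pairs that never co-occur get a count of 0, matching the zero-initialised matrix above.
def _contingency_counts(variable_values, covariate_values, variable_levels, covariate_levels):
    from collections import Counter

    # Count every observed (variable level, covariate level) pair once
    pair_counts = Counter(zip(variable_values, covariate_values))
    # Rows follow variable_levels, columns follow covariate_levels; missing pairs count as 0
    return [[pair_counts[(v, c)] for c in covariate_levels] for v in variable_levels]

# Example: the nested list can be passed to np.array(...) and then to stats.chi2_contingency
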

def fishers_exact_test(
    df: Union[pl.DataFrame, pl.LazyFrame], variable: str, covariate: str, n_quantiles: int = 2
) -> StatisticalTestResult:
    """
    Perform Fisher's exact test for 2x2 contingency table.
    Special case of chi-squared test when both variable and covariate have only 2 levels.
    """
    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df_clean = df.select([variable, covariate]).drop_nulls().collect()
    else:
        df_clean = df.select([variable, covariate]).drop_nulls()

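    # Bin a numeric covariate into quantiles (n_quantiles should be 2 for Fisher's exact);
    # a categorical covariate with two levels is used as-is
    covariate_is_numeric = df_clean[covariate].dtype.is_numeric()
    if covariate_is_numeric:
        df_with_bins = bin_numeric_column(df_clean, covariate, n_quantiles, "quantile")
        df_with_quantiles = df_with_bins.with_columns(pl.col(f"{covariate}_bins").alias("quantile_bins"))
        quantiles = [f"Q{i + 1}" for i in range(n_quantiles)]
    else:
        df_with_quantiles = df_clean.with_columns(pl.col(covariate).alias("quantile_bins"))
        quantiles = df_clean[covariate].unique().sort().to_list()

    # Create 2x2 contingency table
    contingency = df_with_quantiles.group_by([variable, "quantile_bins"]).agg(pl.len().alias("count"))

    variables = df_clean[variable].unique().sort().to_list()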

    if len(variables) != 2 or len(quantiles) != 2:
        raise ValueError("Fisher's exact test requires 2x2 contingency table")

    observed = np.zeros((2, 2))

    for i, var_val in enumerate(variables):
        for j, quant_val in enumerate(quantiles):
            count_row = contingency.filter((pl.col(variable) == var_val) & (pl.col("quantile_bins") == quant_val))
            if len(count_row) > 0:
                observed[i, j] = count_row["count"][0]

    # Perform Fisher's exact test
    odds_ratio, p_value = stats.fisher_exact(observed)

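    # Convert the odds ratio to an approximate z-score, log(OR) / SE, using Woolf's
    # SE = sqrt(1/a + 1/b + 1/c + 1/d). If any cell is zero the odds ratio is 0 or inf,
    # so the Haldane-Anscombe correction (add 0.5 to every cell) keeps both terms finite.
    table = observed + 0.5 if np.any(observed == 0) else observed
    log_odds_ratio = np.log((table[0, 0] * table[1, 1]) / (table[0, 1] * table[1, 0]))
    z_score = float(abs(log_odds_ratio) / np.sqrt(np.sum(1.0 / table)))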

    # Critical value for z-score at 5% significance level
    critical_value = float(stats.norm.ppf(0.975))

    significant = bool(z_score > critical_value)
    conclusion = f"{variable} and quantile of {covariate} are {'not independent' if significant else 'independent'}."

    return StatisticalTestResult(
        test_name="Fisher's Exact",
        test_statistic=float(z_score),
        critical_value=critical_value,
        significant=significant,
        p_value=float(p_value),
        conclusion=conclusion,
    )

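
# Illustrative sketch (hypothetical helper, not part of the original module): what "exact"
# means for a 2x2 table [[a, b], [c, d]]. With the row and column totals fixed, the top-left
# count follows a hypergeometric distribution; this sketch uses the common two-sided
# definition that sums the probabilities of every table no more likely than the observed one.
# It is for intuition only; the module itself relies on stats.fisher_exact.
def _fisher_exact_two_sided(a, b, c, d):
    from math import comb

    row1, col1, n = a + b, a + c, a + b + c + d

    def table_probability(x):
        # P(top-left cell = x) given the fixed margins
        return comb(col1, x) * comb(n - col1, row1 - x) / comb(n, row1)

    p_observed = table_probability(a)
    # Feasible top-left values given the margins
    low, high = max(0, row1 + col1 - n), min(row1, col1)
    # Small relative tolerance so floating-point ties with the observed table are included
    p_value = sum(
        p for p in (table_probability(x) for x in range(low, high + 1)) if p <= p_observed * (1 + 1e-7)
    )
    return min(1.0, p_value)

# Example: _fisher_exact_two_sided(8, 2, 1, 5) returns about 0.035 (= 280 / 8008)
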

def kolmogorov_smirnov_test(
    df: Union[pl.DataFrame, pl.LazyFrame], variable: str, covariate: Optional[str] = None, n_bins: int = 6
) -> StatisticalTestResult:
    """
    Perform Kolmogorov-Smirnov test for a numeric variable.
    If no covariate: Tests whether the variable departs from a uniform distribution.
    If covariate provided: Tests whether the distribution of variable varies across covariate groups.
    For numeric covariates, uses n_bins quantiles. For categorical covariates, uses category groups.

    Args:
        df (pl.DataFrame | pl.LazyFrame): Input data
        variable (str): Name of the numeric variable to test
        covariate (Optional[str]): Name of the covariate (None for single-variable test)
        n_bins (int): Number of bins to use for numeric covariates (default: 6)
    """
    if covariate is None:
        # Single-sample KS test against a uniform distribution
        if isinstance(df, pl.LazyFrame):
            df_clean = df.select([variable]).drop_nulls().collect()
        else:
            df_clean = df.select([variable]).drop_nulls()
        data = df_clean[variable].to_numpy()

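        # Min-max scale the data to [0, 1] and compare it with Uniform(0, 1). Because the bounds
        # are estimated from the sample itself, the KS p-value is only approximate.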
        data_min, data_max = data.min(), data.max()
        if data_max == data_min:
            # Constant data - definitely not uniform
            return StatisticalTestResult(
                test_name="Kolmogorov-Smirnov",
                test_statistic=1.0,
                critical_value=0.0,
                significant=True,
                p_value=0.0,
                conclusion=f"{variable} departs from a uniform distribution.",
            )

        normalized_data = (data - data_min) / (data_max - data_min)

        # Perform KS test against uniform distribution
        statistic, p_value = stats.kstest(normalized_data, "uniform")

        # Critical value at 5% significance level for KS test
        n = len(normalized_data)
        critical_value = 1.36 / np.sqrt(n)  # Approximation for large n at alpha=0.05

        significant = bool(statistic > critical_value)
        conclusion = f"{variable} {'departs from' if significant else 'does not depart from'} a uniform distribution."

        return StatisticalTestResult(
            test_name="Kolmogorov-Smirnov",
            test_statistic=float(statistic),
            critical_value=float(critical_value),
            significant=significant,
            p_value=float(p_value),
            conclusion=conclusion,
        )

    else:
        # Two-sample KS tests comparing distributions across covariate groups
        if isinstance(df, pl.LazyFrame):
            df_clean = df.select([variable, covariate]).drop_nulls().collect()
        else:
            df_clean = df.select([variable, covariate]).drop_nulls()

        # Treat the covariate as numeric only if its dtype is numeric. Test-casting to Float64
        # would misclassify numeric-looking string labels (e.g. "1", "2") and booleans as numeric.
        is_numeric_covariate = df_clean[covariate].dtype.is_numeric()

        if is_numeric_covariate:
            # Numeric covariate: bin into n_bins quantiles
            df_with_bins = bin_numeric_column(df_clean, covariate, n_bins, "quantile")
            df_with_quantiles = df_with_bins.with_columns(pl.col(f"{covariate}_bins").alias("covariate_groups"))
            groups = [f"Q{i + 1}" for i in range(n_bins)]
            conclusion_template = f"The distribution of {variable} {{}} across quantile of {covariate}."
        else:
            # Categorical covariate: use categories as groups
            df_with_quantiles = df_clean.with_columns(pl.col(covariate).alias("covariate_groups"))
            groups = sorted(df_clean[covariate].unique().to_list())
            conclusion_template = f"The distribution of {variable} {{}} across values of {covariate}."

        # Collect the values of each non-empty group for pairwise two-sample KS tests
        group_data = []
        for group in groups:
            group_values = df_with_quantiles.filter(pl.col("covariate_groups") == group)[variable].to_numpy()
            if len(group_values) > 0:
                group_data.append(group_values)

        if len(group_data) < 2:
            raise ValueError("Need at least 2 groups with data for comparison")

        # Use maximum KS statistic from all pairwise comparisons
        max_statistic = 0.0
        min_p_value = 1.0

        for i in range(len(group_data)):
            for j in range(i + 1, len(group_data)):
                ks_stat, ks_p = stats.ks_2samp(group_data[i], group_data[j])
                max_statistic = max(max_statistic, ks_stat)
                min_p_value = min(min_p_value, ks_p)

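        # Critical value for the two-sample KS test at alpha = 0.05:
        # c(alpha) * sqrt((n1 + n2) / (n1 * n2)) with c(alpha) = sqrt(-0.5 * ln(alpha / 2)) ~= 1.36.
        # Using the average group size n for both groups reduces this to 1.36 * sqrt(2 / n).
        # Neither this threshold nor the minimum p-value below is adjusted for the multiple
        # pairwise comparisons, so the test becomes anti-conservative as the number of groups grows.
        total_n = sum(len(group) for group in group_data)
        avg_group_size = total_n / len(group_data)
        critical_value = 1.36 * np.sqrt(2 / avg_group_size)  # Approximation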

        significant = bool(max_statistic > critical_value)
        conclusion = conclusion_template.format("varies" if significant else "does not vary")

        # Use the minimum p-value from pairwise comparisons
        p_value = min_p_value
        test_statistic = max_statistic

        return StatisticalTestResult(
            test_name="Kolmogorov-Smirnov",
            test_statistic=float(test_statistic),
            critical_value=float(critical_value),
            significant=significant,
            p_value=float(p_value),
            conclusion=conclusion,
        )

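
# Illustrative sketch (hypothetical helper, not part of the original module): the two-sample
# KS statistic D = max |F1(x) - F2(x)| computed directly from the two empirical CDFs, with the
# asymptotic critical value c(alpha) * sqrt((n1 + n2) / (n1 * n2)) for unequal group sizes,
# where c(alpha) = sqrt(-ln(alpha / 2) / 2) (about 1.36 at alpha = 0.05).
def _ks_two_sample(sample1, sample2, alpha=0.05):
    import bisect
    import math

    s1, s2 = sorted(sample1), sorted(sample2)
    n1, n2 = len(s1), len(s2)
    d = 0.0
    # The ECDFs only jump at observed values, so checking those points finds the maximum gap
    for x in s1 + s2:
        f1 = bisect.bisect_right(s1, x) / n1
        f2 = bisect.bisect_right(s2, x) / n2
        d = max(d, abs(f1 - f2))
    c_alpha = math.sqrt(-math.log(alpha / 2) / 2)
    critical_value = c_alpha * math.sqrt((n1 + n2) / (n1 * n2))
    return d, critical_value

# Example: _ks_two_sample([1, 2, 3], [4, 5, 6]) gives D = 1.0 against a critical value of about 1.11
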

def quantile_distribution_test(
    df: Union[pl.DataFrame, pl.LazyFrame], variable: str, covariate: str, n_bins: int = 6
) -> StatisticalTestResult:
    """
    Test whether the distribution of a numeric variable varies across quantiles of a numeric covariate.

    Uses an F-statistic approach based on variance analysis between quantile groups.
    This is distinct from the Kolmogorov-Smirnov test and focuses on variance differences
    rather than cumulative distribution differences.

    Args:
        df (pl.DataFrame): Input dataframe
        variable (str): Name of the numeric variable to test
        covariate (str): Name of the numeric covariate to bin into quantiles
        n_bins (int): Number of quantile bins to create (default: 6)

    Returns:
        StatisticalTestResult: Test results with F-statistic approach
    """
    # Clean data - select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df_clean = df.select([variable, covariate]).drop_nulls().collect()
    else:
        df_clean = df.select([variable, covariate]).drop_nulls()

    # Bin the covariate into quantiles
    df_with_bins = bin_numeric_column(df_clean, covariate, n_bins, "quantile")
    df_with_quantiles = df_with_bins.with_columns(pl.col(f"{covariate}_bins").alias("covariate_groups"))

    # Collect group data for analysis
    groups = [f"Q{i + 1}" for i in range(n_bins)]
    group_data = []
    group_means = []
    group_sizes = []

    for group in groups:
        group_values = df_with_quantiles.filter(pl.col("covariate_groups") == group)[variable].to_numpy()
        if len(group_values) > 0:
            group_data.append(group_values)
            group_means.append(np.mean(group_values))
            group_sizes.append(len(group_values))

    if len(group_data) < 2:
        raise ValueError("Need at least 2 groups with data for comparison")

    # One-way ANOVA across the quantile groups: the F-statistic compares variation between
    # group means to variation within groups, so it detects shifts in location
    all_values = np.concatenate(group_data)
    overall_mean = np.mean(all_values)
    n_groups = len(group_data)
    total_n = len(all_values)

    # Degrees of freedom: k - 1 between groups, N - k within groups
    df_numerator = n_groups - 1
    df_denominator = total_n - n_groups
    if df_denominator < 1:
        raise ValueError("Need more observations than groups for the F-test")

    # Between-group sum of squares
    between_group_ss = sum(size * (mean - overall_mean) ** 2 for size, mean in zip(group_sizes, group_means))

    # Within-group sum of squares
    within_group_ss = sum(np.sum((data - np.mean(data)) ** 2) for data in group_data)

    # F-statistic = between-group mean square / within-group mean square
    if within_group_ss > 0:
        test_statistic = (between_group_ss / df_numerator) / (within_group_ss / df_denominator)
        p_value = stats.f.sf(test_statistic, df_numerator, df_denominator)
    elif between_group_ss > 0:
        # No spread within groups but the group means differ: perfect separation
        test_statistic, p_value = float("inf"), 0.0
    else:
        # Every value is identical: no evidence of a difference
        test_statistic, p_value = 0.0, 1.0

    # Critical value at 5% significance level (F-distribution with k - 1 and N - k df)
    critical_value = stats.f.ppf(0.95, df_numerator, df_denominator)

    significant = bool(test_statistic > critical_value)
    conclusion_template = f"The distribution of {variable} {{}} across quantile of {covariate}."
    conclusion = conclusion_template.format("varies" if significant else "does not vary")

    return StatisticalTestResult(
        test_name="Quantile Distribution Test",
        test_statistic=float(test_statistic),
        critical_value=float(critical_value),
        significant=significant,
        p_value=float(p_value),
        conclusion=conclusion,
    )

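
# Illustrative sketch (hypothetical helper, not part of the original module): a one-way ANOVA
# F-statistic with the standard k - 1 and N - k degrees of freedom. With two groups, F equals
# the square of the pooled two-sample t-statistic.
def _one_way_anova_f(groups):
    all_values = [x for group in groups for x in group]
    k, n = len(groups), len(all_values)
    grand_mean = sum(all_values) / n
    group_means = [sum(group) / len(group) for group in groups]
    # Variation of the group means around the grand mean, weighted by group size
    ss_between = sum(len(g) * (m - grand_mean) ** 2 for g, m in zip(groups, group_means))
    # Variation of each observation around its own group mean
    ss_within = sum(sum((x - m) ** 2 for x in g) for g, m in zip(groups, group_means))
    df_between, df_within = k - 1, n - k
    f_statistic = (ss_between / df_between) / (ss_within / df_within)
    return f_statistic, df_between, df_within

# Example: _one_way_anova_f([[1, 2, 3], [4, 5, 6]]) gives F = 13.5 with (1, 4) degrees of freedom
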

def pearson_correlation_test(
    df: Union[pl.DataFrame, pl.LazyFrame], variable: str, covariate: str
) -> StatisticalTestResult:
    """
    Perform Pearson Product-Moment Correlation test between two numeric variables.
    Tests whether a correlation exists between variable and covariate.
    """
    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df_clean = df.select([variable, covariate]).drop_nulls().collect()
    else:
        df_clean = df.select([variable, covariate]).drop_nulls()

    if len(df_clean) < 3:
        raise ValueError("Need at least 3 observations for correlation test")

    var_data = df_clean[variable].to_numpy()
    cov_data = df_clean[covariate].to_numpy()

    # Perform Pearson correlation test
    correlation, p_value = stats.pearsonr(var_data, cov_data)

    # Degrees of freedom
    n = len(df_clean)
    df_val = n - 2

    # Critical value at 5% significance level: the two-sided t critical value with n - 2
    # degrees of freedom, mapped to the correlation scale via r = t / sqrt(df + t^2).
    # |r| > r_crit is equivalent to |t| > t_crit for t = r * sqrt(df) / sqrt(1 - r^2),
    # and keeps the critical value in the same units as the reported coefficient.
    t_critical = float(stats.t.ppf(0.975, df_val))
    critical_value = t_critical / float(np.sqrt(df_val + t_critical**2))

    significant = bool(abs(correlation) > critical_value)

    # Determine correlation direction
    if significant:
        if correlation > 0:
            direction = "positive"
        else:
            direction = "negative"
        conclusion = f"A {direction} correlation exists between {variable} and {covariate}."
    else:
        conclusion = f"No significant correlation exists between {variable} and {covariate}."

    return StatisticalTestResult(
        test_name="Pearson Correlation",
        test_statistic=float(correlation),
        critical_value=critical_value,
        significant=significant,
        p_value=float(p_value),
        conclusion=conclusion,
        degrees_of_freedom=df_val,
    )

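
# Illustrative sketch (hypothetical helpers, not part of the original module): Pearson's r,
# its t-statistic t = r * sqrt(n - 2) / sqrt(1 - r^2), and the inverse mapping from a t
# critical value to a critical |r|, r_crit = t_crit / sqrt(df + t_crit^2). Comparing |r|
# with r_crit gives the same decision as comparing |t| with t_crit.
def _pearson_r_and_t(x, y):
    import math

    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    r = cov / math.sqrt(var_x * var_y)
    # Assumes |r| < 1; a perfect correlation makes t infinite
    t = r * math.sqrt(n - 2) / math.sqrt(1 - r**2)
    return r, t


def _critical_r(t_critical, df):
    import math

    return t_critical / math.sqrt(df + t_critical**2)

# Example: for x = [1, 2, 3, 4, 5] and y = [2, 1, 4, 3, 5], r = 0.8 and t is about 2.31; with
# df = 3 the two-sided 5% t critical value of about 3.182 maps to a critical r of about 0.878,
# so neither |t| nor |r| crosses its threshold.
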

def spearman_correlation_test(
    df: Union[pl.DataFrame, pl.LazyFrame], variable: str, covariate: str
) -> StatisticalTestResult:
    """
    Perform Spearman's Rank Correlation test between a numeric variable and categorical covariate.
    The categorical covariate is converted to numeric ranks for correlation analysis.
    Tests whether a monotonic correlation exists between variable and covariate.
    """
    # Select only needed columns and collect for analysis
    if isinstance(df, pl.LazyFrame):
        df_clean = df.select([variable, covariate]).drop_nulls().collect()
    else:
        df_clean = df.select([variable, covariate]).drop_nulls()

    if len(df_clean) < 3:
        raise ValueError("Need at least 3 observations for correlation test")

    # Convert the categorical covariate to integer ranks following the sorted order of its
    # levels; this is only meaningful if that order matches the covariate's natural ordering
    categories = sorted(df_clean[covariate].unique().to_list())
    category_to_rank = {cat: i for i, cat in enumerate(categories)}

    # replace_strict allows the output dtype to differ from the input (labels -> integer ranks)
    df_with_ranks = df_clean.with_columns(
        pl.col(covariate).replace_strict(category_to_rank, return_dtype=pl.Int64).alias("covariate_rank")
    )

    var_data = df_with_ranks[variable].to_numpy()
    cov_rank_data = df_with_ranks["covariate_rank"].to_numpy()

    # Perform Spearman correlation test
    correlation, p_value = stats.spearmanr(var_data, cov_rank_data)

    # Degrees of freedom
    n = len(df_clean)
    df_val = n - 2

    # Critical value at 5% significance level: the two-sided t critical value with n - 2
    # degrees of freedom, mapped to the correlation scale via rho = t / sqrt(df + t^2).
    # |rho| > rho_crit is equivalent to |t| > t_crit for t = rho * sqrt(df) / sqrt(1 - rho^2),
    # and keeps the critical value in the same units as the reported coefficient.
    t_critical = float(stats.t.ppf(0.975, df_val))
    critical_value = t_critical / float(np.sqrt(df_val + t_critical**2))

    significant = bool(abs(correlation) > critical_value)

    # Determine correlation direction
    if significant:
        if correlation > 0:
            direction = "positive"
        else:
            direction = "negative"
        conclusion = f"A {direction} correlation exists between {variable} and {covariate}."
    else:
        conclusion = f"No significant correlation exists between {variable} and {covariate}."

    return StatisticalTestResult(
        test_name="Spearman Correlation",
        test_statistic=float(correlation),
        critical_value=critical_value,
        significant=significant,
        p_value=float(p_value),
        conclusion=conclusion,
        degrees_of_freedom=df_val,
    )

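
# Illustrative sketch (hypothetical helper, not part of the original module): mapping an
# ordinal categorical covariate to ranks using an explicit level order instead of the sorted
# order of its labels. Alphabetical sorting can scramble ordinal levels, e.g.
# sorted(["low", "medium", "high"]) is ["high", "low", "medium"].
def _ordinal_ranks(values, order):
    rank = {level: i for i, level in enumerate(order)}
    # Fail loudly rather than silently mis-ranking levels missing from the given order
    missing = set(values) - set(rank)
    if missing:
        raise ValueError(f"Levels not found in order: {sorted(missing)}")
    return [rank[value] for value in values]

# Example: _ordinal_ranks(["low", "high", "medium"], ["low", "medium", "high"]) returns [0, 2, 1]
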