# Source code for bsxplorer.Clusters

from __future__ import annotations

import multiprocessing
import os
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal

import numpy as np
import pandas as pd
import plotly.express as px
import polars as pl
import seaborn as sns
from dynamicTreeCut import cutreeHybrid
from dynamicTreeCut.dynamicTreeCut import get_heights
from fastcluster import linkage
from scipy.cluster.hierarchy import leaves_list, optimal_leaf_ordering
from scipy.spatial.distance import pdist
from scipy.stats import pearsonr

from .Base import MetageneBase, MetageneFilesBase
from .Plots import flank_lines_plotly

# Cap the native BLAS/OpenMP thread pools at the machine CPU count before
# importing scikit-learn, which pulls in these backends for KMeans.
default_n_threads = multiprocessing.cpu_count()
os.environ['OPENBLAS_NUM_THREADS'] = f"{default_n_threads}"
os.environ['MKL_NUM_THREADS'] = f"{default_n_threads}"
os.environ['OMP_NUM_THREADS'] = f"{default_n_threads}"
# NOTE(review): numpy is imported above, so its BLAS backend is already loaded
# and may ignore these environment variables — confirm this ordering is intended.
from sklearn.cluster import KMeans, MiniBatchKMeans


# noinspection PyMissingOrEmptyDocstring
class _ClusterBase(ABC):
    """Abstract base for region-clustering classes.

    Provides the shared preprocessing that turns a metagene report into a
    dense (regions x windows) density matrix; subclasses implement the
    actual clustering strategies.
    """

    @abstractmethod
    def kmeans(self, n_clusters: int = 8, n_init: int = 10, **kwargs):
        ...

    @abstractmethod
    def cut_tree(self, dist_method="euclidean", clust_method="average", cut_height_q=.99, **kwargs):
        # Parameter renamed from ``cutHeight_q`` to ``cut_height_q`` so the
        # abstract signature matches both concrete implementations
        # (ClusterSingle.cut_tree / ClusterMany.cut_tree).
        ...

    @abstractmethod
    def all(self):
        ...

    def __merge_strands(self, df: pl.DataFrame):
        # Keep "+" rows as-is and mirror the fragment axis of "-" rows, so
        # both strands share one 5'->3' window coordinate system.
        return df.filter(pl.col("strand") == "+").vstack(self.__strand_reverse(df.filter(pl.col("strand") == "-")))

    @staticmethod
    def __strand_reverse(df: pl.DataFrame):
        """Reverse the fragment (window) order of a minus-strand frame."""
        max_fragment = df["fragment"].max()
        return df.with_columns((max_fragment - pl.col("fragment")).alias("fragment"))

    def _process_metagene(
            self,
            metagene: MetageneBase,
            count_threshold=5,
            na_rm: float | None = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """Convert a metagene report into a density matrix and region names.

        Parameters
        ----------
        metagene
            Metagene whose ``report_df`` is pivoted into the matrix.
        count_threshold
            Windows with ``count`` not greater than this value are discarded.
        na_rm
            When ``None``, regions with any empty window are dropped entirely;
            otherwise empty windows are filled with this value.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            ``(matrix, names)`` — per-region window densities and the
            matching region identifiers.

        Raises
        ------
        ValueError
            If every region has at least one empty window (nothing survives
            the filtration).
        """
        # Merge strands
        df = self.__merge_strands(metagene.report_df)

        # Per-region aggregation: density per window plus coverage counters.
        grouped = (
            df.lazy()
            .filter(pl.col("count") > count_threshold)
            .with_columns((pl.col("sum") / pl.col("count")).alias("density"))
            .group_by(["chr", "strand", "gene", "context"])
            .agg([pl.first("id"),
                  pl.first("start"),
                  pl.col("density"),
                  pl.col("fragment"),
                  pl.sum("count").alias("gene_count"),
                  pl.count("fragment").alias("count")])
        ).collect()

        by_count = grouped
        if na_rm is None:
            # Keep only regions that have a value in every window.
            by_count = grouped.filter(pl.col("count") == metagene.total_windows)
            print(f"Left after empty windows filtration:\t{len(by_count)}")

        if len(by_count) == 0:
            raise ValueError("All genes have empty windows")

        by_count = by_count.explode(["density", "fragment"]).drop(["gene_count", "count"]).fill_nan(0)

        # Pivot to wide form: one row per region, one column per window.
        # Region name falls back to the gene when no explicit id is present.
        unpivot: pl.DataFrame = (
            by_count
            .sort(["chr", "start"])
            .with_columns(pl.when(pl.col("id").is_null()).then(pl.col("gene")).otherwise(pl.col("id")).alias("name"))
            .pivot(
                index=["chr", "strand", "name"],
                values="density",
                columns="fragment",
                aggregate_function="sum",
                maintain_order=True
            )
            .select(["chr", "strand", "name"] + list(map(str, range(int(metagene.total_windows)))))
            .cast({"name": pl.Utf8})
        )

        if na_rm is None:
            unpivot = unpivot.drop_nulls()
        else:
            unpivot = unpivot.fill_null(na_rm)

        names = unpivot["name"].to_numpy()
        matrix = unpivot.select(pl.all().exclude(["strand", "chr", "name"])).to_numpy()

        return matrix, names


class ClusterSingle(_ClusterBase):
    """Clustering of regions for a single metagene sample."""

    def __init__(self, metagene: MetageneBase, count_threshold=5, na_rm: float | None = None, empty=False):
        if empty:
            # Bare shell; attributes are filled in by :meth:`_from_raw`.
            return
        self.matrix, self.names = self._process_metagene(metagene, count_threshold, na_rm)
        self._x_ticks = metagene._x_ticks
        self._borders = metagene._borders

    @classmethod
    def _from_raw(cls, matrix, names, x_ticks, _borders):
        """Alternate constructor from an already computed matrix."""
        instance = cls(None, empty=True)
        instance.matrix = matrix
        instance.names = names
        instance._x_ticks = x_ticks
        instance._borders = _borders
        return instance

    def kmeans(self, n_clusters: int = 8, n_init: int = 10, **kwargs):
        """
        KMeans clustering on sample regions.

        Clustering is being made with `sklearn.cluster.KMeans <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html>`_.

        Parameters
        ----------
        n_clusters
            The number of clusters to generate.
        n_init
            Number of times the k-means algorithm is run with different centroid seeds.
        kwargs
            See `sklearn.cluster.KMeans <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html>`_.

        Returns
        -------
        :class:`ClusterPlot`
        """
        fitted = KMeans(n_clusters=n_clusters, n_init=n_init, **kwargs).fit(self.matrix)
        print(f"Clustering done in {fitted.n_iter_} iterations")
        return ClusterPlot(ClusterData.from_kmeans(fitted, self.names))

    def cut_tree(self, dist_method="euclidean", clust_method="average", cut_height_q=.99, **kwargs):
        """
        Hierarchical clustering of sample regions with dynamic tree cutting.

        Clustering is being made with `dynamicTreeCut.cutreeHybrid <https://github.com/kylessmith/dynamicTreeCut>`_.

        Parameters
        ----------
        dist_method
            Distances calculation metric
        clust_method
            Hierarchical clustering method
        cut_height_q
            Quantile of leaves height to be cut.
        kwargs
            See `dynamicTreeCut <https://github.com/kylessmith/dynamicTreeCut>`_.

        Returns
        -------
        :class:`ClusterPlot`
        """
        distances = pdist(self.matrix, metric=dist_method)
        tree_linkage = linkage(distances, method=clust_method)
        # Cut the dendrogram at the requested quantile of leaf heights.
        height_cut = np.quantile(get_heights(tree_linkage), q=cut_height_q)
        dynamic_tree = cutreeHybrid(tree_linkage, distances, cutHeight=height_cut, **kwargs)
        return ClusterPlot(ClusterData.from_matrix(self.matrix, dynamic_tree["labels"], self.names))

    def all(self):
        """
        Returns all regions for downstream plotting.

        Returns
        -------
        :class:`ClusterPlot`
        """
        region_ids = np.arange(len(self.matrix), dtype=np.int64)
        return ClusterPlot(ClusterData(self.matrix, region_ids, self.names))
class ClusterMany(_ClusterBase):
    """Class for operating with multiple samples regions clustering"""

    def __init__(self, metagenes: MetageneFilesBase, count_threshold=5, na_rm: float | None = None):
        # Restrict every sample to the genes shared by all samples so the
        # per-sample matrices describe a comparable region set.
        intersect_list = set.intersection(*[set(metagene.report_df["gene"].to_list()) for metagene in metagenes.samples])
        for i in range(len(metagenes.samples)):
            metagenes.samples[i].report_df = metagenes.samples[i].report_df.filter(pl.col("gene").is_in(intersect_list))
        self.clusters = [ClusterSingle(metagene, count_threshold, na_rm) for metagene in metagenes.samples]
        self.sample_names = metagenes.labels

    def compare(self):
        """
        Build a :class:`ClusterSingle` holding the density difference
        (sample B minus sample A) over the regions shared by both samples.

        Returns
        -------
        :class:`ClusterSingle`

        Raises
        ------
        ValueError
            If the number of samples is not exactly two.
        """
        # Fixed: the original check was ``> 2``, which let a single sample
        # through and crashed below with an IndexError.
        if len(self.clusters) != 2:
            raise ValueError("This method is available only for 2 samples")

        # Match region set
        a_sample = self.clusters[0]
        b_sample = self.clusters[1]

        intersection = list(set.intersection(*map(lambda cluster: set(cluster.names), self.clusters)))
        intersection.sort()

        # Align both matrices to the sorted shared-name order via argsort +
        # searchsorted (names themselves are not assumed to be sorted).
        a_order = np.argsort(a_sample.names)
        b_order = np.argsort(b_sample.names)

        a_matrix = a_sample.matrix[a_order[np.searchsorted(a_sample.names, intersection, sorter=a_order)], :]
        b_matrix = b_sample.matrix[b_order[np.searchsorted(b_sample.names, intersection, sorter=b_order)], :]

        diff_matrix = b_matrix - a_matrix
        names = a_sample.names[a_order[np.searchsorted(a_sample.names, intersection, sorter=a_order)]]

        cluster_single = ClusterSingle._from_raw(diff_matrix, names, a_sample._x_ticks, a_sample._borders)
        return cluster_single

    def kmeans(self, n_clusters: int = 8, n_init: int = 10, **kwargs):
        """
        KMeans clustering on sample regions.

        Clustering is being made with `sklearn.cluster.KMeans <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html>`_.

        Parameters
        ----------
        n_clusters
            The number of clusters to generate.
        n_init
            Number of times the k-means algorithm is run with different centroid seeds.
        kwargs
            See `sklearn.cluster.KMeans <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html>`_.

        Returns
        -------
        :class:`ClusterPlot`
        """
        return ClusterPlot([cluster.kmeans(n_clusters, n_init, **kwargs).data for cluster in self.clusters], self.sample_names)

    def cut_tree(self, dist_method="euclidean", clust_method="average", cut_height_q=.99, **kwargs):
        """
        Hierarchical clustering of each sample's regions with dynamic tree cutting.

        Clustering is being made with `dynamicTreeCut.cutreeHybrid <https://github.com/kylessmith/dynamicTreeCut>`_.

        Parameters
        ----------
        dist_method
            Distances calculation metric
        clust_method
            Hierarchical clustering method
        cut_height_q
            Quantile of leaves height to be cut.
        kwargs
            See `dynamicTreeCut <https://github.com/kylessmith/dynamicTreeCut>`_.

        Returns
        -------
        :class:`ClusterPlot`
        """
        # Fixed: the original hardcoded dist_method="euclidean",
        # clust_method="average", cut_height_q=.99 instead of forwarding the
        # caller's arguments.
        return ClusterPlot([
            cluster.cut_tree(dist_method=dist_method, clust_method=clust_method, cut_height_q=cut_height_q, **kwargs).data
            for cluster in self.clusters
        ], self.sample_names)

    def all(self):
        """
        Returns all regions for downstream plotting.

        Returns
        -------
        :class:`ClusterPlot`
        """
        return ClusterPlot([cluster.all().data for cluster in self.clusters], self.sample_names)
# noinspection PyMissingOrEmptyDocstring
class ClusterData:
    """Container for cluster centers, per-region labels and region names."""

    def __init__(self,
                 centers: np.ndarray,
                 labels: np.ndarray,
                 names: list[str] | np.ndarray,
                 ticks: list[int] = None,
                 borders: list[int] = None,
                 matrix: np.ndarray = None):
        # centers: (n_clusters x n_windows) aggregated profiles
        self.centers = centers
        # labels: cluster assignment for every region (row of ``matrix``)
        self.labels = labels
        self.names = names
        self.ticks = ticks
        self.borders = borders
        # matrix: optional full (regions x windows) data; required by
        # ClusterPlot.module_corr.
        self.matrix = matrix

    @classmethod
    def from_kmeans(cls, kmeans: KMeans, names: list[str] | np.ndarray):
        """Build from a fitted ``sklearn.cluster.KMeans`` estimator."""
        return cls(kmeans.cluster_centers_, kmeans.labels_, names)

    @classmethod
    def from_matrix(cls,
                    matrix: np.ndarray,
                    labels: np.ndarray,
                    names: list[str] | np.ndarray,
                    method: Literal["mean", "median", "log1p"] = "mean"):
        """Aggregate ``matrix`` rows per cluster label into centers.

        Parameters
        ----------
        matrix
            Full (regions x windows) density matrix.
        labels
            Cluster label per region.
        names
            Region identifiers.
        method
            Aggregation: "mean", "median" or "log1p" (mean of log1p values).
            Fixed: the original signature used ``method=Literal[...]`` — the
            Literal type object itself as the default — which always fell
            through to the mean branch.
        """
        if method == "median":
            agg_fun = lambda m: np.median(m, axis=0)
        elif method == "log1p":
            # Fixed: np.log1p has no ``axis`` argument; transform then average.
            agg_fun = lambda m: np.mean(np.log1p(m), axis=0)
        else:
            # "mean" and any unknown value default to the arithmetic mean,
            # matching the original fallback behavior.
            agg_fun = lambda m: np.mean(m, axis=0)

        # Fixed: iterate the *unique* labels (one center per cluster); the
        # original iterated ``labels`` and produced one row per region.
        # NOTE(review): centers are ordered by sorted unique label — confirm
        # downstream indexing (ClusterPlot.module_corr) assumes 0..k-1 labels.
        modules = np.array([agg_fun(matrix[labels == label, :]) for label in np.unique(labels)])
        # Fixed: keep a reference to the full matrix so module_corr works.
        return cls(modules, labels, names, matrix=matrix)
class ClusterPlot:
    """Class for plotting cluster data."""

    def __init__(self, data: ClusterData | list[ClusterData], sample_names=None):
        # A single-element list is unwrapped so single-sample and
        # multi-sample code paths can be distinguished via isinstance(list).
        if isinstance(data, list) and len(data) == 1:
            self.data = data[0]
        else:
            self.data = data
        self.sample_names = sample_names

    def save_tsv(self, filename: str):
        """
        Save labels for regions in a TSV file.

        Parameters
        ----------
        filename
            File name for output file
        """
        filename = Path(filename)

        def save(data: ClusterData, path: Path):
            # NOTE(review): labels are typically integers while the schema
            # declares Utf8 — confirm polars coerces them as intended.
            df = pl.DataFrame(dict(name=list(map(str, data.names)), label=data.labels),
                              schema=dict(name=pl.Utf8, label=pl.Utf8))
            df.write_csv(path, include_header=False, separator="\t")

        # Multi-sample: one file per sample, sample name appended to the stem.
        # NOTE(review): nothing is written when data is a list but
        # sample_names is None — confirm this silent skip is intended.
        if self.sample_names is not None and isinstance(self.data, list):
            for data, sample_name in zip(self.data, self.sample_names):
                new_name = filename.name + sample_name
                save(data, filename.with_name(new_name).with_suffix(".tsv"))
        if not isinstance(self.data, list):
            save(self.data, filename.with_suffix(".tsv"))

    def __intersect_genes(self):
        """Report (or reject) the overlap of region names across samples."""
        if isinstance(self.data, list):
            names = [d.names for d in self.data]
            intersection = set.intersection(*map(set, names))
            if len(intersection) < 1:
                raise ValueError("No same regions between samples")
            elif len(intersection) < max(map(len, names)):
                print(
                    f"Found {len(intersection)} intersections between samples with {max(map(len, names))} regions max")

    def draw_mpl(self, method='average', metric='euclidean', cmap: str = "cividis", **kwargs):
        """
        Draws clustermap with seaborn.clustermap.

        Parameters
        ----------
        method
            Method for hierarchical clustering.
        metric
            Metric for distance calculation
        cmap
            Colormap to use
        **kwargs
            ``seaborn.clustermap`` parameters

        See Also
        --------
        `seaborn.clustermap <https://seaborn.pydata.org/generated/seaborn.clustermap.html>`_ : For more information about possible parameters
        """
        if isinstance(self.data, list):
            warnings.warn("Matplotlib version of cluster plot is not available for multiple samples")
            return None

        # Row index shows "label (member count)" for each cluster.
        df = pd.DataFrame(
            self.data.centers,
            index=[f"{name} ({count})" for name, count in zip(*np.unique(self.data.labels, return_counts=True))])
        # Explicit cmap/method/metric win over anything passed in kwargs.
        args = dict(col_cluster=False) | kwargs
        args |= dict(cmap=cmap, method=method, metric=metric)
        fig = sns.clustermap(df, **args)
        return fig

    def draw_plotly(self, method='average', metric='euclidean', cmap: str = "cividis", tick_labels: list[str] = None, **kwargs):
        """
        Draws clustermap with plotly imshow.

        Parameters
        ----------
        method
            Method for hierarchical clustering.
        metric
            Metric for distance calculation
        cmap
            Colormap to use

        Returns
        --------
        ``plotly.graph_objects.Figure``
        """
        if isinstance(self.data, list):
            # Order every sample's heatmap by the leaf order of the FIRST
            # sample's dendrogram so rows stay comparable across frames.
            dist = pdist(self.data[0].centers, metric=metric)
            link = linkage(dist, method, metric)
            link = optimal_leaf_ordering(link, dist, metric=metric)
            order = leaves_list(link)

            im = np.dstack([d.centers[order, :] for d in self.data])
            figure = px.imshow(im, color_continuous_scale=cmap, animation_frame=2, aspect='auto', **kwargs)
            figure.update_layout(sliders=[{"currentvalue": {"prefix": "Sample = "}}])
            if self.sample_names is not None:
                for step, sample_name in zip(figure.layout.sliders[0].steps, self.sample_names):
                    step.label = sample_name
                    step.name = sample_name
            return figure

        dist = pdist(self.data.centers, metric=metric)
        link = linkage(dist, method, metric)
        link = optimal_leaf_ordering(link, dist, metric=metric)
        order = leaves_list(link)

        im = self.data.centers[order, :]
        ticktext = np.array([f"{label} ({count})"
                             for label, count in zip(*np.unique(self.data.labels, return_counts=True))])
        figure = px.imshow(im, color_continuous_scale=cmap, aspect='auto', **kwargs)
        figure.update_layout(
            yaxis=dict(
                tickmode='array',
                tickvals=list(range(len(order))),
                ticktext=ticktext[order]
            )
        )

        if tick_labels is None:
            tick_labels = ["Upstream", "", "Body", "", "Downstream"]
        figure = flank_lines_plotly(figure, self.data.ticks, tick_labels, self.data.borders)
        return figure

    @property
    def labels(self):
        # Cluster label per region (single-sample data only).
        return self.data.labels

    @property
    def names(self):
        # Region identifiers (single-sample data only).
        return self.data.names

    def module_corr(self, module: int = None, p_cutoff: float = None):
        """Pearson correlation of each module member against its module center.

        Parameters
        ----------
        module
            Cluster (module) index to correlate against. Required.
        p_cutoff
            If given, rows with p-value above this cutoff are dropped.

        Returns
        -------
        ``polars.DataFrame`` with name/cor/pvalue columns, or ``None`` when
        the full matrix was not stored with the cluster data.
        """
        if self.data.matrix is None:
            return None
        # Fixed: the original compared ``None > int`` when module was left at
        # its default, raising an opaque TypeError.
        if module is None:
            raise ValueError("Specify the module (cluster) index")
        if module > self.data.centers.shape[0] - 1:
            raise ValueError(f"Max cluster index is {self.data.centers.shape[0] - 1}!")

        module_index = self.data.labels == module
        module_members = self.data.matrix[module_index, :]
        module_vector = self.data.centers[module, :]

        # Each row -> (statistic, pvalue) pair from scipy.stats.pearsonr.
        cor = np.apply_along_axis(lambda row: pearsonr(row, module_vector),
                                  axis=1,
                                  arr=module_members).astype(np.float64)

        res_df = pl.DataFrame(
            data=np.c_[self.data.names[self.data.labels == module], cor].T.tolist(),
            schema=dict(name=pl.String, cor=pl.Float64, pvalue=pl.Float64)
        )

        if p_cutoff is not None:
            res_df = res_df.filter(pl.col("pvalue") <= p_cutoff)
        return res_df

    def get_module_ids(self, module: int):
        """Return the region names assigned to cluster ``module``."""
        if module > self.data.centers.shape[0] - 1:
            raise ValueError(f"Max cluster index is {self.data.centers.shape[0] - 1}!")
        return self.names[self.labels == module]