# Source code for bsxplorer.UniversalReader_batches

from __future__ import annotations

from abc import ABC, abstractmethod
from collections import OrderedDict
from copy import deepcopy
from typing import Literal

import polars as pl
import pyarrow as pa

from .utils import polars2arrow_convert

# Arrow column layouts (name, type, order) keyed by report type name.
ARROW_SCHEMAS = {
    "bedgraph": pa.schema([
        pa.field("chr", pa.utf8()),
        pa.field("start", pa.uint64()),
        pa.field("end", pa.uint64()),
        pa.field("density", pa.float32()),
    ]),
    "coverage": pa.schema([
        pa.field("chr", pa.utf8()),
        pa.field("position", pa.uint64()),
        pa.field("end", pa.uint64()),
        pa.field("density", pa.float32()),
        pa.field("count_m", pa.uint32()),
        pa.field("count_um", pa.uint32()),
    ]),
    "cgmap": pa.schema([
        pa.field("chr", pa.utf8()),
        pa.field("nuc", pa.utf8()),  # G/C
        pa.field("position", pa.uint64()),
        pa.field("context", pa.utf8()),
        pa.field("dinuc", pa.utf8()),
        pa.field("density", pa.float32()),
        pa.field("count_m", pa.uint32()),
        pa.field("count_total", pa.uint32()),
    ]),
    "bismark": pa.schema([
        pa.field("chr", pa.utf8()),
        pa.field("position", pa.uint64()),
        pa.field("strand", pa.utf8()),
        pa.field("count_m", pa.uint32()),
        pa.field("count_um", pa.uint32()),
        pa.field("context", pa.utf8()),
        pa.field("trinuc", pa.utf8()),
    ]),
    "binom": pa.schema([
        pa.field("chr", pa.utf8()),
        pa.field("strand", pa.utf8()),
        pa.field("position", pa.uint64()),
        pa.field("context", pa.utf8()),
        pa.field("p_value", pa.float64()),
    ]),
}

# Static type for a report type name, plus the same names as a runtime list.
ReportTypes = Literal["bismark", "cgmap", "bedgraph", "coverage", "binom"]
REPORT_TYPES_LIST = ["bismark", "cgmap", "bedgraph", "coverage", "binom"]


# noinspection PyMissingOrEmptyDocstring
class BaseBatch(ABC):
    """
    Abstract container for one batch of report data held as a polars DataFrame.

    Subclasses declare their expected columns and dtypes through
    :meth:`pl_schema`; the constructor validates the incoming frame against
    that schema and stores the result in ``self.data``.
    """

    def __init__(self, df: pl.DataFrame):
        self.data = self.get_validated(df)

    def get_validated(self, df: pl.DataFrame) -> pl.DataFrame:
        """
        Select the schema columns from ``df`` and cast them to the schema dtypes.

        Parameters
        ----------
        df
            Frame that must contain every column named by :meth:`colnames`.

        Returns
        -------
        polars.DataFrame
            ``df`` restricted to the schema columns (in schema order) and cast
            to :meth:`pl_schema`.

        Raises
        ------
        KeyError
            If any schema column is absent from ``df``.
        polars.SchemaError
            If selecting or casting the columns fails.
        """
        expected = self.colnames()
        # Walk the schema in order so the reported list is deterministic
        # (a set difference would yield an arbitrary order).
        missing = [name for name in expected if name not in df.columns]
        if missing:
            raise KeyError(f"Not all columns from schema in batch (missing {missing})")

        try:
            return df.select(expected).cast(self.pl_schema())
        except Exception as e:
            # Surface any select/cast failure as a schema problem, keeping the cause.
            raise pl.SchemaError(e) from e

    @classmethod
    @abstractmethod
    def pl_schema(cls) -> OrderedDict:
        """
        Returns
        -------
        collections.OrderedDict
            Mapping of column name to polars data type for this batch type.
        """
        ...

    @classmethod
    def pa_schema(cls) -> pa.Schema:
        """
        Returns
        -------
        pyarrow.Schema
            Result of passing :meth:`pl_schema` to ``polars2arrow_convert``.
        """
        return polars2arrow_convert(cls.pl_schema())

    @classmethod
    def colnames(cls):
        """
        Returns
        -------
        list
            Column names of this class's own :meth:`pl_schema`, in schema order.
        """
        # Use the schema of the concrete class rather than a fixed subclass, so
        # every batch type is validated against its own columns.
        return list(cls.pl_schema().keys())

    def to_arrow(self):
        """
        Returns
        -------
        pyarrow.Table
            ``self.data`` converted to Arrow and cast to :meth:`pa_schema`.
        """
        return self.data.to_arrow().cast(self.pa_schema())

# noinspection PyMissingOrEmptyDocstring

# noinspection PyMissingOrEmptyDocstring
class ConvertedBatch(BaseBatch):
    """
    Abstract base for batch types that can be built from a
    :class:`UniversalBatch`.
    """

    @classmethod
    @abstractmethod
    def from_full(cls, full_batch: UniversalBatch):
        """
        Build an instance of this batch type from a full batch.

        Parameters
        ----------
        full_batch
            Source :class:`UniversalBatch`.
        """


class UniversalBatch(BaseBatch):
    """
    Class for storing and converting methylation report data.
    """

    @classmethod
    def pl_schema(cls) -> OrderedDict:
        """
        Returns
        -------
        collections.OrderedDict
            Dictionary with column names and data types for polars.
        """
        return OrderedDict(
            chr=pl.Utf8,
            strand=pl.Utf8,
            position=pl.UInt64,
            context=pl.Utf8,
            trinuc=pl.Utf8,
            count_m=pl.UInt32,
            count_total=pl.UInt32,
            density=pl.Float64
        )

    def __init__(self, data: pl.DataFrame, raw: pa.Table | pa.RecordBatch | None = None):
        """
        Parameters
        ----------
        data
            Frame validated against :meth:`pl_schema` by the base class.
        raw
            Optional Arrow table or record batch stored unchanged as
            ``self.raw``.
        """
        super().__init__(data)
        self.raw = raw

    def __len__(self):
        return len(self.data)

    def filter_not_none(self):
        """
        Drop rows whose ``density`` is NaN (cytosines which have 0 reads).

        The batch is modified in place.

        Returns
        -------
        UniversalBatch
            This same instance, to allow chaining.
        """
        self.data = self.data.filter(pl.col("density").is_not_nan())
        return self

    def to_bismark(self) -> pl.DataFrame:
        """
        Returns
        -------
        polars.DataFrame
            DataFrame with methylation data as Bismark methylation report.
        """
        # Unmethylated count is derived from the stored total.
        count_um = (pl.col("count_total") - pl.col("count_m")).alias("count_um")
        return self.data.select([
            "chr",
            "position",
            "strand",
            "count_m",
            count_um,
            "context",
            "trinuc",
        ])

    def to_cgmap(self) -> pl.DataFrame:
        """
        Returns
        -------
        polars.DataFrame
            DataFrame with methylation data as CGmap.
        """
        # "+" strand rows are labelled "C", every other row "G".
        nuc = pl.when(strand="+").then(pl.lit("C")).otherwise(pl.lit("G")).alias("nuc")
        # Dinucleotide is the first two characters of the trinucleotide.
        dinuc = pl.col("trinuc").str.slice(0, 2).alias("dinuc")
        return self.data.select([
            "chr",
            nuc,
            "position",
            "context",
            dinuc,
            "density",
            "count_m",
            "count_total",
        ])

    def to_coverage(self) -> pl.DataFrame:
        """
        Rows with ``count_total == 0`` are dropped; ``start`` is ``position``
        and ``end`` is ``position + 1``.

        Returns
        -------
        polars.DataFrame
            DataFrame with methylation data as coverage.
        """
        covered = self.data.filter(pl.col("count_total") != 0)
        return covered.select([
            "chr",
            (pl.col("position")).alias("start"),
            (pl.col("position") + 1).alias("end"),
            "density",
            "count_m",
            (pl.col("count_total") - pl.col("count_m")).alias("count_um"),
        ])

    def to_bedGraph(self) -> pl.DataFrame:
        """
        Rows with ``count_total == 0`` are dropped; ``start`` is
        ``position - 1``, ``end`` is ``position`` and ``density`` is recomputed
        as ``count_m / count_total``.

        Returns
        -------
        polars.DataFrame
            DataFrame with methylation data as bedGraph.
        """
        # NOTE(review): position is unsigned, so this presumably relies on
        # positions being >= 1 - confirm against the readers.
        covered = self.data.filter(pl.col("count_total") != 0)
        return covered.select([
            "chr",
            (pl.col("position") - 1).alias("start"),
            (pl.col("position")).alias("end"),
            (pl.col("count_m") / pl.col("count_total")).alias("density"),
        ])

    def filter_data(self, *predicates, **kwargs) -> UniversalBatch:
        """
        Filter data by expression or keyword arguments

        Parameters
        ----------
        predicates
            Expressions to pass positionally to
            `polars.filter() <https://docs.pola.rs/api/python/version/0.20/reference/dataframe/api/polars.DataFrame.filter.html#polars.DataFrame.filter>`_
        kwargs
            Keyword arguments to pass to
            `polars.filter() <https://docs.pola.rs/api/python/version/0.20/reference/dataframe/api/polars.DataFrame.filter.html#polars.DataFrame.filter>`_

        Returns
        -------
        UniversalBatch
            A deep copy of this batch holding the filtered data; the original
            batch is left unchanged.
        """
        new = deepcopy(self)
        new.data = new.data.filter(*predicates, **kwargs)
        return new