Source code for bsxplorer.UniversalReader_classes

from __future__ import annotations

import gc
import gzip
import os
import shutil
import tempfile
import warnings
from abc import abstractmethod
from collections import defaultdict
from fractions import Fraction
from pathlib import Path

import func_timeout
import numpy as np
import polars as pl
import pyarrow as pa
from pyarrow import csv as pcsv, parquet as pq

from .UniversalReader_batches import UniversalBatch, ARROW_SCHEMAS, ReportTypes, REPORT_TYPES_LIST
from .utils import ReportBar


# noinspection PyMissingOrEmptyDocstring
def invalid_row_handler(row):
    print(f"Got invalid row: {row}")
    return "skip"


# noinspection PyMissingOrEmptyDocstring
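# Opening a CSV with a mismatched schema can make Arrow hang indefinitely, so reader
# initialization is given a hard 20-second timeout (handled in ArrowReaderCSV.get_reader).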
@func_timeout.func_set_timeout(20)
def open_csv(
        file: str | Path,
        read_options: pcsv.ReadOptions = None,
        parse_options: pcsv.ParseOptions = None,
        convert_options: pcsv.ConvertOptions = None,
        memory_pool=None
):
    return pcsv.open_csv(
        file,
        read_options,
        parse_options,
        convert_options,
        memory_pool
    )


# noinspection PyMissingOrEmptyDocstring
class ArrowParquetReader:
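    """
    Base reader that iterates over a .parquet file one row group at a time;
    subclasses wrap each row group into a :class:`UniversalBatch`.
    """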
    def __init__(self,
                 file: str | Path,
                 use_cols: list = None,
                 use_threads: bool = True,
                 **kwargs):
        self.file = Path(file).expanduser().absolute()
        if not self.file.exists():
            raise FileNotFoundError(self.file)

        self.reader = pq.ParquetFile(self.file)
        self.__current_group = 0

        self.__use_cols = use_cols
        self.__use_threads = use_threads

    def __len__(self):
        return self.reader.num_row_groups

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.reader.close()

    def __iter__(self):
        return self

    def _mutate_next(self, batch: pa.RecordBatch):
        return batch

    def __next__(self) -> UniversalBatch:
        old_group = self.__current_group
        # Check whether there are row groups left to read
        if old_group < self.reader.num_row_groups:
            self.__current_group += 1

            batch = self.reader.read_row_group(
                old_group,
                columns=self.__use_cols,
                use_threads=self.__use_threads
            )

            mutated = self._mutate_next(batch)
            return mutated
        raise StopIteration()

    @property
    def batch_size(self):
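        # Rough batch size estimate: total file size averaged over the row groups,
        # used by UniversalReader to advance the progress bar.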
        return int(self.file.stat().st_size / self.reader.num_row_groups)


# noinspection PyMissingOrEmptyDocstring
class BinomReader(ArrowParquetReader):
    """
    Class for reading methylation data from .parquet files generated by :class:`BinomialData`.

    Parameters
    ----------
    file
        Path to the .parquet file.
    methylation_pvalue
        P-value threshold at or below which a cytosine is considered methylated.
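
    Examples
    --------
    A minimal usage sketch (the file path and ``do_something`` are placeholders):

    >>> with BinomReader("binom_report.parquet", methylation_pvalue=0.01) as reader:
    ...     for batch in reader:
    ...         do_something(batch)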
    """

    def __init__(self, file: str | Path, methylation_pvalue=.05, **kwargs):
        super().__init__(file, **kwargs)

        if 0 < methylation_pvalue <= 1:
            self.methylation_pvalue = methylation_pvalue
        else:
            self.methylation_pvalue = 0.05
            warnings.warn(
                f"P-value needs to be in (0;1] interval, not {methylation_pvalue}. Setting to default ({self.methylation_pvalue})")

    def _mutate_next(self, batch: pa.RecordBatch):
        df = pl.from_arrow(batch)

        mutated = (
            df
            .with_columns(
                (pl.col("p_value") <= self.methylation_pvalue).cast(pl.UInt8).alias("count_m")
            )
            .with_columns([
                pl.col("context").alias("trinuc"),
                pl.lit(1).alias("count_total"),
                pl.col("count_m").cast(pl.Float64).alias("density")
            ])
        )

        return UniversalBatch(mutated, batch)


# noinspection PyMissingOrEmptyDocstring
class ArrowReaderCSV:

    __slots__ = ["file", "batch_size", "_current_batch", "reader"]

    @property
    @abstractmethod
    def pa_schema(self):
        ...

    @abstractmethod
    def convert_options(self, **kwargs):
        ...

    @abstractmethod
    def parse_options(self, **kwargs):
        ...

    @abstractmethod
    def read_options(self, **kwargs):
        ...

    def __init__(self, file, block_size_mb):
        self.file = file
        self.batch_size = block_size_mb * 1024 ** 2
        self._current_batch = None

    def __enter__(self):
        return self

    def __iter__(self):
        return self

    def get_reader(
            self,
            read_options: pcsv.ReadOptions = None,
            parse_options: pcsv.ParseOptions = None,
            convert_options: pcsv.ConvertOptions = None,
            memory_pool=None
    ) -> pcsv.CSVStreamingReader:
        """
        This function is needed, because if Arrow tries to open CSV with wrong schema, it gets stuck on it forever.
        This function makes a timout on initializing reader
        """
        try:
            reader = open_csv(
                self.file,
                read_options,
                parse_options,
                convert_options,
                memory_pool
            )

            return reader
        except pa.ArrowInvalid as e:
            print(f"Error opening file: {self.file}")
            raise e
        except func_timeout.exceptions.FunctionTimedOut:
            print(
                "Time for oppening file exceeded. Check if input type is correct or try making your batch size smaller.")
            os._exit(0)

    def _mutate_next(self, batch: pa.RecordBatch) -> UniversalBatch:
        return UniversalBatch(pl.from_arrow(batch), batch)

    def __next__(self):
        try:
            raw = self.reader.read_next_batch()
            batch = self._mutate_next(raw)
            return batch
        except pa.ArrowInvalid as e:
            print(e)
            return self.__next__()
        except StopIteration:
            raise StopIteration

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.reader.close()


# noinspection PyMissingOrEmptyDocstring
class BismarkReader(ArrowReaderCSV):
    pa_schema = ARROW_SCHEMAS["bismark"]

    def convert_options(self):
        return pcsv.ConvertOptions(
            column_types=self.pa_schema,
            strings_can_be_null=False
        )

    def parse_options(self):
        return pcsv.ParseOptions(
            delimiter="\t",
            quote_char=False,
            escape_char=False,
            ignore_empty_lines=True,
            invalid_row_handler=lambda _: "skip"
        )

    def read_options(self, use_threads=True, block_size=None):
        return pcsv.ReadOptions(
            use_threads=use_threads,
            block_size=block_size,
            skip_rows=0,
            skip_rows_after_names=0,
            column_names=self.pa_schema.names
        )

    def __init__(
            self,
            file: str | Path,
            use_threads: bool = True,
            block_size_mb: int = 50,
            memory_pool: pa.MemoryPool = None,
            **kwargs
    ):
        super().__init__(file, block_size_mb)
        self.reader = self.get_reader(
            self.read_options(use_threads=use_threads, block_size=self.batch_size),
            self.parse_options(),
            self.convert_options(),
            memory_pool=memory_pool
        )

    def _mutate_next(self, batch: pa.RecordBatch):
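        # Bismark reports store methylated (count_m) and unmethylated (count_um) counts
        # separately; derive total coverage and methylation density from them.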
        mutated = (
            pl.from_arrow(batch)
            .with_columns((pl.col("count_m") + pl.col("count_um")).alias("count_total"))
            .with_columns((pl.col("count_m") / pl.col("count_total")).alias("density"))
        )
        return UniversalBatch(mutated, batch)


# noinspection PyMissingOrEmptyDocstring
class CGMapReader(BismarkReader):
    pa_schema = ARROW_SCHEMAS["cgmap"]

    def _mutate_next(self, batch: pa.RecordBatch):
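        # In CGmap reports the "nuc" column holds the reference base: "C" marks a cytosine
        # on the forward (+) strand, anything else (i.e. "G") the reverse (-) strand.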
        mutated = (
            pl.from_arrow(batch)
            .with_columns([
                pl.when(pl.col("nuc") == "C").then(pl.lit("+")).otherwise(pl.lit("-")).alias("strand"),
                pl.when(pl.col("dinuc") == "CG")
                # if dinuc == "CG" => trinuc = "CG"
                .then(pl.col("dinuc"))
                # otherwise trinuc = dinuc + context_last
                .otherwise(pl.concat_str([pl.col("dinuc"), pl.col("context").str.slice(-1)]))
                .alias("trinuc")
            ])
        )
        return UniversalBatch(mutated, batch)


# noinspection PyMissingOrEmptyDocstring
class CoverageReader(ArrowReaderCSV):
    pa_schema = ARROW_SCHEMAS["coverage"]

    read_options = BismarkReader.read_options
    parse_options = BismarkReader.parse_options
    convert_options = BismarkReader.convert_options

    def __init__(
            self,
            file: str | Path,
            cytosine_file: str | Path,
            use_threads: bool = True,
            block_size_mb: int = 50,
            memory_pool: pa.MemoryPool = None,
            **kwargs
    ):
        super().__init__(file, block_size_mb)
        self.reader = self.get_reader(
            self.read_options(use_threads=use_threads, block_size=self.batch_size),
            self.parse_options(),
            self.convert_options(),
            memory_pool=memory_pool
        )

        self.cytosine_file = cytosine_file
        self._sequence_rows_read = 0
        self._sequence_metadata = pq.read_metadata(cytosine_file)

    def _align(self, pa_batch: pa.RecordBatch):
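        # Join the report positions with the reference cytosine .parquet file (read per
        # chromosome, restricted to the batch's position range) to restore the strand and
        # context columns that bedGraph/.coverage reports do not carry.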
        batch = pl.from_arrow(pa_batch)

        # get batch stats
        batch_stats = batch.group_by("chr").agg([
            pl.col("position").max().alias("max"),
            pl.col("position").min().alias("min")
        ])

        output = None

        for chrom in batch_stats["chr"]:
            chrom_min, chrom_max = batch_stats.filter(chr=chrom).select(["min", "max"]).row(0)

            # Read and filter sequence
            filters = [
                ("chr", "=", chrom),
                ("position", ">=", chrom_min),
                ("position", "<=", chrom_max)
            ]
            pa_filtered_sequence = pq.read_table(self.cytosine_file, filters=filters)

            modified_schema = pa_filtered_sequence.schema
            modified_schema = modified_schema.set(1, pa.field("context", pa.utf8()))
            modified_schema = modified_schema.set(2, pa.field("chr", pa.utf8()))

            pa_filtered_sequence = pa_filtered_sequence.cast(modified_schema)

            pl_sequence = (
                pl.from_arrow(pa_filtered_sequence)
                .with_columns([
                    pl.when(pl.col("strand") == True).then(pl.lit("+"))
                    .otherwise(pl.lit("-"))
                    .alias("strand")
                ])
                .cast(dict(position=batch.schema["position"]))
            )

            # Align batch with sequence
            batch_filtered = batch.filter(pl.col("chr") == chrom)
            aligned = (
                batch_filtered.lazy()
                .cast({"position": pl.UInt64})
                .set_sorted("position")
                .join(pl_sequence.lazy(), on="position", how="left")
            ).collect()

            self._sequence_rows_read += len(pl_sequence)

            if output is None:
                output = aligned
            else:
                output.extend(aligned)

            gc.collect()

        return output

    def _mutate_next(self, batch: pa.RecordBatch):
        mutated = (
            self._align(batch)
            .with_columns([
                (pl.col("count_m") + pl.col("count_um")).alias("count_total"),
                pl.col("context").alias("trinuc")
            ])
            .with_columns((pl.col("count_m") / pl.col("count_total")).alias("density"))
        )
        return UniversalBatch(mutated, batch)


# noinspection PyMissingOrEmptyDocstring
class BedGraphReader(CoverageReader):
    pa_schema = ARROW_SCHEMAS["bedgraph"]

    def __init__(self, file: str | Path, cytosine_file: str | Path, max_count=100, **kwargs):
        super().__init__(file, cytosine_file, **kwargs)
        self.max_count = max_count

    def _get_fraction(self, df: pl.DataFrame):
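        # bedGraph stores density as a percentage; approximate each value with a fraction
        # whose denominator is capped at ``max_count`` to recover plausible
        # count_m / count_total values.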
        def fraction(n, limit):
            f = Fraction(n).limit_denominator(limit)
            return f.numerator, f.denominator

        fraction_v = np.vectorize(fraction)
        converted = fraction_v((df["density"] / 100).to_list(), self.max_count)

        final = (
            df.with_columns([
                pl.lit(converted[0]).alias("count_m"),
                pl.lit(converted[1]).alias("count_total")
            ])
        )

        return final

    def _mutate_next(self, batch: pa.RecordBatch):
        aligned = self._align(batch)
        fractioned = self._get_fraction(aligned)
        full = fractioned.with_columns(pl.col("context").alias("trinuc"))

        return UniversalBatch(full, batch)


class UniversalReader(object):
    """
    Class for batched reading of methylation reports.

    Parameters
    ----------
    file
        Path to the methylation report file.
    report_type
        Type of the methylation report. Possible types: "bismark", "cgmap", "bedgraph", "coverage", "binom".
    use_threads
        Whether reading should be multithreaded.
    bar
        Whether to display a progress bar while reading.
    kwargs
        Reader-specific keyword arguments.

    Other Parameters
    ----------------
    cytosine_file: str | Path
        Path to the cytosine .parquet file (see :class:`Sequence`), required for bedGraph and .coverage reports.
    methylation_pvalue: float
        P-value threshold at or below which a cytosine is considered methylated (used by the "binom" reader).
    block_size_mb: int
        Size of the batch in megabytes which will be read from the report file (for report types other than "binom").

    Examples
    --------
    >>> reader = UniversalReader("path/to/file.txt", report_type="bismark", use_threads=True)
    >>> for batch in reader:
    ...     do_something(batch)
    """

    def __init__(
            self,
            file: str | Path,
            report_type: ReportTypes,
            use_threads: bool = False,
            bar: bool = True,
            **kwargs
    ):
        super().__init__()
        self.__validate(file, report_type)
        self.__decompressed = self.__decompress()
        if self.__decompressed is not None:
            self.file = Path(self.__decompressed.name)

        self.memory_pool = pa.system_memory_pool()
        # Read from self.file so that the decompressed temporary file (if any) is used
        # instead of the original gzipped report.
        self.reader = self.__readers[report_type](
            self.file,
            use_threads=use_threads,
            memory_pool=self.memory_pool,
            **kwargs
        )
        self.report_type = report_type

        self._bar_on = bar
        self.bar = None

    __readers = {
        "bismark": BismarkReader,
        "cgmap": CGMapReader,
        "bedgraph": BedGraphReader,
        "coverage": CoverageReader,
        "binom": BinomReader
    }

    def __validate(self, file, report_type):
        file = Path(file).expanduser().absolute()
        if not file.exists():
            raise FileNotFoundError(file)
        if not file.is_file():
            raise IsADirectoryError(file)
        self.file = file

        if report_type not in self.__readers.keys():
            raise KeyError(report_type)
        self.report_type: str = report_type

    def __decompress(self):
        if ".gz" in self.file.suffixes:
            temp_file = tempfile.NamedTemporaryFile(dir=self.file.parent, prefix=self.file.stem)
            print(f"Temporarily unpack {self.file} to {temp_file.name}")

            with gzip.open(self.file, mode="rb") as file:
                shutil.copyfileobj(file, temp_file)
            # Make sure the decompressed data is on disk before the reader opens it.
            temp_file.flush()

            return temp_file
        return None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.__decompressed is not None:
            self.__decompressed.close()
        if self.bar is not None:
            self.bar.goto(self.bar.max)
            self.bar.finish()
        self.reader.__exit__(exc_type, exc_val, exc_tb)
        self.memory_pool.release_unused()

    def __iter__(self):
        if self._bar_on:
            self.bar = ReportBar(max=self.file_size)
            self.bar.start()
        return self

    def __next__(self) -> UniversalBatch:
        full_batch = self.reader.__next__()
        if self.bar is not None:
            self.bar.next(self.batch_size)
        return full_batch

    @property
    def batch_size(self):
        """
        Returns
        -------
        int
            Size of batch in bytes.
        """
        return self.reader.batch_size

    @property
    def file_size(self):
        """
        Returns
        -------
        int
            File size in bytes.
        """
        return self.file.stat().st_size


class UniversalReplicatesReader(object):
    """
    Class for reading methylation reports from replicates. The reader sums up the methylation counts.

    Parameters
    ----------
    readers
        List of initialized instances of :class:`UniversalReader`.

    Examples
    --------
    >>> reader1 = UniversalReader("path/to/file1.txt", report_type="bismark", use_threads=True)
    >>> reader2 = UniversalReader("path/to/file2.txt", report_type="bismark", use_threads=True)
    >>>
    >>> for batch in UniversalReplicatesReader([reader1, reader2]):
    ...     do_something(batch)
    """

    def __init__(
            self,
            readers: list[UniversalReader],
    ):
        self.readers = readers
        self.haste_limit = 1e9
        self.bar = None

        if any(map(lambda reader: reader.report_type in ["bedgraph"], self.readers)):
            warnings.warn("Merging bedGraph may lead to incorrect results. Please use other report types.")

    def __iter__(self):
        self.bar = ReportBar(max=self.full_size)
        self.bar.start()

        self._seen_chroms = []
        self._unfinished = None
        self._readers_data = {
            idx: dict(iterator=iter(reader), read_rows=0, haste=0, finished=False,
                      name=reader.file, chr=None, pos=None)
            for reader, idx in zip(self.readers, range(len(self.readers)))
        }
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for reader in self.readers:
            reader.__exit__(exc_type, exc_val, exc_tb)
        if self.bar is not None:
            self.bar.goto(self.bar.max)
            self.bar.finish()

    def __next__(self):
        # Key is the reader's index
        reading_from = [
            key for key, value in self._readers_data.items()
            if value["haste"] < self.haste_limit and not value["finished"]
        ]

        # Read batches from readers that are not hasting (too far ahead of the others)
        batches_data = {}
        for key in reading_from:
            # If a batch is present, drop the "density" col and mark "chr" and "position"
            # as sorted for the sorted merge below
            try:
                batch = next(self._readers_data[key]["iterator"])
                batches_data[key] = (
                    batch.data
                    .set_sorted(["chr", "position"])
                    .drop("density")
                    .with_columns(pl.lit([np.uint8(key)]).alias("group_idx"))
                )

                # Update metadata
                self._readers_data[key]["read_rows"] += len(batch)
                self._readers_data[key] |= dict(
                    chr=batch.data[-1]["chr"].to_list()[0],
                    pos=batch.data[-1]["position"].to_list()[0]
                )
            # Otherwise mark the reader as finished
            except StopIteration:
                self._readers_data[key]["finished"] = True

        if not batches_data and len(self._unfinished) == 0:
            raise StopIteration
        elif batches_data:
            self.haste_limit = sum(len(batch) for batch in batches_data.values()) // 3

        # We assume that the input files are sorted in some way, so we gather the chromosome
        # order from the input files to understand which chromosome comes last
        if self._unfinished is not None:
            batches_data[-1] = self._unfinished

        pack_seen_chroms = []
        for key in batches_data.keys():
            batch_chrs = batches_data[key]["chr"].unique(maintain_order=True).to_list()
            [pack_seen_chroms.append(chrom) for chrom in batch_chrs if chrom not in pack_seen_chroms]
        [self._seen_chroms.append(chrom) for chrom in pack_seen_chroms if chrom not in self._seen_chroms]

        # Merge read batches and unfinished data with each other and then group
        merged = []
        for chrom in pack_seen_chroms:
            # Retrieve the first batch (order doesn't matter) with which we will merge
            chr_merged = batches_data[list(batches_data.keys())[0]].filter(chr=chrom)
            # Get the keys of the other batches
            other = [key for key in batches_data.keys() if key != list(batches_data.keys())[0]]
            for key in other:
                chr_merged = chr_merged.merge_sorted(batches_data[key].filter(chr=chrom), key="position")
            merged.append(chr_merged)
        merged = pl.concat(merged).set_sorted(["chr", "position"])

        # Group merged rows by chromosome and position and track which readers have contributed
        grouped = (
            merged.lazy()
            .group_by(["chr", "position"], maintain_order=True)
            .agg([
                pl.first("strand", "context", "trinuc"),
                pl.sum("count_m", "count_total"),
                pl.col("group_idx").explode()
            ])
            .with_columns(pl.col("group_idx").list.len().alias("group_count"))
        )

        # Finished rows are those which have grouped among all readers
        min_chr_idx = min(
            self._seen_chroms.index(reader_data["chr"])
            for reader_data in self._readers_data.values()
        )
        min_position = min(
            reader_data["pos"]
            for reader_data in self._readers_data.values()
            if reader_data["chr"] == self._seen_chroms[min_chr_idx]
        )

        # A row is finished if all readers have grouped on it, or if there is no chance to group it
        # further because its position has already been passed by every reader
        marked = (
            grouped
            .with_columns([
                pl.col("chr")
                .replace(self._seen_chroms, list(range(len(self._seen_chroms))))
                .cast(pl.Int8)
                .alias("chr_idx"),
                pl.lit(min_position).alias("min_pos")
            ])
            .with_columns(
                pl.when(
                    (pl.col("group_count") == len(self._readers_data)) |
                    (
                        (pl.col("chr_idx") < min_chr_idx) |
                        ((pl.col("chr_idx") == min_chr_idx) & (pl.col("position") < pl.col("min_pos")))
                    )
                ).then(pl.lit(True)).otherwise(pl.lit(False)).alias("finished")
            )
            .collect()
        )

        self._unfinished = marked.filter(finished=False).select(merged.columns)

        hasting_stats = defaultdict(int)
        if len(self._unfinished) > 0:
            group_idx_stats = self._unfinished.select(pl.col("group_idx").list.to_struct()).unnest("group_idx")
            for col in group_idx_stats.columns:
                hasting_stats[group_idx_stats[col].drop_nulls().item(0)] += group_idx_stats[col].count()
        for key in self._readers_data.keys():
            self._readers_data[key]["haste"] = hasting_stats[key]

        # Update the progress bar
        if reading_from:
            self.bar.next(sum(self.readers[idx].batch_size for idx in reading_from))

        out = marked.filter(finished=True)
        return UniversalBatch(self._convert_to_full(out), out)

    @staticmethod
    def _convert_to_full(df: pl.DataFrame):
        return (
            df
            .with_columns((pl.col("count_m") / pl.col("count_total")).alias("density"))
            .select(UniversalBatch.colnames())
        )

    @property
    def full_size(self):
        """
        Returns
        -------
        int
            Total size of readers' files in bytes.
        """
        return sum(map(lambda reader: reader.file_size, self.readers))


class UniversalWriter:
    """
    Class for writing reports in a specific methylation report format.

    Parameters
    ----------
    file
        Path where the file will be written.
    report_type
        Type of the methylation report. Possible types: "bismark", "cgmap", "bedgraph", "coverage"
        (writing "binom" reports is not supported by :meth:`write`).
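
    Examples
    --------
    A minimal usage sketch (the file paths and ``reader`` are placeholders):

    >>> reader = UniversalReader("report.txt", report_type="bismark")
    >>> with UniversalWriter("report.cov", report_type="coverage") as writer:
    ...     for batch in reader:
    ...         writer.write(batch)
    """
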
    def __init__(
            self,
            file: str | Path,
            report_type: ReportTypes,
    ):
        file = Path(file).expanduser().absolute()
        self.file = file

        if report_type not in REPORT_TYPES_LIST:
            raise KeyError(report_type)
        self.report_type: str = report_type

        self.memory_pool = pa.system_memory_pool()
        self.writer = None

    def __enter__(self):
        self.open()
        return self

    def close(self):
        """
        This method should be called after writing all the data, when the writer is used without
        the ``with`` context manager.
        """
        self.writer.close()
        self.memory_pool.release_unused()

    def open(self):
        """
        This method should be called before writing data, when the writer is used without
        the ``with`` context manager.
        """
        write_options = pcsv.WriteOptions(
            delimiter="\t",
            include_header=False,
            quoting_style="none"
        )
        self.writer = pcsv.CSVWriter(
            self.file,
            ARROW_SCHEMAS[self.report_type],
            write_options=write_options,
            memory_pool=self.memory_pool
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def write(self, universal_batch: UniversalBatch):
        """
        Method for writing a batch to the report file.
        """
        if universal_batch is None:
            return
        if self.writer is None:
            self.__enter__()

        if self.report_type == "bismark":
            fmt_df = universal_batch.to_bismark()
        elif self.report_type == "cgmap":
            fmt_df = universal_batch.to_cgmap()
        elif self.report_type == "bedgraph":
            fmt_df = universal_batch.to_bedGraph()
        elif self.report_type == "coverage":
            fmt_df = universal_batch.to_coverage()
        else:
            raise KeyError(f"{self.report_type} not supported")

        self.writer.write(fmt_df.to_arrow().cast(ARROW_SCHEMAS[self.report_type]))
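

# A minimal end-to-end sketch of how the classes above fit together: read a Bismark report in
# batches and rewrite it in CGmap format. The file paths below are placeholders, not files
# shipped with bsxplorer.
if __name__ == "__main__":
    with UniversalReader("report_CX.txt", report_type="bismark", use_threads=True) as reader, \
            UniversalWriter("report.cgmap", report_type="cgmap") as writer:
        for converted_batch in reader:
            writer.write(converted_batch)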