# dbtk/readers/base.py
"""
Base classes and utilities for file readers.
Defines the abstract Reader interface and Clean enumeration for
header normalization across all reader implementations.
"""
import itertools
import logging
import re
import time
from abc import ABC, abstractmethod
from typing import Any, Iterator, List, Optional, Union
from os import path
from ..record import Record
from ..defaults import settings
logger = logging.getLogger(__name__)
class _Progress:
__slots__ = ('tell', 'byte_total', 'row_total')
def __init__(self, source_obj=None, row_total: Optional[int] = None):
if source_obj is not None and hasattr(source_obj, 'tell'):
self.tell = source_obj.tell
self.byte_total = getattr(source_obj, '_uncompressed_size', None)
else:
self.tell = lambda: 0
self.byte_total = None
self.row_total = row_total
def current(self) -> int:
try:
return self.tell()
except Exception:
return 0
def update(self, row_num: int) -> str:
if self.byte_total is not None:
# Byte-based progress
pos = self.current()
if self.byte_total == 0:
return ""
pct = pos / self.byte_total
current_val = pos // 1024
total_val = self.byte_total // 1024
unit = "K"
elif self.row_total is not None:
# Row-based progress
if self.row_total == 0:
return ""
pct = row_num / self.row_total
current_val = row_num
total_val = self.row_total
unit = ""
else:
return ""
filled = max(0, min(20, round(20 * pct)))
bar = "█" * filled + "░" * (20 - filled)
return f"{bar} {current_val:,}{unit}/{total_val:,}{unit}"
[docs]
class Reader(ABC):
"""
Abstract base class for all file readers in DBTK.
Provides unified interface and common functionality for reading various file formats
(CSV, Excel, JSON, XML, fixed-width). All readers support the same features regardless
of file format: header cleaning, record skipping, row number tracking, and flexible
return types.
Readers are designed to work as context managers and iterators, making them ideal
for memory-efficient processing of large files. They automatically handle resource
cleanup and support both Record objects (with multiple access patterns) and plain
dictionaries as return types.
Common Features
---------------
* **Row number tracking** - Automatic _row_num field for debugging
* **Record skipping** - Skip header rows or bad data
* **Record limiting** - Process only first N records
* **Record filtering** - Filter records with custom functions (new!)
* **Flexible return types** - Record objects or dictionaries
* **Context manager** - Automatic resource cleanup
* **Iterator protocol** - Memory-efficient streaming
* **Null value conversion** - Convert specified values to None
Parameters
----------
add_row_num : bool, default True
Add a '_row_num' field to each record containing the 1-based row number
skip_rows : int, default 0
Number of data rows to skip after headers (useful for skipping footer rows
or known bad data at start of file)
n_rows : int, optional
Maximum number of rows to read. None (default) reads all rows.
null_values : str, list, tuple, or set, optional
Values to convert to None. Can be a single string or a collection of strings.
Common examples: '\\N', 'NULL', 'NA', '' (empty string)
Example
-------
::
# Subclasses implement specific file formats
from dbtk import readers
# CSV with default settings - returns Record objects
with readers.CSVReader(open('data.csv')) as reader:
for record in reader:
print(record.name, record.email) # attribute access
print(record['name'], record['email']) # or dict-style
# Skip first 5 rows, read only 100 records
with readers.CSVReader(open('data.csv'),
skip_rows=5,
n_rows=100) as reader:
for row in reader:
print(row.name)
# Access fields with original or normalized names
with readers.CSVReader(open('messy.csv')) as reader:
# Headers like "ID #", "Student Name" preserved as originals
# but also accessible as normalized: id_hash, student_name
for record in reader:
print(record['ID #'], record['Student Name']) # original
print(record.id, record.student_name) # normalized
# Filter records with custom function
with readers.CSVReader(open('data.csv')) as reader:
reader.add_filter(lambda r: int(r.age) >= 18)
reader.add_filter(lambda r: r.country == 'US')
for record in reader:
print(record.name) # Only US adults
See Also
--------
CSVReader : Read CSV files
JSONReader : Read JSON files
ExcelReader : Read Excel .xlsx files
XMLReader : Read XML files
FixedReader : Read fixed-width text files
Record : Flexible row objects with multiple access patterns
Notes
-----
This is an abstract base class. Use one of the concrete implementations
(CSVReader, JSONReader, etc.) for actual file reading.
Subclasses must implement:
* ``_read_headers()`` - Return list of raw column names from file
* ``_generate_rows()`` - Yield raw data rows as lists
Optionally override:
* ``_cleanup()`` - Release resources (file handles, etc.)
"""
# Class constants for "big" thresholds for adding progress tracking
BIG_ROW_THRESHOLD = 10_000 # Show progress for >10k rows
BIG_BYTE_THRESHOLD = 5 * 1024 * 1024 # Show progress for >5MB files
[docs]
def __init__(self,
add_row_num: bool = True,
skip_rows: int = 0,
n_rows: Optional[int] = None,
headers: Optional[List[str]] = None,
null_values: Union[str, List[str], tuple, set, None] = None
):
"""
Initialize the reader with common options.
Parameters
----------
add_row_num : bool, default True
Add a '_row_num' field to each record containing the 1-based row number
skip_rows : int, default 0
Number of data rows to skip after headers
n_rows : int, optional
Maximum number of rows to read, or None for all rows
headers: Optional list of header names to use instead of reading from row 0
null_values : str, list, tuple, or set, optional
Values to convert to None. Can be a single string or collection of strings.
Common examples: '\\N' (IMDB), 'NULL', 'NA', '' (empty string)
Example
-------
::
# In subclass implementation
class MyReader(Reader):
def __init__(self, file_path, **kwargs):
super().__init__(**kwargs)
self.file = open(file_path)
def _read_headers(self):
return ['id', 'name', 'email']
def _generate_rows(self):
for line in self.file:
yield line.strip().split(',')
"""
self.add_row_num = add_row_num
self._row_num = 0
self._rows_read = 0
self.skip_rows = skip_rows
self.n_rows = n_rows
self._record_class = None
# Normalize null_values to a set for O(1) lookup
if null_values is None:
self._null_values = set()
elif isinstance(null_values, str):
self._null_values = {null_values}
elif isinstance(null_values, (list, tuple, set)):
self._null_values = set(null_values)
else:
raise TypeError(f"null_values must be str, list, tuple, or set, got {type(null_values)}")
self._raw_headers: Optional[List[str]] = headers
self._headers: List[str] = []
self._headers_initialized: bool = False
self._data_iter: Optional[Iterator[List[Any]]] = None
self._total_records = 0 # used by progress tracker when we know the number of rows ahead of time (Excel, DataFrames)
self._trackable = None # used by progress tracker
self._prog: _Progress = None # progress tracker _Progress object
self._big: bool = False # True if source > 5MB (adds progress bar)
self._source: str = None # keep track of source filename for subclasses that use a file pointer directly (Excel)
self._start_time: float = 0 # will get updated when the first record is read (time.monotonic())
self._filters = [] # filter pipeline (list of callables)
self._done: bool = False # True after StopIteration summary has been logged
def __enter__(self):
"""Context manager entry."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit with cleanup."""
self._cleanup()
[docs]
def add_filter(self, func) -> "Reader":
"""
Add a filter function to the filtering pipeline.
Filter functions are applied after skip_rows and null_values conversion, but before
the n_rows limit. Multiple calls to add_filter() accumulate in a pipeline - all filters
must return True for a record to be included.
The filter operates on the final Record object, after null value conversion has
been applied. This allows you to filter on clean data rather than raw values.
Parameters
----------
func : callable
A function that takes a record and returns True to keep it,
False to filter it out. The function should accept a single argument (the record).
Returns
-------
Reader
Returns self to allow method chaining
Example
-------
::
from dbtk import readers
# Single filter
with readers.CSVReader(open('users.csv')) as reader:
reader.add_filter(lambda r: r.age >= 18)
for record in reader:
print(record.name) # Only adults
# Multiple filters (all must pass)
with readers.CSVReader(open('users.csv')) as reader:
reader.add_filter(lambda r: r.age >= 18)
reader.add_filter(lambda r: r.country == 'US')
reader.add_filter(lambda r: r.active == 'true')
for record in reader:
print(record.name) # US adults who are active
# Complex filter function
def valid_email(record):
return '@' in record.email and '.' in record.email
with readers.CSVReader(open('users.csv')) as reader:
reader.add_filter(valid_email)
for record in reader:
print(record.email)
# Get exactly n_rows after filtering
with readers.CSVReader(open('users.csv')) as reader:
reader.add_filter(lambda r: r.age >= 18)
reader.n_rows = 100 # Get 100 records that pass the filter
data = list(reader)
print(len(data)) # Will be 100 (or less if fewer than 100 match)
Notes
-----
* Filters are lazy - they're applied during iteration, not when add_filter() is called
* Execution order: read → skip_rows → null_values → filter pipeline → n_rows
* If both skip_rows and add_filter() are used, a warning is logged (skip_rows applies first)
* The n_rows limit applies after filtering, so n_rows=100 returns 100 filtered records
* Record._row_num field reflects the count of returned (filtered) records, not raw file rows
"""
if not callable(func):
raise TypeError(f"add_filter() requires a callable, got {type(func).__name__}")
self._filters.append(func)
# Warn if combining skip_rows with filtering (only on first filter)
if self.skip_rows > 0 and len(self._filters) == 1:
logger.warning(
"Using both skip_rows and add_filter() - skip_rows applies before filtering. "
"Consider using only add_filter() for clarity."
)
return self
def __iter__(self) -> Iterator[Record]:
"""Make reader iterable."""
return self
def __next__(self):
if not self._headers_initialized:
self._setup_record_class()
if not self._start_time:
self._start_time = time.monotonic()
if self._prog is None:
if self._trackable and hasattr(self._trackable, 'tell'):
self._prog = _Progress(self._trackable)
self._big = self._prog.byte_total is not None and self._prog.byte_total > self.BIG_BYTE_THRESHOLD
elif hasattr(self, '_total_rows') and self._total_rows is not None:
self._prog = _Progress(row_total=self._total_rows)
self._big = self._total_rows > self.BIG_ROW_THRESHOLD
while True:
# Check n_rows limit (applies after filtering)
if self.n_rows is not None and self._row_num >= self.n_rows:
raise StopIteration
try:
row_data = self._read_next_row()
except StopIteration:
if not self._done:
self._done = True
took = time.monotonic() - self._start_time
rate = self._row_num / took if took else 0
if self._big:
print(f"\r{self.__class__.__name__[:-6]} → {self._prog.update(self._row_num)} ✅")
# Show both counts if filtering was used
if self._filters and self._rows_read != self._row_num:
print(f"Done in {took:.2f}s - {self._rows_read:,} read → {self._row_num:,} returned ({int(rate):,} rec/s)")
logger.info(f"Read {self._rows_read:,} rows, returned {self._row_num:,} in {took:.2f}s ({int(rate):,} rec/s)")
else:
print(f"Done in {took:.2f}s ({int(rate):,} rec/s)")
logger.info(f"Read {self._row_num:,} rows in {took:.2f}s ({int(rate):,} rec/s)")
raise # ← let for-loop end
# Track total rows read (before filtering)
self._rows_read += 1
# Show progress based on rows read, not filtered count
if self._big and (self._rows_read == 500 or self._rows_read % 50_000 == 0):
if self._filters:
print(f"\r{self.__class__.__name__[:-6]} → {self._rows_read:,} read → {self._row_num:,} filtered", end="", flush=True)
else:
print(f"\r{self.__class__.__name__[:-6]} → {self._prog.update(self._rows_read)} "
f"({self._rows_read:,})", end="", flush=True)
# Increment _row_num before creating record (needed for _row_num field)
self._row_num += 1
record = self._create_record(row_data)
# Apply filter pipeline
if self._filters:
passed = all(f(record) for f in self._filters)
if not passed:
# Record didn't pass filter, undo increment and try next row
self._row_num -= 1
continue
# Record passed all filters (or no filters present)
return record
def __repr__(self):
source = self._get_source(base_name=True)
if source:
source = f"'{source}'"
return f"{self.__class__.__name__}({source})"
@property
def source(self) -> str:
"""
Get the filename of the source file.
For the ExcelReader and XLSReader, the source must be set manually because the Workbook objects do not keep
a reference to the original file.
"""
if self._source is None:
self._source = self._get_source()
return self._source
@source.setter
def source(self, value: str):
self._source = value
def _get_source(self, base_name: Optional[bool] = False) -> str:
""" Get the source filename for the Reader.
Args:
base_name: If True, return the base filename (no path)
Returns: filename
"""
if hasattr(self, 'fp') and hasattr(self.fp, 'name'):
source = self.fp.name
elif hasattr(self, 'source'):
source = self.source
else:
source = ''
if base_name:
source = path.basename(source)
return source
@property
def row_count(self) -> int:
"""
Returns the number of rows.
This property provides access to the total number of rows, which
is stored in the private attribute `_row_num`.
Returns:
int: The total number of rows.
"""
return self._row_num
@abstractmethod
def _read_headers(self) -> List[str]:
"""
Read and return raw headers from the file.
Must be implemented by subclasses.
"""
pass
@abstractmethod
def _generate_rows(self) -> Iterator[List[Any]]:
"""Generate all raw data rows as lists, without applying skip or limit."""
pass
def _read_next_row(self) -> List[Any]:
if self._data_iter is None:
gen = self._generate_rows()
# Only apply skip_rows here; n_rows is applied after filtering in __next__()
if self.skip_rows > 0:
gen = itertools.islice(gen, self.skip_rows, None)
self._data_iter = gen
return next(self._data_iter)
def _cleanup(self):
"""
Cleanup resources. Override in subclasses if needed.
Default implementation does nothing.
"""
pass
def _setup_record_class(self):
"""Initialize headers and create Record subclass with original field names."""
if self._headers_initialized:
return
# Read raw headers from file (original field names)
raw_headers = self._read_headers()
# Store original headers (no normalization - Record.set_fields() handles it)
self._headers = raw_headers[:]
# Add _row_num if requested and not already present
if self.add_row_num:
if '_row_num' in self._headers:
raise ValueError("Header '_row_num' already exists. Remove it or set add_row_num=False.")
self._headers.append('_row_num')
# Create Record subclass and set fields
# set_fields() will automatically normalize for attribute access
self._record_class = type('FileRecord', (Record,), {})
self._record_class.set_fields(self._headers)
self._headers_initialized = True
def _convert_nulls(self, row_data: List[Any]) -> List[Any]:
"""
Convert null values to None in row data.
Args:
row_data: List of values for this row
Returns:
List with null values converted to None
"""
if not self._null_values:
return row_data
return [None if val in self._null_values else val for val in row_data]
def _create_record(self, row_data: List[Any]) -> Record:
"""
Create a Record from row data.
Args:
row_data: List of values for this row
Returns:
Record instance with values populated
"""
if not self._headers_initialized:
self._setup_record_class()
# Make a copy to avoid modifying the original
row_data = list(row_data)
# Convert null values to None
row_data = self._convert_nulls(row_data)
# Pad with None if row is shorter than expected (excluding _row_num)
expected_data_cols = len(self._headers) - (1 if self.add_row_num and '_row_num' in self._headers else 0)
while len(row_data) < expected_data_cols:
row_data.append(None)
# Add _row_num if it's in headers (always goes at the end)
if self.add_row_num and '_row_num' in self._headers:
row_data.append(self.skip_rows + self._row_num)
# Truncate if row is longer than headers
if len(row_data) > len(self._headers):
row_data = row_data[:len(self._headers)]
# Return Record
return self._record_class(*row_data)
@property
def headers(self) -> List[str]:
"""
Get the column headers.
Returns:
List of cleaned header names
"""
if not self._headers_initialized:
self._setup_record_class()
return self._headers.copy()
@property
def fieldnames(self) -> List[str]:
"""
Alias for headers to maintain compatibility with csv.DictReader.
Returns:
List of cleaned header names
"""
return self.headers