# dbtk/readers/xml.py
"""XML file reader with XPath support for element extraction."""
import logging
import os
from pathlib import Path
from typing import List, Any, Dict, Optional, TextIO, Union, Iterator
from .base import Reader
try:
from lxml import etree
HAS_LXML = True
except ImportError:
from xml.etree import ElementTree as etree
HAS_LXML = False
# lxml and stdlib use different exception types
_XMLParseError = etree.XMLSyntaxError if HAS_LXML else etree.ParseError
_XPathError = etree.XPathEvalError if HAS_LXML else SyntaxError
def _xpath(node, expr):
"""Run an XPath expression using lxml or stdlib ElementTree."""
if HAS_LXML:
return node.xpath(expr)
# stdlib findall supports a limited XPath subset; // must be relative
if expr.startswith('//'):
expr = '.' + expr
try:
return node.findall(expr)
except SyntaxError:
return []
logger = logging.getLogger(__name__)
[docs]
class XMLColumn:
"""Column definition for XML extraction."""
[docs]
def __init__(self, name: str, xpath: Optional[str] = None,
data_type: str = 'text'):
"""
Args:
name: Column name for the Record
xpath: XPath expression (if None, uses simple element matching)
data_type: Data type hint (not enforced, just documentation)
"""
self.name = name
self.xpath = xpath
self.data_type = data_type
def __repr__(self):
return f"XMLColumn('{self.name}', xpath='{self.xpath}', data_type='{self.data_type}')"
[docs]
class XMLReader(Reader):
"""XML file reader that returns Record objects."""
[docs]
def __init__(self,
fp: TextIO,
record_xpath: str = "//record",
columns: Optional[List[XMLColumn]] = None,
sample_size: int = 10,
add_row_num: bool = True,
skip_rows: int = 0,
n_rows: Optional[int] = None,
null_values=None):
"""
Initialize XML reader.
Args:
fp: File pointer to XML file
record_xpath: XPath expression to find record elements
columns: List of XMLColumn definitions for custom extraction
sample_size: Number of records to sample for column discovery
add_row_num: Add _row_num to each record
skip_rows: Number of data rows to skip after headers
n_rows: Maximum number of rows to read, or None for all
null_values: Values to convert to None (e.g., '\\N', 'NULL', 'NA')
"""
super().__init__(add_row_num=add_row_num,
skip_rows=skip_rows, n_rows=n_rows,
null_values=null_values)
self.fp = fp
# Set trackable for progress tracking
if hasattr(fp, '_uncompressed_size'):
# Compressed file - use buffer's tell() but preserve _uncompressed_size
self._trackable = fp.buffer
self._trackable._uncompressed_size = fp._uncompressed_size
elif hasattr(fp, 'buffer'):
# Text mode file - use buffer for better performance
self._trackable = fp.buffer
try:
self._trackable._uncompressed_size = os.fstat(self._trackable.fileno()).st_size
except (AttributeError, OSError):
pass
else:
# Binary mode or other file type
self._trackable = fp
try:
self._trackable._uncompressed_size = os.fstat(self._trackable.fileno()).st_size
except (AttributeError, OSError):
pass
self.record_xpath = record_xpath
self.custom_columns = columns or []
self.sample_size = sample_size
self._tree = None
self._record_nodes = None
self._column_cache = None
self._all_columns = [] # Combined auto-discovered + custom columns
# Parse XML and discover structure
self._parse_xml()
def _parse_xml(self):
"""Parse the XML file and prepare for reading."""
try:
self._tree = etree.parse(self.fp)
except _XMLParseError as e:
raise ValueError(f"Invalid XML: {e}")
# Find all record nodes
try:
self._record_nodes = _xpath(self._tree, self.record_xpath)
except _XPathError as e:
raise ValueError(f"Invalid XPath expression '{self.record_xpath}': {e}")
if not self._record_nodes:
raise ValueError(f"No records found with XPath: {self.record_xpath}")
def _introspect_columns(self) -> List[str]:
"""Analyze first few records to discover all possible columns."""
if self._column_cache is not None:
return self._column_cache
# Start with custom columns
self._all_columns = list(self.custom_columns)
custom_names = {col.name for col in self.custom_columns}
# Auto-discover columns from sample records
discovered_elements = []
sample_records = self._record_nodes[:self.sample_size]
for record_node in sample_records:
for child in record_node:
if child.tag:
col_name = self._flatten_element_name(child.tag)
if col_name not in custom_names and col_name not in discovered_elements: # Don't duplicate custom columns
discovered_elements.append(col_name)
# Add discovered columns as XMLColumn objects (normalization happens in Record.set_fields())
for element_name in discovered_elements:
self._all_columns.append(XMLColumn(element_name))
# Extract just the names for the column cache
self._column_cache = [col.name for col in self._all_columns]
return self._column_cache
def _flatten_element_name(self, tag: str) -> str:
"""Convert XML element name to valid column name."""
# Handle namespaces: {namespace}localname -> localname
if '}' in tag:
tag = tag.split('}')[1]
# Replace invalid characters for Python identifiers
tag = tag.replace('-', '_').replace('.', '_').replace(':', '_')
return tag
def _extract_column_value(self, record_node, xml_column: XMLColumn) -> Optional[str]:
"""Extract value for a column from a record node."""
# Check if it has custom XPath
if xml_column.xpath:
try:
result = _xpath(record_node, xml_column.xpath)
if result:
# Handle different XPath result types
if isinstance(result[0], str):
# Text result
value = result[0]
elif hasattr(result[0], 'text'):
# Element result
value = result[0].text
else:
# Other result (attribute, etc.)
value = str(result[0])
return value.strip() if value else None
else:
return None
except _XPathError:
return None
# Look for simple child element by column name
# Convert column name back to possible XML tag names
possible_tags = [xml_column.name, xml_column.name.replace('_', '-'), xml_column.name.replace('_', '.')]
for tag in possible_tags:
# Try with and without namespace
child = record_node.find(tag)
if child is None:
# Try with any namespace
child = record_node.find(f".//{tag}")
if child is not None:
text = child.text
return text.strip() if text else None
return None
def _read_headers(self) -> List[str]:
"""Read and return column names from XML structure."""
return self._introspect_columns()
def _generate_rows(self) -> Iterator[List[Any]]:
"""Generate data rows from XML record nodes."""
# Ensure columns are discovered before reading data
if not self._all_columns:
self._introspect_columns()
for record_node in self._record_nodes:
row_data = []
for xml_column in self._all_columns:
value = self._extract_column_value(record_node, xml_column)
row_data.append(value)
yield row_data
def _cleanup(self):
"""Close the file pointer."""
if hasattr(self, 'fp') and self.fp:
self.fp.close()
@property
def record_count(self) -> int:
"""Return total number of records found."""
return len(self._record_nodes) if self._record_nodes else 0
@property
def columns(self) -> List[XMLColumn]:
"""Return all column definitions (custom + auto-discovered)."""
if not self._all_columns:
self._introspect_columns() # Trigger discovery
return self._all_columns.copy()
# Convenience function to match other readers
[docs]
def open_xml(filename: Union[str, Path], **kwargs) -> XMLReader:
"""
Open XML file for reading.
Args:
filename: Path to XML file
**kwargs: Arguments passed to XMLReader
Returns:
XMLReader instance
Example
-------
::
with open_xml('data.xml', record_xpath='//user') as reader:
for record in reader:
print(record.name)
"""
fp = open(filename, 'rb')
return XMLReader(fp, **kwargs)
if not HAS_LXML:
logger.debug('lxml not available; falling back to stdlib xml.etree.ElementTree (limited XPath support).')