Source code for dbtk.cursors

# dbtk/cursors.py
"""
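Cursor classes that wrap DB-API database cursors and return rows as Record objects.
Cursor delegates unknown attributes to the underlying database cursor stored in _cursor;
PreparedStatement delegates unknown attributes to the cursor it was created with.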
"""

import inspect
import logging
from pathlib import Path
from typing import List, Any, Optional, Iterator, Callable, Union

from .record import Record
from .utils import ParamStyle, process_sql_parameters, normalize_field_name
from .defaults import settings

logger = logging.getLogger(__name__)
__all__ = ['Cursor', 'PreparedStatement']


def _resolve_sql_path(filename: Union[str, Path], caller_file: str) -> Path:
    """Return the path to use for a SQL file reference.

    An absolute path, or a relative path that already exists relative to the
    current working directory, is returned unchanged. Otherwise the name is
    tried relative to the directory holding *caller_file*, so scripts can name
    sibling SQL files directly. If that sibling does not exist either, the
    original (non-existent) path is returned and opening it will fail normally.
    """
    requested = Path(filename)
    if requested.is_absolute() or requested.exists():
        return requested
    sibling = Path(caller_file).parent / requested
    return sibling if sibling.exists() else requested


class PreparedStatement:
    """
    A prepared SQL statement with a cached parameter mapping.

    The SQL text is obtained once (read from a file, or taken from a query string)
    and converted once to the cursor's paramstyle. The prepared statement can then
    be executed many times without repeating that work. It keeps a reference to the
    cursor and delegates unknown attributes to it, so it can be used like a regular
    cursor (fetchone(), fetchmany(), iteration, etc.).
    """

    def __init__(self, cursor, query: Optional[str] = None,
                 filename: Optional[Union[str, Path]] = None,
                 encoding: Optional[str] = 'utf-8-sig'):
        """
        Create a prepared statement from a SQL file or a query string.

        Args:
            cursor: The cursor that will execute this statement. Its ``paramstyle``
                drives the SQL conversion; its ``prepare_params()`` and ``execute()``
                are used by :meth:`execute`.
            query: SQL query string (used only when *filename* is not given)
            filename: Path to SQL file, opened as given (relative paths resolve
                against the CWD). Takes precedence over *query*.
            encoding: File encoding (default: utf-8-sig)

        Raises:
            ValueError: If neither *query* nor *filename* is provided.
        """
        self.cursor = cursor
        if filename:
            self.filename = filename
            # Read SQL from file
            with open(filename, encoding=encoding) as f:
                original_sql = f.read()
        elif query is not None:
            self.filename = None
            original_sql = query
        else:
            raise ValueError('Must provide either query or filename')

        # Transform SQL for the cursor's paramstyle once; param_names keeps the order
        # in which the named parameters appear so execute() can build the params.
        self.sql, self.param_names = process_sql_parameters(
            original_sql, cursor.paramstyle
        )

    def __iter__(self):
        """Make prepared statement iterable (iterates the underlying cursor)."""
        return self.cursor.__iter__()

    def __next__(self):
        """Iterator protocol."""
        return self.cursor.__next__()

    def execute(self, bind_vars: Optional[dict] = None) -> Any:
        """
        Execute the prepared statement with the given parameters.

        Args:
            bind_vars: Dictionary of named parameters (default: no parameters).
                With dbtk's Cursor, missing names are bound as None and extra
                names are ignored.

        Returns:
            Whatever the cursor's execute() returns: the cursor if
            return_cursor=True, else None.

        Raises:
            Any exception from parameter preparation or execution is logged with
            the transformed SQL and the parameters, then re-raised unchanged.
        """
        if bind_vars is None:
            bind_vars = {}
        try:
            params = self.cursor.prepare_params(self.param_names, bind_vars)
            return self.cursor.execute(self.sql, params)
        except Exception:
            source = self.filename or '<query>'
            logger.error(
                f"Error executing prepared statement from {source}\n"
                f"Transformed SQL: {self.sql}\n"
                f"Parameters: {bind_vars}"
            )
            raise

    def __getattr__(self, key: str):
        """Delegate attribute access to underlying cursor.

        Only called when normal attribute lookup fails.
        """
        if key == 'cursor':
            # ``cursor`` has not been set (e.g. instance created via __new__ during
            # copy/unpickling). Looking it up below would re-enter __getattr__
            # forever, so report a normal missing attribute instead.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'cursor'"
            )
        return getattr(self.cursor, key)
class Cursor:
    """
    Cursor that returns query results as Records.

    It wraps database-specific cursor objects and provides a consistent interface plus
    additional functionality like SQL file execution, parameter conversion, and prepared
    statements. It also maintains a clean reference hierarchy to the connection
    (cursor.connection) and to the driver (cursor.connection.driver).

    Cursor returns Record objects, which provide flexible access via dictionary keys,
    attributes, or integer indices.

    Attributes
    ----------
    connection : Database
        The database connection this cursor belongs to
    paramstyle : str
        Parameter style of the underlying database ('qmark', 'named', etc.)
    placeholder : str
        Placeholder string for bind parameters (e.g., '?', ':1', etc.). Only set when
        the connection exposes a ``placeholder`` attribute.
    description
        Column metadata from the last query (delegated to underlying cursor)

    Note
    ----
    Cursors delegate attribute access to the underlying database-specific cursor,
    so all native cursor functionality is available.

    Example
    -------
    ::

        db = dbtk.connect('prod_ods')
        cursor = db.cursor()
        cursor.execute("SELECT id, name, email FROM users WHERE status = :status",
                       {'status': 'active'})
        for row in cursor:
            # row is a Record: key, attribute and index access all work
            print(f"{row['id']}: {row.name} ({row[2]})")

    See Also
    --------
    Record : Flexible data structure supporting dict, attribute, and index access
    """

    # Attributes that live on this class and are not delegated to the underlying cursor
    _local_attrs = [
        'connection', 'debug', 'return_cursor', 'placeholder', 'paramstyle',
        'record_factory', 'batch_size', '_cursor', '_row_factory_invalid',
        '_statement', '_bind_vars', '_bulk_method'
    ]

    # Attributes that are allowed to be passed in from the connection/configuration layer
    WRAPPER_SETTINGS = ('batch_size', 'debug', 'return_cursor', 'fast_executemany')

    def __init__(self, connection, batch_size: Optional[int] = None,
                 debug: Optional[bool] = False, return_cursor: Optional[bool] = False,
                 **kwargs):
        """
        Initialize a cursor for database operations.

        Parameters
        ----------
        connection : Database
            Database connection object
        batch_size: int, optional
            How many rows to process at a time when using executemany() or bulk
            operations in DataSurge. Defaults to the ``default_batch_size`` setting
            (1000 if unset).
        debug : bool, default False
            Enable debug output showing queries and bind variables
        return_cursor : bool, default False
            If True, execute() returns the cursor for method chaining
        **kwargs
            Additional arguments passed to the underlying database cursor. Keys listed
            in ``WRAPPER_SETTINGS`` are not forwarded; ``fast_executemany`` is applied
            to the underlying cursor as an attribute when it supports it.

        Raises
        ------
        TypeError
            If the underlying cursor cannot be created from *connection*; the original
            error is chained as ``__cause__``.

        Example
        -------
        ::

            # Typically created via Database.cursor()
            cursor = db.cursor()

            # With debug enabled
            cursor = db.cursor(debug=True)

            # With method chaining
            cursor = db.cursor(return_cursor=True)
            results = cursor.execute("SELECT * FROM users").fetchall()
        """
        self.connection = connection
        self.debug = debug
        self.record_factory = None
        self._row_factory_invalid = True
        self.return_cursor = return_cursor
        if batch_size is None:
            batch_size = settings.get('default_batch_size', 1000)
        self.batch_size = batch_size
        self._statement = None  # Stores statement locally if adapter doesn't
        self._bind_vars = None  # Stores bind vars locally if adapter doesn't
        self._bulk_method = None  # Allows us to override executemany if needed

        # remove any kwargs not intended for the underlying cursor
        filtered_kwargs = {key: val for key, val in kwargs.items()
                           if key not in self.WRAPPER_SETTINGS}

        # Create underlying cursor (from the wrapped driver connection when present)
        try:
            if hasattr(self.connection, '_connection'):
                self._cursor = self.connection._connection.cursor(**filtered_kwargs)
            else:
                self._cursor = self.connection.cursor(**filtered_kwargs)
        except Exception as e:
            raise TypeError(f'First argument must be a database connection object: {e}') from e

        # Handle fast_executemany configuration for pyodbc
        if 'fast_executemany' in kwargs:
            if hasattr(self._cursor, 'fast_executemany'):
                self._cursor.fast_executemany = kwargs['fast_executemany']
        elif hasattr(self.connection, 'driver_name') and self.connection.driver_name == 'pyodbc_sqlserver':
            # Not configured at all: suggest it for SQL Server via pyodbc.
            logger.info(
                "pyodbc with SQL Server detected. Consider setting cursor: {fast_executemany: true} "
                "in your connection config for better bulk insert performance. Note: fast_executemany "
                "may cause MemoryError with TEXT/NVARCHAR(MAX)/JSON columns - use VARCHAR types instead."
            )

        # Set parameter style info
        self.paramstyle = self.connection.driver.paramstyle
        if hasattr(self.connection, 'placeholder'):
            self.placeholder = self.connection.placeholder

        # Ensure arraysize exists (some adapters don't have it). Written straight into
        # __dict__ because __setattr__ would otherwise forward it to the underlying cursor.
        if not hasattr(self._cursor, 'arraysize'):
            self.__dict__['arraysize'] = 1000

    def __getattr__(self, key: str) -> Any:
        """Delegate attribute access to underlying cursor.

        Only called when normal lookup fails. ``statement`` and ``bind_vars`` fall back
        to the locally stored copies when the adapter doesn't provide them.
        """
        if key == '_cursor':
            # _cursor is not set (e.g. instance created via __new__ during copy or
            # unpickling). Looking it up below would re-enter __getattr__ forever.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute '_cursor'"
            )
        if key == 'statement' and not hasattr(self._cursor, 'statement'):
            return self._statement
        if key == 'bind_vars' and not hasattr(self._cursor, 'bind_vars'):
            return self._bind_vars
        return getattr(self._cursor, key)

    def __setattr__(self, key: str, value: Any) -> None:
        """Set attributes on this cursor or delegate to underlying cursor."""
        if key in self._local_attrs:
            self.__dict__[key] = value
        else:
            setattr(self._cursor, key, value)

    def __dir__(self) -> List[str]:
        """Return available attributes."""
        return list(set(
            dir(self._cursor) + dir(self.__class__) + self._local_attrs
        ))

    def __iter__(self) -> Iterator:
        """Make cursor iterable; raises if no query has been run."""
        self._is_ready()
        return self

    def __next__(self) -> Any:
        """Iterator protocol."""
        row = self.fetchone()
        if row is None:
            raise StopIteration
        return row

    def prepare_params(self, param_names: list, bind_vars: dict,
                       paramstyle: Optional[str] = None) -> Any:
        """
        Convert named parameters to the format required by the cursor's paramstyle.

        Used automatically by PreparedStatement, Table and DataSurge classes.

        Parameters
        ----------
        param_names : list of str
            Ordered list of parameter names as they appear in the SQL statement.
            Missing names default to ``None`` with a debug-level log message.
        bind_vars : dict
            Dictionary of named parameter values keyed by parameter name.
        paramstyle : str, optional
            Override the cursor's native paramstyle for this call. Must be one of the
            values in :class:`dbtk.utils.ParamStyle`. Ignored if not a recognised style.

        Returns
        -------
        tuple
            For positional styles (``qmark``, ``numeric``, ``format``): values ordered
            to match ``param_names``.
        dict
            For named styles (``named``, ``pyformat``): subset of ``bind_vars``
            containing only the required parameters.

        Example
        -------
        ::

            from dbtk.utils import ParamStyle, process_sql_parameters

            # query with :named or %(pyformat)s parameters
            sql = "SELECT * FROM warriors WHERE nation = :nation AND rank = COALESCE(:rank, rank)"
            # query rewritten in cursor's parameter style and parameter names in order they appear
            query, params_names = process_sql_parameters(sql, ParamStyle.get_positional_style(cur.paramstyle))
            # missing parameter defaults to None, extra parameters are ignored
            cur.prepare_params(params_names, {'nation': 'Fire Nation', 'nick_name': 'Sparky'})
            ('Fire Nation', None)
        """
        missing = set(param_names) - set(bind_vars.keys())
        paramstyle = paramstyle if paramstyle and paramstyle in ParamStyle.values() else self.paramstyle
        if missing:
            # sorted() keeps the log message deterministic
            logger.debug(f"Parameters not provided, defaulting to None: {', '.join(sorted(missing))}")
        if paramstyle in ParamStyle.positional_styles():
            # Build tuple in param_names order
            return tuple(bind_vars.get(name) for name in param_names)
        # Return dict with only the params we need
        return {name: bind_vars.get(name) for name in param_names}

    def _detect_bulk_method(self) -> Callable:
        """
        Detect and return the fastest bulk execution method for this cursor.

        Called once per cursor, on first executemany(). Stores in self._bulk_method.
        """
        adapter = self.connection.driver.__name__
        if adapter == 'psycopg2':
            try:
                from psycopg2.extras import execute_batch

                # Return a bound dispatcher: execute_batch(cur, sql, argslist, page_size)
                def psycopg_batch(cur, sql, argslist):
                    page_size = getattr(self, 'batch_size', 1000)
                    return execute_batch(cur, sql, argslist, page_size=page_size)

                logger.debug("Cursor upgraded: executemany → psycopg2.extras.execute_batch")
                return psycopg_batch
            except ImportError:
                logger.debug("psycopg2.extras not available — using native executemany")

        # Fallback for everything else (SQLite, MySQL, etc.)
        return lambda cur, sql, argslist: cur.executemany(sql, argslist)

    def _create_record_factory(self) -> None:
        """Create Record subclass with original column names from description."""
        self._row_factory_invalid = False
        # Get original column names from description (no transformation)
        if not self.description:
            original_columns = []
        else:
            original_columns = [col[0] for col in self.description]

        # Create dynamic Record subclass and set fields
        # set_fields() will handle normalization automatically
        RecordClass = type('Record', (Record,), {})
        RecordClass.set_fields(original_columns)
        self.record_factory = RecordClass

    def columns(self, normalized: bool = False) -> List[str]:
        """
        Return list of column names.

        Parameters
        ----------
        normalized : bool, default False
            If True, return normalized column names (sanitized for Python attributes).
            If False, return original column names from database.

        Returns
        -------
        List[str]
            Column names in order

        Example
        -------
        ::

            cursor.execute("SELECT 'First Name', 'User ID' FROM ...")
            cursor.columns()  # ['First Name', 'User ID']
            cursor.columns(normalized=True)  # ['first_name', 'user_id']
        """
        if not self.description:
            return []
        if normalized:
            # Return normalized column names
            return [normalize_field_name(c[0]) for c in self.description]
        # Return original column names
        return [c[0] for c in self.description]

    def _is_ready(self) -> bool:
        """Check if ready and update record factory if columns changed.

        Returns True, or raises if the last query produced no result set.
        """
        if self._cursor.description is None:
            raise Exception('Query has not been run or did not succeed.')
        elif self.record_factory is None:
            self._create_record_factory()
        elif self._row_factory_invalid:
            # Check if columns have changed since last query
            if hasattr(self.record_factory, '_fields'):
                # Get current original column names from description.
                # NOTE(review): reuse requires _fields to equal this *list* of original
                # names; if set_fields() stores normalized names or a tuple, the factory
                # is rebuilt after every query - confirm against Record.set_fields.
                current_columns = [col[0] for col in self.description] if self.description else []
                if self.record_factory._fields != current_columns:
                    self._create_record_factory()
                else:
                    self._row_factory_invalid = False
            else:
                self._create_record_factory()
        return True

    def execute(self, query: str, bind_vars: Union[tuple, dict] = (),
                convert_params: bool = False) -> Optional['Cursor']:
        """
        Execute a database query.

        Pass convert_params=True to have the query rewritten to the cursor's paramstyle
        and parameters handled automatically (same as PreparedStatement and execute_file).

        Parameters
        ----------
        query : str
            SQL to execute.
        bind_vars : tuple or dict, default ()
            Bind parameters to pass to the database. When convert_params is False (the
            default), passed directly to the underlying cursor and must already be in the
            format required by the cursor's paramstyle. When convert_params is True, must
            be a dict (or Record).
        convert_params : bool, default False
            If True, parameter order will be extracted from the query and the query will
            be rewritten to match the cursor's paramstyle. Missing parameters will be
            defaulted to None and extra parameters will be ignored.

        Returns
        -------
        Cursor or None
            This cursor when ``return_cursor`` is True (for method chaining), else None.

        Raises
        ------
        ValueError
            If convert_params is True and bind_vars is a non-empty non-mapping.
        """
        self._row_factory_invalid = True
        if convert_params:
            if not hasattr(bind_vars, 'items'):
                if bind_vars:
                    raise ValueError('bind_vars must be a dict when convert_params=True')
                bind_vars = {}
            query, param_names = process_sql_parameters(query, self.paramstyle)
            bind_vars = self.prepare_params(param_names, bind_vars)

        if self.debug:
            logger.debug(f'Query:\n{query}')
            logger.debug(f'Bind vars:\n{bind_vars}')

        # Store statement and bind_vars locally if the adapter doesn't
        if not hasattr(self._cursor, 'statement'):
            self.__dict__['_statement'] = query
        if not hasattr(self._cursor, 'bind_vars'):
            self.__dict__['_bind_vars'] = bind_vars

        # some adapters return a cursor instead of the Database API specified None;
        # the return value is deliberately ignored
        self._cursor.execute(query, bind_vars)
        return self if self.return_cursor else None

    @staticmethod
    def _caller_filename() -> str:
        """Return the source filename of the code calling the public method that calls this.

        Walks two frames up (this helper -> execute_file/prepare_file -> their caller)
        with ``inspect.currentframe()`` rather than ``inspect.stack()``, which would
        build source-context records for every frame on the stack. Returns '' when
        frame introspection is unavailable; ``_resolve_sql_path`` then simply falls
        back to the path as given.
        """
        frame = inspect.currentframe()
        try:
            for _ in range(2):
                if frame is None:
                    break
                frame = frame.f_back
            return frame.f_code.co_filename if frame is not None else ''
        finally:
            # Break the reference cycle frame objects create
            del frame

    def execute_file(self, filename: Union[str, Path], bind_vars: Optional[dict] = None,
                     **kwargs) -> Any:
        """
        Execute SQL query from a file with named parameter substitution.

        This is a convenience method for one-off queries. For queries that will be
        executed multiple times, use prepare_file() instead for better performance.

        Args:
            filename: Path to SQL file. Resolved relative to CWD first; if not found
                there, falls back to the directory of the calling script.
            bind_vars: Dictionary of named parameters. Parameters referenced by the
                SQL but missing here are bound as None.
            **kwargs:
                encoding: File encoding (default: utf-8-sig)

        Returns:
            Cursor if return_cursor=True, else None

        Raises:
            Any error from reading, converting or executing is logged together with
            the transformed SQL and re-raised unchanged.

        Example:
            cursor.execute_file('queries/get_user.sql', {'user_id': 123})
        """
        encoding = kwargs.get('encoding', 'utf-8-sig')
        path = _resolve_sql_path(filename, self._caller_filename())
        transformed_sql = 'N/A'  # reported in the error log if conversion never happens
        try:
            # Read SQL from file
            with open(path, encoding=encoding) as f:
                sql = f.read()

            # Transform SQL for this cursor's paramstyle
            transformed_sql, param_names = process_sql_parameters(sql, self.paramstyle)

            # Prepare parameters: placeholders in the SQL are always bound (missing
            # values become None), matching PreparedStatement.execute(). Only a
            # parameter-less call on parameter-less SQL passes None, as before.
            if bind_vars or param_names:
                params = self.prepare_params(param_names, bind_vars or {})
            else:
                params = None

            return self.execute(transformed_sql, params)
        except Exception:
            logger.error(
                f"Error executing SQL file: {path}\n"
                f"Transformed SQL: {transformed_sql}\n"
                f"Parameters: {bind_vars}"
            )
            raise

    def prepare_file(self, filename: Union[str, Path],
                     encoding: str = 'utf-8-sig') -> 'PreparedStatement':
        """
        Prepare a SQL statement from a file for repeated execution.

        The SQL file is read once and parameter conversion is performed once.
        The returned PreparedStatement can be executed multiple times efficiently.

        Args:
            filename: Path to SQL file. Resolved relative to CWD first; if not found
                there, falls back to the directory of the calling script.
            encoding: File encoding (default: utf-8-sig)

        Returns:
            PreparedStatement object

        Example
        -------
        ::

            stmt = cursor.prepare_file('queries/insert_user.sql')
            for user in users:
                stmt.execute({'user_id': user.id, 'name': user.name})
        """
        path = _resolve_sql_path(filename, self._caller_filename())
        return PreparedStatement(self, filename=path, encoding=encoding)

    def prepare_query(self, query: str) -> 'PreparedStatement':
        """
        Prepare a SQL statement from a query string for repeated execution.

        The parameter conversion is performed once. The returned PreparedStatement
        can be executed multiple times efficiently.

        Args:
            query: SQL text using named (``:name``) or pyformat (``%(name)s``)
                parameters; it is rewritten to this cursor's paramstyle.

        Returns:
            PreparedStatement object

        Example
        -------
        ::

            stmt = cursor.prepare_query('SELECT * FROM users WHERE user_id = :user_id')
            for user in users:
                stmt.execute({'user_id': user.id})
        """
        return PreparedStatement(self, query=query)

    def executemany(self, query: str, bind_vars: List[tuple]) -> Optional['Cursor']:
        """Execute a query against multiple parameter sets.

        *bind_vars* may be any sequence or iterable accepted by the underlying
        driver, including an empty list or a generator.

        Returns this cursor when ``return_cursor`` is True, else None.
        """
        self._row_factory_invalid = True

        # Peek at the first row for debugging/bind_vars only when bind_vars is
        # indexable and non-empty; never consume an iterator meant for the driver.
        try:
            first_row = bind_vars[0]
        except (IndexError, TypeError):
            first_row = None

        if self.debug:
            logger.debug(f'Executemany - Query:\n{query}')
            logger.debug(f'Bind vars (first row):\n{first_row}')

        # Store statement and bind_vars (first row only) locally if the adapter doesn't
        if not hasattr(self._cursor, 'statement'):
            self.__dict__['_statement'] = query
        if not hasattr(self._cursor, 'bind_vars'):
            self.__dict__['_bind_vars'] = first_row

        if self._bulk_method is None:
            # Detect and cache the fastest bulk execution method
            self._bulk_method = self._detect_bulk_method()

        self._bulk_method(self._cursor, query, bind_vars)
        return self if self.return_cursor else None

    def selectinto(self, query: str, bind_vars: tuple = ()) -> Any:
        """Execute query that must return exactly one row.

        Raises the driver's ``DatabaseError`` when zero rows or more than one row
        are returned.
        """
        self.execute(query, bind_vars)
        # Fetching two rows is enough to detect "more than one"
        rows = self.fetchmany(2)
        if len(rows) == 0:
            raise self.connection.driver.DatabaseError('No Data Found.')
        elif len(rows) > 1:
            raise self.connection.driver.DatabaseError(
                'selectinto() must return one and only one row.'
            )
        return rows[0]

    def fetchone(self) -> Optional[Any]:
        """Fetch the next row, or None when no more rows are available."""
        if self._is_ready():
            row = self._cursor.fetchone()
            # DB-API signals exhaustion with None; a zero-column row () is still a row
            if row is not None:
                return self.record_factory(*row)
        return None

    def fetchmany(self, size: Optional[int] = None) -> List[Any]:
        """Fetch the next set of rows (default: ``arraysize`` rows)."""
        if size is None:
            # self.arraysize resolves to the fallback stored by __init__ when the
            # adapter lacks arraysize, otherwise to the underlying cursor's value.
            size = self.arraysize
        if self._is_ready():
            return [
                self.record_factory(*row) for row in self._cursor.fetchmany(size)
            ]
        return []

    def fetchall(self) -> List[Any]:
        """Fetch all remaining rows."""
        if self._is_ready():
            return [
                self.record_factory(*row) for row in self._cursor.fetchall()
            ]
        return []