Source code for openpois.osm.format_observations

#   -------------------------------------------------------------
#   Copyright (c) Henry Spatial Analysis. All rights reserved.
#   Licensed under the MIT License. See LICENSE in project root for information.
#   -------------------------------------------------------------

"""
This module formats OSM changes and versions into per-version "observations":
wide rows that are easier to query and analyze statistically than the raw
change and version tables.
"""

import os
import re
from pathlib import Path

import duckdb
import pyarrow as pa
import pyarrow.parquet as pq


_SAFE_KEY_RE = re.compile(r"^[A-Za-z0-9_:]+$")


def _validate_key(k: str) -> str:
    """Allow only alphanumerics, underscores, and colons in interpolated keys.

    OSM tag keys such as ``addr:street`` are valid; anything else is rejected
    to avoid opening a SQL injection path through the pivot CTE.
    """
    if not isinstance(k, str) or not _SAFE_KEY_RE.match(k):
        raise ValueError(f"Unsafe tag key for SQL interpolation: {k!r}")
    return k
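
# Illustrative behavior (hypothetical inputs):
#
#   _validate_key("addr:street")    # -> "addr:street"
#   _validate_key("name")           # -> "name"
#   _validate_key("name; DROP --")  # raises ValueError (unsafe characters)
#   _validate_key("")               # raises ValueError (empty string)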


def _init_scan_state(keep_keys: list[str]) -> dict:
    """Return a fresh per-POI scan state for :func:`_advance_scan_state`."""
    return {
        "add_to_list": False,
        "last_tag_timestamp": None,
        "last_obs_timestamp": None,
        "last_tag_user": None,
        "last_tag_value": None,
        "tag_value": None,
        "keep_current": {k: None for k in keep_keys},
        "keep_last": {k: None for k in keep_keys},
    }


def _advance_scan_state(
    state: dict,
    row: tuple,
    col_idx: dict,
    tag_key: str,
    keep_keys: list[str],
) -> dict | None:
    """Run one row through the per-POI state machine.

    Returns the observation dict to emit, or ``None`` if this version is
    before the tag was first added (so ``add_to_list`` is still False).
    """
    elem_id = row[col_idx["id"]]
    version = row[col_idx["version"]]
    changeset = row[col_idx["changeset"]]
    obs_timestamp = row[col_idx["timestamp"]]
    user = row[col_idx["user"]]

    # `last_tag_value` on the emitted obs must reflect the PRE-update state;
    # the other `last_*` fields are updated below after `obs` is built.
    prev_last_tag_value = state["last_tag_value"]

    # Keep-keys: shift current → last only when this version's changeset
    # touches the key; otherwise current + last both stay sticky.
    for k in keep_keys:
        ch = row[col_idx[f"{k}__change"]]
        if ch is not None:
            state["keep_last"][k] = state["keep_current"][k]
            state["keep_current"][k] = row[col_idx[f"{k}__value"]]

    tag_val = row[col_idx[f"{tag_key}__value"]]
    tag_ch = row[col_idx[f"{tag_key}__change"]]
    vis_val = row[col_idx["visible__value"]]
    vis_ch = row[col_idx["visible__change"]]

    tag_added = tag_ch == "Added"
    tag_changed = tag_ch == "Changed"
    tag_deleted = tag_ch == "Deleted"
    poi_deleted = (vis_ch is not None) and (vis_val == "false")
    poi_re_added = (
        state["add_to_list"]
        and (vis_ch is not None)
        and (vis_val == "true")
    )
    any_change = (
        tag_added or tag_changed or tag_deleted or poi_deleted or poi_re_added
    )

    if tag_added:
        state["add_to_list"] = True
    if tag_added or tag_changed:
        state["last_tag_value"] = tag_val
        state["tag_value"] = tag_val
    if tag_deleted or poi_deleted:
        state["tag_value"] = None
    if poi_re_added:
        state["tag_value"] = state["last_tag_value"]

    if not state["add_to_list"]:
        return None

    obs = {
        "id": elem_id,
        "version": version,
        "changeset": changeset,
        "obs_timestamp": obs_timestamp,
        "last_obs_timestamp": state["last_obs_timestamp"],
        "last_tag_timestamp": state["last_tag_timestamp"],
        "user": user,
        "last_tag_user": state["last_tag_user"],
        "tag_value": state["tag_value"],
        "last_tag_value": prev_last_tag_value,
        "changed": int(any_change),
        "deleted": None,
        "tag_key": tag_key,
    }
    for k in keep_keys:
        obs[k] = state["keep_current"][k]
        obs[f"{k}_last_value"] = state["keep_last"][k]

    if any_change:
        state["last_tag_timestamp"] = obs_timestamp
        state["last_tag_user"] = user
    state["last_obs_timestamp"] = obs_timestamp
    return obs
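
# Worked sketch of the state machine above on a hypothetical element history
# (tag_key="name", no keep_keys), with rows arriving sorted by version:
#
#   v1  no tag rows              -> returns None (add_to_list still False)
#   v2  name Added   "A"         -> obs: tag_value="A",  last_tag_value=None, changed=1
#   v3  no changes               -> obs: tag_value="A",  last_tag_value="A",  changed=0
#   v4  name Changed "B"         -> obs: tag_value="B",  last_tag_value="A",  changed=1
#   v5  visible Changed "false"  -> obs: tag_value=None, last_tag_value="B",  changed=1
#   v6  visible Changed "true"   -> obs: tag_value="B",  last_tag_value="B",  changed=1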


def format_observations_duckdb(
    changes_path: Path,
    versions_path: Path,
    output_path: Path,
    tag_key: str,
    keep_keys: list[str],
    duckdb_memory_limit: str = "4GB",
    duckdb_threads: int | None = None,
    duckdb_temp_dir: Path | None = None,
    batch_rows: int = 100_000,
    verbose: bool = True,
) -> int:
    """
    Stream POI observations from Parquet inputs to Parquet via DuckDB.

    DuckDB pivots the long-form ``osm_changes.parquet`` wide by tag key,
    LEFT-joins ``osm_versions.parquet`` on ``(type, id, version)``, and
    returns rows sorted by ``(type, id, version)``; the sort spills to
    ``duckdb_temp_dir`` past ``duckdb_memory_limit``. A Python scan then
    iterates the sorted stream through :func:`_advance_scan_state`,
    buffering emitted observations per DuckDB fetch batch and flushing
    them as ``pyarrow.Table`` record batches to a ``ParquetWriter``.
    Peak RSS is bounded to roughly ``duckdb_memory_limit`` plus one fetch
    batch of observations, regardless of input size.

    Args:
        changes_path: Input ``osm_changes.parquet``.
        versions_path: Input ``osm_versions.parquet``.
        output_path: Destination ``.parquet``. Overwritten.
        tag_key: Tag key to model (e.g. ``"name"``).
        keep_keys: Tag keys to retain on each observation. Must not include
            special characters (validated against ``[A-Za-z0-9_:]+``).
        duckdb_memory_limit: DuckDB ``memory_limit`` setting. The sort
            operator spills to disk past this.
        duckdb_threads: DuckDB worker thread count. Defaults to
            ``os.cpu_count()``.
        duckdb_temp_dir: Sort-spill directory. Defaults to
            ``output_path.parent``.
        batch_rows: Rows pulled per ``fetchmany`` call; also the
            ParquetWriter flush size.
        verbose: Print progress.

    Returns:
        Total number of observation rows written.
    """
    changes_path = Path(changes_path)
    versions_path = Path(versions_path)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    tag_key = _validate_key(tag_key)
    keep_keys = [_validate_key(k) for k in keep_keys]

    # Pivot needs tag_key, 'visible', and all keep_keys (deduplicated).
    pivot_keys: list[str] = [tag_key, "visible"]
    for k in keep_keys:
        if k not in pivot_keys:
            pivot_keys.append(k)

    threads = duckdb_threads if duckdb_threads is not None else (os.cpu_count() or 1)
    temp_dir = (
        Path(duckdb_temp_dir) if duckdb_temp_dir is not None else output_path.parent
    )
    temp_dir.mkdir(parents=True, exist_ok=True)

    pivot_exprs: list[str] = []
    for k in pivot_keys:
        pivot_exprs.append(
            f"MAX(CASE WHEN key = '{k}' THEN value END) AS \"{k}__value\""
        )
        pivot_exprs.append(
            f"MAX(CASE WHEN key = '{k}' THEN change END) AS \"{k}__change\""
        )
    pivot_select = ",\n                ".join(pivot_exprs)
    key_list_sql = ", ".join(f"'{k}'" for k in pivot_keys)
    pivot_cols_sql = ", ".join(
        f'p."{k}__value", p."{k}__change"' for k in pivot_keys
    )

    sql = f"""
        WITH pivoted AS (
            SELECT
                type, id, version,
                {pivot_select}
            FROM read_parquet('{changes_path.as_posix()}')
            WHERE key IN ({key_list_sql})
            GROUP BY type, id, version
        )
        SELECT
            v.type, v.id, v.version, v.changeset, v.timestamp, v."user",
            {pivot_cols_sql}
        FROM read_parquet('{versions_path.as_posix()}') v
        LEFT JOIN pivoted p USING (type, id, version)
        ORDER BY v.type, v.id, v.version
    """

    base_cols = ["type", "id", "version", "changeset", "timestamp", "user"]
    col_idx: dict = {c: i for i, c in enumerate(base_cols)}
    for k in pivot_keys:
        col_idx[f"{k}__value"] = len(col_idx)
        col_idx[f"{k}__change"] = len(col_idx)

    schema_fields = [
        ("id", pa.int64()),
        ("version", pa.int64()),
        ("changeset", pa.int64()),
        ("obs_timestamp", pa.string()),
        ("last_obs_timestamp", pa.string()),
        ("last_tag_timestamp", pa.string()),
        ("user", pa.string()),
        ("last_tag_user", pa.string()),
        ("tag_value", pa.string()),
        ("last_tag_value", pa.string()),
        ("changed", pa.int8()),
        ("deleted", pa.int8()),
    ]
    for k in keep_keys:
        schema_fields.append((k, pa.string()))
    for k in keep_keys:
        schema_fields.append((f"{k}_last_value", pa.string()))
    schema_fields.append(("tag_key", pa.string()))
    schema = pa.schema(schema_fields)

    con = duckdb.connect()
    try:
        con.execute(f"SET memory_limit='{duckdb_memory_limit}'")
        con.execute(f"SET threads TO {int(threads)}")
        con.execute(f"SET temp_directory='{temp_dir.as_posix()}'")
        if verbose:
            print(
                f"DuckDB streaming observations "
                f"(threads={threads}, mem={duckdb_memory_limit})..."
            )

        cursor = con.execute(sql)
        total = 0
        buffer: list[dict] = []
        with pq.ParquetWriter(output_path, schema, compression="zstd") as writer:
            current_poi = None
            state = None
            while True:
                rows = cursor.fetchmany(batch_rows)
                if not rows:
                    break
                for row in rows:
                    poi_key = (row[col_idx["type"]], row[col_idx["id"]])
                    if poi_key != current_poi:
                        current_poi = poi_key
                        state = _init_scan_state(keep_keys)
                    obs = _advance_scan_state(
                        state, row, col_idx, tag_key, keep_keys
                    )
                    if obs is not None:
                        buffer.append(obs)
                        total += 1
                if buffer:
                    writer.write_table(pa.Table.from_pylist(buffer, schema=schema))
                    buffer.clear()
            if buffer:
                writer.write_table(pa.Table.from_pylist(buffer, schema=schema))
                buffer.clear()
    finally:
        con.close()

    if verbose:
        print(f"Wrote {total:,} observations to {output_path}")
    return total
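
# Minimal usage sketch (hypothetical paths; all other arguments as documented
# above):
#
#   from pathlib import Path
#   from openpois.osm.format_observations import format_observations_duckdb
#
#   n_obs = format_observations_duckdb(
#       changes_path=Path("data/osm_changes.parquet"),
#       versions_path=Path("data/osm_versions.parquet"),
#       output_path=Path("data/name_observations.parquet"),
#       tag_key="name",
#       keep_keys=["amenity", "addr:street"],
#   )
#   print(f"{n_obs:,} observations written")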