"""IndexedDB-style client helpers for provider processes."""
from __future__ import annotations
import datetime as _dt
import os
import queue
from dataclasses import dataclass, field
from typing import Any, Iterator, NoReturn, Protocol, cast
from urllib import parse as _urlparse

import grpc as _grpc
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf import timestamp_pb2 as _timestamp_pb2

from ._grpc_transport import (
    insecure_internal_channel,
    internal_channel_target,
    secure_internal_channel,
)
from .gen.v1 import datastore_pb2 as _pb
from .gen.v1 import datastore_pb2_grpc as _pb_grpc
# The gRPC/protobuf modules are re-bound under ``Any``-annotated names so
# attribute access on them is not checked statically.
grpc: Any = _grpc
pb: Any = _pb
pb_grpc: Any = _pb_grpc
struct_pb2: Any = _struct_pb2
timestamp_pb2: Any = _timestamp_pb2

#: Environment variable holding the default IndexedDB socket target.
ENV_INDEXEDDB_SOCKET = "GESTALT_INDEXEDDB_SOCKET"
# Appended to a socket variable name to form its relay-token variable name.
_INDEXEDDB_SOCKET_TOKEN_SUFFIX = "_TOKEN"
# gRPC metadata key under which the relay token is sent.
_INDEXEDDB_RELAY_TOKEN_HEADER = "x-gestalt-host-service-relay-token"

# Cursor directions accepted by ``open_cursor`` / ``open_key_cursor``.
#: Iterate in ascending key order.
CURSOR_NEXT = 0
#: Iterate in ascending key order while collapsing duplicate index keys.
CURSOR_NEXT_UNIQUE = 1
#: Iterate in descending key order.
CURSOR_PREV = 2
#: Iterate in descending key order while collapsing duplicate index keys.
CURSOR_PREV_UNIQUE = 3
def indexeddb_socket_env(name: str | None = None) -> str:
    """Return the environment variable name for an IndexedDB socket binding.

    An empty or whitespace-only ``name`` selects the default binding
    (``ENV_INDEXEDDB_SOCKET``). Otherwise the trimmed name is upper-cased, every
    non-alphanumeric character becomes ``_``, and the result is appended after
    an underscore, e.g. ``"main-db"`` -> ``GESTALT_INDEXEDDB_SOCKET_MAIN_DB``.
    """
    suffix = "".join(
        char.upper() if char.isalnum() else "_" for char in (name or "").strip()
    )
    if not suffix:
        return ENV_INDEXEDDB_SOCKET
    return f"{ENV_INDEXEDDB_SOCKET}_{suffix}"
def indexeddb_socket_token_env(name: str | None = None) -> str:
    """Return the environment variable name for an IndexedDB relay token.

    This is the socket variable name from :func:`indexeddb_socket_env` with
    ``_TOKEN`` appended.
    """
    socket_env = indexeddb_socket_env(name)
    return socket_env + _INDEXEDDB_SOCKET_TOKEN_SUFFIX
class NotFoundError(Exception):
    """Raised when an IndexedDB record, store, or cursor target is missing."""
class AlreadyExistsError(Exception):
    """Raised when an IndexedDB object already exists."""
class TransactionError(Exception):
    """Raised when a transaction has failed or already finished."""
@dataclass
class KeyRange:
    """Lower and upper bounds for a cursor or range query.

    A bound left as ``None`` is not sent in the request. The ``*_open`` flags
    are forwarded as-is; in IndexedDB terms an open bound is exclusive.
    """

    # Lower bound value, or ``None`` for no lower bound in the request.
    lower: Any = field(default=None)
    # Upper bound value, or ``None`` for no upper bound in the request.
    upper: Any = field(default=None)
    lower_open: bool = field(default=False)
    upper_open: bool = field(default=False)
@dataclass
class IndexSchema:
    """Definition for an index within an object store."""

    # Name used to address the index, e.g. via ``ObjectStore.index(name)``.
    name: str
    # Key path components the index is built from (forwarded as ``key_path``).
    key_path: list[str] = field(default_factory=list)
    # Forwarded as the index's ``unique`` flag.
    unique: bool = field(default=False)
@dataclass
class ObjectStoreSchema:
    """Schema definition for an object store.

    Currently this is just the list of secondary indexes to create alongside
    the store.
    """

    # Indexes to create with the store; a fresh list per instance.
    indexes: list[IndexSchema] = field(default_factory=list)
class IndexedDB:
    """Client for a host-provided IndexedDB-compatible store.

    The transport target is read from the environment variable named by
    :func:`indexeddb_socket_env` for ``name``. An optional relay token is read
    from the matching :func:`indexeddb_socket_token_env` variable and handed to
    the channel factory. The client owns one gRPC channel, released by
    :meth:`close` or by leaving a ``with`` block.
    """

    def __init__(self, name: str | None = None) -> None:
        socket_var = indexeddb_socket_env(name)
        target = os.getenv(socket_var, "")
        if not target:
            raise RuntimeError(f"{socket_var} is not set")
        relay_token = os.getenv(indexeddb_socket_token_env(name), "")
        self._channel = _indexeddb_channel(target, token=relay_token)
        self._stub = pb_grpc.IndexedDBStub(self._channel)

    def close(self) -> None:
        """Close the underlying gRPC channel."""
        self._channel.close()

    def create_object_store(
        self, name: str, schema: ObjectStoreSchema | None = None
    ) -> None:
        """Create object store ``name``, adding the indexes listed in ``schema``."""
        pb_schema = pb.ObjectStoreSchema()
        index_defs = schema.indexes if schema else []
        for index_def in index_defs:
            pb_schema.indexes.append(
                pb.IndexSchema(
                    name=index_def.name,
                    key_path=index_def.key_path,
                    unique=index_def.unique,
                )
            )
        request = pb.CreateObjectStoreRequest(name=name, schema=pb_schema)
        _grpc_call(self._stub.CreateObjectStore, request)

    def delete_object_store(self, name: str) -> None:
        """Delete an object store by name."""
        request = pb.DeleteObjectStoreRequest(name=name)
        _grpc_call(self._stub.DeleteObjectStore, request)

    def object_store(self, name: str) -> ObjectStore:
        """Return a client bound to object store ``name``."""
        return ObjectStore(self._stub, name)

    def transaction(
        self,
        stores: list[str],
        mode: str = "readonly",
        *,
        durability_hint: str = "default",
    ) -> Transaction:
        """Start an explicit IndexedDB transaction over ``stores``.

        ``mode`` is ``"readonly"`` or ``"readwrite"``; ``durability_hint`` is
        ``"default"``, ``"strict"`` or ``"relaxed"``.
        """
        return Transaction(self._stub, stores, mode, durability_hint=durability_hint)

    def __enter__(self) -> IndexedDB:
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Close the channel when the ``with`` block ends."""
        self.close()
class ObjectStore:
    """Client bound to a single IndexedDB object store.

    Each method issues one unary RPC through ``_grpc_call``, which converts
    ``NOT_FOUND`` / ``ALREADY_EXISTS`` gRPC errors into :class:`NotFoundError` /
    :class:`AlreadyExistsError`.
    """

    def __init__(self, stub: Any, store: str) -> None:
        self._stub = stub
        self._store = store

    def _by_id(self, id: str) -> Any:
        """Build a request addressing one record by primary key."""
        return pb.ObjectStoreRequest(store=self._store, id=id)

    def _in_range(self, key_range: KeyRange | None) -> Any:
        """Build a request covering ``key_range`` (``None`` sends no range)."""
        return pb.ObjectStoreRangeRequest(
            store=self._store, range=_kr_to_proto(key_range)
        )

    def _with_record(self, record: dict[str, Any]) -> Any:
        """Build a request carrying ``record`` encoded as a ``pb.Record``."""
        return pb.RecordRequest(store=self._store, record=_dict_to_record(record))

    def _cursor(
        self, key_range: KeyRange | None, direction: int, *, keys_only: bool
    ) -> Cursor:
        """Open a cursor over this store with the given options."""
        return Cursor(
            self._stub,
            self._store,
            key_range=key_range,
            direction=direction,
            keys_only=keys_only,
        )

    def get(self, id: str) -> dict[str, Any]:
        """Fetch a record by primary key."""
        reply = _grpc_call(self._stub.Get, self._by_id(id))
        return _record_to_dict(reply.record)

    def get_key(self, id: str) -> str:
        """Return the canonical key for a primary key lookup."""
        reply = _grpc_call(self._stub.GetKey, self._by_id(id))
        return reply.key

    def add(self, record: dict[str, Any]) -> None:
        """Insert a new record."""
        _grpc_call(self._stub.Add, self._with_record(record))

    def put(self, record: dict[str, Any]) -> None:
        """Insert or replace a record."""
        _grpc_call(self._stub.Put, self._with_record(record))

    def delete(self, id: str) -> None:
        """Delete a record by primary key."""
        _grpc_call(self._stub.Delete, self._by_id(id))

    def clear(self) -> None:
        """Delete every record in the store."""
        _grpc_call(self._stub.Clear, pb.ObjectStoreNameRequest(store=self._store))

    def get_all(self, key_range: KeyRange | None = None) -> list[dict[str, Any]]:
        """Return all records that fall within ``key_range``."""
        reply = _grpc_call(self._stub.GetAll, self._in_range(key_range))
        return [_record_to_dict(record) for record in reply.records]

    def get_all_keys(self, key_range: KeyRange | None = None) -> list[str]:
        """Return all primary keys that fall within ``key_range``."""
        reply = _grpc_call(self._stub.GetAllKeys, self._in_range(key_range))
        return list(reply.keys)

    def count(self, key_range: KeyRange | None = None) -> int:
        """Return the number of records within ``key_range``."""
        reply = _grpc_call(self._stub.Count, self._in_range(key_range))
        return reply.count

    def delete_range(self, key_range: KeyRange) -> int:
        """Delete all records within ``key_range`` and return how many were removed."""
        reply = _grpc_call(self._stub.DeleteRange, self._in_range(key_range))
        return reply.deleted

    def open_cursor(
        self,
        key_range: KeyRange | None = None,
        direction: int = CURSOR_NEXT,
    ) -> Cursor:
        """Open a record cursor over the store."""
        return self._cursor(key_range, direction, keys_only=False)

    def open_key_cursor(
        self,
        key_range: KeyRange | None = None,
        direction: int = CURSOR_NEXT,
    ) -> Cursor:
        """Open a key-only cursor over the store."""
        return self._cursor(key_range, direction, keys_only=True)

    def index(self, name: str) -> Index:
        """Return a client for the index ``name`` on this store."""
        return Index(self._stub, self._store, name)
class Index:
    """Client bound to a secondary index on an object store.

    Positional ``values`` are converted with ``_to_typed_value`` and sent as the
    indexed values to match. ``key_range``, where accepted, is sent as an
    additional range restriction.
    """

    def __init__(self, stub: Any, store: str, index: str) -> None:
        self._stub = stub
        self._store = store
        self._index = index

    def get(self, *values: Any) -> dict[str, Any]:
        """Fetch the first matching record for the indexed values."""
        reply = _grpc_call(self._stub.IndexGet, self._req(values))
        return _record_to_dict(reply.record)

    def get_key(self, *values: Any) -> str:
        """Fetch the first matching primary key for the indexed values."""
        reply = _grpc_call(self._stub.IndexGetKey, self._req(values))
        return reply.key

    def get_all(
        self, *values: Any, key_range: KeyRange | None = None
    ) -> list[dict[str, Any]]:
        """Return all records matching the indexed values and key range."""
        reply = _grpc_call(self._stub.IndexGetAll, self._req(values, key_range))
        return [_record_to_dict(record) for record in reply.records]

    def get_all_keys(
        self, *values: Any, key_range: KeyRange | None = None
    ) -> list[str]:
        """Return all primary keys matching the indexed values and key range."""
        reply = _grpc_call(self._stub.IndexGetAllKeys, self._req(values, key_range))
        return list(reply.keys)

    def count(self, *values: Any, key_range: KeyRange | None = None) -> int:
        """Return the number of records matching the indexed values."""
        reply = _grpc_call(self._stub.IndexCount, self._req(values, key_range))
        return reply.count

    def delete(self, *values: Any) -> int:
        """Delete records matching the indexed values and return how many were removed."""
        reply = _grpc_call(self._stub.IndexDelete, self._req(values))
        return reply.deleted

    def open_cursor(
        self,
        *values: Any,
        key_range: KeyRange | None = None,
        direction: int = CURSOR_NEXT,
    ) -> Cursor:
        """Open a record cursor over the indexed results."""
        return self._cursor(values, key_range, direction, keys_only=False)

    def open_key_cursor(
        self,
        *values: Any,
        key_range: KeyRange | None = None,
        direction: int = CURSOR_NEXT,
    ) -> Cursor:
        """Open a key-only cursor over the indexed results."""
        return self._cursor(values, key_range, direction, keys_only=True)

    def _cursor(
        self,
        values: tuple[Any, ...],
        key_range: KeyRange | None,
        direction: int,
        *,
        keys_only: bool,
    ) -> Cursor:
        """Open a cursor on this index with the given options."""
        return Cursor(
            self._stub,
            self._store,
            key_range=key_range,
            direction=direction,
            keys_only=keys_only,
            index=self._index,
            values=values,
        )

    def _req(self, values: tuple[Any, ...], key_range: KeyRange | None = None) -> Any:
        """Build an ``IndexQueryRequest`` for ``values`` and an optional key range."""
        typed_values = [_to_typed_value(value) for value in values]
        return pb.IndexQueryRequest(
            store=self._store,
            index=self._index,
            values=typed_values,
            range=_kr_to_proto(key_range),
        )
class Transaction:
    """Explicit IndexedDB transaction over a fixed object-store scope.

    The transaction is one bidirectional ``Transaction`` stream on the stub.
    The constructor sends a ``begin`` message and waits for its reply. Each
    operation is one request/response round trip, and :meth:`commit` or
    :meth:`abort` finishes the stream. A stream failure, an unexpected reply,
    or an operation error marks the transaction finished.

    As a context manager it commits on a clean exit and aborts when the block
    raises. If ``commit()`` or ``abort()`` was already called explicitly inside
    the block, exiting does nothing further.
    """

    def __init__(
        self,
        stub: Any,
        stores: list[str],
        mode: str = "readonly",
        *,
        durability_hint: str = "default",
    ) -> None:
        self._stub = stub
        # True once the stream is finished (committed, aborted or failed).
        self._closed = False
        # True once commit() or abort() has been invoked, explicitly or via __exit__.
        self._completed = False
        self._request_id = 0
        self._request_iter = _RequestIterator()
        self._request_iter.send(
            pb.TransactionClientMessage(
                begin=pb.BeginTransactionRequest(
                    stores=stores,
                    mode=_transaction_mode_to_proto(mode),
                    durability_hint=_durability_hint_to_proto(durability_hint),
                )
            )
        )
        self._response_iter = stub.Transaction(iter(self._request_iter))
        resp = self._recv("begin")
        if resp.WhichOneof("msg") != "begin":
            self._shutdown()
            raise TransactionError("expected transaction begin response")

    def object_store(self, name: str) -> TransactionObjectStore:
        """Return a transaction-scoped object store."""
        return TransactionObjectStore(self, name)

    def commit(self) -> None:
        """Commit the transaction.

        Raises:
            TransactionError: if the transaction is already finished, the stream
                ends, or the reply is not a commit reply.
            NotFoundError, AlreadyExistsError, TransactionError: mapped from the
                gRPC error or the status embedded in the commit reply.
        """
        self._completed = True
        self._ensure_open()
        self._closed = True
        self._request_iter.send(
            pb.TransactionClientMessage(commit=pb.TransactionCommitRequest())
        )
        resp = self._recv("commit")
        self._request_iter.close()
        if resp.WhichOneof("msg") != "commit":
            raise TransactionError("expected transaction commit response")
        _raise_rpc_status(resp.commit.error)

    def abort(self) -> None:
        """Abort the transaction; a no-op if it is already finished."""
        self._completed = True
        if self._closed:
            return
        self._closed = True
        self._request_iter.send(
            pb.TransactionClientMessage(abort=pb.TransactionAbortRequest())
        )
        resp = self._recv("abort")
        self._request_iter.close()
        if resp.WhichOneof("msg") != "abort":
            raise TransactionError("expected transaction abort response")
        _raise_rpc_status(resp.abort.error)

    def _send_operation(self, operation: Any) -> Any:
        """Send one operation and return its response message.

        The operation is stamped with a fresh ``request_id``. The reply must echo
        that id. Any failure finishes the transaction before the error
        propagates.
        """
        self._ensure_open()
        self._request_id += 1
        operation.request_id = self._request_id
        self._request_iter.send(pb.TransactionClientMessage(operation=operation))
        resp = self._recv("operation")
        if resp.WhichOneof("msg") != "operation":
            self._shutdown()
            raise TransactionError("expected transaction operation response")
        op_resp = resp.operation
        if op_resp.request_id != operation.request_id:
            self._shutdown()
            raise TransactionError("transaction response request_id mismatch")
        try:
            _raise_rpc_status(op_resp.error)
        except Exception:
            self._shutdown()
            raise
        return op_resp

    def _recv(self, phase: str) -> Any:
        """Read the next server message, finishing the transaction if the stream fails."""
        try:
            return next(self._response_iter)
        except StopIteration:
            self._shutdown()
            raise TransactionError(
                f"transaction stream ended during {phase}"
            ) from None
        except grpc.RpcError as e:
            self._shutdown()
            _raise_grpc_error(e)

    def _shutdown(self) -> None:
        """Mark the transaction finished and end the request stream."""
        self._closed = True
        self._request_iter.close()

    def _ensure_open(self) -> None:
        if self._closed:
            raise TransactionError("transaction is already finished")

    def __enter__(self) -> Transaction:
        return self

    def __exit__(self, exc_type: Any, _exc: Any, _tb: Any) -> None:
        if self._completed:
            # commit()/abort() already ran inside the block; calling commit()
            # again would raise "transaction is already finished".
            return
        if exc_type is None:
            self.commit()
        else:
            self.abort()
class TransactionObjectStore:
    """Transaction-scoped object store.

    Mirrors :class:`ObjectStore`, but every call is sent as one operation on the
    owning :class:`Transaction` stream rather than as a standalone RPC.
    """

    def __init__(self, tx: Transaction, store: str) -> None:
        self._tx = tx
        self._store = store

    def _op(self, **operation: Any) -> Any:
        """Send a ``TransactionOperation`` built from ``operation``; return its response."""
        return self._tx._send_operation(pb.TransactionOperation(**operation))

    def _in_range(self, key_range: KeyRange | None) -> Any:
        """Build a range request for this store (``None`` sends no range)."""
        return pb.ObjectStoreRangeRequest(
            store=self._store, range=_kr_to_proto(key_range)
        )

    def get(self, id: str) -> dict[str, Any]:
        """Fetch a record by primary key."""
        reply = self._op(get=pb.ObjectStoreRequest(store=self._store, id=id))
        return _record_to_dict(reply.record.record)

    def get_key(self, id: str) -> str:
        """Return the canonical key for a primary key lookup."""
        reply = self._op(get_key=pb.ObjectStoreRequest(store=self._store, id=id))
        return reply.key.key

    def add(self, record: dict[str, Any]) -> None:
        """Insert a new record."""
        self._op(
            add=pb.RecordRequest(store=self._store, record=_dict_to_record(record))
        )

    def put(self, record: dict[str, Any]) -> None:
        """Insert or replace a record."""
        self._op(
            put=pb.RecordRequest(store=self._store, record=_dict_to_record(record))
        )

    def delete(self, id: str) -> None:
        """Delete a record by primary key."""
        self._op(delete=pb.ObjectStoreRequest(store=self._store, id=id))

    def clear(self) -> None:
        """Delete every record in the store."""
        self._op(clear=pb.ObjectStoreNameRequest(store=self._store))

    def get_all(self, key_range: KeyRange | None = None) -> list[dict[str, Any]]:
        """Return all records within ``key_range``."""
        reply = self._op(get_all=self._in_range(key_range))
        return [_record_to_dict(record) for record in reply.records.records]

    def get_all_keys(self, key_range: KeyRange | None = None) -> list[str]:
        """Return all primary keys within ``key_range``."""
        reply = self._op(get_all_keys=self._in_range(key_range))
        return list(reply.keys.keys)

    def count(self, key_range: KeyRange | None = None) -> int:
        """Return the number of records within ``key_range``."""
        reply = self._op(count=self._in_range(key_range))
        return int(reply.count.count)

    def delete_range(self, key_range: KeyRange) -> int:
        """Delete all records within ``key_range`` and return how many were removed."""
        reply = self._op(delete_range=self._in_range(key_range))
        # NOTE(review): reads the ``delete`` result field, the same one index
        # deletes use; confirm against the proto definition.
        return int(reply.delete.deleted)

    def index(self, name: str) -> TransactionIndex:
        """Return a transaction-scoped client for index ``name``."""
        return TransactionIndex(self._tx, self._store, name)
class TransactionIndex:
    """Transaction-scoped secondary index.

    Mirrors :class:`Index`, but every call is sent as one operation on the
    owning :class:`Transaction` stream.
    """

    def __init__(self, tx: Transaction, store: str, index: str) -> None:
        self._tx = tx
        self._store = store
        self._index = index

    def _op(self, **operation: Any) -> Any:
        """Send a ``TransactionOperation`` built from ``operation``; return its response."""
        return self._tx._send_operation(pb.TransactionOperation(**operation))

    def get(self, *values: Any) -> dict[str, Any]:
        """Fetch the first matching record for the indexed values."""
        reply = self._op(index_get=self._req(values))
        return _record_to_dict(reply.record.record)

    def get_key(self, *values: Any) -> str:
        """Fetch the first matching primary key for the indexed values."""
        reply = self._op(index_get_key=self._req(values))
        return reply.key.key

    def get_all(
        self, *values: Any, key_range: KeyRange | None = None
    ) -> list[dict[str, Any]]:
        """Return all records matching the indexed values and key range."""
        reply = self._op(index_get_all=self._req(values, key_range))
        return [_record_to_dict(record) for record in reply.records.records]

    def get_all_keys(
        self, *values: Any, key_range: KeyRange | None = None
    ) -> list[str]:
        """Return all primary keys matching the indexed values and key range."""
        reply = self._op(index_get_all_keys=self._req(values, key_range))
        return list(reply.keys.keys)

    def count(self, *values: Any, key_range: KeyRange | None = None) -> int:
        """Return the number of records matching the indexed values."""
        reply = self._op(index_count=self._req(values, key_range))
        return int(reply.count.count)

    def delete(self, *values: Any) -> int:
        """Delete records matching the indexed values and return how many were removed."""
        reply = self._op(index_delete=self._req(values))
        return int(reply.delete.deleted)

    def _req(self, values: tuple[Any, ...], key_range: KeyRange | None = None) -> Any:
        """Build an ``IndexQueryRequest`` for ``values`` and an optional key range."""
        typed_values = [_to_typed_value(value) for value in values]
        return pb.IndexQueryRequest(
            store=self._store,
            index=self._index,
            values=typed_values,
            range=_kr_to_proto(key_range),
        )
def _indexeddb_channel(raw_target: str, *, token: str = "") -> grpc.Channel:
    """Build a gRPC channel for an IndexedDB transport target.

    Accepted forms are ``tcp://host:port`` (insecure), ``tls://host:port``
    (secure), ``unix://path``, and a bare path, which is treated as a Unix
    socket. The channel is wrapped with the relay ``token`` (if any).

    Raises:
        RuntimeError: for an empty target, a missing address/path, or an
            unsupported ``scheme://`` prefix.
    """
    target = raw_target.strip()
    if not target:
        raise RuntimeError("IndexedDB transport target is required")

    scheme, separator, remainder = target.partition("://")
    if not separator:
        # No scheme at all: the whole target is a Unix socket path.
        return _with_indexeddb_relay_token(
            insecure_internal_channel(internal_channel_target("unix", target)),
            token,
        )

    # What each supported scheme requires after "://" (used in error messages).
    requirements = {"tcp": "host:port", "tls": "host:port", "unix": "a socket path"}
    if scheme not in requirements:
        unsupported = _urlparse.urlparse(target).scheme
        raise RuntimeError(f"unsupported IndexedDB target scheme {unsupported!r}")

    address = remainder.strip()
    if not address:
        raise RuntimeError(
            f"IndexedDB {scheme} target {raw_target!r} is missing {requirements[scheme]}"
        )
    make_channel = (
        secure_internal_channel if scheme == "tls" else insecure_internal_channel
    )
    return _with_indexeddb_relay_token(
        make_channel(internal_channel_target(scheme, address)),
        token,
    )
def _with_indexeddb_relay_token(channel: grpc.Channel, token: str) -> grpc.Channel:
    """Return ``channel`` wrapped so outgoing calls carry the relay token.

    The stripped token is appended to the call metadata under
    ``_INDEXEDDB_RELAY_TOKEN_HEADER`` for unary-unary and stream-stream calls.
    A blank token returns ``channel`` unchanged.
    """
    relay_token = token.strip()
    if not relay_token:
        return channel

    class _ClientCallDetails(grpc.ClientCallDetails):
        """Attribute-based ``ClientCallDetails`` carrying the rewritten metadata."""

        def __init__(
            self,
            method: str,
            timeout: float | None,
            metadata: Any,
            credentials: Any,
            wait_for_ready: bool | None,
            compression: Any,
        ) -> None:
            self.method = method
            self.timeout = timeout
            self.metadata = metadata
            self.credentials = credentials
            self.wait_for_ready = wait_for_ready
            self.compression = compression

    class _RelayTokenInterceptor(
        grpc.UnaryUnaryClientInterceptor,
        grpc.StreamStreamClientInterceptor,
    ):
        """Adds the relay-token metadata entry before continuing each call."""

        def __init__(self, token: str) -> None:
            self._token = token

        def _details(
            self, client_call_details: grpc.ClientCallDetails
        ) -> grpc.ClientCallDetails:
            incoming = cast(_ClientCallDetailsFields, client_call_details)
            # Keep caller-supplied metadata and add the token after it.
            metadata = [
                *(incoming.metadata or []),
                (_INDEXEDDB_RELAY_TOKEN_HEADER, self._token),
            ]
            return _ClientCallDetails(
                method=incoming.method,
                timeout=incoming.timeout,
                metadata=metadata,
                credentials=incoming.credentials,
                wait_for_ready=incoming.wait_for_ready,
                compression=incoming.compression,
            )

        def intercept_unary_unary(
            self,
            continuation: Any,
            client_call_details: grpc.ClientCallDetails,
            request: Any,
        ) -> Any:
            return continuation(self._details(client_call_details), request)

        def intercept_stream_stream(
            self,
            continuation: Any,
            client_call_details: grpc.ClientCallDetails,
            request_iterator: Any,
        ) -> Any:
            return continuation(self._details(client_call_details), request_iterator)

    return grpc.intercept_channel(channel, _RelayTokenInterceptor(relay_token))
class _ClientCallDetailsFields(Protocol):
    """Structural view of the attributes read from a ``grpc.ClientCallDetails``.

    Used to ``cast`` the interceptor's incoming call details so these attributes
    can be accessed with static types.
    """

    # RPC method being invoked.
    method: str
    # Call deadline in seconds, if one was set.
    timeout: float | None
    # Caller-supplied metadata, possibly None.
    metadata: Any
    credentials: Any
    wait_for_ready: bool | None
    compression: Any
class _RequestIterator:
    """Blocking iterator that feeds queued messages to a gRPC request stream.

    :meth:`send` enqueues a message and :meth:`close` enqueues a ``None``
    sentinel that ends iteration. ``__next__`` blocks until an item is
    available.
    """

    def __init__(self) -> None:
        self._outbox: queue.Queue[Any | None] = queue.Queue()

    def send(self, msg: Any) -> None:
        """Queue ``msg`` for delivery on the request stream."""
        self._outbox.put(msg)

    def close(self) -> None:
        """Queue the end-of-stream sentinel."""
        self._outbox.put(None)

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        pending = self._outbox.get()
        if pending is not None:
            return pending
        raise StopIteration
class Cursor:
    """Stateful cursor over object store or index results.

    A cursor wraps one bidirectional ``OpenCursor`` stream. Navigation
    (:meth:`continue_`, :meth:`continue_to_key`, :meth:`advance`) and mutation
    (:meth:`delete`, :meth:`update`) each enqueue one command and read one
    response. The open acknowledgement is read and discarded, so ``key``,
    ``primary_key`` and ``value`` stay unset until a navigation call returns
    ``True``. Whenever the stream ends or fails, the current entry is cleared.
    """

    def __init__(
        self,
        stub: Any,
        store: str,
        *,
        key_range: KeyRange | None = None,
        direction: int = CURSOR_NEXT,
        keys_only: bool = False,
        index: str = "",
        values: tuple[Any, ...] = (),
    ) -> None:
        self._keys_only = keys_only
        self._closed = False
        self._exhausted = False
        # Index cursors report ``key`` as a list of key parts; object-store
        # cursors unwrap a single-part key.
        self._index_cursor = bool(index)
        self._key: Any = None
        self._primary_key: str | None = None
        self._record: dict[str, Any] | None = None
        self._request_iter = _RequestIterator()
        open_req = pb.OpenCursorRequest(
            store=store,
            range=_kr_to_proto(key_range),
            direction=direction,
            keys_only=keys_only,
            index=index,
            values=[_to_typed_value(v) for v in values],
        )
        self._request_iter.send(pb.CursorClientMessage(open=open_req))
        self._response_iter = stub.OpenCursor(iter(self._request_iter))
        # Read the open ack to surface creation errors synchronously.
        try:
            next(self._response_iter)
        except StopIteration:
            # The stream ended before any command. Behave like an empty, closed
            # cursor (the same handling _advance_to_next gives end-of-stream)
            # instead of leaking StopIteration out of the constructor.
            self._shutdown()
        except grpc.RpcError as e:
            self._raise_stream_error(e)

    def _send_command(self, **kwargs: Any) -> None:
        """Enqueue one ``CursorCommand`` built from ``kwargs``."""
        cmd = pb.CursorCommand(**kwargs)
        self._request_iter.send(pb.CursorClientMessage(command=cmd))

    def _clear_entry(self) -> None:
        """Forget the current key, primary key and record."""
        self._key = None
        self._primary_key = None
        self._record = None

    def _shutdown(self) -> None:
        """Mark the cursor closed, drop the current entry and end the request stream."""
        self._closed = True
        self._clear_entry()
        self._request_iter.close()

    def _raise_stream_error(self, err: grpc.RpcError) -> NoReturn:
        """Shut the cursor down and raise ``err`` as a module exception where one applies.

        ``NOT_FOUND`` and ``ALREADY_EXISTS`` become :class:`NotFoundError` and
        :class:`AlreadyExistsError`. Any other status re-raises ``err`` itself.
        """
        self._shutdown()
        code = err.code()
        details = err.details()
        if code == grpc.StatusCode.NOT_FOUND:
            raise NotFoundError(details) from err
        if code == grpc.StatusCode.ALREADY_EXISTS:
            raise AlreadyExistsError(details) from err
        raise err

    def _advance_to_next(self) -> bool:
        """Read the response to a navigation command; return ``True`` if it holds an entry."""
        try:
            resp = next(self._response_iter)
        except StopIteration:
            # Previously the stale entry survived here; it is cleared now.
            self._shutdown()
            return False
        except grpc.RpcError as e:
            self._raise_stream_error(e)
        if resp.WhichOneof("result") == "done":
            self._clear_entry()
            self._exhausted = True
            return False
        self._refresh_from_entry(resp.entry)
        return True

    def continue_(self) -> bool:
        """Advance to the next matching cursor entry."""
        if self._closed or self._exhausted:
            return False
        self._send_command(next=True)
        return self._advance_to_next()

    def continue_to_key(self, key: Any) -> bool:
        """Advance the cursor to ``key`` or the next greater entry."""
        if self._closed or self._exhausted:
            return False
        self._send_command(
            continue_to_key=pb.CursorKeyTarget(
                key=_cursor_key_to_proto(key, self._index_cursor),
            )
        )
        return self._advance_to_next()

    def advance(self, count: int) -> bool:
        """Skip forward by ``count`` entries."""
        if self._closed or self._exhausted:
            return False
        self._send_command(advance=count)
        return self._advance_to_next()

    @property
    def key(self) -> Any:
        """Current key for the cursor entry (``None`` when there is no entry)."""
        return self._key

    @property
    def primary_key(self) -> str | None:
        """Current primary key for the cursor entry (``None`` when there is no entry)."""
        return self._primary_key

    @property
    def value(self) -> dict[str, Any]:
        """Current record value for the cursor entry.

        Raises:
            TypeError: for key-only cursors, or when there is no current entry.
        """
        if self._keys_only:
            raise TypeError("cursor opened with keys_only=True has no value")
        if self._record is None:
            raise TypeError("cursor is exhausted")
        return self._record

    def _refresh_from_entry(self, entry: Any) -> None:
        """Load key, primary key and (unless key-only) record from ``entry``."""
        keys = list(entry.key)
        if not self._index_cursor and len(keys) == 1:
            self._key = _key_value_to_python(keys[0])
        elif len(keys) > 0:
            self._key = [_key_value_to_python(k) for k in keys]
        else:
            self._key = None
        self._primary_key = entry.primary_key
        if not self._keys_only:
            self._record = _record_to_dict(entry.record)

    def _recv_mutation_ack(self) -> None:
        """Read the reply to ``delete``/``update``; refresh the entry if one is returned."""
        try:
            resp = next(self._response_iter)
        except StopIteration:
            self._shutdown()
            raise TypeError("cursor stream ended during mutation") from None
        except grpc.RpcError as e:
            self._raise_stream_error(e)
        if resp.WhichOneof("result") == "entry":
            self._refresh_from_entry(resp.entry)

    def delete(self) -> None:
        """Delete the current cursor entry."""
        if self._exhausted:
            raise NotFoundError("cursor is exhausted")
        if self._closed:
            raise TypeError("cursor is closed")
        self._send_command(delete=True)
        self._recv_mutation_ack()

    def update(self, value: dict[str, Any]) -> None:
        """Replace the current cursor entry with ``value``."""
        if self._exhausted:
            raise NotFoundError("cursor is exhausted")
        if self._closed:
            raise TypeError("cursor is closed")
        self._send_command(update=_dict_to_record(value))
        self._recv_mutation_ack()

    def close(self) -> None:
        """Close the cursor stream (idempotent)."""
        if self._closed:
            return
        self._closed = True
        self._clear_entry()
        try:
            self._send_command(close=True)
        except Exception:
            # Best effort: a failure to build or enqueue the close command must
            # not stop the request stream from being ended below.
            pass
        self._request_iter.close()

    def __enter__(self) -> Cursor:
        """Return the cursor for ``with`` statements."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Close the cursor at the end of a context manager block."""
        self.close()
def _grpc_call(method: Any, request: Any) -> Any:
    """Invoke a unary stub ``method`` with ``request`` and return its response.

    ``grpc.RpcError`` is translated via ``_raise_grpc_error``.
    """
    try:
        response = method(request)
    except grpc.RpcError as exc:
        _raise_grpc_error(exc)
    else:
        return response
def _raise_grpc_error(err: grpc.RpcError) -> NoReturn:
    """Translate a gRPC error into this module's exception types and raise it.

    This function never returns. It is annotated ``NoReturn`` so type checkers
    know that callers' variables assigned in the ``try`` are bound after the
    ``except`` branch.

    Raises:
        NotFoundError: for ``NOT_FOUND``.
        AlreadyExistsError: for ``ALREADY_EXISTS``.
        TransactionError: for ``FAILED_PRECONDITION`` and ``INVALID_ARGUMENT``.
        grpc.RpcError: ``err`` itself for every other status.
    """
    code = err.code()
    details = err.details()
    if code == grpc.StatusCode.NOT_FOUND:
        raise NotFoundError(details) from err
    if code == grpc.StatusCode.ALREADY_EXISTS:
        raise AlreadyExistsError(details) from err
    if code in (grpc.StatusCode.FAILED_PRECONDITION, grpc.StatusCode.INVALID_ARGUMENT):
        raise TransactionError(details) from err
    raise err
def _raise_rpc_status(status: Any) -> None:
    """Raise the module exception matching an embedded RPC status, if any.

    ``status`` is expected to expose an integer gRPC ``code`` and a ``message``
    (presumably a ``google.rpc.Status``). ``None`` or code 0 (``OK``) returns
    normally. Code 5 (``NOT_FOUND``) raises :class:`NotFoundError`, code 6
    (``ALREADY_EXISTS``) raises :class:`AlreadyExistsError`, and every other
    code raises :class:`TransactionError`.
    """
    if status is None or status.code == 0:
        return
    by_code = {5: NotFoundError, 6: AlreadyExistsError}
    raise by_code.get(status.code, TransactionError)(status.message)
def _transaction_mode_to_proto(mode: str) -> int:
    """Map a transaction mode name to its proto enum value.

    Hyphens and underscores are ignored and matching is case-insensitive, so
    ``"read-write"``, ``"READ_WRITE"`` and ``"readwrite"`` are equivalent.

    Raises:
        ValueError: for any other mode.
    """
    normalized = mode.translate(str.maketrans("", "", "-_")).lower()
    enum_name = {
        "readonly": "TRANSACTION_READONLY",
        "readwrite": "TRANSACTION_READWRITE",
    }.get(normalized)
    if enum_name is None:
        raise ValueError(f"unsupported transaction mode: {mode!r}")
    return getattr(pb, enum_name)
def _durability_hint_to_proto(hint: str) -> int:
    """Map a durability hint name (``default``/``strict``/``relaxed``) to its proto enum.

    Hyphens and underscores are ignored and matching is case-insensitive.

    Raises:
        ValueError: for any other hint.
    """
    normalized = hint.translate(str.maketrans("", "", "-_")).lower()
    enum_name = {
        "default": "TRANSACTION_DURABILITY_DEFAULT",
        "strict": "TRANSACTION_DURABILITY_STRICT",
        "relaxed": "TRANSACTION_DURABILITY_RELAXED",
    }.get(normalized)
    if enum_name is None:
        raise ValueError(f"unsupported transaction durability hint: {hint!r}")
    return getattr(pb, enum_name)
def _dict_to_record(d: dict[str, Any]) -> Any:
    """Convert a plain dict into a ``pb.Record`` whose fields hold typed values."""
    record = pb.Record()
    fields = record.fields
    for name in d:
        fields[name].CopyFrom(_to_typed_value(d[name]))
    return record
def _record_to_dict(record: Any) -> dict[str, Any]:
    """Convert a ``pb.Record`` into a plain dict of Python values."""
    decoded: dict[str, Any] = {}
    for name, typed in record.fields.items():
        decoded[name] = _typed_value_to_python(typed)
    return decoded
def _to_typed_value(v: Any) -> Any:
    """Encode a Python value as a ``pb.TypedValue``.

    ``None``, ``bool``, ``int``, ``float``, ``str`` and bytes-like values map to
    the matching scalar field. ``datetime`` values become ``time_value``, with
    naive datetimes interpreted as UTC. Anything else is encoded as
    ``json_value`` through ``_to_json_value``.
    """
    typed = pb.TypedValue()
    # bool is tested before int because bool is a subclass of int.
    if v is None:
        typed.null_value = 0
    elif isinstance(v, bool):
        typed.bool_value = v
    elif isinstance(v, int):
        typed.int_value = v
    elif isinstance(v, float):
        typed.float_value = v
    elif isinstance(v, str):
        typed.string_value = v
    elif isinstance(v, (bytes, bytearray, memoryview)):
        typed.bytes_value = bytes(v)
    elif isinstance(v, _dt.datetime):
        stamp = timestamp_pb2.Timestamp()
        aware = v if v.tzinfo is not None else v.replace(tzinfo=_dt.timezone.utc)
        stamp.FromDatetime(aware.astimezone(_dt.timezone.utc))
        typed.time_value.CopyFrom(stamp)
    else:
        typed.json_value.CopyFrom(_to_json_value(v))
    return typed
def _typed_value_to_python(v: Any) -> Any:
    """Decode a ``pb.TypedValue`` into the equivalent Python value.

    Unset and ``null_value`` decode to ``None``. ``time_value`` decodes to a
    UTC-aware ``datetime``, and ``json_value`` is decoded recursively.

    Raises:
        TypeError: if the value uses an unknown ``kind``.
    """
    kind = v.WhichOneof("kind")
    if kind is None or kind == "null_value":
        return None
    if kind in ("string_value", "int_value", "float_value", "bool_value"):
        return getattr(v, kind)
    if kind == "time_value":
        stamp = v.time_value
        # Exact integer arithmetic instead of
        # ``fromtimestamp(seconds + nanos / 1e9)``. The float sum loses
        # sub-microsecond precision for large timestamps (it rounds
        # 9999-12-31T23:59:59.999999 past datetime.max), and ``fromtimestamp``
        # may fail outside the platform's C time range. Nanos are truncated to
        # microseconds, as protobuf's ``Timestamp.ToDatetime`` does.
        epoch = _dt.datetime(1970, 1, 1, tzinfo=_dt.timezone.utc)
        return epoch + _dt.timedelta(
            seconds=stamp.seconds, microseconds=stamp.nanos // 1000
        )
    if kind == "bytes_value":
        return bytes(v.bytes_value)
    if kind == "json_value":
        return _json_value_to_python(v.json_value)
    raise TypeError(f"unsupported typed value kind: {kind}")
def _to_json_value(v: Any) -> Any:
    """Encode a JSON-compatible Python value as a ``google.protobuf.Value``.

    Numbers (``int`` and ``float``) are stored as ``number_value`` floats.
    Dicts become ``struct_value`` and lists/tuples become ``list_value``,
    recursively.

    Raises:
        TypeError: for values of any other type.
    """
    value = struct_pb2.Value()
    # bool is tested before the numeric branch because bool is an int subclass.
    if v is None:
        value.null_value = 0
    elif isinstance(v, bool):
        value.bool_value = v
    elif isinstance(v, (int, float)):
        value.number_value = float(v)
    elif isinstance(v, str):
        value.string_value = v
    elif isinstance(v, dict):
        struct = struct_pb2.Struct()
        for key in v:
            struct.fields[key].CopyFrom(_to_json_value(v[key]))
        value.struct_value.CopyFrom(struct)
    elif isinstance(v, (list, tuple)):
        items = struct_pb2.ListValue()
        for element in v:
            items.values.append(_to_json_value(element))
        value.list_value.CopyFrom(items)
    else:
        raise TypeError(f"unsupported JSON value type: {type(v)!r}")
    return value
def _json_value_to_python(v: Any) -> Any:
    """Decode a ``google.protobuf.Value`` into plain Python data.

    Structs become dicts and lists become lists (recursively). Numbers come back
    as floats.

    Raises:
        TypeError: if the value uses an unknown ``kind``.
    """
    kind = v.WhichOneof("kind")
    if kind is None or kind == "null_value":
        return None
    if kind in ("number_value", "string_value", "bool_value"):
        return getattr(v, kind)
    if kind == "struct_value":
        return {
            name: _json_value_to_python(inner)
            for name, inner in v.struct_value.fields.items()
        }
    if kind == "list_value":
        return [_json_value_to_python(inner) for inner in v.list_value.values]
    raise TypeError(f"unsupported JSON value kind: {kind}")
def _key_value_to_python(kv: Any) -> Any:
    """Decode a ``pb.KeyValue``: scalars via ``_typed_value_to_python``, arrays recursively.

    An unset or unknown ``kind`` decodes to ``None``.
    """
    kind = kv.WhichOneof("kind")
    if kind == "array":
        return [_key_value_to_python(element) for element in kv.array.elements]
    return _typed_value_to_python(kv.scalar) if kind == "scalar" else None
def _python_to_key_value(v: Any) -> Any:
    """Encode a Python key as a ``pb.KeyValue``; lists/tuples become nested arrays."""
    if not isinstance(v, (list, tuple)):
        return pb.KeyValue(scalar=_to_typed_value(v))
    elements = [_python_to_key_value(part) for part in v]
    return pb.KeyValue(array=pb.KeyValueArray(elements=elements))
def _cursor_key_to_proto(key: Any, index_cursor: bool) -> list[Any]:
    """Encode a cursor target key as the list of ``pb.KeyValue`` parts sent on the wire.

    For index cursors a list/tuple key is split into one part per element.
    Otherwise the whole key is a single part.
    """
    parts = key if index_cursor and isinstance(key, (list, tuple)) else [key]
    return [_python_to_key_value(part) for part in parts]
def _kr_to_proto(kr: KeyRange | None) -> Any:
    """Convert a :class:`KeyRange` into ``pb.KeyRange``; ``None`` stays ``None``.

    A ``None`` bound is left unset in the message.
    """
    if kr is None:
        return None

    def _bound(value: Any) -> Any:
        return None if value is None else _to_typed_value(value)

    return pb.KeyRange(
        lower=_bound(kr.lower),
        upper=_bound(kr.upper),
        lower_open=kr.lower_open,
        upper_open=kr.upper_open,
    )