Source code for gestalt._cache

"""Cache client helpers for provider processes."""

from __future__ import annotations

import datetime as _dt
import os
from dataclasses import dataclass
from typing import Any, Iterable
from urllib import parse as _urlparse

import grpc
from google.protobuf import duration_pb2 as _duration_pb2

from ._grpc_transport import (
    insecure_internal_channel,
    internal_channel_target,
    secure_internal_channel,
)
from .gen.v1 import cache_pb2 as _pb
from .gen.v1 import cache_pb2_grpc as _pb_grpc

pb: Any = _pb
pb_grpc: Any = _pb_grpc
duration_pb2: Any = _duration_pb2

ENV_CACHE_SOCKET = "GESTALT_CACHE_SOCKET"
_CACHE_SOCKET_TOKEN_SUFFIX = "_TOKEN"
_CACHE_RELAY_TOKEN_HEADER = "x-gestalt-host-service-relay-token"
ENV_CACHE_SOCKET_TOKEN = f"{ENV_CACHE_SOCKET}{_CACHE_SOCKET_TOKEN_SUFFIX}"


[docs] def cache_socket_env(name: str | None = None) -> str: """Return the environment variable name for a cache socket binding.""" trimmed = (name or "").strip() if not trimmed: return ENV_CACHE_SOCKET normalized = "".join( ch.upper() if ch.isascii() and ch.isalnum() else "_" for ch in trimmed ) return f"{ENV_CACHE_SOCKET}_{normalized}"
def cache_socket_token_env(name: str | None = None) -> str: """Return the environment variable name for a cache relay token.""" return f"{cache_socket_env(name)}{_CACHE_SOCKET_TOKEN_SUFFIX}"
[docs] @dataclass(frozen=True) class CacheEntry: """Cache key and value pair used by batch operations.""" key: str value: bytes
[docs] class Cache: """Client for a host-provided Gestalt cache provider.""" def __init__(self, name: str | None = None) -> None: env_name = cache_socket_env(name) target = os.environ.get(env_name, "") if not target: raise RuntimeError(f"{env_name} is not set") token = os.environ.get(cache_socket_token_env(name), "") self._channel = _cache_channel(target, token=token) self._stub = pb_grpc.CacheStub(self._channel)
[docs] def close(self) -> None: """Close the underlying gRPC channel.""" self._channel.close()
[docs] def get(self, key: str) -> bytes | None: """Return the cached value for ``key`` if it exists.""" resp = _grpc_call(self._stub.Get, pb.CacheGetRequest(key=key)) if not resp.found: return None return bytes(resp.value)
[docs] def get_many(self, keys: list[str]) -> dict[str, bytes]: """Return the subset of ``keys`` that currently exist.""" resp = _grpc_call(self._stub.GetMany, pb.CacheGetManyRequest(keys=keys)) out: dict[str, bytes] = {} for entry in resp.entries: if entry.found: out[entry.key] = bytes(entry.value) return out
[docs] def set( self, key: str, value: bytes, ttl: _dt.timedelta | None = None, ) -> None: """Store ``value`` for ``key`` with an optional TTL.""" _grpc_call( self._stub.Set, pb.CacheSetRequest(key=key, value=bytes(value), ttl=_duration_from_ttl(ttl)), )
[docs] def set_many( self, entries: Iterable[CacheEntry], ttl: _dt.timedelta | None = None, ) -> None: """Store multiple cache entries with one RPC.""" _grpc_call( self._stub.SetMany, pb.CacheSetManyRequest( entries=[ pb.CacheSetEntry(key=entry.key, value=bytes(entry.value)) for entry in entries ], ttl=_duration_from_ttl(ttl), ), )
[docs] def delete(self, key: str) -> bool: """Delete ``key`` and return whether an entry existed.""" resp = _grpc_call(self._stub.Delete, pb.CacheDeleteRequest(key=key)) return bool(resp.deleted)
[docs] def delete_many(self, keys: list[str]) -> int: """Delete multiple keys and return the number removed.""" resp = _grpc_call(self._stub.DeleteMany, pb.CacheDeleteManyRequest(keys=keys)) return int(resp.deleted)
[docs] def touch(self, key: str, ttl: _dt.timedelta) -> bool: """Refresh the TTL for ``key`` if the entry exists.""" resp = _grpc_call( self._stub.Touch, pb.CacheTouchRequest(key=key, ttl=_duration_from_ttl(ttl)), ) return bool(resp.touched)
[docs] def __enter__(self) -> Cache: """Return the cache client for ``with`` statements.""" return self
[docs] def __exit__(self, *args: Any) -> None: """Close the cache client at the end of a context manager block.""" self.close()
def _duration_from_ttl(ttl: _dt.timedelta | None) -> Any: if ttl is None: return None if ttl.total_seconds() <= 0: return None duration = duration_pb2.Duration() duration.FromTimedelta(ttl) return duration class _RelayTokenInterceptor(grpc.UnaryUnaryClientInterceptor): def __init__(self, token: str) -> None: self._token = token def intercept_unary_unary(self, continuation: Any, client_call_details: Any, request: Any) -> Any: metadata = list(client_call_details.metadata or []) metadata.append((_CACHE_RELAY_TOKEN_HEADER, self._token)) details = _ClientCallDetails( method=client_call_details.method, timeout=client_call_details.timeout, metadata=metadata, credentials=client_call_details.credentials, wait_for_ready=client_call_details.wait_for_ready, compression=client_call_details.compression, ) return continuation(details, request) class _ClientCallDetails(grpc.ClientCallDetails): def __init__( self, *, method: str, timeout: float | None, metadata: list[tuple[str, str]], credentials: Any, wait_for_ready: bool | None, compression: Any, ) -> None: self.method = method self.timeout = timeout self.metadata = metadata self.credentials = credentials self.wait_for_ready = wait_for_ready self.compression = compression def _cache_channel(target: str, *, token: str = "") -> Any: scheme, address = _parse_cache_target(target) if scheme == "unix": channel = insecure_internal_channel(internal_channel_target("unix", address)) elif scheme == "tcp": channel = insecure_internal_channel(internal_channel_target("tcp", address)) elif scheme == "tls": channel = secure_internal_channel(internal_channel_target("tls", address)) else: raise RuntimeError(f"unsupported cache transport scheme {scheme!r}") if token: channel = grpc.intercept_channel(channel, _RelayTokenInterceptor(token)) return channel def _parse_cache_target(raw: str) -> tuple[str, str]: target = raw.strip() if not target: raise RuntimeError("cache transport target is required") if target.startswith("tcp://"): address = target.removeprefix("tcp://").strip() if not address: raise RuntimeError(f"cache tcp target {raw!r} is missing host:port") return "tcp", address if target.startswith("tls://"): address = target.removeprefix("tls://").strip() if not address: raise RuntimeError(f"cache tls target {raw!r} is missing host:port") return "tls", address if target.startswith("unix://"): address = target.removeprefix("unix://").strip() if not address: raise RuntimeError(f"cache unix target {raw!r} is missing a socket path") return "unix", address if "://" in target: parsed = _urlparse.urlparse(target) raise RuntimeError(f"unsupported cache target scheme {parsed.scheme!r}") return "unix", target def _grpc_call(method: Any, request: Any) -> Any: try: return method(request) except grpc.RpcError: raise