# Source code for gestalt._catalog

"""Catalog helpers for integration plugins.

The handwritten helpers in this module build and serialize catalog documents
around the generated ``Catalog`` protobuf messages exported by :mod:`gestalt`.
"""

from __future__ import annotations

import collections.abc
import dataclasses
import pathlib
from collections.abc import Mapping
from dataclasses import MISSING
from typing import (
    Any,
    Iterable,
    Protocol,
    cast,
    get_origin,
    get_type_hints,
    runtime_checkable,
)

import yaml

from ._api import FIELD_DESCRIPTION_KEY, FIELD_REQUIRED_KEY, Request
from ._operations import OperationDefinition, is_optional_type, strip_optional

# Optional protobuf runtime support. Both names start as ``None`` (cast to
# ``Any`` so type checkers accept later attribute access) and are rebound only
# when ``google.protobuf`` imports successfully.
json_format: Any = cast(Any, None)
_struct_pb2: Any = cast(Any, None)
try:
    from google.protobuf import json_format as _json_format
    from google.protobuf import struct_pb2 as _google_struct_pb2
except ModuleNotFoundError:
    # protobuf is not installed; leave both names as None.
    pass
else:
    json_format = _json_format
    _struct_pb2 = _google_struct_pb2

# Generated plugin messages. Functions in this module test
# ``plugin_pb2 is None`` and fall back to plain ``dict`` data when the
# generated module cannot be imported.
plugin_pb2: Any = cast(Any, None)
try:
    from .gen.v1 import plugin_pb2 as _plugin_pb2_module
except ModuleNotFoundError:
    # NOTE(review): only a missing module is tolerated here; any other error
    # raised while importing the generated code propagates.
    pass
else:
    plugin_pb2 = _plugin_pb2_module

# Module-level name for protobuf's ``struct_pb2`` (``None`` without protobuf).
struct_pb2: Any = cast(Any, _struct_pb2)

# Message classes used by the helpers in this module. Without ``plugin_pb2``
# each name is bound to ``dict[str, Any]`` instead. That fallback is a
# parameterized generic, which ``isinstance`` rejects with TypeError, so any
# ``isinstance(x, Catalog)`` check must be guarded by ``plugin_pb2 is not None``.
Catalog: Any = plugin_pb2.Catalog if plugin_pb2 is not None else dict[str, Any]  # ty: ignore[unresolved-attribute]
CatalogOperation: Any = plugin_pb2.CatalogOperation if plugin_pb2 is not None else dict[str, Any]  # ty: ignore[unresolved-attribute]
CatalogParameter: Any = plugin_pb2.CatalogParameter if plugin_pb2 is not None else dict[str, Any]  # ty: ignore[unresolved-attribute]
OperationAnnotations: Any = plugin_pb2.OperationAnnotations if plugin_pb2 is not None else dict[str, Any]  # ty: ignore[unresolved-attribute]


@runtime_checkable
class SessionCatalogProvider(Protocol):
    """Structural interface for plugins that supply a catalog per request.

    Any object exposing a compatible ``catalog_for_request`` method satisfies
    this protocol. Because it is ``@runtime_checkable``, callers can test for
    it with ``isinstance`` (which checks only that the method exists).
    """

    def catalog_for_request(
        self,
        request: Request,
    ) -> Catalog | Mapping[str, Any] | None:
        """Return the catalog to use for ``request``.

        Implementations may return a ``Catalog`` message, a plain mapping, or
        ``None``.
        """
def build_catalog(
    *,
    plugin_name: str,
    operations: Iterable[OperationDefinition],
) -> Catalog:
    """Build a catalog from authored operation definitions.

    Args:
        plugin_name: Value stored in the catalog's ``name`` field.
        operations: Operation definitions to include, in order. The iterable
            is consumed exactly once.

    Returns:
        A ``Catalog`` protobuf message, or a plain ``dict`` with ``name`` and
        ``operations`` keys when the generated ``plugin_pb2`` module is not
        importable.
    """
    catalog_operations = [_catalog_operation(op) for op in operations]
    if plugin_pb2 is None:
        return {"name": plugin_name, "operations": catalog_operations}
    return Catalog(name=plugin_name, operations=catalog_operations)


def catalog_to_proto(catalog: Catalog | Mapping[str, Any] | None) -> Catalog | None:
    """Normalize catalog input to a protobuf message.

    Args:
        catalog: A ``Catalog`` message, a catalog-shaped mapping (snake_case or
            lowerCamelCase keys), or ``None``.

    Returns:
        ``None`` for ``None`` input; ``catalog`` itself when it already is a
        ``Catalog``; otherwise the result of :func:`_catalog_from_mapping`
        (a shallow ``dict`` copy when ``plugin_pb2`` is unavailable).

    Raises:
        TypeError: If ``catalog`` is neither a ``Catalog`` nor a mapping.
    """
    if catalog is None:
        return None
    # Without plugin_pb2, ``Catalog`` is a parameterized generic that
    # isinstance() rejects, hence the guard.
    if plugin_pb2 is not None and isinstance(catalog, Catalog):
        return catalog
    if isinstance(catalog, Mapping):
        return _catalog_from_mapping(catalog)
    raise TypeError("catalog must be a gestalt.Catalog or mapping")


def catalog_to_dict(
    catalog: Catalog | Mapping[str, Any], *, field_style: str = "yaml"
) -> dict[str, Any]:
    """Convert a catalog protobuf or mapping into plain Python data.

    Args:
        catalog: A ``Catalog`` message or a mapping.
        field_style: ``"yaml"`` keeps protobuf (snake_case) field names; any
            other value produces protobuf JSON (lowerCamelCase) names. Only
            affects ``Catalog`` input.

    Returns:
        For a message, the ``json_format.MessageToDict`` result with an
        ``operations`` key guaranteed to be present. For a mapping, a shallow
        ``dict`` copy.

    Raises:
        TypeError: If ``catalog`` is neither a ``Catalog`` nor a mapping.
    """
    if plugin_pb2 is not None and isinstance(catalog, Catalog):
        raw = json_format.MessageToDict(
            catalog, preserving_proto_field_name=(field_style == "yaml")
        )
        # MessageToDict omits empty repeated fields; keep the key present.
        raw.setdefault("operations", [])
        return raw
    if isinstance(catalog, Mapping):
        return dict(catalog)
    raise TypeError("catalog must be a gestalt.Catalog or mapping")


def write_catalog(
    path: str | pathlib.Path, *, catalog: Catalog | Mapping[str, Any]
) -> None:
    """Write a catalog document to YAML on disk.

    Protobuf field names are preserved (snake_case), the document always
    contains an ``operations`` list, and keys keep their insertion order.
    Missing parent directories are created.

    Args:
        path: Destination file path.
        catalog: A ``Catalog`` message or a catalog-shaped mapping.

    Raises:
        TypeError: If ``catalog`` is neither a ``Catalog`` nor a mapping.
    """
    # Convert (and validate) before touching the filesystem so bad input does
    # not leave a freshly created parent directory behind.
    as_dict = catalog_to_dict(catalog, field_style="yaml")
    as_dict.setdefault("operations", [])
    catalog_path = pathlib.Path(path)
    catalog_path.parent.mkdir(parents=True, exist_ok=True)
    data = yaml.dump(as_dict, default_flow_style=False, sort_keys=False)
    catalog_path.write_text(data, encoding="utf-8")


def _catalog_operation(operation: OperationDefinition) -> CatalogOperation:
    """Convert one operation definition into a catalog operation entry.

    Returns a ``CatalogOperation`` message, or an equivalent ``dict`` when the
    generated ``plugin_pb2`` module is unavailable. ``visible`` is only set
    when the definition's value is not ``None``.
    """
    if plugin_pb2 is None:
        raw: dict[str, Any] = {
            "id": operation.id,
            "method": operation.method,
            "title": operation.title,
            "description": operation.description,
            "read_only": operation.read_only,
            "parameters": _catalog_parameters(operation.input_type),
            "allowed_roles": list(operation.allowed_roles),
            "tags": list(operation.tags),
        }
        if operation.visible is not None:
            raw["visible"] = operation.visible
        return raw
    op = cast(
        Any,
        CatalogOperation(
            id=operation.id,
            method=operation.method,
            title=operation.title,
            description=operation.description,
            read_only=operation.read_only,
        ),
    )
    op.parameters.extend(_catalog_parameters(operation.input_type))
    op.allowed_roles.extend(operation.allowed_roles)
    op.tags.extend(operation.tags)
    if operation.visible is not None:
        op.visible = operation.visible
    return op


def _catalog_parameters(input_type: Any) -> list[CatalogParameter]:
    """Describe the fields of a dataclass input type as catalog parameters.

    ``input_type`` may be wrapped in ``Optional`` or be a parameterized
    generic; in that case its origin class is inspected. Anything that is not
    a dataclass *class* (including ``None``) yields an empty list.

    Each field contributes its name, a catalog type from
    :func:`_catalog_type`, the description stored under
    ``FIELD_DESCRIPTION_KEY`` in the field metadata, and a ``required`` flag.
    ``required`` comes from ``FIELD_REQUIRED_KEY`` metadata when present;
    otherwise a field is required when it has no default, no default factory,
    and a non-optional annotation. Plain defaults (not default factories) are
    recorded as the parameter's ``default``.
    """
    if input_type is None:
        return []
    input_type = strip_optional(input_type)
    origin = get_origin(input_type)
    if origin is not None:
        input_type = origin
    # dataclasses.is_dataclass() is also true for dataclass *instances*;
    # only classes are meaningful input types here.
    if not (isinstance(input_type, type) and dataclasses.is_dataclass(input_type)):
        return []
    type_hints = get_type_hints(input_type)
    parameters: list[CatalogParameter] = []
    for field_definition in dataclasses.fields(input_type):
        annotation = type_hints.get(field_definition.name, field_definition.type)
        description = str(
            field_definition.metadata.get(FIELD_DESCRIPTION_KEY, "")
        ).strip()
        required = field_definition.metadata.get(FIELD_REQUIRED_KEY)
        if required is None:
            required = (
                field_definition.default is MISSING
                and field_definition.default_factory is MISSING
                and not is_optional_type(annotation)
            )
        has_default = field_definition.default is not MISSING
        if plugin_pb2 is None:
            param: dict[str, Any] = {
                "name": field_definition.name,
                "type": _catalog_type(annotation),
                "description": description,
                "required": bool(required),
            }
            if has_default:
                param["default"] = field_definition.default
            parameters.append(param)
            continue
        proto_param = cast(
            Any,
            CatalogParameter(
                name=field_definition.name,
                type=_catalog_type(annotation),
            ),
        )
        proto_param.description = description
        proto_param.required = bool(required)
        if has_default:
            proto_param.default.CopyFrom(_to_proto_value(field_definition.default))
        parameters.append(proto_param)
    return parameters


def _to_proto_value(value: Any) -> Any:
    """Convert a plain Python value into a ``google.protobuf.Value``.

    ``None``, booleans, numbers and strings map to the matching scalar kind
    (numbers are stored as floats). Lists and tuples become ``list_value`` and
    mappings become ``struct_value``, converted recursively with keys coerced
    via ``str``. Any other object falls back to its ``str()`` form.
    """
    if value is None:
        return struct_pb2.Value(null_value=0)
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return struct_pb2.Value(bool_value=value)
    if isinstance(value, (int, float)):
        return struct_pb2.Value(number_value=float(value))
    if isinstance(value, str):
        return struct_pb2.Value(string_value=value)
    if isinstance(value, (list, tuple)):
        list_value = struct_pb2.ListValue()
        for item in value:
            list_value.values.add().CopyFrom(_to_proto_value(item))
        return struct_pb2.Value(list_value=list_value)
    if isinstance(value, Mapping):
        struct_value = struct_pb2.Struct()
        for key, item in value.items():
            struct_value.fields[str(key)].CopyFrom(_to_proto_value(item))
        return struct_pb2.Value(struct_value=struct_value)
    return struct_pb2.Value(string_value=str(value))


# Classes and generic origins reported as catalog type "array". Membership is
# tested by equality (never issubclass), so e.g. ``str`` is not an array.
_ARRAY_TYPES: tuple[Any, ...] = (
    list,
    tuple,
    set,
    frozenset,
    collections.abc.Sequence,
    collections.abc.MutableSequence,
    collections.abc.Set,
    collections.abc.MutableSet,
)


def _catalog_type(annotation: Any) -> str:
    """Map a Python type annotation to a catalog type name.

    ``Optional`` wrappers are stripped first. ``str``, ``bool``, ``int`` and
    ``float`` map to ``"string"``, ``"boolean"``, ``"integer"`` and
    ``"number"``. List-, tuple- and set-like types, bare or parameterized
    (including the ``collections.abc`` sequence/set ABCs), map to ``"array"``.
    Everything else (dicts, mappings, dataclasses, unions, unknown types)
    maps to ``"object"``.
    """
    actual_type = strip_optional(annotation)
    origin = get_origin(actual_type)
    if origin is not None:
        # Parameterized generic such as list[int] or dict[str, int].
        return "array" if origin in _ARRAY_TYPES else "object"
    if actual_type is str:
        return "string"
    if actual_type is bool:
        return "boolean"
    if actual_type is int:
        return "integer"
    if actual_type is float:
        return "number"
    if actual_type in _ARRAY_TYPES:
        return "array"
    return "object"


def _catalog_from_mapping(data: Mapping[str, Any]) -> Catalog:
    """Build a ``Catalog`` message from a catalog-shaped mapping.

    Field names are accepted in snake_case or protobuf JSON lowerCamelCase
    (e.g. ``display_name``/``displayName``). Missing scalar fields fall back
    to empty strings/``False``. Returns a shallow ``dict`` copy when the
    generated ``plugin_pb2`` module is unavailable.
    """
    if plugin_pb2 is None:
        return dict(data)
    catalog = cast(
        Any,
        Catalog(
            name=data.get("name", ""),
            display_name=data.get("display_name", data.get("displayName", "")),
            description=data.get("description", ""),
            icon_svg=data.get("icon_svg", data.get("iconSvg", "")),
        ),
    )
    for raw_op in data.get("operations", []):
        catalog.operations.append(_operation_from_mapping(raw_op))
    return catalog


def _operation_from_mapping(raw_op: Mapping[str, Any]) -> CatalogOperation:
    """Build one ``CatalogOperation`` message from its mapping form."""
    op = cast(
        Any,
        CatalogOperation(
            id=raw_op.get("id", ""),
            method=raw_op.get("method", ""),
            title=raw_op.get("title", ""),
            description=raw_op.get("description", ""),
            input_schema=raw_op.get("input_schema", raw_op.get("inputSchema", "")),
            output_schema=raw_op.get("output_schema", raw_op.get("outputSchema", "")),
            read_only=raw_op.get("read_only", raw_op.get("readOnly", False)),
            transport=raw_op.get("transport", ""),
        ),
    )
    visible = raw_op.get("visible")
    if visible is not None:
        op.visible = visible
    op.allowed_roles.extend(
        raw_op.get("allowed_roles", raw_op.get("allowedRoles", []))
    )
    raw_ann = raw_op.get("annotations") or {}
    if raw_ann:
        # Absent hints are passed as None, which leaves those fields unset.
        op.annotations.CopyFrom(
            OperationAnnotations(
                read_only_hint=raw_ann.get(
                    "read_only_hint", raw_ann.get("readOnlyHint")
                ),
                idempotent_hint=raw_ann.get(
                    "idempotent_hint", raw_ann.get("idempotentHint")
                ),
                destructive_hint=raw_ann.get(
                    "destructive_hint", raw_ann.get("destructiveHint")
                ),
                open_world_hint=raw_ann.get(
                    "open_world_hint", raw_ann.get("openWorldHint")
                ),
            )
        )
    for raw_param in raw_op.get("parameters", []):
        op.parameters.append(_parameter_from_mapping(raw_param))
    op.tags.extend(raw_op.get("tags", []))
    op.required_scopes.extend(
        raw_op.get("required_scopes", raw_op.get("requiredScopes", []))
    )
    return op


def _parameter_from_mapping(raw_param: Mapping[str, Any]) -> CatalogParameter:
    """Build one ``CatalogParameter`` message from its mapping form.

    A ``default`` key, when present (even with a ``None`` value), is copied
    into the message's ``default`` value so parameter defaults survive a
    ``catalog_to_dict`` -> ``catalog_to_proto`` round trip.
    """
    param = cast(
        Any,
        CatalogParameter(
            name=raw_param.get("name", ""),
            type=raw_param.get("type", ""),
            description=raw_param.get("description", ""),
            required=raw_param.get("required", False),
        ),
    )
    if "default" in raw_param:
        param.default.CopyFrom(_to_proto_value(raw_param["default"]))
    return param