Source code for gestalt._plugin

"""Plugin registration and decorator helpers for integration providers."""

from __future__ import annotations

import copy
import datetime as dt
import inspect
import json
import pathlib
import re
import sys
import types
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Final

import yaml

from ._api import Request, Subject
from ._catalog_helpers import catalog_parameters
from ._http_subject import HTTPSubjectRequest, clone_http_subject_request
from ._operations import (
    OperationDefinition,
    OperationResult,
    execute_operation,
    inspect_handler,
    run_sync,
)

# Method recorded for an operation when the registering decorator is not given
# an explicit (or is given an empty) ``method`` argument.
DEFAULT_OPERATION_METHOD: Final[str] = "POST"

if TYPE_CHECKING:
    from ._catalog import Catalog


@dataclass(frozen=True)
class ConnectedToken:
    """Normalized connection payload handed to post-connect handlers.

    This is the value :meth:`Plugin.post_connect_metadata` forwards to a
    handler registered through :meth:`Plugin.post_connect` when that handler
    declares a parameter.

    The dataclass is frozen, so assigning to a field after construction
    raises ``dataclasses.FrozenInstanceError``. Every field has a default,
    which means a token can be built from any subset of the values.
    """

    # NOTE(review): the meaning of the individual fields is not established
    # anywhere in this module; the grouping comments below are inferred from
    # the field names and should be confirmed against the code that builds
    # the token.

    # Identifiers locating the connection.
    id: str = ""
    subject_id: str = ""
    integration: str = ""
    connection: str = ""
    instance: str = ""

    # Credential material.
    access_token: str = ""
    refresh_token: str = ""
    scopes: str = ""

    # Expiry and refresh bookkeeping.
    expires_at: dt.datetime | None = None
    last_refreshed_at: dt.datetime | None = None
    refresh_error_count: int = 0

    # Presumably the serialized form of ``metadata`` — TODO confirm.
    metadata_json: str = ""
    # ``default_factory`` gives every instance its own dict. ``frozen=True``
    # only blocks rebinding the attribute; the dict itself stays mutable.
    metadata: dict[str, str] = field(default_factory=dict)

    # Record timestamps.
    created_at: dt.datetime | None = None
    updated_at: dt.datetime | None = None


class Plugin:
    """Integration plugin definition and operation registry.

    ``Plugin`` collects operation handlers, optional configuration hooks, and
    optional session catalog hooks before handing control to the runtime:

    .. code-block:: python

        from gestalt import Model, Plugin


        class SearchInput(Model):
            query: str


        plugin = Plugin("search")


        @plugin.operation(title="Search")
        def search(params: SearchInput):
            return {"query": params.query}
    """

    def __init__(
        self,
        name: str,
        *,
        module_name: str | None = None,
    ) -> None:
        """Create a plugin with no operations or hooks registered.

        Args:
            name: Plugin name; it is stored on ``self.name`` in slug form.
            module_name: Dotted name of the module that owns this plugin.
                When set, hooks that were never registered explicitly are
                looked up as module-level callables named ``configure``,
                ``session_catalog``, ``post_connect`` and
                ``resolve_http_subject`` in ``sys.modules[module_name]``.
        """
        self.name = _slug_name(name)
        self._module_name = module_name
        self._operations: dict[str, OperationDefinition] = {}
        self._configure_handler: Any = None
        # Each hook below is cached as ``(handler, takes_argument)``.
        self._session_catalog_handler: tuple[Any, bool] | None = None
        self._post_connect_handler: tuple[Any, bool] | None = None
        self._http_subject_handler: tuple[Any, bool] | None = None

    @classmethod
    def from_manifest(
        cls,
        path: str | pathlib.Path,
        *,
        base_dir: pathlib.Path | None = None,
    ) -> "Plugin":
        """Build a plugin whose name is derived from a manifest path.

        Args:
            path: Manifest location. A relative path is resolved against
                ``base_dir``.
            base_dir: Base for relative paths; the current working directory
                is used when this is ``None``.

        Returns:
            A new plugin with no operations registered.
        """
        manifest_path = pathlib.Path(path)
        if not manifest_path.is_absolute():
            resolved_base = base_dir if base_dir is not None else pathlib.Path.cwd()
            manifest_path = resolved_base / manifest_path
        return cls(_derive_name_from_manifest(manifest_path))

    def configure(self, func: Any) -> Any:
        """Register a configuration hook for provider name and config data.

        The hook is later called as ``func(name, config)``. ``func`` is
        returned unchanged so the method can be used as a decorator.
        """
        self._configure_handler = func
        return func

    def session_catalog(self, func: Any) -> Any:
        """Register a per-request catalog hook.

        The handler signature is validated immediately; ``func`` is returned
        unchanged so the method can be used as a decorator.
        """
        self._session_catalog_handler = (func, _inspect_session_catalog_handler(func))
        return func

    def post_connect(self, func: Any) -> Any:
        """Register a connect-time metadata hook.

        The handler signature is validated immediately; ``func`` is returned
        unchanged so the method can be used as a decorator.
        """
        self._post_connect_handler = (func, _inspect_post_connect_handler(func))
        return func

    def http_subject(self, func: Any) -> Any:
        """Register a hosted HTTP subject-resolution hook.

        The handler signature is validated immediately; ``func`` is returned
        unchanged so the method can be used as a decorator.
        """
        self._http_subject_handler = (func, _inspect_http_subject_handler(func))
        return func

    def operation(
        self,
        func: Any | None = None,
        /,
        *,
        id: str | None = None,
        method: str = DEFAULT_OPERATION_METHOD,
        title: str = "",
        description: str = "",
        allowed_roles: list[str] | None = None,
        tags: list[str] | None = None,
        read_only: bool = False,
        visible: bool | None = None,
    ) -> Any:
        """Register an operation handler on this plugin.

        Usable bare (``@plugin.operation``) or with keyword options
        (``@plugin.operation(title="Search")``).

        Args:
            func: The handler when used as a bare decorator, otherwise
                ``None``.
            id: Operation id; defaults to the handler's ``__name__``.
            method: Operation method; upper-cased, and replaced by
                ``DEFAULT_OPERATION_METHOD`` when empty.
            title: Title, stored with surrounding whitespace removed.
            description: Description, stored with surrounding whitespace
                removed.
            allowed_roles: Role names; blanks and duplicates are dropped.
            tags: Tags, copied into a new list.
            read_only: Stored on the operation definition as given.
            visible: Stored on the operation definition as given.

        Returns:
            The handler itself, or a decorator that registers and returns it.

        Raises:
            ValueError: If the resolved id is empty or already registered.
        """

        def decorator(handler: Any) -> Any:
            operation_id = (id or handler.__name__).strip()
            if not operation_id:
                raise ValueError("operation id is required")
            if operation_id in self._operations:
                raise ValueError(f"duplicate operation id {operation_id!r}")

            input_type, takes_request = inspect_handler(handler)
            self._operations[operation_id] = OperationDefinition(
                id=operation_id,
                method=(method or DEFAULT_OPERATION_METHOD).upper(),
                title=title.strip(),
                description=description.strip(),
                allowed_roles=_normalize_allowed_roles(allowed_roles),
                tags=list(tags or []),
                read_only=read_only,
                visible=visible,
                handler=handler,
                input_type=input_type,
                takes_request=takes_request,
            )
            return handler

        if func is None:
            return decorator
        return decorator(func)

    def configure_provider(self, name: str, config: dict[str, Any]) -> None:
        """Invoke the registered configuration hook if one exists."""
        handler = self._resolve_configure_handler()
        if handler is None:
            return
        run_sync(handler(name, config))

    def execute(
        self, operation: str, params: dict[str, Any], request: Request
    ) -> OperationResult:
        """Execute a registered operation against request parameters.

        An unknown ``operation`` id is forwarded to ``execute_operation`` as
        ``None``; reporting that case is left to that function.
        """
        return execute_operation(
            self._operations.get(operation),
            params=params,
            request=request,
        )

    def _static_catalog_dict(self) -> dict[str, Any]:
        """Build the catalog mapping from the currently registered operations."""
        return {
            "name": self.name,
            "operations": [
                _catalog_operation_dict(operation)
                for operation in self._operations.values()
            ],
        }

    def catalog_dict(self) -> dict[str, Any]:
        """Return the static plugin catalog as a plain dictionary.

        The result is deep-copied, so callers may mutate it freely.
        """
        return copy.deepcopy(self._static_catalog_dict())

    def write_catalog(self, path: str | pathlib.Path) -> None:
        """Write the static plugin catalog to disk as YAML.

        Missing parent directories are created. Keys keep their insertion
        order (``sort_keys=False``).
        """
        catalog_path = pathlib.Path(path)
        catalog_path.parent.mkdir(parents=True, exist_ok=True)
        data = yaml.dump(
            self._static_catalog_dict(),
            default_flow_style=False,
            sort_keys=False,
        )
        catalog_path.write_text(data, encoding="utf-8")

    def supports_session_catalog(self) -> bool:
        """Report whether the plugin exposes a session catalog hook."""
        return self._resolve_session_catalog_handler() is not None

    def supports_post_connect(self) -> bool:
        """Report whether the plugin exposes a post-connect metadata hook."""
        return self._resolve_post_connect_handler() is not None

    def supports_http_subject(self) -> bool:
        """Report whether the plugin exposes a hosted HTTP subject hook."""
        return self._resolve_http_subject_handler() is not None

    def catalog_for_request(self, request: Request) -> Catalog | dict[str, Any] | None:
        """Return a per-request catalog if the plugin defines one.

        The hook receives ``request`` only when it declares a parameter.
        Returns ``None`` when no hook is available.
        """
        definition = self._resolve_session_catalog_handler()
        if definition is None:
            return None
        handler, takes_request = definition
        if takes_request:
            return run_sync(handler(request))
        return run_sync(handler())

    def post_connect_metadata(self, token: ConnectedToken) -> dict[str, str] | None:
        """Return additional stored metadata after a connection is established.

        The hook receives ``token`` only when it declares a parameter.
        Returns ``None`` when no hook is available.
        """
        definition = self._resolve_post_connect_handler()
        if definition is None:
            return None
        handler, takes_token = definition
        if takes_token:
            return run_sync(handler(token))
        return run_sync(handler())

    def resolve_http_subject(
        self,
        request: HTTPSubjectRequest,
        context: Request,
    ) -> Subject | None:
        """Resolve an incoming hosted HTTP request to a Gestalt subject.

        The handler always receives copies, never the caller's objects.
        Returns ``None`` when no hook is available.
        """
        definition = self._resolve_http_subject_handler()
        if definition is None:
            return None
        handler, takes_context = definition
        copied_request = clone_http_subject_request(request)
        if not takes_context:
            # The context is never handed over, so skip the deep copy.
            return run_sync(handler(copied_request))
        return run_sync(handler(copied_request, copy.deepcopy(context)))

    def serve(self) -> None:
        """Start the integration runtime for this plugin."""
        # Imported at call time rather than at module import.
        from . import _runtime

        _runtime.serve(self)

    def _module_hook(self, attribute: str, *, local_only: bool) -> Any:
        """Return the callable named ``attribute`` on the owning module.

        Args:
            attribute: Module-level name to look up.
            local_only: When true, accept the callable only if its
                ``__module__`` equals the owning module's name, so names
                imported into that module are ignored.

        Returns:
            The callable, or ``None`` when the plugin has no module name, the
            module is not in ``sys.modules``, or no acceptable callable exists.
        """
        if not self._module_name:
            return None
        module = sys.modules.get(self._module_name)
        if module is None:
            return None
        candidate = getattr(module, attribute, None)
        if not callable(candidate):
            return None
        if local_only and getattr(candidate, "__module__", None) != module.__name__:
            return None
        return candidate

    def _resolve_configure_handler(self) -> Any:
        """Return the configure hook, discovering and caching it if needed."""
        if self._configure_handler is None:
            # NOTE(review): unlike the other hooks, ``configure`` is accepted
            # even when it was imported from another module. This matches the
            # original behaviour; confirm it is intentional.
            self._configure_handler = self._module_hook("configure", local_only=False)
        return self._configure_handler

    def _resolve_session_catalog_handler(self) -> tuple[Any, bool] | None:
        """Return the session catalog hook, discovering and caching it if needed."""
        if self._session_catalog_handler is None:
            hook = self._module_hook("session_catalog", local_only=True)
            if hook is not None:
                self._session_catalog_handler = (
                    hook,
                    _inspect_session_catalog_handler(hook),
                )
        return self._session_catalog_handler

    def _resolve_post_connect_handler(self) -> tuple[Any, bool] | None:
        """Return the post-connect hook, discovering and caching it if needed."""
        if self._post_connect_handler is None:
            hook = self._module_hook("post_connect", local_only=True)
            if hook is not None:
                self._post_connect_handler = (
                    hook,
                    _inspect_post_connect_handler(hook),
                )
        return self._post_connect_handler

    def _resolve_http_subject_handler(self) -> tuple[Any, bool] | None:
        """Return the HTTP subject hook, discovering and caching it if needed."""
        if self._http_subject_handler is None:
            hook = self._module_hook("resolve_http_subject", local_only=True)
            if hook is not None:
                self._http_subject_handler = (
                    hook,
                    _inspect_http_subject_handler(hook),
                )
        return self._http_subject_handler
class _ModulePluginRegistry:
    """Tracks the implicit :class:`Plugin` that belongs to each module."""

    def __init__(self) -> None:
        # Keyed by the module's ``__name__``.
        self._plugins: dict[str, Plugin] = {}

    def for_function(self, func: Any) -> "Plugin":
        """Return the plugin of the module in which ``func`` was defined.

        Raises:
            RuntimeError: If that module is not present in ``sys.modules``.
        """
        owner_name = func.__module__
        owner = sys.modules.get(owner_name)
        if owner is not None:
            return self.for_module(owner)
        raise RuntimeError(f"module {owner_name!r} is not loaded")

    def for_module(self, module: types.ModuleType) -> "Plugin":
        """Return the plugin for ``module``, creating one on first use.

        A ``Plugin`` already bound to the module-level name ``plugin`` wins;
        it is adopted (and given the module name if it has none). Otherwise
        an implicit plugin is created, remembered, and published on the
        module as ``plugin``.
        """
        module_key = module.__name__

        declared = getattr(module, "plugin", None)
        if isinstance(declared, Plugin):
            if declared._module_name is None:
                declared._module_name = module_key
            self._plugins[module_key] = declared
            return declared

        implicit = self._plugins.get(module_key)
        if implicit is None:
            implicit = Plugin(_module_plugin_name(module), module_name=module_key)
            self._plugins[module_key] = implicit

        published = getattr(module, "plugin", None)
        if not isinstance(published, Plugin):
            module.plugin = implicit
        return implicit


# Process-wide registry backing the module-level decorators.
_MODULE_PLUGINS = _ModulePluginRegistry()
def operation(
    func: Any | None = None,
    /,
    *,
    id: str | None = None,
    method: str = DEFAULT_OPERATION_METHOD,
    title: str = "",
    description: str = "",
    allowed_roles: list[str] | None = None,
    tags: list[str] | None = None,
    read_only: bool = False,
    visible: bool | None = None,
) -> Any:
    """Register an operation on the implicit plugin of the handler's module.

    Use this when declaring a module-level ``plugin`` object would add
    nothing:

    .. code-block:: python

        from gestalt import Model, operation


        class SearchInput(Model):
            query: str


        @operation(title="Search")
        def search(params: SearchInput):
            return {"query": params.query}

    All keyword options are forwarded unchanged to :meth:`Plugin.operation`.
    Works both bare (``@operation``) and called (``@operation(...)``).
    """
    forwarded_options = {
        "id": id,
        "method": method,
        "title": title,
        "description": description,
        "allowed_roles": allowed_roles,
        "tags": tags,
        "read_only": read_only,
        "visible": visible,
    }

    def register(handler: Any) -> Any:
        # The owning plugin is looked up from the handler's ``__module__``.
        owner = _MODULE_PLUGINS.for_function(handler)
        bind = owner.operation(**forwarded_options)
        return bind(handler)

    return register if func is None else register(func)
[docs] def session_catalog(func: Any | None = None, /) -> Any: """Register a per-request catalog hook on the implicit module plugin.""" def decorator(handler: Any) -> Any: plugin = _MODULE_PLUGINS.for_function(handler) return plugin.session_catalog(handler) if func is None: return decorator return decorator(func)
def post_connect(func: Any | None = None, /) -> Any:
    """Register a connect-time metadata hook on the implicit module plugin."""

    def decorator(handler: Any) -> Any:
        plugin = _MODULE_PLUGINS.for_function(handler)
        return plugin.post_connect(handler)

    if func is None:
        return decorator
    return decorator(func)


def http_subject(func: Any | None = None, /) -> Any:
    """Register a hosted HTTP subject hook on the implicit module plugin."""

    def decorator(handler: Any) -> Any:
        plugin = _MODULE_PLUGINS.for_function(handler)
        return plugin.http_subject(handler)

    if func is None:
        return decorator
    return decorator(func)


def _module_plugin(module: types.ModuleType) -> "Plugin":
    """Return the plugin associated with ``module``, creating it if needed."""
    return _MODULE_PLUGINS.for_module(module)


def _normalize_allowed_roles(allowed_roles: list[str] | None) -> list[str]:
    """Trim role names, dropping blanks and duplicates, keeping first-seen order."""
    trimmed_roles = (role.strip() for role in allowed_roles or [])
    # ``dict.fromkeys`` removes duplicates while preserving insertion order.
    return list(dict.fromkeys(role for role in trimmed_roles if role))


def _inspect_session_catalog_handler(func: Any) -> bool:
    """Validate a session catalog handler and report whether it takes the request.

    Returns:
        ``True`` when the handler declares one parameter, ``False`` when it
        declares none.

    Raises:
        TypeError: If the handler declares more than one parameter, or its
            parameter is annotated as something other than ``Request``.
    """
    signature = inspect.signature(func)
    parameters = list(signature.parameters.values())
    # ``eval_str=True`` turns string annotations into real objects so they
    # can be compared with the expected class below.
    type_hints = inspect.get_annotations(func, eval_str=True)
    if len(parameters) > 1:
        raise TypeError("session catalog handlers may declare at most one parameter")
    if not parameters:
        return False
    annotation = type_hints.get(parameters[0].name, parameters[0].annotation)
    if annotation not in (inspect.Signature.empty, Request):
        raise TypeError(
            "session catalog handler parameter must be annotated as gestalt.Request"
        )
    return True


def _inspect_post_connect_handler(func: Any) -> bool:
    """Validate a post-connect handler and report whether it takes the token.

    Returns:
        ``True`` when the handler declares one parameter, ``False`` when it
        declares none.

    Raises:
        TypeError: If the handler declares more than one parameter, or its
            parameter is annotated as something other than ``ConnectedToken``.
    """
    signature = inspect.signature(func)
    parameters = list(signature.parameters.values())
    type_hints = inspect.get_annotations(func, eval_str=True)
    if len(parameters) > 1:
        raise TypeError("post_connect handlers may declare at most one parameter")
    if not parameters:
        return False
    annotation = type_hints.get(parameters[0].name, parameters[0].annotation)
    if annotation not in (inspect.Signature.empty, ConnectedToken):
        raise TypeError(
            "post_connect handler parameter must be annotated as gestalt.ConnectedToken"
        )
    return True


def _inspect_http_subject_handler(func: Any) -> bool:
    """Validate an HTTP subject handler and report whether it takes the context.

    Returns:
        ``True`` when the handler declares both request and context
        parameters, ``False`` when it declares only the request.

    Raises:
        TypeError: If the handler does not declare one or two parameters, or
            either parameter carries an unexpected annotation.
    """
    signature = inspect.signature(func)
    parameters = list(signature.parameters.values())
    type_hints = inspect.get_annotations(func, eval_str=True)
    if len(parameters) not in (1, 2):
        raise TypeError(
            "http subject handlers must declare request and optional context parameters"
        )
    request_annotation = type_hints.get(parameters[0].name, parameters[0].annotation)
    if request_annotation not in (inspect.Signature.empty, HTTPSubjectRequest):
        raise TypeError(
            "http subject handler request parameter must be annotated as "
            "gestalt.HTTPSubjectRequest"
        )
    if len(parameters) == 1:
        return False
    context_annotation = type_hints.get(parameters[1].name, parameters[1].annotation)
    if context_annotation not in (inspect.Signature.empty, Request):
        raise TypeError(
            "http subject handler context parameter must be annotated as "
            "gestalt.Request"
        )
    return True


def _catalog_operation_dict(operation: OperationDefinition) -> dict[str, Any]:
    """Convert an operation definition into its catalog entry.

    ``id`` and ``method`` are always present. The remaining keys are added
    only when they carry a value. Insertion order is the emitted key order.
    """
    payload: dict[str, Any] = {
        "id": operation.id,
        "method": operation.method,
    }
    if operation.title:
        payload["title"] = operation.title
    if operation.description:
        payload["description"] = operation.description
    if operation.read_only:
        payload["read_only"] = True
    parameters = _catalog_parameters_dict(operation.input_type)
    if parameters:
        payload["parameters"] = parameters
    if operation.allowed_roles:
        payload["allowed_roles"] = list(operation.allowed_roles)
    if operation.tags:
        payload["tags"] = list(operation.tags)
    # ``None`` means "not specified" and is left out; ``False`` is kept.
    if operation.visible is not None:
        payload["visible"] = operation.visible
    return payload


def _catalog_parameters_dict(input_type: Any) -> list[dict[str, Any]]:
    """Convert the parameters of ``input_type`` into catalog entries.

    ``name`` and ``type`` are always present; ``description``, ``required``
    and ``default`` are added only when applicable.
    """
    parameters: list[dict[str, Any]] = []
    for parameter in catalog_parameters(input_type):
        param: dict[str, Any] = {
            "name": parameter.name,
            "type": parameter.type,
        }
        if parameter.description:
            param["description"] = parameter.description
        if parameter.required:
            param["required"] = True
        if parameter.has_default:
            param["default"] = parameter.default
        parameters.append(param)
    return parameters


def _module_plugin_name(module: types.ModuleType) -> str:
    """Derive a plugin name for ``module``.

    A module with a ``__file__`` is named from the ``manifest.yaml`` beside
    it; otherwise the last component of the module's dotted name is used.
    """
    file_path = getattr(module, "__file__", None)
    if file_path:
        manifest_path = pathlib.Path(file_path).resolve().parent / "manifest.yaml"
        return _derive_name_from_manifest(manifest_path)
    return _slug_name(module.__name__.rsplit(".", 1)[-1])


def _derive_name_from_manifest(path: pathlib.Path) -> str:
    """Derive a plugin name from a manifest file or its containing directory.

    Args:
        path: A manifest file, or a directory expected to hold
            ``manifest.yaml``.

    Returns:
        The slugged name from the manifest. If the manifest cannot be read
        or decoded, the slugged name of its parent directory is returned
        instead (``"plugin"`` when that is empty).
    """
    manifest_path = path / "manifest.yaml" if path.is_dir() else path
    fallback_name = manifest_path.parent.name or "plugin"
    manifest_format = manifest_path.suffix.lower()
    try:
        text = manifest_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # Name derivation is best-effort: a file that is not valid UTF-8 is
        # treated like a missing one instead of aborting plugin creation.
        return _slug_name(fallback_name)
    if manifest_format == ".json":
        return _name_from_json_manifest(text, fallback_name)
    return _name_from_yaml_manifest(text, fallback_name)


def _name_from_manifest_dict(data: Any, fallback_name: str) -> str:
    """Pick a plugin name out of parsed manifest data.

    Preference order: the last path segment of ``source``, then
    ``display_name``, then ``fallback_name``. Anything that is not a mapping
    yields the fallback.
    """
    if not isinstance(data, dict):
        return _slug_name(fallback_name)
    source = data.get("source")
    if isinstance(source, str):
        # Drop trailing slashes first; otherwise "org/name/" would leave an
        # empty last segment and collapse to the generic "plugin" name.
        last_segment = source.strip().rstrip("/").rsplit("/", 1)[-1]
        if last_segment.strip():
            return _slug_name(last_segment)
    display_name = data.get("display_name")
    if isinstance(display_name, str) and display_name.strip():
        return _slug_name(display_name)
    return _slug_name(fallback_name)


def _name_from_json_manifest(text: str, fallback_name: str) -> str:
    """Derive a plugin name from JSON manifest text, falling back on bad JSON."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return _slug_name(fallback_name)
    return _name_from_manifest_dict(data, fallback_name)


class _TagIgnoringLoader(yaml.SafeLoader):
    """``SafeLoader`` variant used together with ``_construct_ignore_tag``."""


def _construct_ignore_tag(
    loader: yaml.SafeLoader, _suffix: str, node: yaml.Node
) -> Any:
    """Construct ``node`` as plain data according to its kind, ignoring its tag.

    Returns ``None`` for a node that is not a scalar, sequence, or mapping.
    """
    if isinstance(node, yaml.ScalarNode):
        return loader.construct_scalar(node)
    if isinstance(node, yaml.SequenceNode):
        return loader.construct_sequence(node)
    if isinstance(node, yaml.MappingNode):
        return loader.construct_mapping(node)
    return None
# Register the fallback constructor under the empty tag prefix.
_TagIgnoringLoader.add_multi_constructor("", _construct_ignore_tag)


def _name_from_yaml_manifest(text: str, fallback_name: str) -> str:
    """Derive a plugin name from YAML manifest text.

    Text that fails to parse as YAML yields the slugged ``fallback_name``.
    """
    # Safe for manifest files: the loader is a ``yaml.SafeLoader`` subclass.
    try:
        parsed = yaml.load(text, Loader=_TagIgnoringLoader)
    except yaml.YAMLError:
        return _slug_name(fallback_name)
    else:
        return _name_from_manifest_dict(parsed, fallback_name)


def _slug_name(value: str) -> str:
    """Reduce ``value`` to a slug of letters, digits, ``.``, ``_`` and ``-``.

    Each run of other characters becomes a single ``-``, and leading and
    trailing ``-`` are removed. An empty result becomes ``"plugin"``.
    """
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "-", value.strip())
    trimmed = collapsed.strip("-")
    if trimmed:
        return trimmed
    return "plugin"