"""S3-compatible client helpers for provider processes."""
from __future__ import annotations
import datetime as _dt
import io
import json
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, BinaryIO, Iterable, Iterator, Protocol, cast
from urllib import parse as _urlparse
import grpc
from google.protobuf import timestamp_pb2 as _timestamp_pb2
from ._grpc_transport import (
insecure_internal_channel,
internal_channel_target,
secure_internal_channel,
)
from .gen.v1 import s3_pb2 as _pb
from .gen.v1 import s3_pb2_grpc as _pb_grpc
# Loosely-typed aliases for the generated protobuf modules so the rest of
# this file can use them without static-analysis friction over codegen stubs.
pb: Any = _pb
pb_grpc: Any = _pb_grpc
timestamp_pb2: Any = _timestamp_pb2
#: Base environment variable for discovering S3 runtime sockets.
ENV_S3_SOCKET = "GESTALT_S3_SOCKET"
# Suffix appended to a socket env var name to form the relay-token env var.
_S3_SOCKET_TOKEN_SUFFIX = "_TOKEN"
# gRPC metadata header that carries the host-service relay token.
_S3_RELAY_TOKEN_HEADER = "x-gestalt-host-service-relay-token"
ENV_S3_SOCKET_TOKEN = f"{ENV_S3_SOCKET}{_S3_SOCKET_TOKEN_SUFFIX}"
# Object bodies are streamed to the provider in chunks of this many bytes.
_WRITE_CHUNK_SIZE = 64 * 1024
_UTC = _dt.timezone.utc
# Type aliases for payloads accepted/produced by the object helpers.
BytesData = bytes
BytesLike = bytes | bytearray | memoryview
ObjectBody = BytesLike | BinaryIO | Iterable[bytes] | None
def s3_socket_env(name: str | None = None) -> str:
    """Return the environment variable name for an S3 socket binding.

    ``None`` or a blank *name* maps to the base variable.  Any other name
    is normalized to the conventional ENV_VAR character set and appended
    as a suffix.
    """
    trimmed = (name or "").strip()
    if not trimmed:
        return ENV_S3_SOCKET
    # ASCII letters/digits are upper-cased; every other character collapses
    # to an underscore (same set as the original explicit range checks).
    normalized = "".join(
        ch.upper() if (ch.isascii() and ch.isalnum()) else "_" for ch in trimmed
    )
    return f"{ENV_S3_SOCKET}_{normalized}"
def s3_socket_token_env(name: str | None = None) -> str:
    """Return the environment variable name for an S3 relay token."""
    base = s3_socket_env(name)
    return base + _S3_SOCKET_TOKEN_SUFFIX
class S3NotFoundError(Exception):
    """Raised when the requested object does not exist."""
class S3PreconditionFailedError(Exception):
    """Raised when conditional request headers fail."""
class S3InvalidRangeError(Exception):
    """Raised when a requested byte range is invalid."""
@dataclass
class ObjectRef:
    """Reference to an S3 object, optionally pinned to a version."""

    bucket: str
    key: str
    # Empty string selects the latest version.
    version_id: str = ""
@dataclass
class ByteRange:
    """Inclusive byte range used for object reads.

    ``None`` for either endpoint leaves that bound open.
    """

    start: int | None = None
    end: int | None = None
@dataclass
class ReadOptions:
    """Conditional and ranged read options for an object request."""

    # Optional inclusive byte range; None reads the whole body.
    range: ByteRange | None = None
    # ETag preconditions; empty strings disable the check.
    if_match: str = ""
    if_none_match: str = ""
    # Timestamp preconditions; None disables the check.
    if_modified_since: _dt.datetime | None = None
    if_unmodified_since: _dt.datetime | None = None
@dataclass
class WriteOptions:
    """Metadata and precondition options for object writes."""

    content_type: str = ""
    cache_control: str = ""
    content_disposition: str = ""
    content_encoding: str = ""
    content_language: str = ""
    # User metadata key/value pairs; fresh dict per instance.
    metadata: dict[str, str] = field(default_factory=dict)
    # ETag preconditions; empty strings disable the check.
    if_match: str = ""
    if_none_match: str = ""
@dataclass
class ListOptions:
    """Query options for listing objects within a bucket."""

    bucket: str
    prefix: str = ""
    delimiter: str = ""
    continuation_token: str = ""
    start_after: str = ""
    # 0 lets the provider choose its default page size.
    max_keys: int = 0
@dataclass
class ListPage:
    """One page of object-listing results."""

    objects: list[ObjectMeta] = field(default_factory=list)
    common_prefixes: list[str] = field(default_factory=list)
    # Token for fetching the next page; empty when exhausted.
    next_continuation_token: str = ""
    has_more: bool = False
@dataclass
class CopyOptions:
    """Conditional headers for copy operations."""

    # ETag preconditions; empty strings disable the check.
    if_match: str = ""
    if_none_match: str = ""
class PresignMethod(str, Enum):
    """HTTP methods supported by presigned object URLs."""

    GET = "GET"
    PUT = "PUT"
    DELETE = "DELETE"
    HEAD = "HEAD"
@dataclass
class PresignOptions:
    """Options for generating a presigned object URL."""

    # HTTP method; None lets the provider apply its default.
    method: PresignMethod | str | None = None
    # URL lifetime; None lets the provider apply its default.
    expires: _dt.timedelta | None = None
    content_type: str = ""
    content_disposition: str = ""
    # Extra signed headers; fresh dict per instance.
    headers: dict[str, str] = field(default_factory=dict)
@dataclass
class PresignResult:
    """Presigned URL response returned by the provider."""

    url: str
    # Method the URL was signed for; None when unspecified.
    method: PresignMethod | str | None = None
    expires_at: _dt.datetime | None = None
    # Headers the caller must send with the request.
    headers: dict[str, str] = field(default_factory=dict)
# Backwards-compatible aliases for the host-mediated object-access URL API,
# which shares its option/result shapes with presigning.
ObjectAccessURLOptions = PresignOptions
ObjectAccessURL = PresignResult
class S3ReadStream:
    """Streaming object body reader returned by :meth:`S3.read_object`.

    Wraps the server-streaming RPC behind a file-like ``read`` plus chunk
    iteration.  The first (metadata) frame is consumed by the caller before
    this wrapper is constructed, so every frame seen here must be data.
    """

    def __init__(self, stream: Any) -> None:
        self._stream = stream
        self._buffer = bytearray()  # bytes received but not yet returned
        self._closed = False

    def __iter__(self) -> Iterator[bytes]:
        """Iterate over the remaining data chunks."""
        return self.iter_chunks()

    def __enter__(self) -> S3ReadStream:
        """Return the stream for ``with`` statements."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Close the stream at the end of a context manager block."""
        self.close()

    def iter_chunks(self) -> Iterator[bytes]:
        """Yield the remaining object body chunks."""
        # Drain anything a previous partial read() left buffered first.
        if self._buffer:
            chunk = bytes(self._buffer)
            self._buffer.clear()
            if chunk:
                yield chunk
        while True:
            chunk = self._recv_chunk()
            if chunk is None:
                return
            if chunk:
                yield chunk

    def read(self, size: int = -1) -> bytes:
        """Read up to ``size`` bytes; a negative ``size`` reads to the end."""
        if size == 0:
            return b""
        if size < 0:
            # Read-all path: buffer contents plus every remaining frame.
            parts: list[bytes] = []
            if self._buffer:
                parts.append(bytes(self._buffer))
                self._buffer.clear()
            while True:
                chunk = self._recv_chunk()
                if chunk is None:
                    break
                if chunk:
                    parts.append(chunk)
            return b"".join(parts)
        # Bounded read: fill the buffer until it can satisfy the request
        # (or the stream ends), then slice off the requested prefix.
        while len(self._buffer) < size:
            chunk = self._recv_chunk()
            if chunk is None:
                break
            if chunk:
                self._buffer.extend(chunk)
        out = bytes(self._buffer[:size])
        del self._buffer[:size]
        return out

    def close(self) -> None:
        """Cancel the underlying RPC stream and discard buffered bytes."""
        self._closed = True
        self._buffer.clear()
        # grpc call objects expose cancel(); tolerate plain iterators too.
        cancel = getattr(self._stream, "cancel", None)
        if callable(cancel):
            cancel()

    def _recv_chunk(self) -> bytes | None:
        """Return the next data frame, or ``None`` once the stream is done."""
        if self._closed:
            return None
        try:
            msg = next(self._stream)
        except StopIteration:
            self._closed = True
            return None
        except grpc.RpcError as error:
            self._closed = True
            raise _map_grpc_error(error) from error
        if msg.WhichOneof("result") == "meta":
            raise RuntimeError("s3: read stream yielded metadata after the first frame")
        return bytes(msg.data)
class _ClientCallDetails(grpc.ClientCallDetails):
    """Concrete call-details record handed back to gRPC by the interceptor."""

    def __init__(
        self,
        method: str,
        timeout: float | None,
        metadata: Any,
        credentials: Any,
        wait_for_ready: bool | None,
        compression: Any,
    ) -> None:
        # Mirror each constructor argument onto an attribute of the same name.
        for attr, value in (
            ("method", method),
            ("timeout", timeout),
            ("metadata", metadata),
            ("credentials", credentials),
            ("wait_for_ready", wait_for_ready),
            ("compression", compression),
        ):
            setattr(self, attr, value)
class _ClientCallDetailsFields(Protocol):
method: str
timeout: float | None
metadata: Any
credentials: Any
wait_for_ready: bool | None
compression: Any
class _RelayTokenInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
):
    """Client interceptor that stamps every RPC with the relay-token header."""

    def __init__(self, token: str) -> None:
        self._token = token

    def _details(self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails:
        """Clone the call details with the relay-token header appended."""
        fields = cast(_ClientCallDetailsFields, client_call_details)
        augmented = [*(fields.metadata or []), (_S3_RELAY_TOKEN_HEADER, self._token)]
        return _ClientCallDetails(
            fields.method,
            fields.timeout,
            augmented,
            fields.credentials,
            fields.wait_for_ready,
            fields.compression,
        )

    def intercept_unary_unary(
        self,
        continuation: Any,
        client_call_details: grpc.ClientCallDetails,
        request: Any,
    ) -> Any:
        """Stamp and forward a unary-unary RPC."""
        return continuation(self._details(client_call_details), request)

    def intercept_unary_stream(
        self,
        continuation: Any,
        client_call_details: grpc.ClientCallDetails,
        request: Any,
    ) -> Any:
        """Stamp and forward a unary-stream RPC."""
        return continuation(self._details(client_call_details), request)

    def intercept_stream_unary(
        self,
        continuation: Any,
        client_call_details: grpc.ClientCallDetails,
        request_iterator: Any,
    ) -> Any:
        """Stamp and forward a stream-unary RPC."""
        return continuation(self._details(client_call_details), request_iterator)
def _s3_channel(target: str, *, token: str = "") -> Any:
    """Open a gRPC channel for *target*.

    When *token* is non-blank the channel is wrapped so every call carries
    the relay-token metadata header.
    """
    scheme, address = _parse_s3_target(target)
    if scheme in ("unix", "tcp"):
        channel = insecure_internal_channel(internal_channel_target(scheme, address))
    elif scheme == "tls":
        channel = secure_internal_channel(internal_channel_target("tls", address))
    else:
        raise RuntimeError(f"unsupported s3 transport scheme {scheme!r}")
    token = token.strip()
    if not token:
        return channel
    return grpc.intercept_channel(channel, _RelayTokenInterceptor(token))
def _parse_s3_target(raw: str) -> tuple[str, str]:
target = raw.strip()
if not target:
raise RuntimeError("s3 transport target is required")
if target.startswith("tcp://"):
address = target.removeprefix("tcp://").strip()
if not address:
raise RuntimeError(f"s3 tcp target {raw!r} is missing host:port")
return "tcp", address
if target.startswith("tls://"):
address = target.removeprefix("tls://").strip()
if not address:
raise RuntimeError(f"s3 tls target {raw!r} is missing host:port")
return "tls", address
if target.startswith("unix://"):
address = target.removeprefix("unix://").strip()
if not address:
raise RuntimeError(f"s3 unix target {raw!r} is missing a socket path")
return "unix", address
if "://" in target:
parsed = _urlparse.urlparse(target)
raise RuntimeError(f"unsupported s3 target scheme {parsed.scheme!r}")
return "unix", target
class S3:
    """Client for a host-provided Gestalt S3 runtime.

    The runtime's gRPC endpoint and optional relay token are discovered via
    the ``GESTALT_S3_SOCKET`` family of environment variables.
    """

    def __init__(self, name: str | None = None) -> None:
        """Connect to the S3 runtime advertised for *name*.

        Raises:
            RuntimeError: if the socket environment variable is unset.
        """
        env_name = s3_socket_env(name)
        target = os.environ.get(env_name, "")
        if not target:
            raise RuntimeError(f"{env_name} is not set")
        token = os.environ.get(s3_socket_token_env(name), "")
        self._channel = _s3_channel(target, token=token)
        self._stub = pb_grpc.S3Stub(self._channel)
        self._object_access_stub = pb_grpc.S3ObjectAccessStub(self._channel)

    def close(self) -> None:
        """Close the underlying gRPC channel."""
        self._channel.close()

    def object(self, bucket: str, key: str) -> S3Object:
        """Return an object helper for the latest version."""
        return S3Object(self, ObjectRef(bucket=bucket, key=key))

    def object_version(self, bucket: str, key: str, version_id: str) -> S3Object:
        """Return an object helper pinned to a specific version."""
        return S3Object(self, ObjectRef(bucket=bucket, key=key, version_id=version_id))

    def head_object(self, ref: ObjectRef) -> ObjectMeta:
        """Fetch metadata for an object without reading its body."""
        resp = _grpc_call(self._stub.HeadObject, pb.HeadObjectRequest(ref=_object_ref_to_proto(ref)))
        return _object_meta_from_proto(resp.meta)

    def read_object(
        self,
        ref: ObjectRef,
        opts: ReadOptions | None = None,
    ) -> tuple[ObjectMeta, S3ReadStream]:
        """Open a streaming read for an object.

        Returns the object metadata (from the stream's first frame) and a
        :class:`S3ReadStream` positioned at the start of the body.
        """
        request = pb.ReadObjectRequest(ref=_object_ref_to_proto(ref))
        if opts is not None:
            if opts.range is not None:
                request.range.CopyFrom(_byte_range_to_proto(opts.range))
            request.if_match = opts.if_match
            request.if_none_match = opts.if_none_match
            if opts.if_modified_since is not None:
                request.if_modified_since.CopyFrom(_timestamp_to_proto(opts.if_modified_since))
            if opts.if_unmodified_since is not None:
                request.if_unmodified_since.CopyFrom(
                    _timestamp_to_proto(opts.if_unmodified_since)
                )
        stream = self._stub.ReadObject(request)
        # The provider sends exactly one metadata frame before any data.
        try:
            first = next(stream)
        except StopIteration as error:
            raise RuntimeError("s3: read stream ended before metadata") from error
        except grpc.RpcError as error:
            raise _map_grpc_error(error) from error
        if first.WhichOneof("result") != "meta":
            raise RuntimeError("s3: read stream did not start with metadata")
        return _object_meta_from_proto(first.meta), S3ReadStream(stream)

    def write_object(
        self,
        ref: ObjectRef,
        body: ObjectBody = None,
        opts: WriteOptions | None = None,
    ) -> ObjectMeta:
        """Write an object body and return the resulting metadata."""
        open_request = pb.WriteObjectOpen(ref=_object_ref_to_proto(ref))
        if opts is not None:
            open_request.content_type = opts.content_type
            open_request.cache_control = opts.cache_control
            open_request.content_disposition = opts.content_disposition
            open_request.content_encoding = opts.content_encoding
            open_request.content_language = opts.content_language
            open_request.metadata.update(dict(opts.metadata))
            open_request.if_match = opts.if_match
            open_request.if_none_match = opts.if_none_match
        response = _grpc_call(
            self._stub.WriteObject,
            _write_request_iter(open_request=open_request, body=body),
        )
        return _object_meta_from_proto(response.meta)

    def delete_object(self, ref: ObjectRef) -> None:
        """Delete an object."""
        _grpc_call(self._stub.DeleteObject, pb.DeleteObjectRequest(ref=_object_ref_to_proto(ref)))

    def list_objects(self, opts: ListOptions) -> ListPage:
        """List objects within a bucket."""
        resp = _grpc_call(
            self._stub.ListObjects,
            pb.ListObjectsRequest(
                bucket=opts.bucket,
                prefix=opts.prefix,
                delimiter=opts.delimiter,
                continuation_token=opts.continuation_token,
                start_after=opts.start_after,
                max_keys=opts.max_keys,
            ),
        )
        return _list_page_from_proto(resp)

    def copy_object(
        self,
        source: ObjectRef,
        destination: ObjectRef,
        opts: CopyOptions | None = None,
    ) -> ObjectMeta:
        """Copy an object and return metadata for the destination."""
        request = pb.CopyObjectRequest(
            source=_object_ref_to_proto(source),
            destination=_object_ref_to_proto(destination),
        )
        if opts is not None:
            request.if_match = opts.if_match
            request.if_none_match = opts.if_none_match
        resp = _grpc_call(self._stub.CopyObject, request)
        return _object_meta_from_proto(resp.meta)

    @staticmethod
    def _apply_access_options(request: Any, opts: PresignOptions | None) -> None:
        """Copy presign/access-URL options onto *request* (shared by both RPCs)."""
        if opts is None:
            return
        request.method = _presign_method_to_proto(opts.method)
        if opts.expires is not None:
            request.expires_seconds = int(opts.expires.total_seconds())
        request.content_type = opts.content_type
        request.content_disposition = opts.content_disposition
        request.headers.update(dict(opts.headers))

    @staticmethod
    def _finish_access_result(resp: Any, opts: PresignOptions | None) -> PresignResult:
        """Convert *resp* to a result, backfilling the method from *opts*."""
        result = _presign_result_from_proto(resp)
        if result.method is None and opts is not None:
            result.method = _normalize_presign_method(opts.method)
        return result

    def presign_object(
        self,
        ref: ObjectRef,
        opts: PresignOptions | None = None,
    ) -> PresignResult:
        """Generate a presigned URL for an object operation."""
        request = pb.PresignObjectRequest(ref=_object_ref_to_proto(ref))
        self._apply_access_options(request, opts)
        resp = _grpc_call(self._stub.PresignObject, request)
        return self._finish_access_result(resp, opts)

    def create_object_access_url(
        self,
        ref: ObjectRef,
        opts: ObjectAccessURLOptions | None = None,
    ) -> ObjectAccessURL:
        """Create a host-mediated object access URL."""
        request = pb.CreateObjectAccessURLRequest(ref=_object_ref_to_proto(ref))
        self._apply_access_options(request, opts)
        resp = _grpc_call(self._object_access_stub.CreateObjectAccessURL, request)
        return self._finish_access_result(resp, opts)

    def create_access_url(
        self,
        ref: ObjectRef,
        opts: ObjectAccessURLOptions | None = None,
    ) -> ObjectAccessURL:
        """Alias for :meth:`create_object_access_url`."""
        return self.create_object_access_url(ref, opts)

    def __enter__(self) -> S3:
        """Return the client for ``with`` statements."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Close the client at the end of a context manager block."""
        self.close()
class S3Object:
    """Convenience wrapper binding an :class:`S3` client to one object ref."""

    def __init__(self, client: S3, ref: ObjectRef) -> None:
        self._client = client
        self.ref = ref

    def stat(self) -> ObjectMeta:
        """Fetch object metadata."""
        return self._client.head_object(self.ref)

    def exists(self) -> bool:
        """Return whether the object exists."""
        try:
            self.stat()
            return True
        except S3NotFoundError:
            return False

    def stream(self, opts: ReadOptions | None = None) -> tuple[ObjectMeta, S3ReadStream]:
        """Open a streaming read for the object."""
        return self._client.read_object(self.ref, opts)

    def bytes(self, opts: ReadOptions | None = None) -> BytesData:
        """Read the full object body as bytes."""
        _meta, stream = self.stream(opts)
        with stream:
            return stream.read()

    def text(self, opts: ReadOptions | None = None, *, encoding: str = "utf-8") -> str:
        """Read the full object body as text."""
        return self.bytes(opts).decode(encoding)

    def json(self, opts: ReadOptions | None = None) -> Any:
        """Read and decode the full object body as JSON."""
        return json.loads(self.bytes(opts))

    def write(
        self,
        body: ObjectBody = None,
        opts: WriteOptions | None = None,
    ) -> ObjectMeta:
        """Write an object body."""
        return self._client.write_object(self.ref, body, opts)

    def write_bytes(
        self,
        body: BytesLike,
        opts: WriteOptions | None = None,
    ) -> ObjectMeta:
        """Write bytes to the object."""
        return self.write(body, opts)

    def write_text(
        self,
        body: str,
        opts: WriteOptions | None = None,
        *,
        encoding: str = "utf-8",
    ) -> ObjectMeta:
        """Encode and write text to the object."""
        return self.write(body.encode(encoding), opts)

    def write_json(self, value: Any, opts: WriteOptions | None = None) -> ObjectMeta:
        """Encode and write JSON to the object.

        Defaults the content type to ``application/json`` unless the caller
        supplied one; the caller's options object is never mutated.
        """
        payload = json.dumps(value).encode("utf-8")
        if opts is None:
            opts = WriteOptions(content_type="application/json")
        elif not opts.content_type:
            # Rebuild rather than mutate so the caller's options stay intact.
            opts = WriteOptions(
                content_type="application/json",
                cache_control=opts.cache_control,
                content_disposition=opts.content_disposition,
                content_encoding=opts.content_encoding,
                content_language=opts.content_language,
                metadata=dict(opts.metadata),
                if_match=opts.if_match,
                if_none_match=opts.if_none_match,
            )
        return self.write(payload, opts)

    def delete(self) -> None:
        """Delete the object."""
        self._client.delete_object(self.ref)

    def presign(self, opts: PresignOptions | None = None) -> PresignResult:
        """Generate a presigned URL for this object."""
        return self._client.presign_object(self.ref, opts)

    def create_access_url(self, opts: ObjectAccessURLOptions | None = None) -> ObjectAccessURL:
        """Create a host-mediated object access URL for this object."""
        return self._client.create_object_access_url(self.ref, opts)
def _write_request_iter(
    *,
    open_request: Any,
    body: ObjectBody,
) -> Iterator[Any]:
    """Yield the open frame followed by non-empty data frames for *body*."""
    yield pb.WriteObjectRequest(open=open_request)
    yield from (
        pb.WriteObjectRequest(data=chunk) for chunk in _body_chunks(body) if chunk
    )
def _body_chunks(
    body: ObjectBody,
) -> Iterator[bytes]:
    """Yield *body* as a sequence of byte chunks.

    Accepts ``None`` (no body), bytes-like objects (re-chunked to the write
    chunk size), file-like objects exposing ``read``, or any iterable of
    bytes-like chunks.
    """
    if body is None:
        return
    if isinstance(body, (bytes, bytearray, memoryview)):
        data = bytes(body)
        yield from (
            data[offset : offset + _WRITE_CHUNK_SIZE]
            for offset in range(0, len(data), _WRITE_CHUNK_SIZE)
        )
        return
    # File-like objects: pull fixed-size chunks until EOF (b"" or None).
    reader = body.read if isinstance(body, io.IOBase) else getattr(body, "read", None)
    if callable(reader):
        while True:
            chunk = reader(_WRITE_CHUNK_SIZE)
            if chunk in (b"", None):
                return
            yield _ensure_bytes(chunk)
    # Fallback: treat the body as an iterable of bytes-like chunks.
    for chunk in body:
        piece = _ensure_bytes(chunk)
        if piece:
            yield piece
def _ensure_bytes(value: Any) -> bytes:
if isinstance(value, bytes):
return value
if isinstance(value, bytearray):
return bytes(value)
if isinstance(value, memoryview):
return value.tobytes()
raise TypeError("s3: body chunks must be bytes")
def _grpc_call(fn: Any, request: Any) -> Any:
try:
return fn(request)
except grpc.RpcError as error:
raise _map_grpc_error(error) from error
def _map_grpc_error(error: grpc.RpcError) -> Exception:
    """Translate a gRPC status into the matching S3 exception.

    Unrecognized status codes pass the original error through unchanged.
    """
    code = error.code()  # ty: ignore[unresolved-attribute]
    details = error.details()  # ty: ignore[unresolved-attribute]
    by_code = {
        grpc.StatusCode.NOT_FOUND: S3NotFoundError,
        grpc.StatusCode.FAILED_PRECONDITION: S3PreconditionFailedError,
        grpc.StatusCode.OUT_OF_RANGE: S3InvalidRangeError,
    }
    exc_type = by_code.get(code)
    if exc_type is None:
        return error
    return exc_type(details)
def _object_ref_to_proto(ref: ObjectRef) -> Any:
    """Convert an :class:`ObjectRef` into its protobuf message."""
    return pb.S3ObjectRef(
        bucket=ref.bucket,
        key=ref.key,
        version_id=ref.version_id,
    )
def _object_meta_from_proto(meta: Any) -> ObjectMeta:
    """Build an :class:`ObjectMeta` from its protobuf message."""
    modified: _dt.datetime | None = (
        meta.last_modified.ToDatetime(tzinfo=_UTC)
        if meta.HasField("last_modified")
        else None
    )
    ref = ObjectRef(
        bucket=meta.ref.bucket,
        key=meta.ref.key,
        version_id=meta.ref.version_id,
    )
    return ObjectMeta(
        ref=ref,
        etag=meta.etag,
        size=meta.size,
        content_type=meta.content_type,
        last_modified=modified,
        metadata=dict(meta.metadata),
        storage_class=meta.storage_class,
    )
def _byte_range_to_proto(range_value: ByteRange) -> Any:
    """Convert a :class:`ByteRange` into its protobuf message.

    Unset (``None``) endpoints are left at their proto defaults.
    """
    out = pb.ByteRange()
    for attr in ("start", "end"):
        bound = getattr(range_value, attr)
        if bound is not None:
            setattr(out, attr, bound)
    return out
def _timestamp_to_proto(value: _dt.datetime) -> Any:
    """Convert a datetime to a protobuf Timestamp, normalized to UTC.

    Naive datetimes are assumed to already be in UTC.
    """
    normalized = (
        value.replace(tzinfo=_UTC) if value.tzinfo is None else value.astimezone(_UTC)
    )
    out = timestamp_pb2.Timestamp()
    out.FromDatetime(normalized)
    return out
def _list_page_from_proto(resp: Any) -> ListPage:
    """Build a :class:`ListPage` from a ListObjects response message."""
    page = ListPage(
        next_continuation_token=resp.next_continuation_token,
        has_more=resp.has_more,
    )
    page.objects = [_object_meta_from_proto(item) for item in resp.objects]
    page.common_prefixes = list(resp.common_prefixes)
    return page
def _presign_method_to_proto(method: PresignMethod | str | None) -> Any:
    """Map a presign method spec onto its protobuf enum value.

    Unrecognized or empty specs map to UNSPECIFIED.
    """
    by_name = {
        "GET": pb.PRESIGN_METHOD_GET,
        "PUT": pb.PRESIGN_METHOD_PUT,
        "DELETE": pb.PRESIGN_METHOD_DELETE,
        "HEAD": pb.PRESIGN_METHOD_HEAD,
    }
    return by_name.get(_presign_method_value(method), pb.PRESIGN_METHOD_UNSPECIFIED)
def _presign_method_from_proto(value: Any) -> PresignMethod | str | None:
    """Map a protobuf enum value back to a :class:`PresignMethod`.

    Returns ``None`` for UNSPECIFIED or unknown values.
    """
    for proto_value, method in (
        (pb.PRESIGN_METHOD_GET, PresignMethod.GET),
        (pb.PRESIGN_METHOD_PUT, PresignMethod.PUT),
        (pb.PRESIGN_METHOD_DELETE, PresignMethod.DELETE),
        (pb.PRESIGN_METHOD_HEAD, PresignMethod.HEAD),
    ):
        if value == proto_value:
            return method
    return None
def _normalize_presign_method(method: PresignMethod | str | None) -> PresignMethod | str | None:
    """Coerce a method spec to the canonical :class:`PresignMethod`.

    Unrecognized non-empty values pass through unchanged; falsy input
    normalizes to ``None``.
    """
    normalized = _presign_method_value(method)
    try:
        return PresignMethod(normalized)
    except ValueError:
        return method if method else None
def _presign_method_value(method: PresignMethod | str | None) -> str:
    """Return the upper-cased string form of a method spec ("" when unset)."""
    if isinstance(method, PresignMethod):
        raw = method.value
    else:
        raw = str(method or "").strip()
    return raw.upper()
def _presign_result_from_proto(resp: Any) -> PresignResult:
    """Build a :class:`PresignResult` from a provider response message."""
    expiry: _dt.datetime | None = None
    if resp.HasField("expires_at"):
        expiry = resp.expires_at.ToDatetime(tzinfo=_UTC)
    return PresignResult(
        url=resp.url,
        method=_presign_method_from_proto(resp.method),
        expires_at=expiry,
        headers=dict(resp.headers),
    )