diff --git a/.gitignore b/.gitignore index a325469..d72b83c 100644 --- a/.gitignore +++ b/.gitignore @@ -130,6 +130,7 @@ dna/ torbi/ # Development/analysis folders (local only) +.nra-cache/ archive/ audit/ plans/ diff --git a/pyproject.toml b/pyproject.toml index 1dd9dfc..5d6f7da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,9 +37,11 @@ classifiers = [ ] dependencies = [ + "arraybridge>=0.2.9", "numpy>=1.26.0", "portalocker>=2.8.0", # Cross-platform file locking "metaclass-registry", + "imageio>=2.37.0", "zarr>=2.18.0,<3.0", # Required for ZarrStorageBackend "ome-zarr>=0.11.0", # Required for OME-ZARR HCS compliance ] @@ -197,4 +199,4 @@ ignore = [ ] [tool.ruff.per-file-ignores] -"__init__.py" = ["F401"] # unused imports \ No newline at end of file +"__init__.py" = ["F401"] # unused imports diff --git a/src/polystore/__init__.py b/src/polystore/__init__.py index 5c38d68..123c449 100644 --- a/src/polystore/__init__.py +++ b/src/polystore/__init__.py @@ -26,10 +26,10 @@ get_backend, ) from .constants import Backend, MemoryType, TransportMode -from .disk import DiskStorageBackend +from .disk import DiskBackend, DiskStorageBackend from .filemanager import FileManager from .formats import FileFormat, DEFAULT_IMAGE_EXTENSIONS -from .memory import MemoryStorageBackend +from .memory import MemoryBackend, MemoryStorageBackend from .metadata_writer import ( AtomicMetadataWriter, MetadataWriteError, @@ -76,7 +76,9 @@ "register_cleanup_callback", "STORAGE_BACKENDS", "DiskStorageBackend", + "DiskBackend", "MemoryStorageBackend", + "MemoryBackend", "FileManager", "file_lock", "atomic_write_json", diff --git a/src/polystore/backend_registry.py b/src/polystore/backend_registry.py index ad8ac52..eb4cb21 100644 --- a/src/polystore/backend_registry.py +++ b/src/polystore/backend_registry.py @@ -74,7 +74,7 @@ def create_storage_registry() -> Dict[str, DataSink]: # Backends that require context-specific initialization (e.g., plate_root) # These are registered lazily when needed, not at startup - SKIP_BACKENDS = {'virtual_workspace', 'omero_local'} + SKIP_BACKENDS = {'virtual_workspace', 'omero_local', 'bioformats'} registry = {} for backend_type in STORAGE_BACKENDS.keys(): @@ -157,4 +157,3 @@ def cleanup_all_backends() -> None: _backend_instances.clear() logger.info("All backend instances cleaned up") - diff --git a/src/polystore/base.py b/src/polystore/base.py index 2b033fc..baa52fc 100644 --- a/src/polystore/base.py +++ b/src/polystore/base.py @@ -10,6 +10,7 @@ import logging import threading from abc import ABC, abstractmethod +from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union from .constants import Backend @@ -34,13 +35,29 @@ class PicklableBackend(ABC): The pattern is: 1. Main process: Backend stores connection params via get_connection_params() - 2. Pickling: ProcessingContext preserves these params + 2. Pickling: FileManager preserves these params 3. Worker process: Backend recreates connection using set_connection_params() This uses nominal typing (ABC) not structural typing (Protocol), so explicit inheritance is required for isinstance() checks to work. """ + @classmethod + def from_connection_params( + cls, + params: Optional[Dict[str, Any]], + ) -> "PicklableBackend": + """ + Recreate a backend instance from worker-safe connection parameters. + + The default contract is a no-argument constructor followed by + set_connection_params(). Backends with required constructor arguments + must override this method. + """ + backend = cls() + backend.set_connection_params(params) + return backend + @abstractmethod def get_connection_params(self) -> Optional[Dict[str, Any]]: """ @@ -110,6 +127,11 @@ def requires_filesystem_validation(self) -> bool: Default is True for backwards compatibility. """ + def supports_file_path(self, path: Union[str, Path]) -> bool: + """Return whether this backend can save the requested file path.""" + del path + return self.supports_arbitrary_files + class DataSink(BackendBase): """ @@ -514,7 +536,7 @@ def get_backend(backend_type: str) -> DataSink: """ ensure_storage_registry() - if hasattr(backend_type, "value"): + if isinstance(backend_type, Enum): backend_type = backend_type.value backend_key = str(backend_type).lower() if backend_key not in storage_registry: @@ -546,15 +568,16 @@ def reset_memory_backend() -> None: # Clear files from existing memory backend while preserving directories memory_backend = storage_registry[Backend.MEMORY.value] - # DEBUG: Log what's in memory before clearing existing_keys = list(memory_backend._memory_store.keys()) - logger.info(f"🔍 VFS_CLEAR: Memory backend has {len(existing_keys)} entries BEFORE clear") - logger.info(f"🔍 VFS_CLEAR: First 10 keys: {existing_keys[:10]}") + logger.debug("Memory backend has %s entries before clear", len(existing_keys)) + logger.debug("First memory backend keys before clear: %s", existing_keys[:10]) memory_backend.clear_files_only() - # DEBUG: Log what's in memory after clearing remaining_keys = list(memory_backend._memory_store.keys()) - logger.info(f"🔍 VFS_CLEAR: Memory backend has {len(remaining_keys)} entries AFTER clear (directories only)") - logger.info(f"🔍 VFS_CLEAR: First 10 remaining keys: {remaining_keys[:10]}") + logger.debug( + "Memory backend has %s entries after clear (directories only)", + len(remaining_keys), + ) + logger.debug("First memory backend keys after clear: %s", remaining_keys[:10]) logger.info("Memory backend reset - files cleared, directories preserved") diff --git a/src/polystore/bioformats_java.py b/src/polystore/bioformats_java.py new file mode 100644 index 0000000..41c7824 --- /dev/null +++ b/src/polystore/bioformats_java.py @@ -0,0 +1,223 @@ +"""Shared Java Bio-Formats bridge for metadata discovery and plane loading.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from threading import Lock +from typing import Any, Callable + +import numpy as np + + +class BioFormatsJavaUnavailableError(RuntimeError): + """Raised when the Java Bio-Formats runtime cannot be initialized.""" + + +@dataclass(frozen=True, slots=True) +class BioFormatsOpenedReader: + """Open Bio-Formats reader plus its OME metadata store.""" + + reader: Any + metadata: Any + + def close(self) -> None: + self.reader.close() + + +class BioFormatsJavaContext: + """Lazy JVM/ImageJ context for Bio-Formats Java access.""" + + _lock = Lock() + _instance: "BioFormatsJavaContext | None" = None + + def __init__(self, imagej_module: Any, scyjava_module: Any): + self.imagej = imagej_module + self.scyjava = scyjava_module + self.ij = None + self.ImageReader = None + self.MetadataTools = None + self.FormatTools = None + + @classmethod + def instance(cls) -> "BioFormatsJavaContext": + with cls._lock: + if cls._instance is None: + cls._instance = cls._create() + return cls._instance + + @classmethod + def _create(cls) -> "BioFormatsJavaContext": + try: + import imagej + import scyjava + except ImportError as exc: + raise BioFormatsJavaUnavailableError( + "Bio-Formats support requires the optional bioformats/fiji dependencies." + ) from exc + return cls(imagej, scyjava) + + def ensure_initialized(self) -> None: + if self.ij is not None: + return + try: + self.ij = self.imagej.init("sc.fiji:fiji", mode="headless") + self.ImageReader = self.scyjava.jimport("loci.formats.ImageReader") + self.MetadataTools = self.scyjava.jimport("loci.formats.MetadataTools") + self.FormatTools = self.scyjava.jimport("loci.formats.FormatTools") + except Exception as exc: + raise BioFormatsJavaUnavailableError( + "Could not initialize Fiji/Bio-Formats through pyimagej." + ) from exc + + def open_reader(self, source_path: str | Path) -> BioFormatsOpenedReader: + self.ensure_initialized() + metadata = self.MetadataTools.createOMEXMLMetadata() + reader = self.ImageReader() + try: + reader.setMetadataStore(metadata) + reader.setId(str(source_path)) + return BioFormatsOpenedReader(reader=reader, metadata=metadata) + except Exception: + reader.close() + raise + + +def java_int(value: Any) -> int | None: + """Convert nullable Java primitive wrappers to Python int.""" + return OptionalJavaScalar.from_java(value, JAVA_SCALAR_PROJECTOR.readers).convert(int) + + +def java_float(value: Any) -> float | None: + """Convert nullable Java numeric wrappers to Python float.""" + return OptionalJavaScalar.from_java(value, JAVA_SCALAR_PROJECTOR.readers).convert(float) + + +def java_str(value: Any) -> str | None: + """Convert nullable Java strings to Python strings.""" + if value is None: + return None + return str(value) + + +def _read_java_value(value: Any) -> Any: + return value.value() + + +def _read_java_get_value(value: Any) -> Any: + return value.getValue() + + +@dataclass(frozen=True, slots=True) +class JavaScalarProjector: + """Project nullable Java scalar wrappers to Python scalar values.""" + + readers: tuple[Callable[[Any], Any], ...] + + def unwrap(self, value: Any) -> Any: + for reader in self.readers: + try: + return reader(value) + except AttributeError: + continue + return value + + +@dataclass(frozen=True, slots=True) +class OptionalJavaScalar: + """Nullable Java scalar after wrapper unwrapping.""" + + value: Any | None + + @classmethod + def from_java( + cls, + value: Any, + readers: tuple[Callable[[Any], Any], ...], + ) -> "OptionalJavaScalar": + if value is None: + return cls(None) + return cls(JavaScalarProjector(readers).unwrap(value)) + + def convert(self, converter: Callable[[Any], Any]) -> Any | None: + if self.value is None: + return None + return converter(self.value) + + +JAVA_SCALAR_PROJECTOR = JavaScalarProjector( + readers=( + _read_java_value, + _read_java_get_value, + ) +) + + +def load_bioformats_plane( + *, + source_path: Path, + series_index: int, + plane_index: int, +) -> np.ndarray: + """Load a single 2D Bio-Formats plane through the Java ImageReader.""" + context = BioFormatsJavaContext.instance() + opened = context.open_reader(source_path) + reader = opened.reader + try: + reader.setSeries(series_index) + if reader.getRGBChannelCount() != 1: + raise ValueError( + "Bio-Formats RGB/interleaved planes are not yet representable as " + "OpenHCS scalar channel planes." + ) + raw = bytes(reader.openBytes(plane_index)) + dtype = PixelDtypeCatalog.from_format_tools(context.FormatTools).dtype( + pixel_type=int(reader.getPixelType()), + little_endian=bool(reader.isLittleEndian()), + ) + array = np.frombuffer(raw, dtype=dtype) + return array.reshape((int(reader.getSizeY()), int(reader.getSizeX()))) + finally: + opened.close() + + +@dataclass(frozen=True, slots=True) +class PixelDtypeSpec: + """NumPy dtype projection for one Bio-Formats pixel type.""" + + key: int + dtype_code: str + endian_sensitive: bool = True + + def dtype(self, *, little_endian: bool) -> np.dtype: + if not self.endian_sensitive: + return np.dtype(self.dtype_code) + endian = "<" if little_endian else ">" + return np.dtype(endian + self.dtype_code) + + +@dataclass(frozen=True, slots=True) +class PixelDtypeCatalog: + """Authoritative Bio-Formats pixel-type to NumPy dtype mapping.""" + + specs_by_key: dict[int, PixelDtypeSpec] + + @classmethod + def from_format_tools(cls, format_tools: Any) -> "PixelDtypeCatalog": + specs = ( + PixelDtypeSpec(int(format_tools.INT8), "i1", endian_sensitive=False), + PixelDtypeSpec(int(format_tools.UINT8), "u1", endian_sensitive=False), + PixelDtypeSpec(int(format_tools.INT16), "i2"), + PixelDtypeSpec(int(format_tools.UINT16), "u2"), + PixelDtypeSpec(int(format_tools.INT32), "i4"), + PixelDtypeSpec(int(format_tools.UINT32), "u4"), + PixelDtypeSpec(int(format_tools.FLOAT), "f4"), + PixelDtypeSpec(int(format_tools.DOUBLE), "f8"), + ) + return cls({spec.key: spec for spec in specs}) + + def dtype(self, *, pixel_type: int, little_endian: bool) -> np.dtype: + try: + return self.specs_by_key[pixel_type].dtype(little_endian=little_endian) + except KeyError as exc: + raise ValueError(f"Unsupported Bio-Formats pixel type: {pixel_type}") from exc diff --git a/src/polystore/bioformats_storage.py b/src/polystore/bioformats_storage.py new file mode 100644 index 0000000..ba17dcf --- /dev/null +++ b/src/polystore/bioformats_storage.py @@ -0,0 +1,258 @@ +"""Structured-reference backend for Bio-Formats-backed virtual workspaces.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from fnmatch import fnmatch +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from .base import PicklableBackend, ReadOnlyBackend +from .constants import Backend +from .exceptions import StorageResolutionError +from .metadata_writer import get_metadata_path + + +@dataclass(frozen=True, slots=True) +class BioFormatsPlaneRef: + """Serializable reference to one Bio-Formats image plane.""" + + source_path: Path + series_index: int + plane_index: int + c: int + z: int + t: int + reader: str = "bioformats" + + @classmethod + def from_mapping( + cls, + payload: Dict[str, Any], + *, + plate_root: Path, + ) -> "BioFormatsPlaneRef": + source_path = Path(payload["source_path"]) + if not source_path.is_absolute(): + source_path = plate_root / source_path + return cls( + source_path=source_path, + series_index=int(payload.get("series_index", 0)), + plane_index=int(payload["plane_index"]), + c=int(payload["c"]), + z=int(payload["z"]), + t=int(payload["t"]), + reader=str(payload.get("reader", "bioformats")), + ) + + +class BioFormatsStorageBackend(ReadOnlyBackend, PicklableBackend): + """Load normalized virtual source keys from structured Bio-Formats refs.""" + + _backend_type = Backend.BIOFORMATS.value + + def __init__(self, plate_root: Path | None = None): + self.plate_root = None if plate_root is None else Path(plate_root) + self._mapping_cache: Optional[Dict[str, Dict[str, Any]]] = None + self._cache_mtime: Optional[float] = None + + def get_connection_params(self) -> Optional[Dict[str, Any]]: + if self.plate_root is None: + return None + return {"plate_root": str(self.plate_root)} + + def set_connection_params(self, params: Optional[Dict[str, Any]]) -> None: + if not params: + self.plate_root = None + self._mapping_cache = None + self._cache_mtime = None + return + self.plate_root = Path(params["plate_root"]) + self._mapping_cache = None + self._cache_mtime = None + + def load(self, file_path: Union[str, Path], **kwargs) -> Any: + ref = self._resolve_ref(file_path) + if ref.reader == "npy": + return _load_npy_plane(ref) + if ref.reader != "bioformats": + raise BioFormatsReaderUnavailableError( + f"Unsupported Bio-Formats reader {ref.reader!r}." + ) + from .bioformats_java import load_bioformats_plane + + return load_bioformats_plane( + source_path=ref.source_path, + series_index=ref.series_index, + plane_index=ref.plane_index, + ) + + def load_batch(self, file_paths: List[Union[str, Path]], **kwargs) -> List[Any]: + return [self.load(file_path, **kwargs) for file_path in file_paths] + + def list_files( + self, + directory: Union[str, Path], + pattern: Optional[str] = None, + extensions: Optional[Set[str]] = None, + recursive: bool = False, + **kwargs, + ) -> List[str]: + plate_root = self._require_plate_root() + relative_dir = self.relative_to_root(directory) + normalized_dir = _normalize_relative_path(str(relative_dir)) + lowercase_extensions = ( + None if extensions is None else {extension.lower() for extension in extensions} + ) + results = [] + for virtual_path in self._load_mapping().keys(): + if not _virtual_path_in_directory( + virtual_path, + normalized_dir=normalized_dir, + recursive=recursive, + ): + continue + path = Path(virtual_path) + if lowercase_extensions is not None and path.suffix.lower() not in lowercase_extensions: + continue + if pattern is not None and not fnmatch(path.name, pattern): + continue + results.append(str(plate_root / virtual_path)) + return results + + def exists(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + if not relative: + return True + mapping = self._load_mapping() + return relative in mapping or any( + virtual_path.startswith(relative + "/") + for virtual_path in mapping + ) + + def is_file(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + return relative in self._load_mapping() + + def is_dir(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + return not relative or any( + virtual_path.startswith(relative + "/") + for virtual_path in self._load_mapping() + ) + + def list_dir(self, path: Union[str, Path]) -> List[str]: + relative = self.normalized_relative_path(path) + prefix = "" if not relative else relative + "/" + names = set() + for virtual_path in self._load_mapping(): + if not virtual_path.startswith(prefix): + continue + remainder = virtual_path[len(prefix):] + if remainder: + names.add(remainder.split("/", 1)[0]) + return sorted(names) + + def _resolve_ref(self, path: Union[str, Path]) -> BioFormatsPlaneRef: + plate_root = self._require_plate_root() + relative_path = self.normalized_relative_path(path) + mapping = self._load_mapping() + try: + payload = mapping[relative_path] + except KeyError as exc: + raise StorageResolutionError( + f"Path not in Bio-Formats workspace mapping: {relative_path}" + ) from exc + if not isinstance(payload, dict): + raise StorageResolutionError( + f"Bio-Formats workspace mapping for {relative_path!r} is not structured." + ) + return BioFormatsPlaneRef.from_mapping(payload, plate_root=plate_root) + + def _load_mapping(self) -> Dict[str, Dict[str, Any]]: + plate_root = self._require_plate_root() + metadata_path = get_metadata_path(plate_root) + if not metadata_path.exists(): + raise FileNotFoundError(f"Metadata not found: {metadata_path}") + current_mtime = metadata_path.stat().st_mtime + if self._mapping_cache is not None and self._cache_mtime == current_mtime: + return self._mapping_cache + metadata = json.loads(metadata_path.read_text(encoding="utf-8")) + combined_mapping: Dict[str, Dict[str, Any]] = {} + for subdirectory in metadata.get("subdirectories", {}).values(): + if Backend.BIOFORMATS.value not in subdirectory.get("available_backends", {}): + continue + workspace_mapping = subdirectory.get("workspace_mapping", {}) + for virtual_path, ref_payload in workspace_mapping.items(): + if isinstance(ref_payload, dict): + combined_mapping[_normalize_relative_path(str(virtual_path))] = ref_payload + if not combined_mapping: + raise ValueError(f"No Bio-Formats workspace_mapping in {metadata_path}") + self._mapping_cache = combined_mapping + self._cache_mtime = current_mtime + return combined_mapping + + def _require_plate_root(self) -> Path: + if self.plate_root is None: + raise StorageResolutionError("BioFormatsStorageBackend requires plate_root.") + return self.plate_root + + def relative_to_root(self, path: Union[str, Path]) -> Path: + plate_root = self._require_plate_root() + path_obj = Path(path) + if not path_obj.is_absolute(): + return path_obj + try: + return path_obj.relative_to(plate_root) + except ValueError as exc: + raise StorageResolutionError( + f"Path {path_obj} is outside Bio-Formats plate root {plate_root}." + ) from exc + + def normalized_relative_path(self, path: Union[str, Path]) -> str: + return _normalize_relative_path(str(self.relative_to_root(path))) + + +class BioFormatsReaderUnavailableError(RuntimeError): + """Raised when a production Bio-Formats reader has not been configured.""" + + +def _load_npy_plane(ref: BioFormatsPlaneRef) -> Any: + import numpy as np + + array = np.load(ref.source_path) + if array.ndim == 2: + return array + if array.ndim == 5: + return array[ref.t - 1, ref.z - 1, ref.c - 1] + if array.ndim == 3: + return array[ref.plane_index] + raise ValueError( + f"Unsupported npy Bio-Formats fixture shape {array.shape} for {ref.source_path}." + ) + + +def _normalize_relative_path(path: str) -> str: + normalized = path.replace("\\", "/") + return "" if normalized == "." else normalized + + +def _virtual_path_in_directory( + virtual_path: str, + *, + normalized_dir: str, + recursive: bool, +) -> bool: + if recursive: + return not normalized_dir or virtual_path.startswith(normalized_dir + "/") + return _normalize_relative_path(str(Path(virtual_path).parent)) == normalized_dir diff --git a/src/polystore/config.py b/src/polystore/config.py index c45e3e2..57879b5 100644 --- a/src/polystore/config.py +++ b/src/polystore/config.py @@ -13,7 +13,7 @@ class CompressorConfig: """Minimal compressor config used by Zarr backend when real compressors aren't provided.""" name: str = 'none' - def create_compressor(self, level: Optional[int], shuffle: bool = True) -> Optional[Any]: + def create(self, level: Optional[int], shuffle: bool = True) -> Optional[Any]: """Return a compressor object acceptable to zarr or None to disable compression.""" # Minimal fallback: return None (no compression) return None @@ -25,3 +25,7 @@ class ZarrConfig: compression_level: Optional[int] = None compressor: CompressorConfig = field(default_factory=CompressorConfig) chunk_strategy: ZarrChunkStrategy = ZarrChunkStrategy.WELL + + @property + def compressor_factory(self) -> CompressorConfig: + return self.compressor diff --git a/src/polystore/constants.py b/src/polystore/constants.py index 3a27cfb..0103236 100644 --- a/src/polystore/constants.py +++ b/src/polystore/constants.py @@ -19,6 +19,7 @@ class Backend(Enum): FIJI_STREAM = "fiji_stream" OMERO_LOCAL = "omero_local" VIRTUAL_WORKSPACE = "virtual_workspace" + BIOFORMATS = "bioformats" class TransportMode(Enum): diff --git a/src/polystore/disk.py b/src/polystore/disk.py index 40c33d9..a492770 100644 --- a/src/polystore/disk.py +++ b/src/polystore/disk.py @@ -9,6 +9,8 @@ import logging import os import shutil +import importlib +from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -23,7 +25,7 @@ def optional_import(module_name): try: - return __import__(module_name) + return importlib.import_module(module_name) except ImportError: return None @@ -44,6 +46,7 @@ def optional_import(module_name): cupy = get_cupy() tf = get_tf() tifffile = optional_import("tifffile") +imageio = optional_import("imageio.v3") # Optional arraybridge integration for memory conversion try: @@ -72,6 +75,30 @@ def is_registered(self, ext: str) -> bool: return ext.lower() in self._writers and ext.lower() in self._readers +@dataclass(frozen=True, slots=True) +class DiskFileFormatRegistration: + """One disk file format registration with explicit dependency availability.""" + + file_format: FileFormat + writer: Callable + reader: Callable + + +@dataclass(frozen=True, slots=True) +class DiskGlobPattern: + """Nominal glob pattern used by disk file listing.""" + + value: str + + @classmethod + def from_optional(cls, pattern: Optional[str]) -> "DiskGlobPattern": + if pattern is None: + return cls("*") + if pattern == "": + raise ValueError("Disk file listing pattern cannot be empty.") + return cls(pattern) + + class DiskStorageBackend(StorageBackend): """Disk storage backend with automatic registration.""" _backend_type = Backend.DISK.value @@ -87,32 +114,59 @@ def _register_formats(self): Complex formats (CSV, JSON, TIFF, ROI.ZIP, TEXT) use custom handlers. Simple formats (NumPy, Torch, CuPy, JAX, TensorFlow) use library save/load directly. """ - # Format handler metadata: (FileFormat enum, module_check, writer, reader) - # None for writer/reader means use the format's library save/load directly - format_handlers = [ - # Simple formats - use library save/load directly - (FileFormat.NUMPY, True, np.save, np.load), - (FileFormat.TORCH, torch, torch.save if torch else None, torch.load if torch else None), - (FileFormat.JAX, (jax and jnp), self._jax_writer, self._jax_reader), - (FileFormat.CUPY, cupy, self._cupy_writer, self._cupy_reader), - (FileFormat.TENSORFLOW, tf, self._tensorflow_writer, self._tensorflow_reader), - - # Complex formats - use custom handlers - (FileFormat.TIFF, tifffile, self._tiff_writer, self._tiff_reader), - (FileFormat.TEXT, True, self._text_writer, self._text_reader), - (FileFormat.JSON, True, self._json_writer, self._json_reader), - (FileFormat.CSV, True, self._csv_writer, self._csv_reader), - (FileFormat.ROI, True, self._roi_zip_writer, self._roi_zip_reader), - ] + format_handlers = self._available_format_registrations() # Register all available formats - for file_format, module_available, writer, reader in format_handlers: - if not module_available or writer is None or reader is None: - continue - + for registration in format_handlers: # Register all extensions for this format - for ext in file_format.extensions: - self.format_registry.register(ext.lower(), writer, reader) + for ext in registration.file_format.extensions: + self.format_registry.register( + ext.lower(), + registration.writer, + registration.reader, + ) + + def _available_format_registrations(self) -> List[DiskFileFormatRegistration]: + registrations = [ + DiskFileFormatRegistration(FileFormat.NUMPY, np.save, np.load), + DiskFileFormatRegistration(FileFormat.TEXT, self._text_writer, self._text_reader), + DiskFileFormatRegistration(FileFormat.JSON, self._json_writer, self._json_reader), + DiskFileFormatRegistration(FileFormat.CSV, self._csv_writer, self._csv_reader), + DiskFileFormatRegistration(FileFormat.ROI, self._roi_zip_writer, self._roi_zip_reader), + ] + if torch is not None: + registrations.append( + DiskFileFormatRegistration(FileFormat.TORCH, torch.save, torch.load) + ) + if jax is not None and jnp is not None: + registrations.append( + DiskFileFormatRegistration(FileFormat.JAX, self._jax_writer, self._jax_reader) + ) + if cupy is not None: + registrations.append( + DiskFileFormatRegistration(FileFormat.CUPY, self._cupy_writer, self._cupy_reader) + ) + if tf is not None: + registrations.append( + DiskFileFormatRegistration( + FileFormat.TENSORFLOW, + self._tensorflow_writer, + self._tensorflow_reader, + ) + ) + if tifffile is not None: + registrations.append( + DiskFileFormatRegistration(FileFormat.TIFF, self._tiff_writer, self._tiff_reader) + ) + if imageio is not None: + registrations.append( + DiskFileFormatRegistration( + FileFormat.RASTER_IMAGE, + self._image_writer, + self._image_reader, + ) + ) + return registrations # Format-specific writer/reader functions (pickleable) # Only needed for formats that require special handling beyond library save/load @@ -164,6 +218,14 @@ def _tiff_reader(self, path): else: return tifffile.imread(str(path)) + def _image_writer(self, path, data, **kwargs): + """Write standard raster images using imageio.""" + imageio.imwrite(path, np.asarray(data)) + + def _image_reader(self, path): + """Read standard raster images using imageio.""" + return imageio.imread(path) + def _text_writer(self, path, data, **kwargs): """Write text data to file. Accepts and ignores extra kwargs for compatibility.""" path.write_text(str(data)) @@ -261,7 +323,7 @@ def load(self, file_path: Union[str, Path], **kwargs) -> Any: ext = disk_path.suffix.lower() if not self.format_registry.is_registered(ext): - raise ValueError(f"No writer registered for extension '{ext}'") + raise ValueError(f"No reader registered for extension '{ext}'") try: reader = self.format_registry.get_reader(ext) @@ -356,13 +418,7 @@ def save_batch(self, data_list: List[Any], output_paths: List[Union[str, Path]], else: cpu_data_list.append(np.asarray(data)) else: - # Fallback conversion without arraybridge - if hasattr(data, "cpu") and hasattr(data, "numpy"): - cpu_data_list.append(data.cpu().numpy()) - elif hasattr(data, "get"): - cpu_data_list.append(data.get()) - else: - cpu_data_list.append(np.asarray(data)) + cpu_data_list.append(np.asarray(data)) # Save converted data using existing save method for cpu_data, output_path in zip(cpu_data_list, output_paths): @@ -397,7 +453,7 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, # Use breadth-first traversal to prioritize shallower files files = self._list_files_breadth_first(disk_directory, pattern) else: - glob_pattern = pattern if pattern else "*" + glob_pattern = DiskGlobPattern.from_optional(pattern).value # Include both regular files and symlinks (even broken ones) files = [p for p in disk_directory.glob(glob_pattern) if p.is_file() or p.is_symlink()] @@ -738,88 +794,37 @@ def _save_rois(self, rois: List, output_path: Path, images_dir: str = None, **kw Path where ROIs were saved """ import zipfile - import numpy as np - from .roi import PolygonShape, PolylineShape, MaskShape, PointShape, EllipseShape + from .roi import ( + ROI_ZIP_METADATA_MEMBER, + ROIArchivePath, + roi_zip_metadata_payload, + ) + from .roi_converters import FijiROIConverter - output_path = Path(output_path) + output_path = ROIArchivePath.from_output_path(output_path).path # Ensure output directory exists output_path.parent.mkdir(parents=True, exist_ok=True) - # Ensure output path has .roi.zip extension - if not output_path.name.endswith('.roi.zip'): - output_path = output_path.with_suffix('.roi.zip') - - try: - from roifile import ImagejRoi, ROI_TYPE - except ImportError: - logger.error("roifile library not available - cannot save ROIs") - raise ImportError("roifile library required for ROI saving. Install with: pip install roifile") - # Create .roi.zip archive - roi_count = 0 + roi_members = FijiROIConverter.rois_to_imagej_members(rois) + metadata_by_filename = {} + with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zf: - for idx, roi in enumerate(rois): - for shape in roi.shapes: - if isinstance(shape, PolygonShape): - # Convert polygon to ImageJ ROI - # roifile expects (x, y) coordinates, but we have (y, x) - coords_xy = shape.coordinates[:, [1, 0]] # Swap columns - ij_roi = ImagejRoi.frompoints(coords_xy) - - # Use incrementing counter for unique filenames (avoid duplicate names from label values) - ij_roi.name = f"ROI_{roi_count + 1}" - - # Write to zip archive - roi_bytes = ij_roi.tobytes() - zf.writestr(f"{roi_count + 1:04d}.roi", roi_bytes) - roi_count += 1 - - elif isinstance(shape, PolylineShape): - # Convert polyline to ImageJ polyline ROI - # roifile expects (x, y) coordinates, but we have (y, x) - coords_xy = shape.coordinates[:, [1, 0]] # Swap columns - ij_roi = ImagejRoi.frompoints(coords_xy) - ij_roi.roitype = ROI_TYPE.POLYLINE - - # Use incrementing counter for unique filenames - ij_roi.name = f"ROI_{roi_count + 1}" - - # Write to zip archive - roi_bytes = ij_roi.tobytes() - zf.writestr(f"{roi_count + 1:04d}.roi", roi_bytes) - roi_count += 1 - - elif isinstance(shape, PointShape): - # Convert point to ImageJ ROI - coords_xy = np.array([[shape.x, shape.y]]) - ij_roi = ImagejRoi.frompoints(coords_xy) - - ij_roi.name = f"ROI_{roi_count + 1}" - - roi_bytes = ij_roi.tobytes() - zf.writestr(f"{roi_count + 1:04d}.roi", roi_bytes) - roi_count += 1 - - elif isinstance(shape, EllipseShape): - # Convert ellipse to polygon approximation (ImageJ ROI format limitation) - # Generate 64 points around the ellipse - theta = np.linspace(0, 2 * np.pi, 64) - x = shape.center_x + shape.radius_x * np.cos(theta) - y = shape.center_y + shape.radius_y * np.sin(theta) - coords_xy = np.column_stack([x, y]) - - ij_roi = ImagejRoi.frompoints(coords_xy) - ij_roi.name = f"ROI_{roi_count + 1}" - - roi_bytes = ij_roi.tobytes() - zf.writestr(f"{roi_count + 1:04d}.roi", roi_bytes) - roi_count += 1 - - elif isinstance(shape, MaskShape): - # Skip mask shapes - ImageJ ROI format doesn't support binary masks - logger.warning(f"Skipping mask shape for ROI {idx} - not supported in ImageJ .roi format") - continue - - logger.info(f"Saved {roi_count} ROIs to .roi.zip archive: {output_path}") + for roi_count, member in enumerate(roi_members, start=1): + roi_filename = f"{roi_count:04d}.roi" + imagej_roi = member.imagej_roi + imagej_roi.name = f"ROI_{roi_count}" + metadata_by_filename[roi_filename] = member.metadata + zf.writestr(roi_filename, imagej_roi.tobytes()) + if metadata_by_filename: + zf.writestr( + ROI_ZIP_METADATA_MEMBER, + roi_zip_metadata_payload(metadata_by_filename), + ) + + logger.info(f"Saved {len(roi_members)} ROIs to .roi.zip archive: {output_path}") return str(output_path) + + +DiskBackend = DiskStorageBackend diff --git a/src/polystore/fiji_stream.py b/src/polystore/fiji_stream.py index 4d52817..24a56b9 100644 --- a/src/polystore/fiji_stream.py +++ b/src/polystore/fiji_stream.py @@ -12,32 +12,140 @@ """ import logging -import time -from pathlib import Path -from typing import Any, List, Union +from enum import Enum -import zmq - -from .constants import Backend, TransportMode +from .constants import Backend from .streaming_constants import StreamingDataType -from .streaming import StreamingBackend +from .streaming import ( + FilePath, + RoiStreamPayload, + StreamingBuiltBatch, + StreamingBackend, + StreamingComponentNamesRequest, + StreamingItemPreparationRequest, + ViewerDisplayPayloadExtra, +) +from .streaming.viewer_transport import ViewerStreamItemPayload, ViewerStreamRequest from .roi_converters import FijiROIConverter -from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode +from zmqruntime.viewer_protocol import ( + ViewerBatchItemWireField, + ViewerBatchWireField, + ViewerWireMapping, + ViewerWireValue, +) logger = logging.getLogger(__name__) +class FijiDisplayWireField(str, Enum): + """Fiji-specific display fields inside the shared viewer display payload.""" + + LUT = "lut" + AUTO_CONTRAST = "auto_contrast" + + +class FijiDisplayPayload: + """Display payload projection for Fiji stream messages.""" + + @staticmethod + def auto_contrast_value(display_config) -> bool: + return display_config.auto_contrast + + @classmethod + def from_display_config(cls, display_config) -> dict[str, ViewerWireValue]: + return { + FijiDisplayWireField.LUT.value: display_config.get_lut_name(), + FijiDisplayWireField.AUTO_CONTRAST.value: cls.auto_contrast_value( + display_config + ), + } + + +class FijiMessageMetadata: + """Typed access to optional Fiji message metadata.""" + + @staticmethod + def component_names_metadata(message: ViewerWireMapping) -> ViewerWireValue: + return message[ViewerBatchWireField.COMPONENT_NAMES_METADATA.value] + + +class FijiRoiPayload: + """ROI payload inspection for Fiji logging.""" + + @staticmethod + def count(item_data: ViewerWireMapping) -> int: + if ViewerBatchItemWireField.ROIS.value not in item_data: + raise ValueError("Fiji ROI payload missing required 'rois' field") + return len(item_data[ViewerBatchItemWireField.ROIS.value]) + + class FijiStreamingBackend(StreamingBackend): """Fiji streaming backend with ZMQ publisher pattern (matches Napari architecture).""" _backend_type = Backend.FIJI_STREAM.value - # Configure ABC attributes VIEWER_TYPE = 'fiji' SHM_PREFIX = 'fiji_' - # __init__, _get_publisher, save, cleanup now inherited from ABC + def display_payload_extra( + self, + stream_request: ViewerStreamRequest, + ) -> ViewerDisplayPayloadExtra: + return ViewerDisplayPayloadExtra.from_mapping( + FijiDisplayPayload.from_display_config(stream_request.display_config) + ) + + def message_extra( + self, + stream_request: ViewerStreamRequest, + ) -> dict[str, ViewerWireValue]: + return stream_request.message_extra_payload_with_images_dir() + + def component_names_request( + self, + stream_request: ViewerStreamRequest, + ) -> StreamingComponentNamesRequest: + return StreamingComponentNamesRequest.from_stream_request( + stream_request, + log_prefix="🏷️ FIJI BACKEND", + verbose=True, + ) + + def after_batch_message_built( + self, + stream_request: ViewerStreamRequest, + built_batch: StreamingBuiltBatch, + ) -> None: + logger.info( + "🏷️ FIJI BACKEND: Final component_names_metadata: %s", + FijiMessageMetadata.component_names_metadata(built_batch.message), + ) + + for item in built_batch.batch_images: + logger.info( + "🔍 FIJI BACKEND: Added %s item to batch", + item[ViewerBatchItemWireField.DATA_TYPE.value], + ) + + data_types = [ + item[ViewerBatchItemWireField.DATA_TYPE.value] + for item in built_batch.batch_images + ] + type_counts = { + data_type: data_types.count(data_type) + for data_type in set(data_types) + } + logger.info( + "📤 FIJI BACKEND: Sending batch message with %d items to port %s: %s", + len(built_batch.batch_images), + stream_request.port, + type_counts, + ) - def _prepare_rois_data(self, data: Any, file_path: Union[str, Path]) -> dict: + def _prepare_rois_data( + self, + data: RoiStreamPayload, + file_path: FilePath, + ) -> dict[str, ViewerWireValue]: """ Prepare ROIs data for transmission. @@ -53,126 +161,48 @@ def _prepare_rois_data(self, data: Any, file_path: Union[str, Path]) -> dict: rois_encoded = FijiROIConverter.encode_rois_for_transmission(roi_bytes_list) return { - 'path': str(file_path), - 'rois': rois_encoded, + ViewerBatchItemWireField.PATH.value: str(file_path), + ViewerBatchItemWireField.ROIS.value: rois_encoded, } - def _prepare_batch_item(self, data: Any, file_path: Union[str, Path], data_type): - logger.info(f"🔍 FIJI BACKEND: Detected data type: {data_type} for path: {file_path}") - if data_type == StreamingDataType.SHAPES: - logger.info(f"🔍 FIJI BACKEND: Preparing ROI data for {file_path}") - item_data = self._prepare_rois_data(data, file_path) - data_type_value = "rois" - logger.info(f"🔍 FIJI BACKEND: ROI data prepared: {len(item_data.get('rois', []))} ROIs") - else: - logger.info(f"🔍 FIJI BACKEND: Preparing image data for {file_path}") - item_data = self._create_shared_memory(data, file_path) - data_type_value = "image" - return item_data, data_type_value - - def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], **kwargs) -> None: - """Stream batch of images or ROIs to Fiji via ZMQ.""" - - logger.info(f"📦 FIJI BACKEND: save_batch called with {len(data_list)} items") - - # Filter to only supported file types - data_list, file_paths, skipped = self._filter_streamable_files(data_list, file_paths) - if not data_list: - return - - # Extract kwargs using generic polymorphic names - host = kwargs.get('host', 'localhost') - port = kwargs['port'] - transport_mode = kwargs['transport_mode'] - transport_config = kwargs.get('transport_config') - display_config = kwargs['display_config'] - microscope_handler = kwargs['microscope_handler'] - source = kwargs.get('source', 'unknown_source') # Pre-built source value - images_dir = kwargs.get('images_dir') # Source image subdirectory for ROI mapping - plate_path = kwargs.get('plate_path') - logger.info(f"🏷️ FIJI BACKEND: plate_path = {plate_path}") - logger.info(f"🏷️ FIJI BACKEND: microscope_handler = {microscope_handler}") - display_payload_extra = { - "lut": display_config.get_lut_name(), - "auto_contrast": display_config.auto_contrast if hasattr(display_config, "auto_contrast") else True, - } - message_extra = { - "images_dir": images_dir, - } - - message, batch_images, image_ids = self._build_batch_message( - data_list, - file_paths, - microscope_handler, - source, - display_config, - self._prepare_batch_item, - plate_path=plate_path, - component_names_kwargs={"log_prefix": "🏷️ FIJI BACKEND", "verbose": True}, - display_payload_extra=display_payload_extra, - message_extra=message_extra, - ) - + def _prepare_batch_item( + self, + request: StreamingItemPreparationRequest, + ) -> ViewerStreamItemPayload: logger.info( - "🏷️ FIJI BACKEND: Final component_names_metadata: %s", - message.get("component_names_metadata", {}), - ) - - for item in batch_images: - logger.info(f"🔍 FIJI BACKEND: Added {item['data_type']} item to batch") - - # Log batch composition - data_types = [item['data_type'] for item in batch_images] - type_counts = {dt: data_types.count(dt) for dt in set(data_types)} - logger.info(f"📤 FIJI BACKEND: Sending batch message with {len(batch_images)} items to port {port}: {type_counts}") - - # Register sent images with queue tracker BEFORE sending - # This prevents race condition with IPC mode where acks arrive before registration - self._register_with_queue_tracker( - port, - image_ids, - transport_mode=transport_mode, - transport_config=transport_config, + "🔍 FIJI BACKEND: Detected data type: %s for path: %s", + request.streaming_data_type, + request.item_path.value, ) - - # Create FRESH REQ socket for each send - REQ sockets cannot be reused - # This prevents the "Operation cannot be accomplished in current state" error - # when multiple streams happen concurrently - transport_config = transport_config or self._transport_config - url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), - config=transport_config, + if request.streaming_data_type == StreamingDataType.SHAPES: + logger.info( + "🔍 FIJI BACKEND: Preparing ROI data for %s", + request.item_path.value, + ) + item_data = self._prepare_rois_data( + request.data, + request.item_path.value, + ) + output_streaming_data_type = StreamingDataType.ROIS + logger.info( + "🔍 FIJI BACKEND: ROI data prepared: %d ROIs", + FijiRoiPayload.count(item_data), + ) + else: + logger.info( + "🔍 FIJI BACKEND: Preparing image data for %s", + request.item_path.value, + ) + item_data = self.create_shared_memory_payload( + request.data, + request.item_path.value, + ) + output_streaming_data_type = StreamingDataType.IMAGE + return ViewerStreamItemPayload( + item_payload=item_data, + streaming_data_type=output_streaming_data_type, ) - if self._context is None: - self._context = zmq.Context() - - socket = self._context.socket(zmq.REQ) - socket.connect(url) - time.sleep(0.1) # Brief delay for connection to establish - - try: - # Send with REQ socket (BLOCKING - worker waits for Fiji to acknowledge) - # Worker blocks until Fiji receives, copies data from shared memory, and sends ack - # This guarantees no messages are lost and shared memory is only closed after Fiji is done - logger.info(f"📤 FIJI BACKEND: Sending batch of {len(batch_images)} images to Fiji on port {port} (REQ/REP - blocking until ack)") - socket.send_json(message) # Blocking send - - # Wait for acknowledgment from Fiji (REP socket) - # Fiji will only reply after it has copied all data from shared memory - ack_response = socket.recv_json() - logger.info(f"✅ FIJI BACKEND: Received ack from Fiji: {ack_response.get('status', 'unknown')}") - - finally: - # Always close the socket - never reuse REQ sockets - socket.close() - - # Clean up publisher's handles after successful send - # Receiver will unlink the shared memory after copying the data - self._cleanup_shared_memory_blocks(batch_images, unlink=False) - # cleanup() now inherited from ABC def __del__(self): diff --git a/src/polystore/filemanager.py b/src/polystore/filemanager.py index 16a348f..23ac2e5 100644 --- a/src/polystore/filemanager.py +++ b/src/polystore/filemanager.py @@ -6,11 +6,12 @@ """ import logging +from enum import Enum from pathlib import Path from typing import List, Set, Union, Tuple, Any from .formats import DEFAULT_IMAGE_EXTENSIONS -from .base import DataSink +from .base import DataSink, PicklableBackend from .exceptions import StorageResolutionError logger = logging.getLogger(__name__) @@ -49,6 +50,28 @@ def __init__(self, registry): logger.debug("FileManager initialized with registry") + def __getstate__(self) -> dict[str, Any]: + picklable_backends = {} + for backend_key, backend_instance in self.registry.items(): + if isinstance(backend_instance, PicklableBackend): + picklable_backends[backend_key] = ( + backend_instance.get_connection_params() + ) + return {"picklable_backends": picklable_backends} + + def __setstate__(self, state: dict[str, Any]) -> None: + from .backend_registry import STORAGE_BACKENDS + from .base import ensure_storage_registry, storage_registry + + ensure_storage_registry() + STORAGE_BACKENDS._discover() + for backend_key, connection_params in state["picklable_backends"].items(): + backend_class = STORAGE_BACKENDS[backend_key] + storage_registry[backend_key] = backend_class.from_connection_params( + connection_params + ) + self.registry = storage_registry + def _get_backend(self, backend_name: str) -> DataSink: """ Get a backend by name. @@ -75,7 +98,7 @@ def _get_backend(self, backend_name: str) -> DataSink: # Normalize backend name if backend_name is None: raise StorageResolutionError("Backend name must be provided") - if hasattr(backend_name, "value"): + if isinstance(backend_name, Enum): backend_name = backend_name.value backend_name = str(backend_name).lower() @@ -139,14 +162,7 @@ def save(self, data: Any, output_path: Union[str, Path], backend: str, **kwargs) try: backend_instance = self._get_backend(backend) - # If materialization context exists, merge it into kwargs - # This allows backends to access context like images_dir for OMERO ROI/analysis linking - if hasattr(self, '_materialization_context') and self._materialization_context: - # Merge context into kwargs (kwargs takes precedence if keys overlap) - merged_kwargs = {**self._materialization_context, **kwargs} - backend_instance.save(data, output_path, **merged_kwargs) - else: - backend_instance.save(data, output_path, **kwargs) + backend_instance.save(data, output_path, **kwargs) except StorageResolutionError: # Allow specific backend errors to propagate if they are StorageResolutionError raise except Exception as e: diff --git a/src/polystore/formats.py b/src/polystore/formats.py index ddfb9a5..3643361 100644 --- a/src/polystore/formats.py +++ b/src/polystore/formats.py @@ -20,6 +20,7 @@ class FileFormat(Enum): # Image formats TIFF = "tiff" + RASTER_IMAGE = "raster_image" # Data formats CSV = "csv" @@ -44,6 +45,7 @@ def extensions(self): FileFormat.TENSORFLOW: [".tf"], FileFormat.ZARR: [".zarr"], FileFormat.TIFF: [".tif", ".tiff"], + FileFormat.RASTER_IMAGE: [".bmp", ".gif", ".jpeg", ".jpg", ".png"], FileFormat.CSV: [".csv"], FileFormat.JSON: [".json"], FileFormat.TEXT: [".txt"], @@ -51,7 +53,14 @@ def extensions(self): } # Default image extensions -DEFAULT_IMAGE_EXTENSIONS = {".tif", ".tiff", ".TIF", ".TIFF"} +DEFAULT_IMAGE_EXTENSIONS = { + extension + for extensions in ( + FILE_FORMAT_EXTENSIONS[FileFormat.TIFF], + FILE_FORMAT_EXTENSIONS[FileFormat.RASTER_IMAGE], + ) + for extension in extensions +} def get_format_from_extension(ext: str) -> FileFormat: diff --git a/src/polystore/memory.py b/src/polystore/memory.py index a59114f..872d581 100644 --- a/src/polystore/memory.py +++ b/src/polystore/memory.py @@ -139,6 +139,9 @@ def list_files( if self._memory_store[dir_key] is not None: raise NotADirectoryError(f"Path is not a directory: {directory}") + lowercase_extensions = ( + None if extensions is None else {extension.lower() for extension in extensions} + ) result = [] dir_prefix = dir_key + "/" if not dir_key.endswith("/") else dir_key @@ -159,7 +162,10 @@ def list_files( filename = Path(rel_path).name # If pattern is None, match all files if pattern is None or fnmatch(filename, pattern): - if not extensions or Path(filename).suffix in extensions: + if ( + lowercase_extensions is None + or Path(filename).suffix.lower() in lowercase_extensions + ): # Calculate depth for breadth-first sorting depth = rel_path.count('/') result.append((Path(path), depth)) @@ -651,3 +657,6 @@ def __init__(self, target: str): def __repr__(self): return f"" + + +MemoryBackend = MemoryStorageBackend diff --git a/src/polystore/napari_stream.py b/src/polystore/napari_stream.py index 630bcc8..87ce4f2 100644 --- a/src/polystore/napari_stream.py +++ b/src/polystore/napari_stream.py @@ -13,32 +13,74 @@ """ import logging -import time -from pathlib import Path -from typing import Any, List, Union - -import zmq - -from .constants import Backend, TransportMode -from .streaming_constants import StreamingDataType -from .streaming import StreamingBackend +from enum import Enum + +from .constants import Backend +from .streaming import ( + FilePath, + RoiStreamPayload, + StreamingBackend, + StreamingItemPreparationRequest, + ViewerDisplayPayloadExtra, +) +from .streaming.viewer_transport import ViewerStreamItemPayload, ViewerStreamRequest from .roi_converters import NapariROIConverter -from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode +from zmqruntime.viewer_protocol import ( + ViewerBatchItemWireField, + ViewerWireMapping, + ViewerWireValue, +) logger = logging.getLogger(__name__) +class NapariDisplayWireField(str, Enum): + """Napari-specific display fields inside the shared viewer display payload.""" + + COLORMAP = "colormap" + VARIABLE_SIZE_HANDLING = "variable_size_handling" + + +class NapariDisplayPayload: + """Display payload projection for Napari stream messages.""" + + @staticmethod + def variable_size_handling_value(display_config): + variable_size_handling = display_config.variable_size_handling + if variable_size_handling is None: + return None + return variable_size_handling.value + + @classmethod + def from_display_config(cls, display_config) -> dict[str, ViewerWireValue]: + return { + NapariDisplayWireField.COLORMAP.value: display_config.get_colormap_name(), + NapariDisplayWireField.VARIABLE_SIZE_HANDLING.value: ( + cls.variable_size_handling_value(display_config) + ), + } + + class NapariStreamingBackend(StreamingBackend): """Napari streaming backend with automatic registration.""" _backend_type = Backend.NAPARI_STREAM.value - # Configure ABC attributes VIEWER_TYPE = 'napari' SHM_PREFIX = 'napari_' - # __init__, _get_publisher, save, cleanup now inherited from ABC + def display_payload_extra( + self, + stream_request: ViewerStreamRequest, + ) -> ViewerDisplayPayloadExtra: + return ViewerDisplayPayloadExtra.from_mapping( + NapariDisplayPayload.from_display_config(stream_request.display_config) + ) - def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: + def _prepare_shapes_data( + self, + data: RoiStreamPayload, + file_path: FilePath, + ) -> dict[str, ViewerWireValue]: """ Prepare shapes data for transmission. @@ -52,107 +94,29 @@ def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: shapes_data = NapariROIConverter.rois_to_shapes(data) return { - 'path': str(file_path), - 'shapes': shapes_data, + ViewerBatchItemWireField.PATH.value: str(file_path), + ViewerBatchItemWireField.SHAPES.value: shapes_data, } - def _prepare_batch_item(self, data: Any, file_path: Union[str, Path], data_type): - if data_type in (StreamingDataType.SHAPES, StreamingDataType.POINTS): - item_data = self._prepare_shapes_data(data, file_path) - data_type_value = data_type.value + def _prepare_batch_item( + self, + request: StreamingItemPreparationRequest, + ) -> ViewerStreamItemPayload: + if request.streaming_data_type.uses_napari_vector_payload: + item_data = self._prepare_shapes_data( + request.data, + request.item_path.value, + ) else: - item_data = self._create_shared_memory(data, file_path) - data_type_value = data_type.value - return item_data, data_type_value - - def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], **kwargs) -> None: - """ - Stream multiple images or ROIs to napari as a batch. - - Args: - data_list: List of image data or ROI lists - file_paths: List of path identifiers - **kwargs: Additional metadata - """ - # Filter to only supported file types - data_list, file_paths, skipped = self._filter_streamable_files(data_list, file_paths) - if not data_list: - return - - # Extract kwargs using generic polymorphic names - host = kwargs.get('host', 'localhost') - port = kwargs['port'] - transport_mode = kwargs['transport_mode'] - transport_config = kwargs.get('transport_config') - display_config = kwargs['display_config'] - microscope_handler = kwargs['microscope_handler'] - source = kwargs.get('source', 'unknown_source') # Pre-built source value - plate_path = kwargs.get('plate_path') - display_payload_extra = { - "colormap": display_config.get_colormap_name(), - "variable_size_handling": display_config.variable_size_handling.value - if hasattr(display_config, "variable_size_handling") and display_config.variable_size_handling - else None, - } - - message, batch_images, image_ids = self._build_batch_message( - data_list, - file_paths, - microscope_handler, - source, - display_config, - self._prepare_batch_item, - plate_path=plate_path, - display_payload_extra=display_payload_extra, - ) - - # Register sent images with queue tracker BEFORE sending - # This prevents race condition with IPC mode where acks arrive before registration - self._register_with_queue_tracker( - port, - image_ids, - transport_mode=transport_mode, - transport_config=transport_config, - ) - - # Create FRESH REQ socket for each send - REQ sockets cannot be reused - # This prevents the "Operation cannot be accomplished in current state" error - # when multiple streams happen concurrently - transport_config = transport_config or self._transport_config - url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), - config=transport_config, + item_data = self.create_shared_memory_payload( + request.data, + request.item_path.value, + ) + return ViewerStreamItemPayload( + item_payload=item_data, + streaming_data_type=request.streaming_data_type, ) - if self._context is None: - self._context = zmq.Context() - - socket = self._context.socket(zmq.REQ) - socket.connect(url) - time.sleep(0.1) # Brief delay for connection to establish - - try: - # Send with REQ socket (BLOCKING - worker waits for Napari to acknowledge) - # Worker blocks until Napari receives, copies data from shared memory, and sends ack - # This guarantees no messages are lost and shared memory is only closed after Napari is done - logger.info(f"📤 NAPARI BACKEND: Sending batch of {len(batch_images)} images to Napari on port {port} (REQ/REP - blocking until ack)") - socket.send_json(message) # Blocking send - - # Wait for acknowledgment from Napari (REP socket) - # Napari will only reply after it has copied all data from shared memory - ack_response = socket.recv_json() - logger.info(f"✅ NAPARI BACKEND: Received ack from Napari: {ack_response.get('status', 'unknown')}") - - finally: - # Always close the socket - never reuse REQ sockets - socket.close() - - # Clean up publisher's handles after successful send - # Receiver will unlink the shared memory after copying the data - self._cleanup_shared_memory_blocks(batch_images, unlink=False) - # cleanup() now inherited from ABC def __del__(self): diff --git a/src/polystore/roi.py b/src/polystore/roi.py index fb6bdb6..a626856 100644 --- a/src/polystore/roi.py +++ b/src/polystore/roi.py @@ -6,18 +6,108 @@ """ import logging +from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union import numpy as np +from metaclass_registry import AutoRegisterMeta +from zmqruntime.viewer_protocol import ( + ViewerSourceSpatialDomainPayload, + ViewerSourceSpatialWireField, +) from .constants import Backend +from .formats import FileFormat logger = logging.getLogger(__name__) +ROI_ZIP_EXTENSION = FileFormat.ROI.extensions[0] +ROI_ZIP_METADATA_MEMBER = "__polystore_roi_metadata__.json" +ROI_ZIP_SUFFIXES = tuple(Path(f"archive{ROI_ZIP_EXTENSION}").suffixes) +ROI_TUPLE_METADATA_KEYS = frozenset( + ( + "bbox", + "centroid", + "plane_indices", + "plane_shape", + ViewerSourceSpatialWireField.SOURCE_SPATIAL_SHAPE_YX.value, + ViewerSourceSpatialWireField.SPATIAL_ORIGIN_YX.value, + ) +) + + +@dataclass(frozen=True, slots=True) +class ROIArchivePath: + """Normalized filesystem path for an ImageJ ROI zip archive.""" + + path: Path + + @classmethod + def from_output_path(cls, output_path: Union[str, Path]) -> "ROIArchivePath": + path = Path(output_path) + if tuple(path.suffixes[-len(ROI_ZIP_SUFFIXES):]) == ROI_ZIP_SUFFIXES: + return cls(path) + return cls(path.with_suffix(ROI_ZIP_EXTENSION)) + + +def roi_zip_metadata_payload(metadata_by_filename: Dict[str, Dict[str, Any]]) -> str: + """Serialize per-ROI metadata into an ImageJ-compatible zip sidecar.""" + import json + + return json.dumps( + { + filename: _jsonable_roi_metadata(metadata) + for filename, metadata in metadata_by_filename.items() + }, + sort_keys=True, + ) + + +def load_roi_zip_metadata(zip_file: Any) -> Dict[str, Dict[str, Any]]: + """Load per-ROI metadata sidecar from a .roi.zip archive.""" + import json + + if ROI_ZIP_METADATA_MEMBER not in zip_file.namelist(): + return {} + payload = json.loads(zip_file.read(ROI_ZIP_METADATA_MEMBER).decode("utf-8")) + if not isinstance(payload, dict): + raise ValueError( + f"Invalid ROI metadata sidecar in {ROI_ZIP_METADATA_MEMBER}: expected mapping." + ) + return { + str(filename): _restore_roi_metadata(metadata) + for filename, metadata in payload.items() + if isinstance(metadata, dict) + } + + +def _jsonable_roi_metadata(value: Any) -> Any: + if isinstance(value, np.generic): + return value.item() + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, tuple): + return [_jsonable_roi_metadata(item) for item in value] + if isinstance(value, list): + return [_jsonable_roi_metadata(item) for item in value] + if isinstance(value, dict): + return {str(key): _jsonable_roi_metadata(item) for key, item in value.items()} + return value + + +def _restore_roi_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]: + restored = dict(metadata) + for key in ROI_TUPLE_METADATA_KEYS: + value = restored.get(key) + if isinstance(value, list): + restored[key] = tuple(value) + return restored + + class ShapeType(Enum): """ROI shape types.""" POLYGON = "polygon" @@ -27,11 +117,26 @@ class ShapeType(Enum): ELLIPSE = "ellipse" +class ShapeTypeRegistryBase(ABC): + """Shared declaration surface for shape-type keyed registries.""" + + __registry_key__ = "shape_type" + __skip_if_no_key__ = True + + shape_type: ClassVar[ShapeType | None] = None + + +class ROIShape(ABC): + """Nominal base for all ROI shape records.""" + + shape_type: ClassVar[ShapeType] + + @dataclass(frozen=True) -class PolygonShape: +class PolygonShape(ROIShape): """Polygon ROI shape defined by vertex coordinates.""" coordinates: np.ndarray # Nx2 array of (y, x) coordinates - shape_type: ShapeType = field(default=ShapeType.POLYGON, init=False) + shape_type: ClassVar[ShapeType] = ShapeType.POLYGON def __post_init__(self): if self.coordinates.ndim != 2 or self.coordinates.shape[1] != 2: @@ -41,10 +146,10 @@ def __post_init__(self): @dataclass(frozen=True) -class PolylineShape: +class PolylineShape(ROIShape): """Polyline ROI shape defined by path coordinates (open path, not closed polygon).""" coordinates: np.ndarray # Nx2 array of (y, x) coordinates - shape_type: ShapeType = field(default=ShapeType.POLYLINE, init=False) + shape_type: ClassVar[ShapeType] = ShapeType.POLYLINE def __post_init__(self): if self.coordinates.ndim != 2 or self.coordinates.shape[1] != 2: @@ -54,11 +159,11 @@ def __post_init__(self): @dataclass(frozen=True) -class MaskShape: +class MaskShape(ROIShape): """Binary mask ROI shape.""" mask: np.ndarray # 2D boolean array bbox: Tuple[int, int, int, int] # (min_y, min_x, max_y, max_x) - shape_type: ShapeType = field(default=ShapeType.MASK, init=False) + shape_type: ClassVar[ShapeType] = ShapeType.MASK def __post_init__(self): if self.mask.ndim != 2: @@ -68,21 +173,21 @@ def __post_init__(self): @dataclass(frozen=True) -class PointShape: +class PointShape(ROIShape): """Point ROI shape.""" y: float x: float - shape_type: ShapeType = field(default=ShapeType.POINT, init=False) + shape_type: ClassVar[ShapeType] = ShapeType.POINT @dataclass(frozen=True) -class EllipseShape: +class EllipseShape(ROIShape): """Ellipse ROI shape.""" center_y: float center_x: float radius_y: float radius_x: float - shape_type: ShapeType = field(default=ShapeType.ELLIPSE, init=False) + shape_type: ClassVar[ShapeType] = ShapeType.ELLIPSE @dataclass(frozen=True) @@ -95,76 +200,341 @@ def __post_init__(self): if not self.shapes: raise ValueError("ROI must have at least one shape") for shape in self.shapes: - if not hasattr(shape, "shape_type"): - raise ValueError(f"Shape {shape} must have shape_type attribute") + if not isinstance(shape, ROIShape): + raise ValueError(f"Shape {shape} must be an ROIShape") -def extract_rois_from_labeled_mask( - labeled_mask: np.ndarray, - min_area: int = 10, - extract_contours: bool = True, -) -> List[ROI]: - """Extract ROIs from a labeled segmentation mask.""" - from skimage import measure - from skimage.measure import regionprops - from scipy.ndimage import find_objects +@dataclass(frozen=True, slots=True) +class LabeledMaskROIExtractionRequest: + """Request to extract ROIs from a labeled mask or stack.""" - if labeled_mask.ndim != 2: - raise ValueError(f"Labeled mask must be 2D, got shape {labeled_mask.shape}") + labeled_mask: np.ndarray + min_area: int = 10 + extract_contours: bool = True + spatial_origin_yx: Optional[Tuple[int, int]] = None + source_spatial_shape_yx: Optional[Tuple[int, int]] = None - if not np.issubdtype(labeled_mask.dtype, np.integer): - labeled_mask = labeled_mask.astype(np.int32) - regions = regionprops(labeled_mask) - slices = find_objects(labeled_mask) +class LabeledMaskROIExtractor(ABC, metaclass=AutoRegisterMeta): + """Registered extraction behavior for one labeled-mask dimensional family.""" - rois = [] - for region in regions: - if region.area < min_area: - continue - - metadata = { - "label": int(region.label), - "area": float(region.area), - "perimeter": float(region.perimeter), - "centroid": tuple(float(c) for c in region.centroid), - "bbox": tuple(int(b) for b in region.bbox), - } + __registry_key__ = "__name__" + __skip_if_no_key__ = True - shapes = [] - if extract_contours: - label_idx = region.label - 1 - if label_idx < len(slices) and slices[label_idx] is not None: - slice_y, slice_x = slices[label_idx] - cropped_mask = labeled_mask[slice_y, slice_x] - binary_mask = (cropped_mask == region.label).astype(np.uint8) - padded_mask = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=0) - contours = measure.find_contours(padded_mask, level=0.5) - offset_y = slice_y.start - offset_x = slice_x.start - padding_offset = np.array([offset_y, offset_x]) - 1 - for contour in contours: - if len(contour) >= 3: - contour_full = contour + padding_offset - shapes.append(PolygonShape(coordinates=contour_full)) - else: - binary_mask = (labeled_mask == region.label) - shapes.append(MaskShape(mask=binary_mask, bbox=region.bbox)) + @classmethod + def for_request( + cls, + request: LabeledMaskROIExtractionRequest, + ) -> "LabeledMaskROIExtractor": + for extractor_type in cls.__registry__.values(): + extractor = extractor_type() + if extractor.accepts(request.labeled_mask): + return extractor + raise ValueError( + "No ROI extractor registered for labeled mask shape " + f"{request.labeled_mask.shape}." + ) - if shapes: - rois.append(ROI(shapes=shapes, metadata=metadata)) + @abstractmethod + def accepts(self, labeled_mask: np.ndarray) -> bool: + """Return whether this extractor owns the mask dimensionality.""" - logger.info(f"Extracted {len(rois)} ROIs from labeled mask") - return rois + @abstractmethod + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + """Extract ROIs from the request.""" + + +class TwoDimensionalLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Extract ROIs from a single 2D labeled mask.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim == 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + from skimage import measure + from skimage.measure import regionprops + from scipy.ndimage import find_objects + + labeled_mask = request.labeled_mask + if not np.issubdtype(labeled_mask.dtype, np.integer): + labeled_mask = labeled_mask.astype(np.int32) + + regions = regionprops(labeled_mask) + slices = find_objects(labeled_mask) + origin_y, origin_x = request.spatial_origin_yx or (0, 0) + + rois = [] + for region in regions: + if region.area < request.min_area: + continue + min_y, min_x, max_y, max_x = region.bbox + + metadata = { + "label": int(region.label), + "area": float(region.area), + "perimeter": float(region.perimeter), + "centroid": ( + float(region.centroid[0] + origin_y), + float(region.centroid[1] + origin_x), + ), + "bbox": ( + int(min_y + origin_y), + int(min_x + origin_x), + int(max_y + origin_y), + int(max_x + origin_x), + ), + } + metadata.update( + ViewerSourceSpatialDomainPayload( + origin_yx=request.spatial_origin_yx, + source_shape_yx=request.source_spatial_shape_yx, + ).to_wire_mapping() + ) + + shapes = [] + if request.extract_contours: + label_idx = region.label - 1 + if label_idx < len(slices) and slices[label_idx] is not None: + slice_y, slice_x = slices[label_idx] + cropped_mask = labeled_mask[slice_y, slice_x] + binary_mask = (cropped_mask == region.label).astype(np.uint8) + padded_mask = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=0) + contours = measure.find_contours(padded_mask, level=0.5) + offset_y = slice_y.start + offset_x = slice_x.start + padding_offset = np.array([offset_y + origin_y, offset_x + origin_x]) - 1 + for contour in contours: + if len(contour) >= 3: + contour_full = contour + padding_offset + shapes.append(PolygonShape(coordinates=contour_full)) + else: + binary_mask = labeled_mask == region.label + shapes.append(MaskShape(mask=binary_mask, bbox=metadata["bbox"])) + + if shapes: + rois.append(ROI(shapes=shapes, metadata=metadata)) + + logger.info(f"Extracted {len(rois)} ROIs from labeled mask") + return rois + + +class NonSpatialLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Treat scalar and otherwise non-spatial label payloads as empty ROI sets.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim < 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + return [] + + +class StackedLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Extract ROIs from all 2D planes in a labeled-mask stack.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim > 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + stack = request.labeled_mask + leading_shape = stack.shape[:-2] + rois: list[ROI] = [] + for plane_indices in np.ndindex(leading_shape): + plane_request = LabeledMaskROIExtractionRequest( + labeled_mask=stack[plane_indices], + min_area=request.min_area, + extract_contours=request.extract_contours, + spatial_origin_yx=request.spatial_origin_yx, + source_spatial_shape_yx=request.source_spatial_shape_yx, + ) + for roi in TwoDimensionalLabeledMaskROIExtractor().extract(plane_request): + rois.append(self._with_plane_metadata(roi, plane_indices, leading_shape)) + return rois + + @staticmethod + def _with_plane_metadata( + roi: ROI, + plane_indices: tuple[int, ...], + leading_shape: tuple[int, ...], + ) -> ROI: + return ROI( + shapes=roi.shapes, + metadata={ + **roi.metadata, + "plane_indices": tuple(int(index) for index in plane_indices), + "plane_shape": tuple(int(size) for size in leading_shape), + }, + ) + + +class ROIJsonShapeDecoder(ShapeTypeRegistryBase, ABC, metaclass=AutoRegisterMeta): + """Decode one serialized ROI shape variant.""" + + @classmethod + def for_serialized_shape(cls, shape_dict: Dict[str, Any]) -> "ROIJsonShapeDecoder": + record = SerializedROIShapeRecord(shape_dict) + shape_type = record.shape_type() + try: + shape_key = ShapeType(shape_type) + except ValueError: + raise ValueError(f"Unknown ROI shape type: {shape_type!r}") from None + return cls.__registry__[shape_key]() + + @abstractmethod + def decode(self, shape_dict: Dict[str, Any]) -> Any: + """Return the concrete ROI shape represented by ``shape_dict``.""" + + +@dataclass(frozen=True, slots=True) +class SerializedROIShapeRecord: + """Typed access to one serialized ROI shape record.""" + + payload: Dict[str, Any] + + def shape_type(self) -> str: + value = self.required("type") + if not isinstance(value, str): + raise TypeError("Serialized ROI shape 'type' must be a string.") + return value + + def coordinates(self) -> np.ndarray: + return np.array(self.required("coordinates")) + + def mask(self) -> np.ndarray: + return np.array(self.required("mask"), dtype=bool) + + def bbox(self) -> Tuple[int, int, int, int]: + return tuple(self.required("bbox")) + + def point_yx(self) -> Tuple[float, float]: + return (self.numeric("y"), self.numeric("x")) + + def ellipse(self) -> "SerializedEllipseShape": + return SerializedEllipseShape( + center_y=self.numeric("center_y"), + center_x=self.numeric("center_x"), + radius_y=self.numeric("radius_y"), + radius_x=self.numeric("radius_x"), + ) + + def numeric(self, key: str) -> float: + value = self.required(key) + if not isinstance(value, (int, float)): + raise TypeError(f"Serialized ROI shape field {key!r} must be numeric.") + return float(value) + + def required(self, key: str) -> Any: + if key not in self.payload: + raise ValueError(f"Serialized ROI shape missing required field {key!r}.") + return self.payload[key] + + +@dataclass(frozen=True, slots=True) +class SerializedEllipseShape: + """Nominal serialized ellipse shape fields.""" + + center_y: float + center_x: float + radius_y: float + radius_x: float + + def to_shape(self) -> EllipseShape: + return EllipseShape( + center_y=self.center_y, + center_x=self.center_x, + radius_y=self.radius_y, + radius_x=self.radius_x, + ) + + +@dataclass(frozen=True, slots=True) +class SerializedROIRecord: + """Typed access to one serialized ROI record.""" + + payload: Dict[str, Any] + + def metadata(self) -> Dict[str, Any]: + value = self.required("metadata") + if not isinstance(value, dict): + raise TypeError("Serialized ROI 'metadata' must be a mapping.") + return value + + def shapes(self) -> Tuple[Dict[str, Any], ...]: + value = self.required("shapes") + if not isinstance(value, list): + raise TypeError("Serialized ROI 'shapes' must be a list.") + return tuple(value) + + def required(self, key: str) -> Any: + if key not in self.payload: + raise ValueError(f"Serialized ROI missing required field {key!r}.") + return self.payload[key] + + +class PolygonROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POLYGON + + def decode(self, shape_dict: Dict[str, Any]) -> PolygonShape: + return PolygonShape( + coordinates=SerializedROIShapeRecord(shape_dict).coordinates() + ) + + +class PolylineROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POLYLINE + + def decode(self, shape_dict: Dict[str, Any]) -> PolylineShape: + return PolylineShape( + coordinates=SerializedROIShapeRecord(shape_dict).coordinates() + ) + + +class MaskROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.MASK + + def decode(self, shape_dict: Dict[str, Any]) -> MaskShape: + record = SerializedROIShapeRecord(shape_dict) + return MaskShape( + mask=record.mask(), + bbox=record.bbox(), + ) + + +class PointROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POINT + + def decode(self, shape_dict: Dict[str, Any]) -> PointShape: + y, x = SerializedROIShapeRecord(shape_dict).point_yx() + return PointShape(y=y, x=x) + + +class EllipseROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.ELLIPSE + + def decode(self, shape_dict: Dict[str, Any]) -> EllipseShape: + return SerializedROIShapeRecord(shape_dict).ellipse().to_shape() + + +def extract_rois_from_labeled_mask( + labeled_mask: np.ndarray, + min_area: int = 10, + extract_contours: bool = True, + spatial_origin_yx: Optional[Tuple[int, int]] = None, + source_spatial_shape_yx: Optional[Tuple[int, int]] = None, +) -> List[ROI]: + """Extract ROIs from a labeled segmentation mask.""" + request = LabeledMaskROIExtractionRequest( + labeled_mask=np.asarray(labeled_mask), + min_area=min_area, + extract_contours=extract_contours, + spatial_origin_yx=spatial_origin_yx, + source_spatial_shape_yx=source_spatial_shape_yx, + ) + return LabeledMaskROIExtractor.for_request(request).extract(request) def _get_backend_from_filemanager(filemanager: Any, backend: Union[str, Backend]): - backend_name = backend.value if hasattr(backend, "value") else str(backend) - if hasattr(filemanager, "_get_backend"): - return filemanager._get_backend(backend_name) - if hasattr(filemanager, "registry"): - return filemanager.registry[backend_name] - raise AttributeError("FileManager does not provide backend lookup") + backend_name = backend.value if isinstance(backend, Backend) else str(backend) + return filemanager._get_backend(backend_name) def materialize_rois( @@ -172,17 +542,11 @@ def materialize_rois( output_path: str, filemanager: Any, backend: Union[str, Backend], + images_dir: str | None = None, ) -> str: """Materialize ROIs to backend-specific format.""" backend_obj = _get_backend_from_filemanager(filemanager, backend) - - images_dir = None - if hasattr(filemanager, "_materialization_context"): - images_dir = filemanager._materialization_context.get("images_dir") - - if hasattr(backend_obj, "_save_rois"): - return backend_obj._save_rois(rois, Path(output_path), images_dir=images_dir) - raise NotImplementedError(f"Backend {backend} does not support ROI saving") + return backend_obj._save_rois(rois, Path(output_path), images_dir=images_dir) def load_rois_from_json(json_path: Path) -> List[ROI]: @@ -200,34 +564,12 @@ def load_rois_from_json(json_path: Path) -> List[ROI]: rois = [] for roi_dict in rois_data: - metadata = roi_dict.get("metadata", {}) + record = SerializedROIRecord(roi_dict) + metadata = record.metadata() shapes = [] - for shape_dict in roi_dict.get("shapes", []): - shape_type = shape_dict.get("type") - - if shape_type == "polygon": - coordinates = np.array(shape_dict["coordinates"]) - shapes.append(PolygonShape(coordinates=coordinates)) - elif shape_type == "polyline": - coordinates = np.array(shape_dict["coordinates"]) - shapes.append(PolylineShape(coordinates=coordinates)) - elif shape_type == "mask": - mask = np.array(shape_dict["mask"], dtype=bool) - bbox = tuple(shape_dict["bbox"]) - shapes.append(MaskShape(mask=mask, bbox=bbox)) - elif shape_type == "point": - shapes.append(PointShape(y=shape_dict["y"], x=shape_dict["x"])) - elif shape_type == "ellipse": - shapes.append( - EllipseShape( - center_y=shape_dict["center_y"], - center_x=shape_dict["center_x"], - radius_y=shape_dict["radius_y"], - radius_x=shape_dict["radius_x"], - ) - ) - else: - logger.warning(f"Unknown shape type: {shape_type}, skipping") + for shape_dict in record.shapes(): + decoder = ROIJsonShapeDecoder.for_serialized_shape(shape_dict) + shapes.append(decoder.decode(shape_dict)) if shapes: rois.append(ROI(shapes=shapes, metadata=metadata)) @@ -250,23 +592,26 @@ def load_rois_from_zip(zip_path: Path) -> List[ROI]: rois = [] with zipfile.ZipFile(zip_path, "r") as zf: + metadata_by_filename = load_roi_zip_metadata(zf) for filename in zf.namelist(): if not filename.endswith(".roi"): continue - try: - roi_bytes = zf.read(filename) - ij_roi = ImagejRoi.frombytes(roi_bytes) - coords = ij_roi.coordinates() - if coords is not None and len(coords) > 0: - coords_yx = coords[:, [1, 0]] - if ij_roi.roitype == ROI_TYPE.POLYLINE: - shape = PolylineShape(coordinates=coords_yx) - else: - shape = PolygonShape(coordinates=coords_yx) - rois.append(ROI(shapes=[shape], metadata={"label": ij_roi.name or filename.replace(".roi", "")})) - except Exception as exc: - logger.warning(f"Failed to load ROI from {filename}: {exc}") - continue + roi_bytes = zf.read(filename) + ij_roi = ImagejRoi.frombytes(roi_bytes) + coords = ij_roi.coordinates() + if coords is None or len(coords) == 0: + raise ValueError(f"ImageJ ROI member {filename!r} has no coordinates.") + coords_yx = coords[:, [1, 0]] + if ij_roi.roitype == ROI_TYPE.POLYLINE: + shape = PolylineShape(coordinates=coords_yx) + else: + shape = PolygonShape(coordinates=coords_yx) + if filename not in metadata_by_filename: + raise ValueError( + f"ROI archive {zip_path} missing metadata sidecar entry for {filename!r}." + ) + metadata = dict(metadata_by_filename[filename]) + rois.append(ROI(shapes=[shape], metadata=metadata)) if not rois: raise ValueError(f"No valid ROIs found in {zip_path}") diff --git a/src/polystore/roi_converters.py b/src/polystore/roi_converters.py index 46e8631..e746a9e 100644 --- a/src/polystore/roi_converters.py +++ b/src/polystore/roi_converters.py @@ -7,93 +7,548 @@ """ import logging -from typing import Any, Dict, List, Tuple +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, ClassVar, Dict, List, Tuple import numpy as np +from metaclass_registry import AutoRegisterMeta -from .roi import EllipseShape, PointShape, PolygonShape, PolylineShape, ROI -from .streaming_constants import NapariShapeType +from .roi import ( + EllipseShape, + MaskShape, + PointShape, + PolygonShape, + PolylineShape, + ROI, + ROIShape, + ShapeType, + ShapeTypeRegistryBase, +) +from .streaming_constants import StreamingDataType logger = logging.getLogger(__name__) -class NapariROIConverter: - """Convert ROI objects to Napari shapes format.""" +class UnsupportedImageJROIShapeError(ValueError): + """Raised when an ROI shape has no ImageJ .roi representation.""" + + +class UnsupportedNapariROIShapeError(ValueError): + """Raised when an ROI shape has no Napari vector-payload representation.""" + + +@dataclass(frozen=True, slots=True) +class ImageJROIMember: + """ImageJ ROI archive member with the metadata that must follow it.""" + + imagej_roi: Any + metadata: Dict[str, Any] + + +@dataclass(frozen=True, slots=True) +class NapariShapeTypeAlias: + """Inert alias from Napari wire shape names to ROI shape types.""" + + alias: str + shape_type: ShapeType + + +NAPARI_SHAPE_TYPE_ALIASES = ( + NapariShapeTypeAlias("path", ShapeType.POLYLINE), + NapariShapeTypeAlias("points", ShapeType.POINT), +) + + +@dataclass(frozen=True, slots=True) +class NapariShapeMetadata: + """Required metadata carried by one Napari ROI shape payload.""" + + label: Any + area: Any + centroid_yx: tuple[Any, Any] + + @classmethod + def from_shape_payload( + cls, + shape_dict: Mapping[str, Any], + *, + area: Any | None = None, + centroid_yx: Sequence[Any] | None = None, + ) -> "NapariShapeMetadata": + metadata = required_shape_metadata(shape_dict) + return cls.from_metadata(metadata, area=area, centroid_yx=centroid_yx) + + @classmethod + def from_metadata( + cls, + metadata: Mapping[str, Any], + *, + area: Any | None = None, + centroid_yx: Sequence[Any] | None = None, + ) -> "NapariShapeMetadata": + if "label" not in metadata: + raise ValueError("Napari shape metadata missing required 'label'.") + resolved_area = cls._required_value(metadata, "area") if area is None else area + resolved_centroid = ( + cls._required_value(metadata, "centroid") + if centroid_yx is None + else centroid_yx + ) + if len(resolved_centroid) != 2: + raise ValueError( + "Napari shape metadata 'centroid' must contain exactly two values." + ) + return cls( + label=metadata["label"], + area=resolved_area, + centroid_yx=(resolved_centroid[0], resolved_centroid[1]), + ) - _SHAPE_DIMENSION_HANDLERS = { - "polygon": lambda shape_dict, prepend_dims: np.hstack( - [np.tile(prepend_dims, (len(shape_dict["coordinates"]), 1)), np.array(shape_dict["coordinates"])] - ), - "polyline": lambda shape_dict, prepend_dims: np.hstack( - [np.tile(prepend_dims, (len(shape_dict["coordinates"]), 1)), np.array(shape_dict["coordinates"])] - ), - "ellipse": lambda shape_dict, prepend_dims: np.hstack( + @staticmethod + def _required_value(metadata: Mapping[str, Any], field: str) -> Any: + if field not in metadata: + raise ValueError(f"Napari shape metadata missing required {field!r}.") + return metadata[field] + + +def required_shape_metadata(shape_dict: Mapping[str, Any]) -> Mapping[str, Any]: + """Return the required metadata mapping from a Napari ROI shape payload.""" + if "metadata" not in shape_dict: + raise ValueError("Napari shape payload missing required 'metadata'.") + metadata = shape_dict["metadata"] + if not isinstance(metadata, Mapping): + raise TypeError("Napari shape payload 'metadata' must be a mapping.") + return metadata + + +@dataclass(slots=True) +class NapariShapeProperties: + """Mutable Napari shape-property columns collected during layer conversion.""" + + label: list[Any] = field(default_factory=list) + area: list[Any] = field(default_factory=list) + centroid_y: list[Any] = field(default_factory=list) + centroid_x: list[Any] = field(default_factory=list) + + def append(self, metadata: NapariShapeMetadata) -> None: + self.label.append(metadata.label) + self.area.append(metadata.area) + self.centroid_y.append(metadata.centroid_yx[0]) + self.centroid_x.append(metadata.centroid_yx[1]) + + def to_mapping(self) -> dict[str, list[Any]]: + return asdict(self) + + +@dataclass(frozen=True, slots=True) +class NapariEllipsePayload: + """Required geometric fields for one Napari ellipse payload.""" + + center_yx: np.ndarray + radii_yx: np.ndarray + + @classmethod + def from_shape_payload( + cls, + shape_dict: Mapping[str, Any], + ) -> "NapariEllipsePayload": + return cls( + center_yx=np.array(cls._required_field(shape_dict, "center")), + radii_yx=np.array(cls._required_field(shape_dict, "radii")), + ) + + @staticmethod + def _required_field(shape_dict: Mapping[str, Any], field_name: str) -> Any: + if field_name not in shape_dict: + raise ValueError( + f"Napari ellipse payload missing required {field_name!r}." + ) + return shape_dict[field_name] + + def corner_rows(self) -> np.ndarray: + return np.array( [ - np.tile(prepend_dims, (4, 1)), - np.array( - [ - [ - shape_dict["center"][0] - shape_dict["radii"][0], - shape_dict["center"][1] - shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] - shape_dict["radii"][0], - shape_dict["center"][1] + shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] + shape_dict["radii"][0], - shape_dict["center"][1] + shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] + shape_dict["radii"][0], - shape_dict["center"][1] - shape_dict["radii"][1], - ], - ] - ), + [ + self.center_yx[0] - self.radii_yx[0], + self.center_yx[1] - self.radii_yx[1], + ], + [ + self.center_yx[0] - self.radii_yx[0], + self.center_yx[1] + self.radii_yx[1], + ], + [ + self.center_yx[0] + self.radii_yx[0], + self.center_yx[1] + self.radii_yx[1], + ], + [ + self.center_yx[0] + self.radii_yx[0], + self.center_yx[1] - self.radii_yx[1], + ], ] - ), - "point": lambda shape_dict, prepend_dims: np.concatenate([prepend_dims, shape_dict["coordinates"]]).reshape(1, -1), - } + ) + + def bounding_box_rows(self) -> np.ndarray: + return np.array( + [ + self.center_yx - self.radii_yx, + self.center_yx + self.radii_yx, + ] + ) + + +class NapariShapeConverter(ShapeTypeRegistryBase, ABC, metaclass=AutoRegisterMeta): + """Registered conversion behavior for one ROI shape type.""" + + @classmethod + def for_shape_dict(cls, shape_dict: Dict[str, Any]) -> "NapariShapeConverter": + return cls.__registry__[_shape_type_from_napari(shape_dict["type"])]() + + def append_common_properties( + self, + metadata: NapariShapeMetadata, + properties: NapariShapeProperties, + ) -> None: + properties.append(metadata) + + @abstractmethod + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + """Add dimensions to a 2D shape to make it nD.""" + + @abstractmethod + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: NapariShapeProperties, + ) -> None: + """Append this shape to a Napari layer payload.""" + + +def _shape_type_from_napari(shape_type: object) -> ShapeType: + value = str(shape_type.value) if isinstance(shape_type, Enum) else str(shape_type) + for alias in NAPARI_SHAPE_TYPE_ALIASES: + if alias.alias == value: + return alias.shape_type + return ShapeType(value) + + +class CoordinateNapariShapeConverter(NapariShapeConverter): + """Shared converter for coordinate-list shapes.""" + + napari_shape_type: ClassVar[str] + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + coordinates = np.array(shape_dict["coordinates"]) + return np.hstack([np.tile(prepend_dims, (len(coordinates), 1)), coordinates]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: NapariShapeProperties, + ) -> None: + napari_shapes.append(np.array(shape_dict["coordinates"])) + shape_types.append(self.napari_shape_type) + self.append_common_properties( + NapariShapeMetadata.from_shape_payload(shape_dict), + properties, + ) + + +class PolygonNapariShapeConverter(CoordinateNapariShapeConverter): + shape_type = ShapeType.POLYGON + napari_shape_type = "polygon" + + +class PolylineNapariShapeConverter(CoordinateNapariShapeConverter): + shape_type = ShapeType.POLYLINE + napari_shape_type = "path" + + +class EllipseNapariShapeConverter(NapariShapeConverter): + shape_type = ShapeType.ELLIPSE + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + ellipse = NapariEllipsePayload.from_shape_payload(shape_dict) + return np.hstack([np.tile(prepend_dims, (4, 1)), ellipse.corner_rows()]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: NapariShapeProperties, + ) -> None: + ellipse = NapariEllipsePayload.from_shape_payload(shape_dict) + napari_shapes.append(ellipse.bounding_box_rows()) + shape_types.append("ellipse") + self.append_common_properties( + NapariShapeMetadata.from_shape_payload(shape_dict), + properties, + ) + + +class PointNapariShapeConverter(NapariShapeConverter): + shape_type = ShapeType.POINT + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + coordinates = np.array(shape_dict["coordinates"]) + if coordinates.ndim == 1: + coordinates = coordinates.reshape(1, -1) + return np.hstack([np.tile(prepend_dims, (len(coordinates), 1)), coordinates]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: NapariShapeProperties, + ) -> None: + coordinates = np.array(shape_dict["coordinates"]) + if coordinates.ndim == 1: + coordinates = coordinates.reshape(1, -1) + for coordinate in coordinates: + napari_shapes.append(np.array([coordinate])) + shape_types.append("point") + self.append_common_properties( + NapariShapeMetadata.from_shape_payload( + shape_dict, + centroid_yx=coordinate, + area=0, + ), + properties, + ) + + +class ROIShapeConverterRegistryBase(ShapeTypeRegistryBase): + """Shared lookup behavior for ROI-shape converter registries.""" + + @classmethod + def for_shape(cls, shape: ROIShape) -> Any: + return cls.__registry__[shape.shape_type]() + + +class ROIShapeNapariPayloadConverter( + ROIShapeConverterRegistryBase, + ABC, + metaclass=AutoRegisterMeta, +): + """Registered projection from ROI shape objects into Napari wire payloads.""" + + streaming_data_type: ClassVar[StreamingDataType] = StreamingDataType.SHAPES + + @classmethod + def streaming_data_type_for_rois(cls, rois: List[ROI]) -> StreamingDataType: + shape_stream_types = tuple( + cls.for_shape(shape).streaming_data_type + for roi in rois + for shape in roi.shapes + ) + if shape_stream_types and all( + stream_type == StreamingDataType.POINTS + for stream_type in shape_stream_types + ): + return StreamingDataType.POINTS + return StreamingDataType.SHAPES + + @abstractmethod + def shape_payloads( + self, + shape: ROIShape, + metadata: Dict[str, Any], + ) -> tuple[Dict[str, Any], ...]: + """Return one or more Napari shape dictionaries for this ROI shape.""" + + +class CoordinateROIShapeNapariPayloadConverter(ROIShapeNapariPayloadConverter): + """Shared Napari payload projection for coordinate-list ROI shapes.""" + + napari_payload_type: ClassVar[str] + + def shape_payloads( + self, + shape: ROIShape, + metadata: Dict[str, Any], + ) -> tuple[Dict[str, Any], ...]: + return ( + { + "type": self.napari_payload_type, + "coordinates": self.coordinates_yx(shape).tolist(), + "metadata": metadata, + }, + ) + + @abstractmethod + def coordinates_yx(self, shape: ROIShape) -> np.ndarray: + """Return shape coordinates as an Nx2 YX array.""" + + +class PolygonROIShapeNapariPayloadConverter(CoordinateROIShapeNapariPayloadConverter): + shape_type = ShapeType.POLYGON + napari_payload_type = "polygon" + + def coordinates_yx(self, shape: PolygonShape) -> np.ndarray: + return shape.coordinates + + +class PolylineROIShapeNapariPayloadConverter(CoordinateROIShapeNapariPayloadConverter): + shape_type = ShapeType.POLYLINE + napari_payload_type = "path" + + def coordinates_yx(self, shape: PolylineShape) -> np.ndarray: + return shape.coordinates + + +class EllipseROIShapeNapariPayloadConverter(ROIShapeNapariPayloadConverter): + shape_type = ShapeType.ELLIPSE + + def shape_payloads( + self, + shape: EllipseShape, + metadata: Dict[str, Any], + ) -> tuple[Dict[str, Any], ...]: + return ( + { + "type": "ellipse", + "center": [shape.center_y, shape.center_x], + "radii": [shape.radius_y, shape.radius_x], + "metadata": metadata, + }, + ) + + +class PointROIShapeNapariPayloadConverter(ROIShapeNapariPayloadConverter): + shape_type = ShapeType.POINT + streaming_data_type = StreamingDataType.POINTS + + def shape_payloads( + self, + shape: PointShape, + metadata: Dict[str, Any], + ) -> tuple[Dict[str, Any], ...]: + return ( + { + "type": "points", + "coordinates": [[shape.y, shape.x]], + "metadata": metadata, + }, + ) + + +class MaskROIShapeNapariPayloadConverter(ROIShapeNapariPayloadConverter): + shape_type = ShapeType.MASK + + def shape_payloads( + self, + shape: MaskShape, + metadata: Dict[str, Any], + ) -> tuple[Dict[str, Any], ...]: + raise UnsupportedNapariROIShapeError( + "MaskShape cannot be represented as a Napari vector ROI payload." + ) + + +class ImageJROIShapeConverter( + ROIShapeConverterRegistryBase, + ABC, + metaclass=AutoRegisterMeta, +): + """Registered projection from ROI shape objects into ImageJ ROI records.""" + + @abstractmethod + def imagej_roi(self, shape: ROIShape, name: str) -> Any: + """Return a roifile ImagejRoi for this ROI shape.""" + + +class PolygonImageJROIShapeConverter(ImageJROIShapeConverter): + shape_type = ShapeType.POLYGON + + def imagej_roi(self, shape: PolygonShape, name: str) -> Any: + from roifile import ImagejRoi + + imagej_roi = ImagejRoi.frompoints(shape.coordinates[:, [1, 0]]) + imagej_roi.name = name + return imagej_roi + + +class PolylineImageJROIShapeConverter(ImageJROIShapeConverter): + shape_type = ShapeType.POLYLINE + + def imagej_roi(self, shape: PolylineShape, name: str) -> Any: + from roifile import ImagejRoi, ROI_TYPE + + imagej_roi = ImagejRoi.frompoints(shape.coordinates[:, [1, 0]]) + imagej_roi.roitype = ROI_TYPE.POLYLINE + imagej_roi.name = name + return imagej_roi + + +class EllipseImageJROIShapeConverter(ImageJROIShapeConverter): + shape_type = ShapeType.ELLIPSE + + def imagej_roi(self, shape: EllipseShape, name: str) -> Any: + from roifile import ImagejRoi, ROI_TYPE + + left = shape.center_x - shape.radius_x + top = shape.center_y - shape.radius_y + width = 2 * shape.radius_x + height = 2 * shape.radius_y + imagej_roi = ImagejRoi.frompoints( + np.array([[left, top], [left + width, top + height]]) + ) + imagej_roi.roitype = ROI_TYPE.OVAL + imagej_roi.name = name + return imagej_roi + + +class PointImageJROIShapeConverter(ImageJROIShapeConverter): + shape_type = ShapeType.POINT + + def imagej_roi(self, shape: PointShape, name: str) -> Any: + from roifile import ImagejRoi + + imagej_roi = ImagejRoi.frompoints(np.array([[shape.x, shape.y]])) + imagej_roi.name = name + return imagej_roi + + +class MaskImageJROIShapeConverter(ImageJROIShapeConverter): + shape_type = ShapeType.MASK + + def imagej_roi(self, shape: MaskShape, name: str) -> Any: + raise UnsupportedImageJROIShapeError( + "MaskShape cannot be represented in ImageJ .roi format." + ) + + +class NapariROIConverter: + """Convert ROI objects to Napari shapes format.""" @staticmethod def add_dimensions_to_shape(shape_dict: Dict[str, Any], prepend_dims: List[float]) -> np.ndarray: """Add dimensions to a 2D shape to make it nD.""" - shape_type = shape_dict["type"] - shape_type_enum = NapariShapeType(shape_type) if isinstance(shape_type, str) else shape_type - handler = NapariROIConverter._SHAPE_DIMENSION_HANDLERS.get(shape_type_enum.value) - if handler is None: - raise ValueError(f"Unsupported shape type: {shape_type}") - return handler(shape_dict, np.array(prepend_dims)) + return NapariShapeConverter.for_shape_dict(shape_dict).add_dimensions( + shape_dict, + np.array(prepend_dims), + ) @staticmethod def rois_to_shapes(rois: List[ROI]) -> List[Dict[str, Any]]: """Convert ROI objects to Napari shapes data.""" shapes_data = [] for roi in rois: - if roi.shapes and all(isinstance(shape, PointShape) for shape in roi.shapes): - points = [[shape.y, shape.x] for shape in roi.shapes] - shapes_data.append({"type": "points", "coordinates": points, "metadata": roi.metadata}) - else: - for shape in roi.shapes: - if isinstance(shape, PolygonShape): - shapes_data.append( - {"type": "polygon", "coordinates": shape.coordinates.tolist(), "metadata": roi.metadata} - ) - elif isinstance(shape, PolylineShape): - shapes_data.append( - {"type": "path", "coordinates": shape.coordinates.tolist(), "metadata": roi.metadata} - ) - elif isinstance(shape, EllipseShape): - shapes_data.append( - { - "type": "ellipse", - "center": [shape.center_y, shape.center_x], - "radii": [shape.radius_y, shape.radius_x], - "metadata": roi.metadata, - } - ) - elif isinstance(shape, PointShape): - shapes_data.append({"type": "point", "coordinates": [shape.y, shape.x], "metadata": roi.metadata}) + for shape in roi.shapes: + shapes_data.extend( + ROIShapeNapariPayloadConverter.for_shape(shape).shape_payloads( + shape, + roi.metadata, + ) + ) return shapes_data @staticmethod @@ -101,92 +556,75 @@ def shapes_to_napari_format(shapes_data: List[Dict]) -> Tuple[List[np.ndarray], """Convert shape dicts to Napari layer format.""" napari_shapes = [] shape_types = [] - properties = {"label": [], "area": [], "centroid_y": [], "centroid_x": []} + properties = NapariShapeProperties() for shape_dict in shapes_data: - shape_type = shape_dict.get("type") - metadata = shape_dict.get("metadata", {}) - - if shape_type == "polygon": - coords = np.array(shape_dict["coordinates"]) - napari_shapes.append(coords) - shape_types.append("polygon") - centroid = metadata.get("centroid", (0, 0)) - properties["label"].append(metadata.get("label", "")) - properties["area"].append(metadata.get("area", 0)) - properties["centroid_y"].append(centroid[0]) - properties["centroid_x"].append(centroid[1]) - - elif shape_type == "ellipse": - center = np.array(shape_dict["center"]) - radii = np.array(shape_dict["radii"]) - corners = np.array([center - radii, center + radii]) - napari_shapes.append(corners) - shape_types.append("ellipse") - centroid = metadata.get("centroid", (0, 0)) - properties["label"].append(metadata.get("label", "")) - properties["area"].append(metadata.get("area", 0)) - properties["centroid_y"].append(centroid[0]) - properties["centroid_x"].append(centroid[1]) - - elif shape_type == "point": - coords = np.array([shape_dict["coordinates"]]) - napari_shapes.append(coords) - shape_types.append("point") - point_coords = shape_dict["coordinates"] - properties["label"].append(metadata.get("label", "")) - properties["area"].append(0) - properties["centroid_y"].append(point_coords[0]) - properties["centroid_x"].append(point_coords[1]) - - return napari_shapes, shape_types, properties + NapariShapeConverter.for_shape_dict(shape_dict).append_napari_format( + shape_dict, + napari_shapes, + shape_types, + properties, + ) + + return napari_shapes, shape_types, properties.to_mapping() class FijiROIConverter: """Convert ROI objects to ImageJ ROI bytes.""" @staticmethod - def rois_to_imagej_bytes(rois: List[ROI], roi_prefix: str = "") -> List[bytes]: - """Convert ROI objects to ImageJ ROI bytes.""" + def rois_to_imagej_members( + rois: List[ROI], + roi_prefix: str = "", + ) -> List[ImageJROIMember]: + """Convert ROI objects to ImageJ ROI members with per-member metadata.""" try: - from roifile import ImagejRoi, ROI_TYPE + import roifile # noqa: F401 except ImportError: raise ImportError("roifile library required for ImageJ ROI conversion. Install with: pip install roifile") - roi_bytes_list = [] - for roi in rois: - for shape in roi.shapes: - if isinstance(shape, PolygonShape): - coords_xy = shape.coordinates[:, [1, 0]] - ij_roi = ImagejRoi.frompoints(coords_xy) - ij_roi.name = f"{roi_prefix}_ROI_{roi.metadata.get('label', '')}".rstrip("_") - roi_bytes_list.append(ij_roi.tobytes()) - elif isinstance(shape, PolylineShape): - coords_xy = shape.coordinates[:, [1, 0]] - ij_roi = ImagejRoi.frompoints(coords_xy) - ij_roi.roitype = ROI_TYPE.POLYLINE - ij_roi.name = f"{roi_prefix}_ROI_{roi.metadata.get('label', '')}".rstrip("_") - roi_bytes_list.append(ij_roi.tobytes()) - elif isinstance(shape, EllipseShape): - center_x = shape.center_x - center_y = shape.center_y - radius_x = shape.radius_x - radius_y = shape.radius_y - left = center_x - radius_x - top = center_y - radius_y - width = 2 * radius_x - height = 2 * radius_y - ij_roi = ImagejRoi.frompoints(np.array([[left, top], [left + width, top + height]])) - ij_roi.roitype = ImagejRoi.OVAL if hasattr(ImagejRoi, "OVAL") else ROI_TYPE.OVAL - ij_roi.name = f"{roi_prefix}_ROI_{roi.metadata.get('label', '')}".rstrip("_") - roi_bytes_list.append(ij_roi.tobytes()) - elif isinstance(shape, PointShape): - coords_xy = np.array([[shape.x, shape.y]]) - ij_roi = ImagejRoi.frompoints(coords_xy) - ij_roi.name = f"{roi_prefix}_ROI_{roi.metadata.get('label', '')}".rstrip("_") - roi_bytes_list.append(ij_roi.tobytes()) - - return roi_bytes_list + members: list[ImageJROIMember] = [] + for roi_index, roi in enumerate(rois, start=1): + for shape_index, shape in enumerate(roi.shapes, start=1): + imagej_roi = ImageJROIShapeConverter.for_shape(shape).imagej_roi( + shape, + FijiROIConverter.imagej_roi_name( + roi_prefix=roi_prefix, + roi_index=roi_index, + shape_index=shape_index, + ), + ) + members.append( + ImageJROIMember( + imagej_roi=imagej_roi, + metadata=dict(roi.metadata), + ) + ) + return members + + @staticmethod + def imagej_roi_name( + *, + roi_prefix: str, + roi_index: int, + shape_index: int, + ) -> str: + """Return the stable ImageJ ROI name for one projected shape.""" + stem = f"ROI_{roi_index}_{shape_index}" + if roi_prefix: + return f"{roi_prefix}_{stem}" + return stem + + @staticmethod + def rois_to_imagej_bytes(rois: List[ROI], roi_prefix: str = "") -> List[bytes]: + """Convert ROI objects to ImageJ ROI bytes.""" + return [ + member.imagej_roi.tobytes() + for member in FijiROIConverter.rois_to_imagej_members( + rois, + roi_prefix=roi_prefix, + ) + ] @staticmethod def encode_rois_for_transmission(roi_bytes_list: List[bytes]) -> List[str]: @@ -200,6 +638,19 @@ def decode_rois_from_transmission(encoded_rois: List[str]) -> List[bytes]: import base64 return [base64.b64decode(roi_encoded) for roi_encoded in encoded_rois] + @staticmethod + def transmission_to_java_rois( + encoded_rois: List[str], + scyjava_module, + ) -> List[Any]: + """Decode transmitted ImageJ ROI bytes into Java ROI instances.""" + return [ + FijiROIConverter.bytes_to_java_roi(roi_bytes, scyjava_module) + for roi_bytes in FijiROIConverter.decode_rois_from_transmission( + encoded_rois + ) + ] + @staticmethod def bytes_to_java_roi(roi_bytes: bytes, scyjava_module) -> Any: """Convert ROI bytes to Java ROI object via temporary file.""" diff --git a/src/polystore/streaming/__init__.py b/src/polystore/streaming/__init__.py index 8a8536e..2217b4b 100644 --- a/src/polystore/streaming/__init__.py +++ b/src/polystore/streaming/__init__.py @@ -10,7 +10,32 @@ # This allows both: # from polystore.streaming import StreamingBackend # from polystore.streaming.receivers import FijiBatchProcessor -from polystore.streaming._streaming_backend import StreamingBackend - -__all__ = ["StreamingBackend"] +from polystore.streaming._streaming_backend import ( + FilePath, + RoiStreamPayload, + StreamablePayload, + StreamingBatchItemPreparationAuthority, + StreamingBatchMessageBuilder, + StreamingBatchMessageRequest, + StreamingBuiltBatch, + StreamingPreparedBatchItems, + StreamingBackend, + StreamingComponentNamesRequest, + StreamingItemPreparationRequest, + ViewerDisplayPayloadExtra, +) +__all__ = [ + "FilePath", + "RoiStreamPayload", + "StreamablePayload", + "StreamingBatchItemPreparationAuthority", + "StreamingBatchMessageBuilder", + "StreamingBatchMessageRequest", + "StreamingBuiltBatch", + "StreamingPreparedBatchItems", + "StreamingBackend", + "StreamingComponentNamesRequest", + "StreamingItemPreparationRequest", + "ViewerDisplayPayloadExtra", +] diff --git a/src/polystore/streaming/_streaming_backend.py b/src/polystore/streaming/_streaming_backend.py index 417baa2..b848ebe 100644 --- a/src/polystore/streaming/_streaming_backend.py +++ b/src/polystore/streaming/_streaming_backend.py @@ -5,25 +5,524 @@ data to external systems without persistent storage capabilities. """ +from __future__ import annotations + import logging -import os import time import uuid +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from multiprocessing import resource_tracker, shared_memory from pathlib import Path -from typing import Any, Callable, List, Set, Union +from types import MappingProxyType +from typing import TypeAlias import numpy as np +import zmq +from arraybridge import convert_memory, detect_memory_type +from arraybridge.types import MemoryType as ArrayBridgeMemoryType from ..base import DataSink -from ..constants import TransportMode +from ..formats import DEFAULT_IMAGE_EXTENSIONS from ..streaming_constants import StreamingDataType -from ..roi import ROI, PointShape +from ..roi import ROI, ROI_ZIP_EXTENSION +from ..roi_converters import ROIShapeNapariPayloadConverter from ..zmq_config import POLYSTORE_ZMQ_CONFIG +from .viewer_transport import ( + ViewerMicroscopeHandlerABC, + ViewerStreamBatchItemInput, + ViewerStreamBatchItemSource, + ViewerStreamBackendKwargs, + ViewerStreamItemPayload, + ViewerStreamRequest, + ViewerTransportDefaults, +) from zmqruntime.ack_listener import GlobalAckListener -from zmqruntime.transport import coerce_transport_mode, get_zmq_transport_url +from zmqruntime.config import ZMQConfig +from zmqruntime.viewer_protocol import ( + ViewerBatchItemWireField, + ViewerBatchMessagePayload, + ViewerComponentMetadataPayload, + ViewerDisplayConfigWireField, + ViewerTransportEndpoint, + ViewerWirePayload, + ViewerWireMapping, + ViewerWireValue, +) logger = logging.getLogger(__name__) +FilePath: TypeAlias = str | Path +RoiStreamPayload: TypeAlias = Sequence[ROI] +StreamablePayload: TypeAlias = np.ndarray | Sequence[ViewerWireValue] | RoiStreamPayload +ComponentValue = str | int | float | bool | tuple | None +ViewerDisplayPayloadExtraValues: TypeAlias = Mapping[ + str | ViewerDisplayConfigWireField, + ViewerWireValue, +] +STREAMING_TRANSPORT_DEFAULTS = ViewerTransportDefaults() + + +@dataclass(frozen=True) +class ViewerDisplayPayloadExtra: + """Nominal viewer-specific display payload extension.""" + + values: ViewerDisplayPayloadExtraValues = field( + default_factory=lambda: MappingProxyType({}) + ) + + @classmethod + def from_mapping( + cls, + values: ViewerDisplayPayloadExtraValues, + ) -> "ViewerDisplayPayloadExtra": + return cls(values) + + def to_wire_mapping(self) -> dict[str, ViewerWireValue]: + return ViewerWirePayload.mapping( + self.values, + context="viewer display payload extra", + ) + + +EMPTY_DISPLAY_PAYLOAD_EXTRA = ViewerDisplayPayloadExtra() + + +@dataclass(frozen=True) +class StreamingComponentDomainValue: + """Viewer component value normalized for a batch-level domain.""" + + value: ComponentValue + + @classmethod + def from_wire( + cls, + value: ViewerWireValue, + ) -> "StreamingComponentDomainValue": + if isinstance(value, (str, int, float, bool)) or value is None: + return cls(value) + if isinstance(value, tuple): + return cls(value) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + return cls(tuple(value)) + raise TypeError( + "Streaming component values must be JSON scalar or tuple-like, " + f"got {type(value).__name__}." + ) + + +@dataclass(frozen=True) +class StreamingBatchImageMetadata: + """Validated metadata carried by one prepared viewer batch item.""" + + values: ViewerWireMapping + + @classmethod + def from_image_payload( + cls, + image_payload: ViewerWireMapping, + ) -> "StreamingBatchImageMetadata": + metadata = image_payload[ViewerBatchItemWireField.METADATA.value] + if not isinstance(metadata, Mapping): + raise TypeError( + "Streaming batch item metadata must be a mapping, " + f"got {type(metadata).__name__}." + ) + return cls( + ViewerWirePayload.mapping( + metadata, + context="streaming batch item metadata", + ) + ) + + def component_value(self, component: str) -> ComponentValue | None: + if component not in self.values: + return None + return StreamingComponentDomainValue.from_wire( + self.values[component] + ).value + + +class StreamingComponentValueDomainAuthority: + """Build batch-level component value domains from stream item metadata.""" + + @staticmethod + def wire_payload( + stream_request: ViewerStreamRequest, + batch_images: Sequence[ViewerWireMapping], + ) -> dict[str, ViewerWireValue]: + component_order = stream_request.display_semantics.component_order + values_by_component: dict[str, list[ComponentValue]] = { + component: [] for component in component_order + } + for image_payload in batch_images: + metadata = StreamingBatchImageMetadata.from_image_payload(image_payload) + for component in component_order: + value = metadata.component_value(component) + if value is None: + continue + if value not in values_by_component[component]: + values_by_component[component].append(value) + return { + component: values + for component, values in values_by_component.items() + if values + } + +@dataclass(frozen=True) +class StreamingComponentNamesRequest: + """Component-label metadata requested for one viewer batch.""" + + component_names: Sequence[str] + log_prefix: str | None = None + verbose: bool = False + + @classmethod + def from_stream_request( + cls, + stream_request: ViewerStreamRequest, + log_prefix: str | None = None, + verbose: bool = False, + ) -> "StreamingComponentNamesRequest": + return cls( + component_names=stream_request.display_semantics.component_order, + log_prefix=log_prefix, + verbose=verbose, + ) + + +@dataclass(frozen=True) +class StreamingBatchMessageRequest: + """Inputs for building one viewer batch message.""" + + data_list: list[StreamablePayload] + file_paths: list[FilePath] + stream_request: ViewerStreamRequest + component_names_request: StreamingComponentNamesRequest + display_payload_extra: ViewerDisplayPayloadExtra = field( + default_factory=ViewerDisplayPayloadExtra + ) + + +@dataclass(frozen=True) +class StreamingPreparedBatchItems: + """Prepared per-item viewer payloads before batch-level metadata is attached.""" + + batch_images: list[dict[str, ViewerWireValue]] + image_ids: list[str] + + +@dataclass(frozen=True) +class StreamingBuiltBatch(StreamingPreparedBatchItems): + """Prepared viewer message and per-item transmission bookkeeping.""" + + message: dict[str, ViewerWireValue] + + +@dataclass(frozen=True) +class StreamingItemPath: + """Nominal path identity for one item in a viewer stream batch.""" + + value: FilePath + + @property + def wire_value(self) -> str: + return str(self.value) + + +@dataclass(frozen=True) +class StreamingPayloadFileRequest: + """Shared payload/file identity for viewer item preparation requests.""" + + data: StreamablePayload + item_path: StreamingItemPath + + +@dataclass(frozen=True) +class StreamingItemPreparationRequest(StreamingPayloadFileRequest): + """Inputs needed to prepare one payload for a viewer batch item.""" + + streaming_data_type: StreamingDataType + + +@dataclass(frozen=True) +class StreamingSharedMemoryRequest(StreamingPayloadFileRequest): + """Inputs needed to allocate one image payload into shared memory.""" + + shm_prefix: str + + +@dataclass(frozen=True) +class StreamingSharedMemoryPayload: + """Wire payload describing a shared-memory image allocation.""" + + item_path: StreamingItemPath + shape: tuple[int, ...] + dtype: str + shm_name: str + + def to_wire_mapping(self) -> dict[str, ViewerWireValue]: + return { + ViewerBatchItemWireField.PATH.value: self.item_path.wire_value, + ViewerBatchItemWireField.SHAPE.value: self.shape, + ViewerBatchItemWireField.DTYPE.value: self.dtype, + ViewerBatchItemWireField.SHM_NAME.value: self.shm_name, + } + + +@dataclass(frozen=True) +class StreamingSharedMemoryBlock: + """Allocated shared memory and the wire payload that names it.""" + + shared_memory: shared_memory.SharedMemory + payload: StreamingSharedMemoryPayload + + +@dataclass(frozen=True) +class StreamingSharedMemoryName: + """Unique shared-memory name for one viewer transfer block.""" + + value: str + + @classmethod + def unique( + cls, + shm_prefix: str, + ) -> "StreamingSharedMemoryName": + return cls(f"{shm_prefix}{uuid.uuid4().hex[:12]}") + + +class StreamingPayloadMemoryAuthority: + """Memory conversion authority for streamable image payloads.""" + + @staticmethod + def to_numpy(data: StreamablePayload) -> np.ndarray: + if isinstance(data, np.ndarray): + return data + if isinstance(data, (list, tuple)): + return np.asarray(data) + return convert_memory( + data, + detect_memory_type(data), + ArrayBridgeMemoryType.NUMPY.value, + gpu_id=0, + ) + + +class StreamingSharedMemoryAuthority: + """Allocate image payloads for viewer transfer through shared memory.""" + + @classmethod + def create( + cls, + request: StreamingSharedMemoryRequest, + ) -> StreamingSharedMemoryBlock: + np_data = StreamingPayloadMemoryAuthority.to_numpy(request.data) + shm_name = StreamingSharedMemoryName.unique(request.shm_prefix).value + shm = shared_memory.SharedMemory( + create=True, + size=np_data.nbytes, + name=shm_name, + ) + resource_tracker.unregister(shm._name, "shared_memory") + + shm_array = np.ndarray(np_data.shape, dtype=np_data.dtype, buffer=shm.buf) + shm_array[:] = np_data[:] + + return StreamingSharedMemoryBlock( + shared_memory=shm, + payload=StreamingSharedMemoryPayload( + item_path=request.item_path, + shape=tuple(int(dimension) for dimension in np_data.shape), + dtype=str(np_data.dtype), + shm_name=shm_name, + ), + ) + + +class StreamingDataTypeAuthority: + """Detect the viewer payload kind for one streamed object.""" + + @staticmethod + def detect(data: StreamablePayload) -> StreamingDataType: + is_roi = isinstance(data, list) and len(data) > 0 and isinstance(data[0], ROI) + + if not is_roi: + return StreamingDataType.IMAGE + + return ROIShapeNapariPayloadConverter.streaming_data_type_for_rois(data) + + +class StreamingComponentNamesMetadataCollector: + """Collect viewer component-label metadata for one batch.""" + + @staticmethod + def collect( + plate_path: FilePath | None, + microscope_handler: ViewerMicroscopeHandlerABC, + request: StreamingComponentNamesRequest, + ) -> dict[str, ViewerWireValue]: + component_names_metadata = {} + + if plate_path is None: + if request.verbose and request.log_prefix: + logger.warning("%s: No plate_path in kwargs", request.log_prefix) + return component_names_metadata + + for component_name in request.component_names: + metadata = microscope_handler.metadata_handler.get_component_values( + plate_path, + component_name, + ) + if request.verbose and request.log_prefix: + logger.info( + "%s: Got %s metadata: %s", + request.log_prefix, + component_name, + metadata, + ) + if metadata: + component_names_metadata[component_name] = metadata + + return component_names_metadata + + +class StreamingDisplayPayloadBuilder: + """Build the shared viewer display-config payload.""" + + @staticmethod + def build( + stream_request: ViewerStreamRequest, + display_payload_extra: ViewerDisplayPayloadExtra, + ): + return stream_request.display_semantics.batch_display_payload( + display_payload_extra.to_wire_mapping() + ) + + +class StreamingBatchItemPreparationAuthority: + """Prepare per-item viewer payloads and transmission bookkeeping.""" + + @staticmethod + def prepare( + backend: "StreamingBackend", + request: StreamingBatchMessageRequest, + ) -> StreamingPreparedBatchItems: + batch_images = [] + image_ids = [] + + for index, (data, file_path) in enumerate( + zip(request.data_list, request.file_paths) + ): + item_path = StreamingItemPath(file_path) + image_id = str(uuid.uuid4()) + image_ids.append(image_id) + + streaming_data_type = StreamingDataTypeAuthority.detect(data) + item_payload = backend._prepare_batch_item( + StreamingItemPreparationRequest( + data=data, + item_path=item_path, + streaming_data_type=streaming_data_type, + ) + ) + + batch_images.append( + request.stream_request.producer.batch_item_payload( + ViewerStreamBatchItemSource.from_input( + ViewerStreamBatchItemInput( + stream_source=request.stream_request.source, + item_payload=item_payload.item_payload, + streaming_data_type=item_payload.streaming_data_type, + file_path=item_path.value, + index=index, + image_id=image_id, + ) + ) + ).to_wire_mapping() + ) + + return StreamingPreparedBatchItems( + batch_images=batch_images, + image_ids=image_ids, + ) + + +class StreamingComponentMetadataPayloadAuthority: + """Resolve the component metadata payload for one viewer batch.""" + + @staticmethod + def payload( + request: StreamingBatchMessageRequest, + prepared_items: StreamingPreparedBatchItems, + ) -> ViewerComponentMetadataPayload: + declared = ViewerComponentMetadataPayload.from_optional_wire_mapping( + request.stream_request.message_extra_payload() + ) + if declared is not None: + return declared + return ViewerComponentMetadataPayload( + component_names_metadata=( + StreamingComponentNamesMetadataCollector.collect( + request.stream_request.source.identity.plate_path, + request.stream_request.source.identity.microscope_handler, + request.component_names_request, + ) + ), + component_value_domain=( + StreamingComponentValueDomainAuthority.wire_payload( + request.stream_request, + prepared_items.batch_images, + ) + ), + ) + + +class StreamingBatchMessageBuilder: + """Build complete viewer batch messages from prepared items.""" + + @classmethod + def build( + cls, + backend: "StreamingBackend", + request: StreamingBatchMessageRequest, + ) -> StreamingBuiltBatch: + if len(request.data_list) != len(request.file_paths): + raise ValueError("data_list and file_paths must have the same length") + + prepared_items = StreamingBatchItemPreparationAuthority.prepare( + backend, + request, + ) + + component_metadata_payload = ( + StreamingComponentMetadataPayloadAuthority.payload( + request, + prepared_items, + ) + ) + + display_payload = StreamingDisplayPayloadBuilder.build( + request.stream_request, + request.display_payload_extra, + ) + message = ViewerBatchMessagePayload.from_parts( + images=prepared_items.batch_images, + display_payload=display_payload, + component_metadata=component_metadata_payload, + timestamp=time.time(), + extra=ViewerComponentMetadataPayload.strip_component_metadata( + backend.message_extra(request.stream_request) + ), + ).to_wire_mapping() + + return StreamingBuiltBatch( + message=message, + batch_images=prepared_items.batch_images, + image_ids=prepared_items.image_ids, + ) + + class StreamingBackend(DataSink): """ Abstract base class for ZeroMQ-based streaming backends. @@ -42,15 +541,22 @@ class StreamingBackend(DataSink): """ # Abstract class attributes that subclasses must define - VIEWER_TYPE: str = None - SHM_PREFIX: str = None + VIEWER_TYPE: str + SHM_PREFIX: str # Class attribute: streaming backends only support image array data and ROIs supports_arbitrary_files: bool = False # Extensions that streaming backends can handle # Subclasses can override to add support for specific formats - SUPPORTED_EXTENSIONS: set[str] = {'.tif', '.tiff', '.png', '.jpg', '.jpeg', '.roi.zip'} + SUPPORTED_EXTENSIONS: frozenset[str] = frozenset( + (*DEFAULT_IMAGE_EXTENSIONS, ROI_ZIP_EXTENSION) + ) + + def supports_file_path(self, path: FilePath) -> bool: + """Return whether the stream backend can render this output path.""" + name = Path(path).name.lower() + return any(name.endswith(ext) for ext in self.SUPPORTED_EXTENSIONS) @property def requires_filesystem_validation(self) -> bool: @@ -59,9 +565,9 @@ def requires_filesystem_validation(self) -> bool: def _filter_streamable_files( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - ) -> tuple[List[Any], List[Union[str, Path]], List[Union[str, Path]]]: + data_list: list[StreamablePayload], + file_paths: list[FilePath], + ) -> tuple[list[StreamablePayload], list[FilePath], list[FilePath]]: """ Filter data to only include files with supported extensions. @@ -77,13 +583,7 @@ def _filter_streamable_files( skipped_paths = [] for data, path in zip(data_list, file_paths): - path_obj = Path(path) - name = path_obj.name.lower() - - # Check if extension is supported - is_supported = any(name.endswith(ext) for ext in self.SUPPORTED_EXTENSIONS) - - if is_supported: + if self.supports_file_path(path): filtered_data.append(data) filtered_paths.append(path) else: @@ -97,155 +597,33 @@ def _filter_streamable_files( return filtered_data, filtered_paths, skipped_paths - def __init__(self, transport_config=None): + def __init__(self, transport_config: ZMQConfig = POLYSTORE_ZMQ_CONFIG): """Initialize ZeroMQ and shared memory infrastructure.""" self._publishers = {} self._context = None self._shared_memory_blocks = {} - self._transport_config = transport_config or POLYSTORE_ZMQ_CONFIG - - def _get_publisher(self, host: str, port: int, transport_mode: TransportMode, transport_config=None): - """ - Lazy initialization of ZeroMQ publisher (common for all streaming backends). - - Uses REQ socket for Fiji (synchronous request/reply with blocking) - and PUB socket for Napari (broadcast pattern). - - Args: - host: Host to connect to (ignored for IPC mode) - port: Port to connect to - transport_mode: IPC or TCP transport (required - comes from config) - - Returns: - ZeroMQ publisher socket - """ - # Generate transport URL using centralized function - transport_config = transport_config or self._transport_config - url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), - config=transport_config, - ) - - key = url # Use URL as key instead of host:port - if key not in self._publishers: - try: - import zmq - if self._context is None: - self._context = zmq.Context() - - # Use REQ socket for all viewers (synchronous request/reply) - # All viewers must send acknowledgment after processing - publisher = self._context.socket(zmq.REQ) - - publisher.connect(url) - socket_name = "REQ" - logger.info(f"{self.VIEWER_TYPE} streaming {socket_name} socket connected to {url}") - time.sleep(0.1) - self._publishers[key] = publisher - - except ImportError: - logger.error("ZeroMQ not available - streaming disabled") - raise RuntimeError("ZeroMQ required for streaming") - - return self._publishers[key] - - def _parse_component_metadata(self, file_path: Union[str, Path], microscope_handler, - source: str) -> dict: - """ - Parse component metadata from filename (common for all streaming backends). - - Args: - file_path: Path to parse - microscope_handler: Handler with parser - source: Pre-built source value (step_name during execution, subdir when loading from disk) - - Returns: - Component metadata dict with source added - """ - filename = os.path.basename(str(file_path)) - component_metadata = microscope_handler.parser.parse_filename(filename) - - # Add pre-built source value directly - component_metadata['source'] = source - - return component_metadata - - def _detect_data_type(self, data: Any): - """ - Detect if data is ROI (shapes/points) or image (common for all streaming backends). + self._transport_config = transport_config - Args: - data: Data to check - - Returns: - StreamingDataType enum value (IMAGE, SHAPES, or POINTS) - """ - is_roi = isinstance(data, list) and len(data) > 0 and isinstance(data[0], ROI) - - if not is_roi: - return StreamingDataType.IMAGE - - # Check if all ROIs contain only PointShape objects (for points layer) - all_points = all( - roi.shapes and all(isinstance(shape, PointShape) for shape in roi.shapes) - for roi in data + def create_shared_memory_payload( + self, + data: StreamablePayload, + file_path: FilePath, + ) -> dict[str, ViewerWireValue]: + block = StreamingSharedMemoryAuthority.create( + StreamingSharedMemoryRequest( + data=data, + item_path=StreamingItemPath(file_path), + shm_prefix=self.SHM_PREFIX, + ) ) - - return StreamingDataType.POINTS if all_points else StreamingDataType.SHAPES - - def _create_shared_memory(self, data: Any, file_path: Union[str, Path]) -> dict: - """ - Create shared memory for image data (common for all streaming backends). - - Args: - data: Image data to put in shared memory - file_path: Path identifier - - Returns: - Dict with shared memory metadata - """ - # Convert to numpy - np_data = data.cpu().numpy() if hasattr(data, 'cpu') else \ - data.get() if hasattr(data, 'get') else np.asarray(data) - - # Create shared memory with hash-based naming to avoid "File name too long" errors - # Hash the timestamp and object ID to create a short, unique name - from multiprocessing import shared_memory, resource_tracker - import hashlib - timestamp = time.time_ns() - obj_id = id(data) - hash_input = f"{obj_id}_{timestamp}" - hash_suffix = hashlib.md5(hash_input.encode()).hexdigest()[:8] - shm_name = f"{self.SHM_PREFIX}{hash_suffix}" - shm = shared_memory.SharedMemory(create=True, size=np_data.nbytes, name=shm_name) - - # Unregister from resource tracker - we manage cleanup manually - # This prevents resource tracker warnings when worker processes exit - # before the viewer has unlinked the shared memory - try: - resource_tracker.unregister(shm._name, "shared_memory") - except Exception: - pass # Ignore errors if already unregistered - - shm_array = np.ndarray(np_data.shape, dtype=np_data.dtype, buffer=shm.buf) - shm_array[:] = np_data[:] - self._shared_memory_blocks[shm_name] = shm - - return { - 'path': str(file_path), - 'shape': np_data.shape, - 'dtype': str(np_data.dtype), - 'shm_name': shm_name, - } + self._shared_memory_blocks[block.payload.shm_name] = block.shared_memory + return block.payload.to_wire_mapping() def _register_with_queue_tracker( self, - port: int, - image_ids: List[str], - transport_mode: TransportMode | None = None, - transport_config=None, + transport_endpoint: ViewerTransportEndpoint, + image_ids: list[str], + transport_config: ZMQConfig, ) -> None: """ Register sent images with queue tracker (common for all streaming backends). @@ -255,168 +633,143 @@ def _register_with_queue_tracker( image_ids: List of image IDs to register """ listener = GlobalAckListener() - transport_config = transport_config or self._transport_config listener.start( port=transport_config.shared_ack_port, - transport_mode=coerce_transport_mode(transport_mode), + transport_mode=transport_endpoint.resolved_transport_mode(), config=transport_config, ) from zmqruntime.queue_tracker import GlobalQueueTrackerRegistry registry = GlobalQueueTrackerRegistry() - tracker = registry.get_or_create_tracker(port, self.VIEWER_TYPE) + tracker = registry.get_or_create_tracker( + transport_endpoint.port, + self.VIEWER_TYPE, + ) for image_id in image_ids: tracker.register_sent(image_id) - def _build_component_modes(self, display_config) -> dict: - component_modes = {} - for comp_name in display_config.COMPONENT_ORDER: - mode_field = f"{comp_name}_mode" - if hasattr(display_config, mode_field): - mode = getattr(display_config, mode_field) - component_modes[comp_name] = mode.value - return component_modes - - def _build_display_config_base(self, display_config, component_modes: dict) -> dict: - return { - "component_modes": component_modes, - "component_order": display_config.COMPONENT_ORDER, - } + def _cleanup_shared_memory_blocks(self, batch_images, unlink: bool = False) -> None: + for img in batch_images: + shm_name = img.get(ViewerBatchItemWireField.SHM_NAME.value) + if shm_name and shm_name in self._shared_memory_blocks: + try: + shm = self._shared_memory_blocks.pop(shm_name) + shm.close() + if unlink: + shm.unlink() + except Exception as e: + logger.warning(f"Failed to cleanup shared memory {shm_name}: {e}") - def _collect_component_names_metadata( + def _prepare_batch_item( self, - plate_path, - microscope_handler, - component_names: List[str] | None = None, - log_prefix: str | None = None, - verbose: bool = False, - ) -> dict: - component_names = component_names or ["channel", "well", "site"] - component_names_metadata = {} - - if not plate_path or not microscope_handler: - if verbose and log_prefix: - if not plate_path: - logger.warning(f"{log_prefix}: No plate_path in kwargs") - if not microscope_handler: - logger.warning(f"{log_prefix}: No microscope_handler") - return component_names_metadata - - try: - for comp_name in component_names: - method_name = f"get_{comp_name}_values" - method = getattr(microscope_handler.metadata_handler, method_name, None) - if callable(method): - try: - metadata = method(plate_path) - if verbose and log_prefix: - logger.info(f"{log_prefix}: Got {comp_name} metadata: {metadata}") - if metadata: - component_names_metadata[comp_name] = metadata - except Exception as e: - if verbose and log_prefix: - logger.warning(f"{log_prefix}: Could not get {comp_name} metadata: {e}", exc_info=True) - elif verbose and log_prefix: - logger.info(f"{log_prefix}: No method {method_name} on metadata_handler") - except Exception as e: - if verbose and log_prefix: - logger.warning(f"{log_prefix}: Could not get component metadata: {e}", exc_info=True) + request: StreamingItemPreparationRequest, + ) -> ViewerStreamItemPayload: + raise NotImplementedError - return component_names_metadata - - def _prepare_batch_items( + def display_payload_extra( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - microscope_handler, - source: str, - prepare_item: Callable[[Any, Union[str, Path], Any], tuple[dict, str]], - ) -> tuple[list[dict], list[str]]: - batch_images = [] - image_ids = [] - - for data, file_path in zip(data_list, file_paths): - image_id = str(uuid.uuid4()) - image_ids.append(image_id) + stream_request: ViewerStreamRequest, + ) -> ViewerDisplayPayloadExtra: + return EMPTY_DISPLAY_PAYLOAD_EXTRA - data_type = self._detect_data_type(data) - component_metadata = self._parse_component_metadata( - file_path, microscope_handler, source - ) - item_data, data_type_value = prepare_item(data, file_path, data_type) - - batch_images.append( - { - **item_data, - "data_type": data_type_value, - "metadata": component_metadata, - "image_id": image_id, - } - ) + def message_extra( + self, + stream_request: ViewerStreamRequest, + ) -> dict[str, ViewerWireValue]: + return stream_request.message_extra_payload() - return batch_images, image_ids + def component_names_request( + self, + stream_request: ViewerStreamRequest, + ) -> StreamingComponentNamesRequest: + return StreamingComponentNamesRequest.from_stream_request(stream_request) - def _build_batch_message( + def after_batch_message_built( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - microscope_handler, - source: str, - display_config, - prepare_item: Callable[[Any, Union[str, Path], Any], tuple[dict, str]], - plate_path: Union[str, Path, None] = None, - component_names_kwargs: dict | None = None, - display_payload_extra: dict | None = None, - message_extra: dict | None = None, - ) -> tuple[dict, list[dict], list[str]]: - if len(data_list) != len(file_paths): - raise ValueError("data_list and file_paths must have the same length") + stream_request: ViewerStreamRequest, + built_batch: StreamingBuiltBatch, + ) -> None: + pass - batch_images, image_ids = self._prepare_batch_items( + def save_batch( + self, + data_list: list[StreamablePayload], + file_paths: list[FilePath], + **kwargs, + ) -> None: + """Stream a batch of image or ROI payloads to this viewer.""" + data_list, file_paths, _skipped_paths = self._filter_streamable_files( data_list, file_paths, - microscope_handler, - source, - prepare_item, ) + if not data_list: + return + + stream_request = ViewerStreamBackendKwargs.from_kwargs(kwargs).stream_request + built_batch = StreamingBatchMessageBuilder.build( + self, + StreamingBatchMessageRequest( + data_list=data_list, + file_paths=file_paths, + stream_request=stream_request, + component_names_request=self.component_names_request(stream_request), + display_payload_extra=self.display_payload_extra(stream_request), + ), + ) + self.after_batch_message_built(stream_request, built_batch) - component_modes = self._build_component_modes(display_config) - - component_names_metadata = self._collect_component_names_metadata( - plate_path, - microscope_handler, - **(component_names_kwargs or {}), + transport_config = stream_request.transport_config.resolve( + self._transport_config ) + transport_endpoint = stream_request.viewer_transport + self._register_with_queue_tracker( + transport_endpoint, + built_batch.image_ids, + transport_config=transport_config, + ) + url = transport_endpoint.data_url(transport_config) - display_payload = self._build_display_config_base(display_config, component_modes) - if display_payload_extra: - display_payload.update(display_payload_extra) + if self._context is None: + self._context = zmq.Context() - message = { - "type": "batch", - "images": batch_images, - "display_config": display_payload, - "component_names_metadata": component_names_metadata, - "timestamp": time.time(), - } - if message_extra: - message.update(message_extra) + viewer_name = str(self.VIEWER_TYPE).title() + viewer_label = viewer_name.upper() + ack_policy = STREAMING_TRANSPORT_DEFAULTS.ack_policy(viewer_name) + socket = self._context.socket(zmq.REQ) + ack_policy.apply_socket_options(socket) + socket.connect(url) + time.sleep(0.1) - return message, batch_images, image_ids + try: + logger.info( + "📤 %s BACKEND: Sending batch of %d images to %s on port %s " + "(REQ/REP - blocking until ack)", + viewer_label, + len(built_batch.batch_images), + viewer_name, + transport_endpoint.port, + ) + socket.send_json(built_batch.message) + ack_response = ack_policy.receive( + socket, + lambda: self._cleanup_shared_memory_blocks( + built_batch.batch_images, + unlink=True, + ), + port=transport_endpoint.port, + ) + logger.info( + "✅ %s BACKEND: Received ack from %s: %s", + viewer_label, + viewer_name, + ack_policy.status(ack_response), + ) + finally: + socket.close() - def _cleanup_shared_memory_blocks(self, batch_images, unlink: bool = False) -> None: - for img in batch_images: - shm_name = img.get("shm_name") - if shm_name and shm_name in self._shared_memory_blocks: - try: - shm = self._shared_memory_blocks.pop(shm_name) - shm.close() - if unlink: - shm.unlink() - except Exception as e: - logger.warning(f"Failed to cleanup shared memory {shm_name}: {e}") + self._cleanup_shared_memory_blocks(built_batch.batch_images, unlink=False) - def save(self, data: Any, file_path: Union[str, Path], **kwargs) -> None: + def save(self, data: StreamablePayload | str, file_path: FilePath, **kwargs) -> None: """ Stream single item (common for all streaming backends). diff --git a/src/polystore/streaming/base.py b/src/polystore/streaming/base.py index 5cbfec0..c0bfb70 100644 --- a/src/polystore/streaming/base.py +++ b/src/polystore/streaming/base.py @@ -27,7 +27,6 @@ class TypedData(Generic[T]): """ items: List[T] metadata: Dict[str, Any] - source: str class ComponentAccessor(ABC): diff --git a/src/polystore/streaming/handlers/fiji_rois.py b/src/polystore/streaming/handlers/fiji_rois.py index 489adaa..a1fb43f 100644 --- a/src/polystore/streaming/handlers/fiji_rois.py +++ b/src/polystore/streaming/handlers/fiji_rois.py @@ -5,10 +5,123 @@ """ import logging +from collections.abc import Mapping, Sequence +from dataclasses import dataclass + from polystore.streaming.handlers import HandlerBase from polystore.streaming.base import HandlerContext +from zmqruntime.viewer_protocol import ( + ViewerBatchItemWireField, + ViewerWireMapping, + ViewerWireValue, +) logger = logging.getLogger(__name__) +FijiROIComponentValue = ( + str + | int + | float + | bool + | None + | tuple["FijiROIComponentValue", ...] +) +FijiROIComponentKey = tuple[FijiROIComponentValue, ...] + + +def fiji_roi_component_value(value: ViewerWireValue) -> FijiROIComponentValue: + """Return a hashable ImageJ-axis component value from viewer wire data.""" + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, tuple): + return tuple(fiji_roi_component_value(item) for item in value) + if isinstance(value, list): + return tuple(fiji_roi_component_value(item) for item in value) + raise TypeError( + "Fiji ROI component metadata values must be scalar or tuple-like, " + f"got {type(value).__name__}." + ) + + +@dataclass(frozen=True, slots=True) +class FijiROIWireItem: + """Typed ROI wire item used by the legacy Fiji handler path.""" + + payload: ViewerWireMapping + + @classmethod + def from_payload(cls, payload: Mapping[str, ViewerWireValue]) -> "FijiROIWireItem": + return cls(payload) + + @property + def rois(self) -> list[str]: + field = ViewerBatchItemWireField.ROIS.value + if field not in self.payload: + raise ValueError("Fiji ROI item missing required 'rois' field.") + value = self.payload[field] + if not isinstance(value, Sequence) or isinstance(value, (str, bytes)): + raise TypeError("Fiji ROI item 'rois' field must be a sequence.") + return [str(encoded_roi) for encoded_roi in value] + + @property + def metadata(self) -> ViewerWireMapping: + field = ViewerBatchItemWireField.METADATA.value + if field not in self.payload: + raise ValueError("Fiji ROI item missing required 'metadata' field.") + value = self.payload[field] + if not isinstance(value, Mapping): + raise TypeError("Fiji ROI item 'metadata' field must be a mapping.") + return value + + @property + def image_id(self) -> str | None: + field = ViewerBatchItemWireField.IMAGE_ID.value + if field not in self.payload or self.payload[field] is None: + return None + return str(self.payload[field]) + + def component_value_tuple( + self, + component_names: Sequence[str], + ) -> FijiROIComponentKey: + return tuple( + fiji_roi_component_value(self.metadata[component_name]) + for component_name in component_names + ) + + +@dataclass(frozen=True, slots=True) +class FijiROIAxisPosition: + """Strict one-based ImageJ coordinate resolver for legacy Fiji ROI handling.""" + + component_names: Sequence[str] + values: Sequence[FijiROIComponentKey] + + @classmethod + def from_items( + cls, + items: Sequence[FijiROIWireItem], + component_names: Sequence[str], + ) -> "FijiROIAxisPosition": + values = tuple( + sorted( + { + item.component_value_tuple(component_names) + for item in items + } + ) + ) + return cls(component_names, values) + + def one_based_position(self, item: FijiROIWireItem) -> int: + if not self.component_names: + return 1 + value_tuple = item.component_value_tuple(self.component_names) + if value_tuple not in self.values: + raise ValueError( + f"Fiji ROI component value {value_tuple!r} is outside axis domain " + f"{self.values!r}." + ) + return self.values.index(value_tuple) + 1 class FijiROIHandler(HandlerBase): @@ -24,9 +137,6 @@ def can_handle(data_type: str) -> bool: @staticmethod def handle(context: HandlerContext) -> None: """Add ROIs to ImageJ ROI Manager.""" - # Access data via typed wrapper - roi_data = context.data - # Get or create RoiManager on EDT import scyjava as sj RoiManager = sj.jimport("ij.plugin.frame.RoiManager") @@ -49,47 +159,54 @@ def run(self): # Get or assign integer group ID for this window group_id = context.server._get_or_create_group_id(context.window_key) + roi_items = tuple( + FijiROIWireItem.from_payload(item) + for item in context.data.items + ) + # Process ROIs with component positioning channel_comps = context.components.get_by_mode("channel") slice_comps = context.components.get_by_mode("slice") frame_comps = context.components.get_by_mode("frame") - channel_values = context.components.collect_values(channel_comps) - slice_values = context.components.collect_values(slice_comps) - frame_values = context.components.collect_values(frame_comps) + channel_position = FijiROIAxisPosition.from_items(roi_items, channel_comps) + slice_position = FijiROIAxisPosition.from_items(roi_items, slice_comps) + frame_position = FijiROIAxisPosition.from_items(roi_items, frame_comps) total_rois_added = 0 - for roi_item in roi_data.items: - rois_encoded = roi_item.get("rois", []) + from polystore.roi_converters import FijiROIConverter + + for roi_item in roi_items: + rois_encoded = roi_item.rois if not rois_encoded: - if image_id := roi_item.get("image_id"): + if image_id := roi_item.image_id: context.server._send_ack(image_id, status="success") continue - meta = roi_item.get("metadata", {}) - file_path = roi_item.get("path", "unknown") - logger.info(f"🔬 FIJI ROI HANDLER: Processing {len(rois_encoded)} ROIs") # Convert ROIs to ImageJ format - from polystore.roi_converters import FijiROIConverter - rois_list = FijiROIConverter.to_fiji_rois( + java_rois = FijiROIConverter.transmission_to_java_rois( rois_encoded, - channel_values=channel_values, - slice_values=slice_values, - frame_values=frame_values, - channel_components=channel_comps, - slice_components=slice_comps, - frame_components=frame_comps, + sj, + ) + imagej_position = ( + channel_position.one_based_position(roi_item), + slice_position.one_based_position(roi_item), + frame_position.one_based_position(roi_item), ) # Add ROIs to manager with group ID - for roi_obj in rois_list: - roi_obj.setProperty("group", group_id) + for roi_obj in java_rois: + roi_obj.setPosition(*imagej_position) + roi_obj.setGroup(group_id) rm.addRoi(roi_obj) - total_rois_added += len(rois_list) + total_rois_added += len(java_rois) + + if image_id := roi_item.image_id: + context.server._send_ack(image_id, status="success") logger.info( f"🔬 FIJI ROI HANDLER: Added {total_rois_added} ROIs to window '{context.window_key}'" diff --git a/src/polystore/streaming/identity.py b/src/polystore/streaming/identity.py new file mode 100644 index 0000000..1c543a6 --- /dev/null +++ b/src/polystore/streaming/identity.py @@ -0,0 +1,241 @@ +"""Nominal stream identity records shared by viewer streaming backends.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import ClassVar, Mapping, Sequence, TypeAlias + + +StreamProducerPayloadValue: TypeAlias = str | int | None +StreamProducerPayloadMapping: TypeAlias = Mapping[str, StreamProducerPayloadValue] +RouteKeyPart: TypeAlias = str | int | float | bool | None + + +class StreamProducerOrigin(str, Enum): + """Nominal stream producer origin values.""" + + PIPELINE = "pipeline" + MANUAL = "manual" + DIRECT = "direct" + + +class FixedStreamProducerIdentityKind(str, Enum): + """Producer identities whose origin and output kind are intentionally equal.""" + + MANUAL = StreamProducerOrigin.MANUAL.value + DIRECT = StreamProducerOrigin.DIRECT.value + + +class StreamProducerIdentityPayload(dict[str, StreamProducerPayloadValue]): + """Wire payload for one stream producer identity.""" + + @classmethod + def from_identity( + cls, + identity: "StreamProducerIdentity", + ) -> "StreamProducerIdentityPayload": + return cls( + origin=identity.origin, + output_kind=identity.output_kind, + output_key=identity.output_key, + step_name=identity.step_name, + pipeline_position=identity.pipeline_position, + step_scope_id=identity.step_scope_id, + invocation_key=identity.invocation_key, + artifact_kind=identity.artifact_kind, + ) + + +@dataclass(frozen=True, slots=True) +class StreamProducerIdentity: + """Producer/output identity for one streamed viewer item.""" + + origin: str + output_kind: str + output_key: str + step_name: str | None = None + pipeline_position: int | None = None + step_scope_id: str | None = None + invocation_key: str | None = None + artifact_kind: str | None = None + + @classmethod + def pipeline_output( + cls, + *, + output_kind: str, + output_key: str, + step_name: str, + pipeline_position: int | None, + step_scope_id: str | None = None, + artifact_kind: str | None = None, + ) -> "StreamProducerIdentity": + """Build identity for one pipeline-produced stream output.""" + return cls( + origin=StreamProducerOrigin.PIPELINE.value, + output_kind=output_kind, + output_key=output_key, + step_name=step_name, + pipeline_position=pipeline_position, + step_scope_id=step_scope_id, + artifact_kind=artifact_kind, + ) + + @classmethod + def fixed_output( + cls, + kind: FixedStreamProducerIdentityKind, + output_key: str, + ) -> "StreamProducerIdentity": + """Build identity for producer kinds whose origin owns the output kind.""" + return cls( + origin=kind.value, + output_kind=kind.value, + output_key=output_key, + ) + + @classmethod + def from_payload( + cls, + payload: "StreamProducerIdentity | StreamProducerPayloadMapping", + ) -> "StreamProducerIdentity": + if isinstance(payload, cls): + return payload + if not isinstance(payload, Mapping): + raise TypeError( + "Stream producer identity must be a mapping or StreamProducerIdentity, " + f"got {type(payload).__name__}." + ) + return cls( + origin=_required_payload_str(payload, "origin"), + output_kind=_required_payload_str(payload, "output_kind"), + output_key=_required_payload_str(payload, "output_key"), + step_name=_optional_payload_str(payload, "step_name"), + pipeline_position=_optional_payload_int(payload, "pipeline_position"), + step_scope_id=_optional_payload_str(payload, "step_scope_id"), + invocation_key=_optional_payload_str(payload, "invocation_key"), + artifact_kind=_optional_payload_str(payload, "artifact_kind"), + ) + + def to_payload(self) -> StreamProducerIdentityPayload: + return StreamProducerIdentityPayload.from_identity(self) + + def route_parts(self) -> tuple[str, ...]: + parts = [ + f"origin_{self.origin}", + f"kind_{self.output_kind}", + f"out_{self.output_key}", + ] + if self.pipeline_position is not None: + parts.append(f"step_{self.pipeline_position}") + if self.step_scope_id: + parts.append(f"scope_{self.step_scope_id}") + if self.step_name: + parts.append(f"name_{self.step_name}") + if self.invocation_key: + parts.append(f"invocation_{self.invocation_key}") + if self.artifact_kind: + parts.append(f"artifact_{self.artifact_kind}") + return tuple(parts) + + +def _required_payload_str( + payload: StreamProducerPayloadMapping, + field_name: str, +) -> str: + if field_name not in payload: + raise ValueError( + f"Stream producer identity missing required field: {field_name}" + ) + value = payload[field_name] + if value in (None, ""): + raise ValueError( + f"Stream producer identity missing required field: {field_name}" + ) + return str(value) + + +def _optional_payload_str( + payload: StreamProducerPayloadMapping, + field_name: str, +) -> str | None: + if field_name not in payload: + return None + return _optional_str(payload[field_name]) + + +def _optional_payload_int( + payload: StreamProducerPayloadMapping, + field_name: str, +) -> int | None: + if field_name not in payload: + return None + value = payload[field_name] + if value is None: + return None + return int(value) + + +def _optional_str(value: StreamProducerPayloadValue) -> str | None: + if value is None: + return None + text = str(value) + return text or None + + +class StreamProducerDisplayNameAuthority: + """Own user-facing labels derived from stream producer identity.""" + + PIPELINE_DISPLAY_INDEX_BASE: ClassVar[int] = 1 + OUTPUT_KEY_OMITTING_KINDS: ClassVar[frozenset[str]] = frozenset( + {"main", "manual", "direct"} + ) + + @staticmethod + def producer_base(producer: StreamProducerIdentity) -> str: + if producer.step_name: + return producer.step_name + return producer.output_key + + @classmethod + def producer_label(cls, producer: StreamProducerIdentity) -> str: + base = cls.producer_base(producer) + if producer.pipeline_position is None: + return base + return f"{producer.pipeline_position + cls.PIPELINE_DISPLAY_INDEX_BASE}. {base}" + + @classmethod + def output_label(cls, producer: StreamProducerIdentity) -> str: + parts = [cls.producer_label(producer)] + if cls.includes_output_key(producer): + parts.append(producer.output_key) + return " ".join(part for part in parts if part) + + @classmethod + def disambiguation_label(cls, producer: StreamProducerIdentity) -> str: + if producer.pipeline_position is not None: + return f"step {producer.pipeline_position + cls.PIPELINE_DISPLAY_INDEX_BASE}" + return producer.output_key or producer.origin + + @classmethod + def includes_output_key(cls, producer: StreamProducerIdentity) -> bool: + if not producer.output_key: + return False + if producer.output_kind in cls.OUTPUT_KEY_OMITTING_KINDS: + return False + return producer.output_key != cls.producer_base(producer) + + +class StreamRouteKeyAuthority: + """Own stable key-token projection for viewer route keys.""" + + @staticmethod + def token(value: RouteKeyPart) -> str: + return str(value).replace("/", "_").replace("\\", "_").replace(" ", "_") + + @classmethod + def join(cls, parts: Sequence[RouteKeyPart]) -> str: + if not parts: + raise ValueError("Cannot build a stream route key with no parts.") + return "_".join(cls.token(part) for part in parts) diff --git a/src/polystore/streaming/receivers/__init__.py b/src/polystore/streaming/receivers/__init__.py index b1876be..c6d58db 100644 --- a/src/polystore/streaming/receivers/__init__.py +++ b/src/polystore/streaming/receivers/__init__.py @@ -9,13 +9,15 @@ WindowProjectionABC, DebouncedBatchEngine, GroupedWindowItems, + WindowProjectionPayloadProvider, + WindowProjectionSource, group_items_by_component_modes, ) from polystore.streaming.receivers.fiji.fiji_batch_processor import FijiBatchProcessor from polystore.streaming.receivers.napari import ( NapariBatchProcessor, normalize_component_layout, - build_layer_key, + build_route_key, ) __all__ = [ @@ -23,9 +25,11 @@ "WindowProjectionABC", "DebouncedBatchEngine", "GroupedWindowItems", + "WindowProjectionPayloadProvider", + "WindowProjectionSource", "group_items_by_component_modes", "FijiBatchProcessor", "NapariBatchProcessor", "normalize_component_layout", - "build_layer_key", + "build_route_key", ] diff --git a/src/polystore/streaming/receivers/core/__init__.py b/src/polystore/streaming/receivers/core/__init__.py index 084d686..786f526 100644 --- a/src/polystore/streaming/receivers/core/__init__.py +++ b/src/polystore/streaming/receivers/core/__init__.py @@ -7,6 +7,8 @@ ) from polystore.streaming.receivers.core.window_projection import ( GroupedWindowItems, + WindowProjectionPayloadProvider, + WindowProjectionSource, group_items_by_component_modes, ) @@ -15,6 +17,7 @@ "WindowProjectionABC", "DebouncedBatchEngine", "GroupedWindowItems", + "WindowProjectionPayloadProvider", + "WindowProjectionSource", "group_items_by_component_modes", ] - diff --git a/src/polystore/streaming/receivers/core/window_projection.py b/src/polystore/streaming/receivers/core/window_projection.py index 4987960..308df65 100644 --- a/src/polystore/streaming/receivers/core/window_projection.py +++ b/src/polystore/streaming/receivers/core/window_projection.py @@ -2,87 +2,194 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable - - -WindowValueNormalizer = Callable[[str, Any, dict[str, Any], str | None], Any] - - -@dataclass(frozen=True) -class GroupedWindowItems: +from typing import Generic, TypeVar + +from polystore.streaming.identity import ( + StreamProducerDisplayNameAuthority, + StreamProducerIdentity, + StreamRouteKeyAuthority, +) +from zmqruntime.viewer_protocol import ( + ViewerBatchDisplayPayload, + ViewerBatchItemWireField, + ViewerComponentMode, + ViewerWireMapping, + ViewerWireValue, +) + + +WINDOW_COMPONENT_MODES = ( + ViewerComponentMode.WINDOW, + ViewerComponentMode.CHANNEL, + ViewerComponentMode.SLICE, + ViewerComponentMode.FRAME, +) +WindowLabel = tuple[str, ViewerWireValue] +WindowProjectionItemT = TypeVar("WindowProjectionItemT") +WindowProjectionProviderT = TypeVar( + "WindowProjectionProviderT", + bound="WindowProjectionPayloadProvider", +) + + +class WindowProjectionPayloadProvider(ABC): + """Item that can expose its viewer wire payload for window projection.""" + + @abstractmethod + def window_projection_payload(self) -> Mapping[str, ViewerWireValue]: + """Return the wire payload used for component/window projection.""" + + +class WindowItemPayload(dict[str, ViewerWireValue]): + """Normalized wire payload retained for one projected window item.""" + + @classmethod + def from_mapping( + cls, + payload: Mapping[str, ViewerWireValue], + ) -> "WindowItemPayload": + return cls(dict(payload)) + + +@dataclass(frozen=True, slots=True) +class GroupedWindowItems(Generic[WindowProjectionItemT]): """Projection result for a single batch.""" window_components: list[str] channel_components: list[str] slice_components: list[str] frame_components: list[str] - windows: dict[str, list[dict[str, Any]]] - fixed_window_labels: dict[str, list[tuple[str, Any]]] - - -def _default_normalizer( - component_name: str, - value: Any, - item: dict[str, Any], - images_dir: str | None, -) -> Any: - """Normalize window component values for stable keying across payload types.""" - data_type = item.get("data_type") - if component_name == "source" and images_dir and data_type == "rois": - value_str = str(value) - if "_results" in value_str or "/" in value_str: - return Path(images_dir).name - return value + windows: dict[str, list[WindowProjectionItemT]] + fixed_window_labels: dict[str, tuple[WindowLabel, ...]] + + +@dataclass(frozen=True, slots=True) +class WindowProjectionSource(Generic[WindowProjectionItemT]): + """Validated receiver item source used by window projection.""" + + item: WindowProjectionItemT + payload: Mapping[str, ViewerWireValue] + metadata: ViewerWireMapping + producer: StreamProducerIdentity + + @classmethod + def from_wire_payload( + cls, + payload: Mapping[str, ViewerWireValue], + ) -> "WindowProjectionSource[WindowItemPayload]": + window_payload = WindowItemPayload.from_mapping(payload) + return cls.from_item(window_payload, window_payload) + + @classmethod + def from_wire_payloads( + cls, + payloads: Sequence[Mapping[str, ViewerWireValue]], + ) -> list["WindowProjectionSource[WindowItemPayload]"]: + return [cls.from_wire_payload(payload) for payload in payloads] + + @classmethod + def from_payload_provider( + cls, + item: WindowProjectionProviderT, + ) -> "WindowProjectionSource[WindowProjectionProviderT]": + return cls.from_item(item, item.window_projection_payload()) + + @classmethod + def from_payload_providers( + cls, + items: Sequence[WindowProjectionProviderT], + ) -> list["WindowProjectionSource[WindowProjectionProviderT]"]: + return [cls.from_payload_provider(item) for item in items] + + @classmethod + def from_item( + cls, + item: WindowProjectionItemT, + payload: Mapping[str, ViewerWireValue], + ) -> "WindowProjectionSource[WindowProjectionItemT]": + metadata = cls._required_mapping( + payload, + ViewerBatchItemWireField.METADATA.value, + ) + producer_identity = cls._required_mapping( + payload, + ViewerBatchItemWireField.PRODUCER_IDENTITY.value, + ) + return cls( + item=item, + payload=payload, + metadata=metadata, + producer=StreamProducerIdentity.from_payload(producer_identity), + ) + + @staticmethod + def _required_mapping( + payload: Mapping[str, ViewerWireValue], + field_name: str, + ) -> ViewerWireMapping: + if field_name not in payload: + raise ValueError( + f"Viewer window projection item missing required field {field_name!r}." + ) + value = payload[field_name] + if not isinstance(value, Mapping): + raise TypeError( + f"Viewer window projection item field {field_name!r} must be a mapping, " + f"got {type(value).__name__}." + ) + return dict(value) def group_items_by_component_modes( - items: list[dict[str, Any]], - component_modes: dict[str, str], - component_order: list[str], - *, - images_dir: str | None = None, - normalizer: WindowValueNormalizer | None = None, -) -> GroupedWindowItems: + items: Sequence[WindowProjectionSource[WindowProjectionItemT]], + display_layout: ViewerBatchDisplayPayload, +) -> GroupedWindowItems[WindowProjectionItemT]: """Project items into window groups using declared component modes.""" - if normalizer is None: - normalizer = _default_normalizer - - result: dict[str, list[str]] = { - "window": [], - "channel": [], - "slice": [], - "frame": [], - } - for comp_name in component_order: - mode = component_modes[comp_name] - result[mode].append(comp_name) - - window_components = result["window"] - channel_components = result["channel"] - slice_components = result["slice"] - frame_components = result["frame"] - - windows: dict[str, list[dict[str, Any]]] = {} - fixed_window_labels: dict[str, list[tuple[str, Any]]] = {} + mode_groups = display_layout.component_mode_groups(WINDOW_COMPONENT_MODES) + mode_groups.require_all_supported("window projection") + + window_components = list( + mode_groups.components_for_mode(ViewerComponentMode.WINDOW) + ) + channel_components = list( + mode_groups.components_for_mode(ViewerComponentMode.CHANNEL) + ) + slice_components = list( + mode_groups.components_for_mode(ViewerComponentMode.SLICE) + ) + frame_components = list( + mode_groups.components_for_mode(ViewerComponentMode.FRAME) + ) + + windows: dict[str, list[WindowProjectionItemT]] = {} + fixed_window_labels: dict[str, tuple[WindowLabel, ...]] = {} for item in items: - meta = item.get("metadata", {}) - key_parts: list[str] = [] - fixed_labels: list[tuple[str, Any]] = [] + key_parts: list[str] = list(item.producer.route_parts()) + fixed_labels: list[WindowLabel] = [ + ( + "producer", + StreamProducerDisplayNameAuthority.output_label(item.producer), + ) + ] for comp in window_components: - if comp not in meta: - continue - value = normalizer(comp, meta[comp], item, images_dir) + if comp not in item.metadata: + raise ValueError( + f"Viewer window projection item missing window component {comp!r}." + ) + value = item.metadata[comp] key_parts.append(f"{comp}_{value}") fixed_labels.append((comp, value)) - window_key = "_".join(key_parts) if key_parts else "default_window" - windows.setdefault(window_key, []).append(item) + window_key = StreamRouteKeyAuthority.join(key_parts) if window_key not in fixed_window_labels: - fixed_window_labels[window_key] = fixed_labels + windows[window_key] = [] + fixed_window_labels[window_key] = tuple(fixed_labels) + windows[window_key].append(item.item) return GroupedWindowItems( window_components=window_components, @@ -92,4 +199,3 @@ def group_items_by_component_modes( windows=windows, fixed_window_labels=fixed_window_labels, ) - diff --git a/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py b/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py index 60e0a3b..b1a95f8 100644 --- a/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py +++ b/src/polystore/streaming/receivers/fiji/fiji_batch_processor.py @@ -57,6 +57,7 @@ def add_items( display_config: Dict[str, Any], images_dir: str, component_names_metadata: Dict[str, Any], + component_value_domain: Dict[str, Any], ): """ Add items to the batch for processing. @@ -65,13 +66,15 @@ def add_items( window_key: Unique identifier for the Fiji window items: List of items to add (images) display_config: Display configuration dict - images_dir: Source image subdirectory + images_dir: Artifact image directory context. component_names_metadata: Component name mappings for dimension labels + component_value_domain: Component value domains for axis cardinality """ context = { "display_config": display_config, "images_dir": images_dir, "component_names_metadata": component_names_metadata, + "component_value_domain": component_value_domain, "window_key": window_key, } self._engine.enqueue(items=items, context=context) @@ -90,15 +93,17 @@ def _process_batch(self, items: List[Dict[str, Any]], context: Dict[str, Any]) - display_config = context["display_config"] images_dir = context["images_dir"] component_names_metadata = context["component_names_metadata"] + component_value_domain = context["component_value_domain"] window_key = context["window_key"] logger.info( "FijiBatchProcessor: Processing batch of %d items for window '%s'", len(items), window_key, ) - self.fiji_server._process_items_from_batch( + self.fiji_server.batch_processor.process_wire_items( items=items, - display_config_dict=display_config, + display_config=display_config, images_dir=images_dir, component_names_metadata=component_names_metadata, + component_value_domain=component_value_domain, ) diff --git a/src/polystore/streaming/receivers/napari/__init__.py b/src/polystore/streaming/receivers/napari/__init__.py index 9ece1bf..6472c4c 100644 --- a/src/polystore/streaming/receivers/napari/__init__.py +++ b/src/polystore/streaming/receivers/napari/__init__.py @@ -3,7 +3,7 @@ from polystore.streaming.receivers.napari.napari_batch_processor import NapariBatchProcessor from polystore.streaming.receivers.napari.layer_key import ( normalize_component_layout, - build_layer_key, + build_route_key, ) -__all__ = ["NapariBatchProcessor", "normalize_component_layout", "build_layer_key"] +__all__ = ["NapariBatchProcessor", "normalize_component_layout", "build_route_key"] diff --git a/src/polystore/streaming/receivers/napari/layer_key.py b/src/polystore/streaming/receivers/napari/layer_key.py index dec6fff..c382853 100644 --- a/src/polystore/streaming/receivers/napari/layer_key.py +++ b/src/polystore/streaming/receivers/napari/layer_key.py @@ -1,46 +1,93 @@ -"""Canonical napari layer-key construction from component metadata.""" +"""Canonical napari route-key construction.""" from __future__ import annotations -from typing import Any +from collections.abc import Mapping +from polystore.streaming.identity import StreamProducerIdentity, StreamRouteKeyAuthority from polystore.streaming_constants import StreamingDataType +from zmqruntime.viewer_protocol import ( + ViewerBatchDisplayPayload, + ViewerComponentMode, + ViewerDisplayConfigWireField, + ViewerWireMapping, + ViewerWireValue, +) -def normalize_component_layout(display_config: Any) -> tuple[dict[str, str], list[str]]: - """Return canonical (component_modes, component_order) from display config.""" +def normalize_component_layout( + display_config: ViewerBatchDisplayPayload | ViewerWireMapping, +) -> ViewerBatchDisplayPayload: + """Return canonical display layout from a viewer display-config payload.""" + if isinstance(display_config, ViewerBatchDisplayPayload): + return display_config if isinstance(display_config, dict): - component_modes = display_config["component_modes"] - component_order = display_config["component_order"] - return component_modes, component_order - - component_order = list(display_config.COMPONENT_ORDER) - component_modes: dict[str, str] = {} - for component in component_order: - mode_field = f"{component}_mode" - mode_value = display_config.__getattribute__(mode_field) - component_modes[component] = mode_value.value - return component_modes, component_order - - -def build_layer_key( - component_info: dict[str, Any], - component_modes: dict[str, str], - component_order: list[str], + return ViewerBatchDisplayPayload( + component_modes=_required_mapping( + display_config, + ViewerDisplayConfigWireField.COMPONENT_MODES.value, + ), + component_order=_required_sequence( + display_config, + ViewerDisplayConfigWireField.COMPONENT_ORDER.value, + ), + ) + + raise TypeError( + "Napari component layout requires ViewerBatchDisplayPayload or mapping, " + f"got {type(display_config).__name__}." + ) + + +def build_route_key( + producer_identity: StreamProducerIdentity | Mapping[str, ViewerWireValue], + component_info: Mapping[str, ViewerWireValue], + display_layout: ViewerBatchDisplayPayload, data_type: StreamingDataType, ) -> str: - """Build canonical layer key from slice-mode components and payload type.""" - layer_key_parts: list[str] = [] - for component in component_order: - mode = component_modes[component] - if mode == "slice" and component in component_info: - layer_key_parts.append(f"{component}_{component_info[component]}") + """Build hidden route key from producer identity, slice components, and type.""" + producer = StreamProducerIdentity.from_payload(producer_identity) + route_parts: list[str] = list(producer.route_parts()) + for component in display_layout.components_for_mode(ViewerComponentMode.SLICE): + if component not in component_info: + raise ValueError( + f"Napari route key missing slice component {component!r}." + ) + route_parts.append(f"{component}_{component_info[component]}") + + route_key = StreamRouteKeyAuthority.join(route_parts) + + return f"{route_key}{data_type.napari_layer_suffix}" + - layer_key = "_".join(layer_key_parts) if layer_key_parts else "default_layer" +def _required_mapping( + payload: Mapping[str, ViewerWireValue], + field_name: str, +) -> dict[str, str]: + if field_name not in payload: + raise ValueError(f"Display config missing required field {field_name!r}.") + value = payload[field_name] + if not isinstance(value, Mapping): + raise TypeError( + f"Display config field {field_name!r} must be a mapping, " + f"got {type(value).__name__}." + ) + return { + str(component): str(mode) + for component, mode in value.items() + } - if data_type == StreamingDataType.SHAPES: - return f"{layer_key}_shapes" - if data_type == StreamingDataType.POINTS: - return f"{layer_key}_points" - return layer_key +def _required_sequence( + payload: Mapping[str, ViewerWireValue], + field_name: str, +) -> list[str]: + if field_name not in payload: + raise ValueError(f"Display config missing required field {field_name!r}.") + value = payload[field_name] + if isinstance(value, str) or not isinstance(value, list | tuple): + raise TypeError( + f"Display config field {field_name!r} must be a sequence, " + f"got {type(value).__name__}." + ) + return [str(component) for component in value] diff --git a/src/polystore/streaming/receivers/napari/napari_batch_processor.py b/src/polystore/streaming/receivers/napari/napari_batch_processor.py index b8dcbdd..e6e80d5 100644 --- a/src/polystore/streaming/receivers/napari/napari_batch_processor.py +++ b/src/polystore/streaming/receivers/napari/napari_batch_processor.py @@ -1,20 +1,45 @@ import logging -from typing import Any, Dict, List, Optional - -from polystore.streaming.receivers.core import DebouncedBatchEngine +from dataclasses import dataclass +from collections.abc import Sequence +from typing import Generic, Optional, TypeVar logger = logging.getLogger(__name__) +NapariBatchItemT = TypeVar("NapariBatchItemT") +NapariDisplayPayloadT = TypeVar("NapariDisplayPayloadT") +NapariComponentNamesMetadataT = TypeVar("NapariComponentNamesMetadataT") + + +@dataclass(frozen=True) +class NapariBatchDisplayRequest( + Generic[ + NapariBatchItemT, + NapariDisplayPayloadT, + NapariComponentNamesMetadataT, + ] +): + """Nominal request for one debounced Napari display update.""" + + layer_key: str + items: Sequence[NapariBatchItemT] + display_payload: NapariDisplayPayloadT + component_names_metadata: NapariComponentNamesMetadataT + + def dispatch_to(self, napari_server) -> None: + napari_server.display_layer_batch( + layer_key=self.layer_key, + items=self.items, + display_payload=self.display_payload, + component_names_metadata=self.component_names_metadata, + ) class NapariBatchProcessor: """ - Batch processor for Napari viewer with configurable batching strategies. - - Accumulates items and displays them based on batch_size configuration: - - None: Wait for all items in operation, then display once - - N: Display every N items incrementally - - Uses debouncing to collect items arriving in rapid succession. + Batch processor for Napari viewer display operations. + + Napari layer mutation must run on the Qt event-loop thread. OpenHCS owns that + Qt-thread debounce before this processor is called, so this class only + adapts batch payloads into the server display operation. """ def __init__( @@ -29,22 +54,15 @@ def __init__( Args: napari_server: Reference to NapariViewerServer for display operations - batch_size: Number of items to batch before displaying - None = wait for all (default), N = display every N items - debounce_delay_ms: Wait time after last item before processing (ms) - max_debounce_wait_ms: Maximum total wait time before forcing display (ms) + batch_size: Reserved for compatibility with viewer configuration + debounce_delay_ms: Qt-thread debounce delay owned by the caller + max_debounce_wait_ms: Reserved for compatibility with viewer configuration """ self.napari_server = napari_server self.batch_size = batch_size self.debounce_delay_ms = debounce_delay_ms self.max_debounce_wait_ms = max_debounce_wait_ms - self._engine = DebouncedBatchEngine( - process_fn=self._process_batch, - debounce_delay_ms=debounce_delay_ms, - max_debounce_wait_ms=max_debounce_wait_ms, - ) - logger.info( f"NapariBatchProcessor: Created with batch_size={batch_size}, " f"debounce={debounce_delay_ms}ms, max_wait={max_debounce_wait_ms}ms" @@ -53,27 +71,25 @@ def __init__( def add_items( self, layer_key: str, - items: List[Dict[str, Any]], - display_config: Dict[str, Any], - component_names_metadata: Dict[str, Any], + items: Sequence[NapariBatchItemT], + display_payload: NapariDisplayPayloadT, + component_names_metadata: NapariComponentNamesMetadataT, ): """ - Add items to the batch for processing. + Display items already released by the Qt-thread debounce. Args: layer_key: Unique identifier for the layer items: List of items to add (images or ROIs) - display_config: Display configuration dict + display_payload: Viewer-owned display payload object component_names_metadata: Component name mappings for dimension labels """ - self._engine.enqueue( + NapariBatchDisplayRequest( + layer_key=layer_key, items=items, - context={ - "display_config": display_config, - "component_names_metadata": component_names_metadata, - "layer_key": layer_key, - }, - ) + display_payload=display_payload, + component_names_metadata=component_names_metadata, + ).dispatch_to(self.napari_server) logger.debug( "NapariBatchProcessor: Added %d items to batch for layer '%s'", len(items), @@ -81,14 +97,4 @@ def add_items( ) def flush(self) -> None: - """Force immediate processing of the pending batch.""" - self._engine.flush() - - def _process_batch(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> None: - """Process callback used by shared debounced batch engine.""" - self.napari_server._display_layer_batch( - layer_key=context["layer_key"], - items=items, - display_config=context["display_config"], - component_names_metadata=context["component_names_metadata"], - ) + """Compatibility no-op; OpenHCS owns the Qt-thread debounce timer.""" diff --git a/src/polystore/streaming/viewer_transport.py b/src/polystore/streaming/viewer_transport.py new file mode 100644 index 0000000..6f3ef77 --- /dev/null +++ b/src/polystore/streaming/viewer_transport.py @@ -0,0 +1,512 @@ +"""Nominal transport helpers for blocking viewer stream backends.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, replace +from enum import Enum +from pathlib import Path +from typing import ( + ClassVar, + TypeAlias, +) + +from polystore.registry import AutoRegisterMeta +from polystore.streaming_constants import StreamingDataType +from polystore.streaming.identity import StreamProducerIdentity +from zmqruntime.config import ZMQConfig +from zmqruntime.viewer_protocol import ( + ViewerAckPolicy, + ViewerBatchDisplayPayload, + ViewerBatchContextWireField, + ViewerBatchItemPayload, + ViewerTransportEndpoint, + ViewerTransportMode, + ViewerWirePayload, + ViewerWireMapping, + ViewerWireValue, +) + + +DisplayComponentToken: TypeAlias = str | Enum +DisplayModeToken: TypeAlias = str | Enum | None +ViewerIndexedComponentMetadata: TypeAlias = Sequence[ViewerWireMapping] +ViewerPathComponentMetadata: TypeAlias = Mapping[str, ViewerWireMapping] + + +class ViewerDisplayConfigABC(ABC): + """Display-config surface required by viewer streaming backends.""" + + COMPONENT_ORDER: Sequence[DisplayComponentToken] + + @abstractmethod + def component_modes(self) -> Mapping[DisplayComponentToken, DisplayModeToken]: + """Return mode assignments by display component.""" + + +class ViewerFilenameParserABC(ABC): + """Filename parser surface needed by viewer streaming metadata.""" + + @abstractmethod + def parse_filename(self, filename: str) -> ViewerWireMapping | None: + """Return component metadata parsed from a filename.""" + + +class ViewerMetadataHandlerABC(ABC): + """Metadata-handler surface needed by viewer component labels.""" + + @abstractmethod + def get_component_values( + self, + plate_path: str | Path | None, + component_name: str, + ) -> ViewerWireValue: + """Return display-name metadata for one component.""" + + +class ViewerMicroscopeHandlerABC(ABC): + """Microscope-handler surface used by viewer streaming.""" + + parser: ViewerFilenameParserABC + metadata_handler: ViewerMetadataHandlerABC + + +class ViewerTransportConfigSelection(ABC, metaclass=AutoRegisterMeta): + """Nominal selection of the transport config used for one stream request.""" + + __registry_key__ = "registry_key" + registry_key: ClassVar[str | None] = None + + @classmethod + def select(cls, value) -> "ViewerTransportConfigSelection": + if isinstance(value, cls): + return value + for selection_type in cls.__registry__.values(): + if selection_type.accepts(value): + return selection_type.from_raw(value) + raise TypeError( + "transport_config must be a ZMQConfig, " + "ViewerTransportConfigSelection, or None." + ) + + @classmethod + @abstractmethod + def accepts(cls, value) -> bool: + """Return whether this registered selection can adapt the raw value.""" + + @classmethod + @abstractmethod + def from_raw(cls, value) -> "ViewerTransportConfigSelection": + """Adapt the raw value into a concrete transport-config selection.""" + + @abstractmethod + def resolve(self, default_transport_config: ZMQConfig) -> ZMQConfig: + """Return the concrete config for this request.""" + + +@dataclass(frozen=True) +class DefaultViewerTransportConfig(ViewerTransportConfigSelection): + """Use the backend's configured transport settings.""" + + registry_key: ClassVar[str] = "default" + + @classmethod + def accepts(cls, value) -> bool: + return value is None + + @classmethod + def from_raw(cls, value) -> "DefaultViewerTransportConfig": + return cls() + + def resolve(self, default_transport_config: ZMQConfig) -> ZMQConfig: + return default_transport_config + + +@dataclass(frozen=True) +class ExplicitViewerTransportConfig(ViewerTransportConfigSelection): + """Use a caller-supplied transport config for this request.""" + + registry_key: ClassVar[str] = "explicit" + + config: ZMQConfig + + @classmethod + def accepts(cls, value) -> bool: + return isinstance(value, ZMQConfig) + + @classmethod + def from_raw(cls, value) -> "ExplicitViewerTransportConfig": + return cls(value) + + def resolve(self, default_transport_config: ZMQConfig) -> ZMQConfig: + return self.config + + +@dataclass(frozen=True) +class ViewerTransportDefaults: + """Declared transport defaults shared by viewer streaming backends.""" + + ack_timeout_ms: int = 30_000 + + def ack_policy(self, viewer_name: str) -> ViewerAckPolicy: + return ViewerAckPolicy( + viewer_name=viewer_name, + timeout_ms=self.ack_timeout_ms, + ) + + +class ViewerSourceComponentMetadataPayload(dict[str, ViewerWireValue]): + """Validated component metadata payload for one streamed source item.""" + + @classmethod + def from_mapping( + cls, + value: ViewerWireMapping, + *, + source_label: str, + ) -> "ViewerSourceComponentMetadataPayload": + if not isinstance(value, Mapping): + raise TypeError( + "Viewer stream component metadata must be a mapping " + f"for {source_label}; got {type(value).__name__}." + ) + return cls( + ViewerWirePayload.mapping( + value, + context=f"viewer stream component metadata for {source_label}", + ) + ) + + +class ViewerStreamSourceMetadata(ABC, metaclass=AutoRegisterMeta): + """Component metadata authority for streamed source items.""" + + __registry_key__ = "metadata_kind" + __skip_if_no_key__ = True + metadata_kind: ClassVar[str | None] = None + + @abstractmethod + def component_metadata_for_item( + self, + file_path: str | Path, + index: int, + ) -> ViewerSourceComponentMetadataPayload: + """Return explicit component metadata for one batch item.""" + + +@dataclass(frozen=True) +class BatchViewerStreamSourceMetadata(ViewerStreamSourceMetadata): + """One component metadata payload shared by every streamed item.""" + + metadata_kind: ClassVar[str] = "batch" + component_metadata: ViewerWireMapping + + def component_metadata_for_item( + self, + file_path: str | Path, + index: int, + ) -> ViewerSourceComponentMetadataPayload: + return ViewerSourceComponentMetadataPayload.from_mapping( + self.component_metadata, + source_label=f"batch metadata for {file_path!r}", + ) + + +@dataclass(frozen=True) +class PathMappedViewerStreamSourceMetadata(ViewerStreamSourceMetadata): + """Component metadata selected by stream item path identity.""" + + metadata_kind: ClassVar[str] = "path_mapped" + metadata_by_path: ViewerPathComponentMetadata + + def component_metadata_for_item( + self, + file_path: str | Path, + index: int, + ) -> ViewerSourceComponentMetadataPayload: + path = Path(file_path) + for key in (str(file_path), path.as_posix(), path.name): + if key in self.metadata_by_path: + return ViewerSourceComponentMetadataPayload.from_mapping( + self.metadata_by_path[key], + source_label=f"path metadata for {file_path!r}", + ) + raise KeyError( + "Viewer stream path-mapped component metadata has no entry for " + f"{file_path!r}." + ) + + +@dataclass(frozen=True) +class IndexedViewerStreamSourceMetadata(ViewerStreamSourceMetadata): + """Component metadata selected by stream item batch position.""" + + metadata_kind: ClassVar[str] = "indexed" + metadata_by_index: ViewerIndexedComponentMetadata + + def component_metadata_for_item( + self, + file_path: str | Path, + index: int, + ) -> ViewerSourceComponentMetadataPayload: + if index >= len(self.metadata_by_index): + raise IndexError( + "Viewer stream indexed component metadata has no entry for " + f"item {index} at {file_path!r}." + ) + return ViewerSourceComponentMetadataPayload.from_mapping( + self.metadata_by_index[index], + source_label=f"indexed metadata for {file_path!r}", + ) + + +@dataclass(frozen=True) +class ViewerStreamProducer: + """Producer identity carrier that owns viewer item identity projection.""" + + identity: StreamProducerIdentity + + @classmethod + def from_identity( + cls, + identity: StreamProducerIdentity, + ) -> "ViewerStreamProducer": + return cls(identity=identity) + + def batch_item_payload( + self, + item_source: "ViewerStreamBatchItemSource", + ) -> ViewerBatchItemPayload: + return ViewerBatchItemPayload.from_parts( + item_payload=item_source.item_payload, + data_type=item_source.wire_data_type, + metadata=item_source.metadata, + producer_identity=self.identity.to_payload(), + image_id=item_source.image_id, + ) + + +@dataclass(frozen=True) +class ViewerStreamItemPayload: + """Typed item payload produced by a concrete viewer streaming backend.""" + + item_payload: ViewerWireMapping + streaming_data_type: StreamingDataType + + @property + def wire_data_type(self) -> str: + return self.streaming_data_type.value + + +@dataclass(frozen=True) +class ViewerStreamBatchItemInput(ViewerStreamItemPayload): + """Nominal input for constructing one viewer batch item source.""" + + stream_source: "ViewerStreamSource" + file_path: str | Path + index: int + image_id: str + + +@dataclass(frozen=True) +class ViewerStreamBatchItemSource(ViewerStreamItemPayload): + """Declared source payload for one viewer batch item.""" + + metadata: ViewerSourceComponentMetadataPayload + image_id: str + + @classmethod + def from_input( + cls, + source_input: ViewerStreamBatchItemInput, + ) -> "ViewerStreamBatchItemSource": + return cls( + item_payload=source_input.item_payload, + streaming_data_type=source_input.streaming_data_type, + metadata=source_input.stream_source.metadata.component_metadata_for_item( + source_input.file_path, + source_input.index, + ), + image_id=source_input.image_id, + ) + + +@dataclass(frozen=True) +class ViewerStreamSourceIdentity: + """Stable source identity shared by all stream batches for one plate.""" + + microscope_handler: ViewerMicroscopeHandlerABC + plate_path: str | Path | None = None + + +class ViewerStreamKwarg(str, Enum): + """Raw kwarg names accepted at the top-level viewer stream boundary.""" + + STREAM_REQUEST = "stream_request" + + +@dataclass(frozen=True) +class ViewerStreamSource: + """Source provenance and metadata authority for one viewer stream.""" + + identity: ViewerStreamSourceIdentity + metadata: ViewerStreamSourceMetadata + + +@dataclass(frozen=True) +class ViewerStreamDisplaySemantics: + """Normalized display-axis semantics for a viewer stream request.""" + + display_config: ViewerDisplayConfigABC + + @property + def component_order(self) -> tuple[str, ...]: + return tuple(str(component) for component in self.display_config.COMPONENT_ORDER) + + @property + def component_modes(self) -> dict[str, str]: + return { + str(component): str(mode.value if isinstance(mode, Enum) else mode) + for component, mode in self.display_config.component_modes().items() + } + + def batch_display_payload( + self, + extra: Mapping[str | Enum, ViewerWireValue] | None = None, + ) -> ViewerBatchDisplayPayload: + if extra is None: + extra_payload: dict[str | Enum, ViewerWireValue] = {} + else: + extra_payload = dict(extra) + return ViewerBatchDisplayPayload( + component_modes=self.component_modes, + component_order=self.component_order, + extra=extra_payload, + ) + + +@dataclass(frozen=True, kw_only=True, slots=True) +class ViewerStreamMessageContext: + """Viewer message context carried through stream request boundaries.""" + + message_extra: ViewerWireMapping | None = None + images_dir: str | None = None + + def message_extra_payload(self) -> dict[str, ViewerWireValue]: + return ViewerMessageExtraAuthority.payload(self.message_extra) + + def message_extra_payload_with_images_dir(self) -> dict[str, ViewerWireValue]: + payload = self.message_extra_payload() + payload[ViewerBatchContextWireField.IMAGES_DIR.value] = self.images_dir + return payload + + +@dataclass(frozen=True, kw_only=True) +class ViewerStreamRequest(ViewerStreamMessageContext): + """Typed view of backend kwargs at the viewer streaming boundary.""" + + viewer_transport: ViewerTransportEndpoint + display_config: ViewerDisplayConfigABC + source: ViewerStreamSource + producer: ViewerStreamProducer + transport_config: ViewerTransportConfigSelection = DefaultViewerTransportConfig() + + @classmethod + def from_message_context( + cls, + *, + message_context: ViewerStreamMessageContext, + viewer_transport: ViewerTransportEndpoint, + display_config: ViewerDisplayConfigABC, + source: ViewerStreamSource, + producer: ViewerStreamProducer, + transport_config: ViewerTransportConfigSelection = DefaultViewerTransportConfig(), + ) -> "ViewerStreamRequest": + return cls( + viewer_transport=viewer_transport, + display_config=display_config, + source=source, + producer=producer, + transport_config=transport_config, + message_extra=message_context.message_extra, + images_dir=message_context.images_dir, + ) + + @property + def host(self) -> str: + return self.viewer_transport.host + + @property + def port(self) -> int: + return self.viewer_transport.port + + @property + def transport_mode(self) -> ViewerTransportMode: + return self.viewer_transport.transport_mode + + @property + def display_semantics(self) -> ViewerStreamDisplaySemantics: + return ViewerStreamDisplaySemantics(self.display_config) + + +ViewerStreamKwargPayloadMapping: TypeAlias = Mapping[ + str, + "ViewerStreamRequest", +] + + +@dataclass(frozen=True) +class ViewerStreamBackendKwargs: + """The only accepted FileManager kwarg payload for viewer stream backends.""" + + stream_request: ViewerStreamRequest + + @classmethod + def from_kwargs( + cls, + kwargs: ViewerStreamKwargPayloadMapping, + ) -> "ViewerStreamBackendKwargs": + expected = frozenset((ViewerStreamKwarg.STREAM_REQUEST.value,)) + actual = frozenset(kwargs) + if actual != expected: + raise ValueError( + "Viewer stream backends require exactly one kwarg: stream_request" + ) + value = kwargs[ViewerStreamKwarg.STREAM_REQUEST.value] + if not isinstance(value, ViewerStreamRequest): + raise TypeError("stream_request must be a ViewerStreamRequest instance") + return cls(value) + + def to_kwargs(self) -> dict[str, ViewerStreamRequest]: + return {ViewerStreamKwarg.STREAM_REQUEST.value: self.stream_request} + + def with_single_item_component_metadata( + self, + component_metadata: ViewerWireMapping, + ) -> "ViewerStreamBackendKwargs": + """Return kwargs with component metadata for a single streamed item.""" + stream_request = self.stream_request + source = replace( + stream_request.source, + metadata=BatchViewerStreamSourceMetadata( + component_metadata=ViewerWirePayload.mapping( + component_metadata, + context="single-item viewer stream component metadata", + ), + ), + ) + return type(self)(replace(stream_request, source=source)) + + +class ViewerMessageExtraAuthority: + """Formal boundary for absent caller-supplied viewer message extras.""" + + @staticmethod + def payload(message_extra: Mapping[str, ViewerWireValue] | None) -> dict[str, ViewerWireValue]: + if message_extra is None: + return {} + return ViewerWirePayload.mapping( + message_extra, + context="viewer message extra", + ) diff --git a/src/polystore/streaming_constants.py b/src/polystore/streaming_constants.py index d7f0596..05c834c 100644 --- a/src/polystore/streaming_constants.py +++ b/src/polystore/streaming_constants.py @@ -15,6 +15,21 @@ class StreamingDataType(Enum): POINTS = "points" # Napari points layer (e.g., skeleton tracings) ROIS = "rois" # Fiji ROI payloads + @property + def uses_napari_vector_payload(self) -> bool: + """Whether napari should receive this type through vector layer payloads.""" + return self in (type(self).SHAPES, type(self).POINTS) + + @property + def napari_layer_suffix(self) -> str: + """Layer-key suffix contributed by this data type.""" + return { + type(self).IMAGE: "", + type(self).SHAPES: "_shapes", + type(self).POINTS: "_points", + type(self).ROIS: "", + }[self] + class NapariShapeType(Enum): """Napari shape types for ROI visualization.""" diff --git a/src/polystore/virtual_workspace.py b/src/polystore/virtual_workspace.py index 45081a3..54ff2a9 100644 --- a/src/polystore/virtual_workspace.py +++ b/src/polystore/virtual_workspace.py @@ -1,20 +1,191 @@ """Virtual Workspace Backend - Symlink-free workspace using metadata mapping.""" -import logging import json +import logging from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Union +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, ClassVar, Dict, Hashable, List, Mapping, Optional, Set, Union from fnmatch import fnmatch +import numpy as np + from .disk import DiskStorageBackend from .metadata_writer import get_metadata_path from .exceptions import StorageResolutionError -from .base import ReadOnlyBackend +from .base import PicklableBackend, ReadOnlyBackend +from .constants import Backend +from .registry import AutoRegisterMeta logger = logging.getLogger(__name__) -class VirtualWorkspaceBackend(ReadOnlyBackend): +class VirtualWorkspaceSourceRefResolver(ABC, metaclass=AutoRegisterMeta): + """Nominal loader family for virtual-workspace source references.""" + + __registry_key__ = "resolver_key" + __skip_if_no_key__ = True + resolver_key: ClassVar[str | None] = None + priority: ClassVar[int] + + @classmethod + def for_ref(cls, source_ref: Any) -> "VirtualWorkspaceSourceRefResolver": + for resolver_type in sorted( + cls.__registry__.values(), + key=lambda registered_type: registered_type.priority, + ): + resolver = resolver_type() + if resolver.accepts(source_ref): + return resolver + raise StorageResolutionError( + f"Unsupported virtual workspace source reference: {source_ref!r}" + ) + + @abstractmethod + def accepts(self, source_ref: Any) -> bool: + """Return whether this resolver owns the reference shape.""" + + @abstractmethod + def source_path(self, plate_root: Path, source_ref: Any) -> Path: + """Return the concrete source path for existence and diagnostics.""" + + @abstractmethod + def load( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_ref: Any, + **kwargs: Any, + ) -> Any: + """Load the payload addressed by this source reference.""" + + def batch_key(self, plate_root: Path, source_ref: Any) -> Hashable: + """Return the physical source identity shared by batch-compatible refs.""" + return self.source_path(plate_root, source_ref) + + def load_batch( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_refs: tuple[Any, ...], + **kwargs: Any, + ) -> tuple[Any, ...]: + """Load a batch of references owned by this resolver.""" + return tuple( + self.load(disk_backend, plate_root, source_ref, **kwargs) + for source_ref in source_refs + ) + + +class PathSourceRefResolver(VirtualWorkspaceSourceRefResolver): + """Resolve legacy string path mappings.""" + + resolver_key = "path" + priority = 100 + + def accepts(self, source_ref: Any) -> bool: + return isinstance(source_ref, (str, Path)) + + def source_path(self, plate_root: Path, source_ref: Any) -> Path: + path = Path(source_ref) + return path if path.is_absolute() else plate_root / path + + def load( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_ref: Any, + **kwargs: Any, + ) -> Any: + return disk_backend.load(self.source_path(plate_root, source_ref), **kwargs) + + +class DiskSourceRefResolver(VirtualWorkspaceSourceRefResolver): + """Resolve structured disk refs, including single-plane TIFF pages.""" + + resolver_key = "disk" + priority = 10 + + def accepts(self, source_ref: Any) -> bool: + return ( + isinstance(source_ref, Mapping) + and source_ref.get("backend", Backend.DISK.value) == Backend.DISK.value + and isinstance(source_ref.get("source_path"), (str, Path)) + ) + + def source_path(self, plate_root: Path, source_ref: Any) -> Path: + path = Path(source_ref["source_path"]) + return path if path.is_absolute() else plate_root / path + + def load( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_ref: Any, + **kwargs: Any, + ) -> Any: + payload = disk_backend.load(self.source_path(plate_root, source_ref), **kwargs) + plane_index = source_ref.get("plane_index") + if plane_index is None: + return payload + return _payload_plane(payload, int(plane_index), source_ref) + + def load_batch( + self, + disk_backend: DiskStorageBackend, + plate_root: Path, + source_refs: tuple[Any, ...], + **kwargs: Any, + ) -> tuple[Any, ...]: + if not source_refs: + return () + source_paths = tuple(self.source_path(plate_root, ref) for ref in source_refs) + unique_source_paths = tuple(dict.fromkeys(source_paths)) + if len(unique_source_paths) != 1: + raise StorageResolutionError( + f"{type(self).__name__}.load_batch requires one physical source path, " + f"got {len(unique_source_paths)}." + ) + payload = disk_backend.load(unique_source_paths[0], **kwargs) + return tuple( + payload + if source_ref.get("plane_index") is None + else _payload_plane(payload, int(source_ref["plane_index"]), source_ref) + for source_ref in source_refs + ) + + +def _payload_plane(payload: Any, plane_index: int, source_ref: Mapping[str, Any]) -> Any: + array = np.asarray(payload) + if array.ndim < 3: + raise StorageResolutionError( + f"Source ref {source_ref!r} requested plane {plane_index}, but loaded " + f"payload shape {array.shape!r} is not a stack." + ) + if plane_index < 0 or plane_index >= array.shape[0]: + raise StorageResolutionError( + f"Source ref {source_ref!r} requested plane {plane_index}, but loaded " + f"payload shape is {array.shape!r}." + ) + return array[plane_index] + + +@dataclass(frozen=True, slots=True) +class VirtualWorkspaceResolvedRef: + """Resolved source reference for one virtual workspace request.""" + + output_index: int + source_ref: Any + resolver: VirtualWorkspaceSourceRefResolver + + def batch_key(self, plate_root: Path) -> tuple[type[VirtualWorkspaceSourceRefResolver], Hashable]: + return (type(self.resolver), self.resolver.batch_key(plate_root, self.source_ref)) + + +_UNSET_BATCH_OUTPUT = object() + + +class VirtualWorkspaceBackend(ReadOnlyBackend, PicklableBackend): """ Read-only path translation layer for virtual workspace. @@ -53,12 +224,33 @@ def __init__(self, plate_root: Path): """ self.plate_root = Path(plate_root) self.disk_backend = DiskStorageBackend() - self._mapping_cache: Optional[Dict[str, str]] = None + self._mapping_cache: Optional[Dict[str, Any]] = None self._cache_mtime: Optional[float] = None # Load mapping eagerly - fail loud if metadata missing self._load_mapping() + @classmethod + def from_connection_params( + cls, + params: Optional[Dict[str, Any]], + ) -> "VirtualWorkspaceBackend": + if not params: + raise ValueError("VirtualWorkspaceBackend requires plate_root.") + return cls(plate_root=Path(params["plate_root"])) + + def get_connection_params(self) -> Optional[Dict[str, Any]]: + return {"plate_root": str(self.plate_root)} + + def set_connection_params(self, params: Optional[Dict[str, Any]]) -> None: + if not params: + raise ValueError("VirtualWorkspaceBackend requires plate_root.") + self.plate_root = Path(params["plate_root"]) + self.disk_backend = DiskStorageBackend() + self._mapping_cache = None + self._cache_mtime = None + self._load_mapping() + @staticmethod def _normalize_relative_path(path_str: str) -> str: """ @@ -76,7 +268,7 @@ def _normalize_relative_path(path_str: str) -> str: normalized = path_str.replace('\\', '/') return '' if normalized == '.' else normalized - def _load_mapping(self) -> Dict[str, str]: + def _load_mapping(self) -> Dict[str, Any]: """ Load workspace_mapping from metadata with mtime-based caching. @@ -122,7 +314,7 @@ def _load_mapping(self) -> Dict[str, str]: logger.info(f"Loaded {len(combined_mapping)} mappings for {self.plate_root}") return combined_mapping - def _resolve_path(self, path: Union[str, Path]) -> str: + def _resolve_ref(self, path: Union[str, Path]) -> Any: """ Resolve virtual path to real plate path using plate-relative mapping. @@ -163,20 +355,78 @@ def _resolve_path(self, path: Union[str, Path]) -> str: f"This path must be accessed through the virtual workspace mapping." ) - real_relative = self._mapping_cache[relative_str] - real_absolute = self.plate_root / real_relative - logger.debug(f"Resolved virtual → real: {relative_str} → {real_relative}") - return str(real_absolute) + source_ref = self._mapping_cache[relative_str] + logger.debug("Resolved virtual source ref: %s -> %r", relative_str, source_ref) + return source_ref + + def _resolve_path(self, path: Union[str, Path]) -> str: + """Resolve a virtual path to the concrete source path for diagnostics.""" + source_ref = self._resolve_ref(path) + resolver = VirtualWorkspaceSourceRefResolver.for_ref(source_ref) + return str(resolver.source_path(self.plate_root, source_ref)) def load(self, file_path: Union[str, Path], **kwargs) -> Any: """Load file from virtual workspace.""" - real_path = self._resolve_path(file_path) - return self.disk_backend.load(real_path, **kwargs) + source_ref = self._resolve_ref(file_path) + resolver = VirtualWorkspaceSourceRefResolver.for_ref(source_ref) + return resolver.load( + self.disk_backend, + self.plate_root, + source_ref, + **kwargs, + ) def load_batch(self, file_paths: List[Union[str, Path]], **kwargs) -> List[Any]: """Load multiple files from virtual workspace.""" - real_paths = [self._resolve_path(fp) for fp in file_paths] - return self.disk_backend.load_batch(real_paths, **kwargs) + resolved_refs = tuple( + self._resolved_ref(index, file_path) + for index, file_path in enumerate(file_paths) + ) + grouped_refs: dict[ + tuple[type[VirtualWorkspaceSourceRefResolver], Hashable], + list[VirtualWorkspaceResolvedRef], + ] = {} + for resolved_ref in resolved_refs: + grouped_refs.setdefault( + resolved_ref.batch_key(self.plate_root), + [], + ).append(resolved_ref) + + ordered_outputs: list[Any] = [_UNSET_BATCH_OUTPUT] * len(file_paths) + for group in grouped_refs.values(): + resolver = group[0].resolver + source_refs = tuple(ref.source_ref for ref in group) + outputs = resolver.load_batch( + self.disk_backend, + self.plate_root, + source_refs, + **kwargs, + ) + if len(outputs) != len(group): + raise StorageResolutionError( + f"{type(resolver).__name__}.load_batch returned {len(outputs)} " + f"outputs for {len(group)} virtual workspace refs." + ) + for resolved_ref, output in zip(group, outputs, strict=True): + ordered_outputs[resolved_ref.output_index] = output + + if any(output is _UNSET_BATCH_OUTPUT for output in ordered_outputs): + raise StorageResolutionError( + "Virtual workspace batch load did not populate every requested path." + ) + return ordered_outputs + + def _resolved_ref( + self, + output_index: int, + file_path: Union[str, Path], + ) -> VirtualWorkspaceResolvedRef: + source_ref = self._resolve_ref(file_path) + return VirtualWorkspaceResolvedRef( + output_index=output_index, + source_ref=source_ref, + resolver=VirtualWorkspaceSourceRefResolver.for_ref(source_ref), + ) def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, extensions: Optional[Set[str]] = None, recursive: bool = False, @@ -205,10 +455,20 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, if self._mapping_cache is None: self._load_mapping() - logger.info(f"VirtualWorkspace.list_files called: directory={directory}, recursive={recursive}, pattern={pattern}, extensions={extensions}") - logger.info(f" plate_root={self.plate_root}") - logger.info(f" relative_dir_str='{relative_dir_str}'") - logger.info(f" mapping has {len(self._mapping_cache)} entries") + logger.debug( + "VirtualWorkspace.list_files directory=%s recursive=%s pattern=%s extensions=%s", + directory, + recursive, + pattern, + extensions, + ) + logger.debug(" plate_root=%s", self.plate_root) + logger.debug(" relative_dir_str=%r", relative_dir_str) + logger.debug(" mapping has %s entries", len(self._mapping_cache)) + + lowercase_extensions = ( + None if extensions is None else {ext.lower() for ext in extensions} + ) # Filter paths in this directory results = [] @@ -230,20 +490,20 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, vpath = Path(virtual_relative) if pattern and not fnmatch(vpath.name, pattern): continue - if extensions and vpath.suffix not in extensions: + if lowercase_extensions and vpath.suffix.lower() not in lowercase_extensions: continue # Return absolute path results.append(str(self.plate_root / virtual_relative)) - logger.info(f" VirtualWorkspace.list_files returning {len(results)} files") + logger.debug(" VirtualWorkspace.list_files returning %s files", len(results)) if len(results) == 0 and len(self._mapping_cache) > 0: # Log first few mapping keys to help debug sample_keys = list(self._mapping_cache.keys())[:3] - logger.info(f" Sample mapping keys: {sample_keys}") + logger.debug(" Sample mapping keys: %s", sample_keys) if not recursive and relative_dir_str == '': sample_parents = [str(Path(k).parent).replace('\\', '/') for k in sample_keys] - logger.info(f" Sample parent dirs: {sample_parents}") + logger.debug(" Sample parent dirs: %s", sample_parents) logger.info(f" Expected parent to match: '{relative_dir_str}'") return sorted(results) diff --git a/src/polystore/zarr.py b/src/polystore/zarr.py index 7249323..848d2bf 100644 --- a/src/polystore/zarr.py +++ b/src/polystore/zarr.py @@ -11,9 +11,8 @@ import logging import os import threading -from functools import wraps from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import numpy as np import zarr @@ -29,6 +28,7 @@ ATTR_OUTPUT_PATHS = f"{_ATTR_PREFIX}_output_paths" ATTR_DIMENSIONS = f"{_ATTR_PREFIX}_dimensions" DEFAULT_PLATE_NAME = os.getenv("POLYSTORE_PLATE_NAME", "Polystore_Plate") +DISK_PASSTHROUGH_EXTENSIONS = ('.json', '.csv', '.txt', '.roi.zip', '.zip') def _get_attr(attrs: Dict[str, Any], key: str): @@ -37,77 +37,6 @@ def _get_attr(attrs: Dict[str, Any], key: str): return None -# Decorator for passthrough to disk backend -def passthrough_to_disk(*extensions: str, ensure_parent_dir: bool = False): - """ - Decorator to automatically passthrough certain file types to disk backend. - - Zarr only supports array data, so non-array files (JSON, CSV, TXT, ROI.ZIP, etc.) - are automatically delegated to the disk backend. - - Uses introspection to automatically find the path parameter (any parameter with 'path' in its name). - - Args: - *extensions: File extensions to passthrough (e.g., '.json', '.csv', '.txt') - ensure_parent_dir: If True, ensure parent directory exists before calling disk backend (for save operations) - - Usage: - @passthrough_to_disk('.json', '.csv', '.txt', '.roi.zip', '.zip', ensure_parent_dir=True) - def save(self, data, output_path, **kwargs): - # Zarr-specific save logic here - ... - """ - import inspect - - def decorator(method: Callable) -> Callable: - # Use introspection to find the path parameter index at decoration time - sig = inspect.signature(method) - path_param_index = None - - for i, (param_name, param) in enumerate(sig.parameters.items()): - if param_name == 'self': - continue - # Find first parameter with 'path' in its name - if 'path' in param_name.lower(): - # Adjust for self parameter (subtract 1 since we skip 'self' in args) - path_param_index = i - 1 - break - - if path_param_index is None: - raise ValueError(f"No path parameter found in {method.__name__} signature. " - f"Expected a parameter with 'path' in its name.") - - @wraps(method) - def wrapper(self, *args, **kwargs): - # Extract path from args at the discovered index - path_arg = None - - if len(args) > path_param_index: - arg = args[path_param_index] - if isinstance(arg, (str, Path)): - path_arg = str(arg) - - # Check if path matches passthrough extensions - if path_arg and any(path_arg.endswith(ext) for ext in extensions): - from .constants import Backend - from .backend_registry import get_backend_instance - disk_backend = get_backend_instance(Backend.DISK.value) - - # Ensure parent directory exists if requested (for save operations) - if ensure_parent_dir: - parent_dir = Path(path_arg).parent - disk_backend.ensure_directory(parent_dir) - - # Call the same method on disk backend - return getattr(disk_backend, method.__name__)(*args, **kwargs) - - # Otherwise, call the original method - return method(self, *args, **kwargs) - - return wrapper - return decorator - - def _load_ome_zarr(): """Load ome-zarr and cache imports.""" try: @@ -175,11 +104,11 @@ def _ensure_ome_zarr(timeout: float = 30.0): FCNTL_AVAILABLE = False from .constants import Backend -from .base import StorageBackend +from .base import PicklableBackend, StorageBackend from .exceptions import StorageResolutionError -class ZarrStorageBackend(StorageBackend): +class ZarrStorageBackend(StorageBackend, PicklableBackend): """Zarr storage backend with automatic registration.""" _backend_type = Backend.ZARR.value supports_arbitrary_files = False # Class attribute: zarr only handles array data @@ -214,17 +143,27 @@ def __init__(self, zarr_config: Optional['ZarrConfig'] = None): if zarr_config is None: zarr_config = ZarrConfig() - self.config = zarr_config + self._configure(zarr_config) - # Convenience attributes + def _configure(self, zarr_config: "ZarrConfig") -> None: + self.config = zarr_config self.compression_level = zarr_config.compression_level - - # Create actual compressor from config (shuffle always enabled for Blosc) - self.compressor = self.config.compressor.create_compressor( + self.compressor = self.config.compressor_factory.create( self.config.compression_level, - shuffle=True # Always enable shuffle for better compression + shuffle=True, ) + def get_connection_params(self) -> Optional[Dict[str, Any]]: + return {"zarr_config": self.config} + + def set_connection_params(self, params: Optional[Dict[str, Any]]) -> None: + from .config import ZarrConfig + + if params is None: + self._configure(ZarrConfig()) + return + self._configure(params["zarr_config"]) + def _get_compressor(self) -> Optional[Any]: """ Get the configured compressor with appropriate settings. @@ -232,25 +171,36 @@ def _get_compressor(self) -> Optional[Any]: Returns: Configured compressor instance or None for no compression """ - if self.compressor is None: - return None + return self.compressor - # If compression_level is specified and compressor supports it - if self.compression_level is not None: - # Check if compressor has level parameter - if hasattr(self.compressor, '__class__'): - try: - # Create new instance with compression level - compressor_class = self.compressor.__class__ - if 'level' in compressor_class.__init__.__code__.co_varnames: - return compressor_class(level=self.compression_level) - elif 'clevel' in compressor_class.__init__.__code__.co_varnames: - return compressor_class(clevel=self.compression_level) - except (AttributeError, TypeError): - # Fall back to original compressor if level setting fails - pass + @staticmethod + def _as_cpu_array(data: Any) -> Any: + try: + import cupy + except ImportError: + pass + else: + if isinstance(data, cupy.ndarray): + return data.get() - return self.compressor + try: + import torch + except ImportError: + pass + else: + if isinstance(data, torch.Tensor): + return data.cpu().numpy() + + try: + import jax + from jax import Array as JaxArray + except ImportError: + pass + else: + if isinstance(data, JaxArray): + return jax.device_get(data) + + return data def _calculate_chunks(self, data_shape: Tuple[int, ...]) -> Tuple[int, ...]: """ @@ -304,7 +254,6 @@ def _split_store_and_key(self, path: Union[str, Path]) -> Tuple[Any, str]: store = zarr.DirectoryStore(str(store_path), dimension_separator='/') return store, relative_key - @passthrough_to_disk('.json', '.csv', '.txt', '.roi.zip', '.zip', ensure_parent_dir=True) def save(self, data: Any, output_path: Union[str, Path], **kwargs): """ Save data to Zarr at the given output_path. @@ -316,7 +265,15 @@ def save(self, data: Any, output_path: Union[str, Path], **kwargs): FileExistsError: If destination key already exists StorageResolutionError: If creation fails """ - # Zarr-specific save logic (non-array files automatically passthrough to disk) + output_path_text = str(output_path) + if output_path_text.endswith(DISK_PASSTHROUGH_EXTENSIONS): + from .backend_registry import get_backend_instance + + disk_backend = get_backend_instance(Backend.DISK.value) + disk_backend.ensure_directory(Path(output_path_text).parent) + disk_backend.save(data, output_path, **kwargs) + return + store, key = self._split_store_and_key(output_path) group = zarr.group(store=store) @@ -470,17 +427,7 @@ def save_batch(self, data_list: List[Any], output_paths: List[Union[str, Path]], logger.debug(f"Saving batch for chunk {chunk_name} with {len(data_list)} images to row={row}, col={col}") # Convert GPU arrays to CPU arrays before saving - cpu_data_list = [] - for data in data_list: - if hasattr(data, 'get'): # CuPy array - cpu_data_list.append(data.get()) - elif hasattr(data, 'cpu'): # PyTorch tensor - cpu_data_list.append(data.cpu().numpy()) - elif hasattr(data, 'device') and 'cuda' in str(data.device).lower(): # JAX on GPU - import jax - cpu_data_list.append(jax.device_get(data)) - else: # Already CPU array (NumPy, etc.) - cpu_data_list.append(data) + cpu_data_list = [self._as_cpu_array(data) for data in data_list] # Ensure parent directory exists store_path.parent.mkdir(parents=True, exist_ok=True) @@ -921,9 +868,13 @@ def delete_all(self, path: Union[str, Path]) -> None: except Exception as e: raise StorageResolutionError(f"Failed to recursively delete Zarr path: {path}") from e - @passthrough_to_disk('.json', '.csv', '.txt') def exists(self, path: Union[str, Path]) -> bool: - # Zarr-specific existence check (text files automatically passthrough to disk) + if str(path).endswith(DISK_PASSTHROUGH_EXTENSIONS): + from .backend_registry import get_backend_instance + + disk_backend = get_backend_instance(Backend.DISK.value) + return disk_backend.exists(path) + path = Path(path) # If path has no file extension, treat as directory existence check @@ -990,22 +941,37 @@ def is_symlink(self, path: Union[str, Path]) -> bool: try: obj = group[key] - attrs = getattr(obj, "attrs", {}) - - if "_symlink" not in attrs: - return False - - # Enforce that the _symlink attr matches schema (e.g. str or list of path components) - if not isinstance(attrs["_symlink"], str): - raise StorageResolutionError(f"Invalid symlink format in Zarr attrs at: {path}") - - return True + return self._symlink_target(obj, path) is not None except KeyError: # Key doesn't exist, so it's not a symlink return False except Exception as e: raise StorageResolutionError(f"Failed to inspect Zarr symlink at: {path}") from e + def _symlink_target(self, obj: Any, path: Union[str, Path]) -> str | None: + if not isinstance(obj, (zarr.core.Array, zarr.hierarchy.Group)): + raise StorageResolutionError(f"Unknown Zarr object at: {path}") + if "_symlink" not in obj.attrs: + return None + target = obj.attrs["_symlink"] + if not isinstance(target, str): + raise StorageResolutionError(f"Invalid symlink format in Zarr attrs at: {path}") + return target + + def _resolve_symlink(self, group: Any, key: str) -> tuple[str, Any]: + seen_keys = set() + while True: + if key not in group: + raise FileNotFoundError(f"Zarr key does not exist: {key}") + obj = group[key] + target = self._symlink_target(obj, key) + if target is None: + return key, obj + if key in seen_keys: + raise StorageResolutionError(f"Symlink cycle detected in Zarr at: {key}") + seen_keys.add(key) + key = target + def _auto_chunks(self, data: Any, chunk_divisor: int = 1) -> Tuple[int, ...]: shape = data.shape @@ -1036,22 +1002,7 @@ def is_file(self, path: Union[str, Path]) -> bool: store, key = self._split_store_and_key(path) group = zarr.group(store=store) - # Resolve symlinks (Zarr-native, via .attrs) - seen_keys = set() - while True: - if key not in group: - raise FileNotFoundError(f"Zarr key does not exist: {key}") - obj = group[key] - - if hasattr(obj, "attrs") and "_symlink" in obj.attrs: - if key in seen_keys: - raise StorageResolutionError(f"Symlink cycle detected in Zarr at: {key}") - seen_keys.add(key) - key = obj.attrs["_symlink"] - continue - break # resolution complete - - # Now obj is the resolved target + _, obj = self._resolve_symlink(group, key) if isinstance(obj, zarr.core.Array): return True elif isinstance(obj, zarr.hierarchy.Group): @@ -1090,24 +1041,8 @@ def is_dir(self, path: Union[str, Path]) -> bool: try: store, key = self._split_store_and_key(path) group = zarr.group(store=store) - - seen_keys = set() - - # Resolve symlink chain - while True: - if key not in group: - raise FileNotFoundError(f"Zarr key does not exist: {key}") - obj = group[key] - - if hasattr(obj, "attrs") and "_symlink" in obj.attrs: - if key in seen_keys: - raise StorageResolutionError(f"Symlink cycle detected in Zarr at: {key}") - seen_keys.add(key) - key = obj.attrs["_symlink"] - continue - break - - # obj is resolved + + _, obj = self._resolve_symlink(group, key) if isinstance(obj, zarr.hierarchy.Group): return True elif isinstance(obj, zarr.core.Array): @@ -1146,16 +1081,7 @@ def move(self, src: Union[str, Path], dst: Union[str, Path]) -> None: if dst_key in dst_group: raise FileExistsError(f"Zarr destination key already exists: {dst_key}") - obj = src_group[src_key] - - # Resolve symlinks if present - seen_keys = set() - while hasattr(obj, "attrs") and "_symlink" in obj.attrs: - if src_key in seen_keys: - raise StorageResolutionError(f"Symlink cycle detected at: {src_key}") - seen_keys.add(src_key) - src_key = obj.attrs["_symlink"] - obj = src_group[src_key] + src_key, obj = self._resolve_symlink(src_group, src_key) try: if src_store is dst_store: @@ -1194,15 +1120,7 @@ def copy(self, src: Union[str, Path], dst: Union[str, Path]) -> None: if dst_key in dst_group: raise FileExistsError(f"Zarr destination key already exists: {dst_key}") - obj = src_group[src_key] - - seen_keys = set() - while hasattr(obj, "attrs") and "_symlink" in obj.attrs: - if src_key in seen_keys: - raise StorageResolutionError(f"Symlink cycle detected at: {src_key}") - seen_keys.add(src_key) - src_key = obj.attrs["_symlink"] - obj = src_group[src_key] + src_key, obj = self._resolve_symlink(src_group, src_key) try: obj.copy(dst_group, name=dst_key) @@ -1230,13 +1148,8 @@ def stat(self, path: Union[str, Path]) -> Dict[str, Any]: try: if key in group: obj = group[key] - attrs = getattr(obj, "attrs", {}) - is_link = "_symlink" in attrs - - if is_link: - target = attrs["_symlink"] - if not isinstance(target, str): - raise StorageResolutionError(f"Invalid symlink format at {key}") + target = self._symlink_target(obj, key) + if target is not None: return { "type": "symlink", "key": key, diff --git a/tests/test_filemanager_extended.py b/tests/test_filemanager_extended.py index 0d8b30a..60a0784 100644 --- a/tests/test_filemanager_extended.py +++ b/tests/test_filemanager_extended.py @@ -8,20 +8,31 @@ - Advanced features (symlinks, find, mirror) """ +import json +import pickle import tempfile import shutil from pathlib import Path import numpy as np import pytest -from polystore import FileManager, BackendRegistry +from polystore import FileManager +from polystore.constants import Backend +from polystore.disk import DiskBackend from polystore.exceptions import StorageResolutionError +from polystore.memory import MemoryBackend +from polystore.metadata_writer import get_metadata_path +from polystore.virtual_workspace import VirtualWorkspaceBackend +from polystore.zarr import ZarrStorageBackend @pytest.fixture def registry(): """Create a backend registry with disk and memory backends.""" - return BackendRegistry() + return { + Backend.DISK.value: DiskBackend(), + Backend.MEMORY.value: MemoryBackend(), + } @pytest.fixture @@ -51,6 +62,33 @@ def test_init_with_valid_registry(self, registry): fm = FileManager(registry) assert fm.registry is registry + def test_pickle_preserves_picklable_registry_backends(self, tmp_path): + """FileManager owns worker-safe recreation of context-specific backends.""" + (tmp_path / "real.tif").write_bytes(b"placeholder") + get_metadata_path(tmp_path).write_text(json.dumps({ + "subdirectories": { + ".": { + "workspace_mapping": {"virtual.tif": "real.tif"}, + "available_backends": {Backend.VIRTUAL_WORKSPACE.value: True}, + } + } + })) + + filemanager = FileManager({ + Backend.ZARR.value: ZarrStorageBackend(), + Backend.VIRTUAL_WORKSPACE.value: VirtualWorkspaceBackend(tmp_path), + }) + + restored = pickle.loads(pickle.dumps(filemanager)) + + assert Backend.ZARR.value in restored.registry + assert Backend.VIRTUAL_WORKSPACE.value in restored.registry + assert restored.registry[Backend.VIRTUAL_WORKSPACE.value].plate_root == tmp_path + assert ( + restored.registry[Backend.ZARR.value].config + == filemanager.registry[Backend.ZARR.value].config + ) + class TestFileManagerBackendResolution: """Test backend resolution and error handling.""" diff --git a/tests/test_memory_backend.py b/tests/test_memory_backend.py index f55996b..ec8a080 100644 --- a/tests/test_memory_backend.py +++ b/tests/test_memory_backend.py @@ -109,6 +109,17 @@ def test_list_files_with_extension_filter(self): npy_files = self.backend.list_files("/test", extensions={".npy"}) assert len(npy_files) == 2 + def test_list_files_extension_filter_is_case_insensitive(self): + """Test extension filtering matches backend contract case-insensitively.""" + self.backend.save(np.array([1]), "/test/image.TIF") + self.backend.save(np.array([2]), "/test/image.tif") + self.backend.save("text", "/test/notes.TXT") + + tif_files = self.backend.list_files("/test", extensions={".tif"}) + + assert len(tif_files) == 2 + assert {path.name for path in tif_files} == {"image.TIF", "image.tif"} + def test_list_files_recursive(self): """Test recursive file listing.""" # Create files in multiple levels diff --git a/tests/test_roi.py b/tests/test_roi.py new file mode 100644 index 0000000..2364c4c --- /dev/null +++ b/tests/test_roi.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest + +from polystore.disk import DiskStorageBackend +from polystore.roi import ROI +from polystore.roi import MaskShape +from polystore.roi import PolygonShape +from polystore.roi import load_rois_from_json +from polystore.roi import load_rois_from_zip +from polystore.roi import extract_rois_from_labeled_mask + + +def test_extract_rois_from_labeled_mask_applies_spatial_origin_to_polygons(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + extract_contours=True, + spatial_origin_yx=(10, 20), + ) + + assert len(rois) == 1 + assert rois[0].metadata["bbox"] == (12, 23, 16, 27) + assert rois[0].metadata["centroid"] == (13.5, 24.5) + assert isinstance(rois[0].shapes[0], PolygonShape) + assert float(rois[0].shapes[0].coordinates[:, 0].min()) >= 11.5 + assert float(rois[0].shapes[0].coordinates[:, 1].min()) >= 22.5 + + +def test_extract_rois_from_labeled_mask_applies_spatial_origin_to_mask_bbox(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + extract_contours=False, + spatial_origin_yx=(10, 20), + ) + + assert len(rois) == 1 + assert isinstance(rois[0].shapes[0], MaskShape) + assert rois[0].shapes[0].bbox == (12, 23, 16, 27) + + +def test_extract_rois_from_labeled_mask_records_source_canvas_shape(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + source_spatial_shape_yx=(100, 200), + ) + + assert len(rois) == 1 + assert rois[0].metadata["source_spatial_shape_yx"] == (100, 200) + + +def test_roi_zip_roundtrip_preserves_source_canvas_shape_metadata(tmp_path): + pytest.importorskip("roifile") + path = tmp_path / "labels.roi.zip" + rois = [ + ROI( + shapes=[ + PolygonShape( + np.array( + [[10, 20], [10, 22], [12, 22], [12, 20]], + dtype=float, + ) + ) + ], + metadata={"label": 7, "source_spatial_shape_yx": (100, 200)}, + ) + ] + + DiskStorageBackend()._save_rois(rois, path) + loaded_rois = load_rois_from_zip(path) + + assert loaded_rois[0].metadata["label"] == 7 + assert loaded_rois[0].metadata["source_spatial_shape_yx"] == (100, 200) + + +def test_load_rois_from_json_decodes_shapes_through_nominal_registry(tmp_path): + roi_path = tmp_path / "rois.json" + roi_path.write_text( + """ + [ + { + "metadata": {"label": 1}, + "shapes": [ + {"type": "polygon", "coordinates": [[1, 2], [3, 4], [5, 6]]}, + {"type": "mask", "mask": [[true, false], [false, true]], "bbox": [10, 20, 12, 22]} + ] + } + ] + """ + ) + + rois = load_rois_from_json(roi_path) + + assert len(rois) == 1 + assert isinstance(rois[0].shapes[0], PolygonShape) + assert isinstance(rois[0].shapes[1], MaskShape) + assert rois[0].shapes[1].bbox == (10, 20, 12, 22) diff --git a/tests/test_streaming_metadata.py b/tests/test_streaming_metadata.py new file mode 100644 index 0000000..e2e07c6 --- /dev/null +++ b/tests/test_streaming_metadata.py @@ -0,0 +1,311 @@ +from types import SimpleNamespace + +import pytest + +from polystore.streaming._streaming_backend import StreamingBackend +from polystore.streaming._streaming_backend import StreamingBatchItemPreparationAuthority +from polystore.streaming._streaming_backend import StreamingBatchMessageBuilder +from polystore.streaming._streaming_backend import StreamingBatchMessageRequest +from polystore.streaming._streaming_backend import StreamingComponentNamesRequest +from polystore.streaming._streaming_backend import StreamingItemPath +from polystore.streaming._streaming_backend import StreamingItemPreparationRequest +from polystore.streaming_constants import StreamingDataType +from polystore.streaming.identity import StreamProducerIdentity +from polystore.streaming.viewer_transport import BatchViewerStreamSourceMetadata +from polystore.streaming.viewer_transport import IndexedViewerStreamSourceMetadata +from polystore.streaming.viewer_transport import PathMappedViewerStreamSourceMetadata +from polystore.streaming.viewer_transport import ViewerDisplayConfigABC +from polystore.streaming.viewer_transport import ViewerMicroscopeHandlerABC +from polystore.streaming.viewer_transport import ViewerStreamProducer +from polystore.streaming.viewer_transport import ViewerStreamItemPayload +from polystore.streaming.viewer_transport import ViewerStreamRequest +from polystore.streaming.viewer_transport import ViewerStreamSource +from polystore.streaming.viewer_transport import ViewerStreamSourceIdentity +from polystore.streaming.viewer_transport import ViewerStreamSourceMetadata +from zmqruntime.config import TransportMode +from zmqruntime.viewer_protocol import ViewerAckPolicy +from zmqruntime.viewer_protocol import ViewerTransportEndpoint + + +class MetadataProbeStreamingBackend(StreamingBackend): + VIEWER_TYPE = "probe" + SHM_PREFIX = "probe_" + + def _prepare_batch_item(self, request: StreamingItemPreparationRequest): + return ViewerStreamItemPayload( + item_payload={"path": request.item_path.wire_value, "payload": "ok"}, + streaming_data_type=StreamingDataType.IMAGE, + ) + + def save_batch(self, data_list, file_paths, **kwargs): + raise NotImplementedError + + +class DisplayConfigStub(ViewerDisplayConfigABC): + COMPONENT_ORDER = ("well", "site", "channel") + + def component_modes(self): + return { + "well": "stack", + "site": "stack", + "channel": "stack", + } + + +PRODUCER_IDENTITY = StreamProducerIdentity( + origin="pipeline", + output_kind="main", + output_key="main", + step_name="IdentifyPrimaryObjects", +) + + +EMPTY_SOURCE_METADATA = BatchViewerStreamSourceMetadata( + {"well": "A01", "site": 1, "channel": 1} +) + + +def stream_request( + microscope_handler, + source_metadata=EMPTY_SOURCE_METADATA, + *, + plate_path=None, + message_extra=None, +): + return ViewerStreamRequest( + viewer_transport=ViewerTransportEndpoint( + host="127.0.0.1", + port=5555, + transport_mode=TransportMode.TCP, + ), + display_config=DisplayConfigStub(), + source=ViewerStreamSource( + identity=ViewerStreamSourceIdentity( + microscope_handler=microscope_handler, + plate_path=plate_path, + ), + metadata=source_metadata, + ), + producer=ViewerStreamProducer.from_identity(PRODUCER_IDENTITY), + message_extra=message_extra, + ) + + +def batch_message_request(data_list, file_paths, viewer_request): + return StreamingBatchMessageRequest( + data_list=data_list, + file_paths=file_paths, + stream_request=viewer_request, + component_names_request=( + StreamingComponentNamesRequest.from_stream_request(viewer_request) + ), + ) + + +def microscope_handler_with_parser(parser): + class MicroscopeHandlerStub(ViewerMicroscopeHandlerABC): + pass + + microscope_handler = MicroscopeHandlerStub() + microscope_handler.parser = parser + microscope_handler.metadata_handler = SimpleNamespace( + get_component_values=lambda _plate_path, _component_name: None + ) + return microscope_handler + + +def test_streaming_source_metadata_is_abstract_boundary() -> None: + with pytest.raises(TypeError, match="abstract"): + ViewerStreamSourceMetadata().component_metadata_for_item( + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", + 0, + ) + + +def test_streaming_batch_items_reject_unparsed_artifact_filename() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) + ) + + with pytest.raises(KeyError, match="path-mapped component metadata"): + StreamingBatchItemPreparationAuthority.prepare( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + stream_request( + microscope_handler, + PathMappedViewerStreamSourceMetadata(metadata_by_path={}), + ), + ) + ) + + +def test_streaming_batch_items_accept_per_path_component_metadata() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) + ) + + prepared_items = StreamingBatchItemPreparationAuthority.prepare( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + stream_request( + microscope_handler, + PathMappedViewerStreamSourceMetadata( + metadata_by_path={ + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip": { + "well": "A01", + "site": 1, + "channel": 1, + }, + } + ), + ), + ) + ) + + assert prepared_items.batch_images[0]["metadata"] == { + "well": "A01", + "site": 1, + "channel": 1, + } + assert ( + prepared_items.batch_images[0]["producer_identity"] + == PRODUCER_IDENTITY.to_payload() + ) + + +def test_streaming_item_component_metadata_preserves_explicit_fields() -> None: + metadata = BatchViewerStreamSourceMetadata( + {"well": "A01", "site": 1, "channel": 1}, + ).component_metadata_for_item( + StreamingItemPath("A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip").value, + 0, + ) + + assert metadata == { + "well": "A01", + "site": 1, + "channel": 1, + } + + +def test_streaming_batch_message_declares_component_value_domain() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) + ) + + built_batch = StreamingBatchMessageBuilder.build( + backend, + batch_message_request( + [object(), object()], + ["A01_s001_w1_z001_t001.tif", "A01_s002_w2_z001_t001.tif"], + stream_request( + microscope_handler, + IndexedViewerStreamSourceMetadata( + metadata_by_index=( + {"well": "A01", "site": 1, "channel": 1}, + {"well": "A01", "site": 2, "channel": 2}, + ), + ), + ), + ), + ) + + assert built_batch.message["component_value_domain"] == { + "well": ["A01"], + "site": [1, 2], + "channel": [1, 2], + } + + +def test_streaming_batch_message_honors_declared_component_metadata_payload() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) + ) + + built_batch = StreamingBatchMessageBuilder.build( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001.tif"], + stream_request( + microscope_handler, + BatchViewerStreamSourceMetadata( + {"well": "A01", "site": 1, "channel": 1} + ), + message_extra={ + "component_value_domain": {"well": ["A01", "B01"]}, + "component_names_metadata": {"well": {"A01": "control"}}, + }, + ), + ), + ) + + assert built_batch.message["component_value_domain"] == {"well": ["A01", "B01"]} + assert built_batch.message["component_names_metadata"] == { + "well": {"A01": "control"} + } + + +def test_streaming_batch_message_rejects_partial_declared_component_metadata_payload() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = microscope_handler_with_parser( + SimpleNamespace(parse_filename=lambda _filename: None) + ) + + with pytest.raises(ValueError, match="component_names_metadata"): + StreamingBatchMessageBuilder.build( + backend, + batch_message_request( + [object()], + ["A01_s001_w1_z001_t001.tif"], + stream_request( + microscope_handler, + BatchViewerStreamSourceMetadata( + {"well": "A01", "site": 1, "channel": 1} + ), + message_extra={"component_value_domain": {"well": ["A01"]}}, + ), + ), + ) + + +def test_streaming_component_metadata_rejects_invalid_explicit_metadata() -> None: + with pytest.raises(TypeError, match="must be a mapping"): + BatchViewerStreamSourceMetadata( + ["not", "metadata"], + ).component_metadata_for_item( + StreamingItemPath("A01_s001_w1_z001_t001.TIF").value, + 0, + ) + + +class ViewerAckSocketStub: + def __init__(self, response): + self.response = response + + def recv_json(self): + return self.response + + +def test_viewer_ack_policy_rejects_error_status_and_cleans_up() -> None: + cleanup_calls = [] + policy = ViewerAckPolicy(viewer_name="Napari", timeout_ms=30_000) + + with pytest.raises(RuntimeError, match="Napari rejected stream batch"): + policy.receive( + ViewerAckSocketStub( + {"status": "error", "message": "missing component_value_domain"} + ), + lambda: cleanup_calls.append("cleanup"), + port=5555, + ) + + assert cleanup_calls == ["cleanup"] diff --git a/tests/test_streaming_receiver_core.py b/tests/test_streaming_receiver_core.py index 6f7cd63..eb7a76b 100644 --- a/tests/test_streaming_receiver_core.py +++ b/tests/test_streaming_receiver_core.py @@ -4,75 +4,230 @@ import time from polystore.streaming_constants import StreamingDataType +from polystore.streaming.identity import ( + FixedStreamProducerIdentityKind, + StreamProducerDisplayNameAuthority, + StreamProducerIdentity, +) from polystore.streaming.receivers.core import ( DebouncedBatchEngine, + WindowProjectionSource, group_items_by_component_modes, ) from polystore.streaming.receivers.napari import ( normalize_component_layout, - build_layer_key, + build_route_key, ) +from zmqruntime.viewer_protocol import ViewerBatchDisplayPayload + +class PipelineProducerFixture: + """Nominal producer fixtures for receiver-core tests.""" + + MAIN_KIND = "main" + MAIN_KEY = "main" + ARTIFACT_KIND = "artifact" + + @classmethod + def main_output( + cls, + *, + step_name: str, + pipeline_position: int, + ) -> StreamProducerIdentity: + return StreamProducerIdentity.pipeline_output( + output_kind=cls.MAIN_KIND, + output_key=cls.MAIN_KEY, + step_name=step_name, + pipeline_position=pipeline_position, + ) + @classmethod + def artifact_output( + cls, + *, + output_key: str, + step_name: str, + pipeline_position: int, + artifact_kind: str | None = None, + ) -> StreamProducerIdentity: + return StreamProducerIdentity.pipeline_output( + output_kind=cls.ARTIFACT_KIND, + output_key=output_key, + step_name=step_name, + pipeline_position=pipeline_position, + artifact_kind=artifact_kind, + ) -def test_group_items_by_component_modes_source_normalization_for_rois() -> None: + +def test_group_items_by_component_modes_keys_windows_by_producer_identity() -> None: + image_identity = PipelineProducerFixture.main_output( + step_name="RawLoad", + pipeline_position=1, + ) + roi_identity = PipelineProducerFixture.artifact_output( + output_key="Nuclei", + step_name="Segment", + pipeline_position=2, + artifact_kind="object_labels", + ) items = [ { "data_type": "rois", - "metadata": {"source": "/tmp/foo_results", "well": "A01", "channel": 1}, + "metadata": {"well": "A01", "channel": 1}, + "producer_identity": roi_identity.to_payload(), }, { "data_type": "image", - "metadata": {"source": "step_1", "well": "A01", "channel": 1}, + "metadata": {"well": "A01", "channel": 1}, + "producer_identity": image_identity.to_payload(), }, ] - component_modes = {"source": "window", "well": "frame", "channel": "channel"} - component_order = ["source", "well", "channel"] + component_modes = {"well": "frame", "channel": "channel"} + component_order = ["well", "channel"] grouped = group_items_by_component_modes( - items, - component_modes=component_modes, - component_order=component_order, - images_dir="/my/plate/images", + WindowProjectionSource.from_wire_payloads(items), + display_layout=ViewerBatchDisplayPayload( + component_modes=component_modes, + component_order=component_order, + ), ) - assert grouped.window_components == ["source"] + assert grouped.window_components == [] assert grouped.channel_components == ["channel"] assert grouped.frame_components == ["well"] - assert "source_images" in grouped.windows - assert "source_step_1" in grouped.windows + assert grouped.slice_components == [] + assert grouped.fixed_window_labels[ + "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_artifact_object_labels" + ] == (("producer", "3. Segment Nuclei"),) + assert set(grouped.windows) == { + "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_artifact_object_labels", + "origin_pipeline_kind_main_out_main_step_1_name_RawLoad", + } + + +def test_group_items_by_component_modes_rejects_missing_metadata() -> None: + producer = PipelineProducerFixture.main_output( + step_name="RawLoad", + pipeline_position=1, + ) + + try: + group_items_by_component_modes( + WindowProjectionSource.from_wire_payloads( + [{"producer_identity": producer.to_payload()}] + ), + display_layout=ViewerBatchDisplayPayload( + component_modes={"well": "window"}, + component_order=["well"], + ), + ) + except ValueError as error: + assert "metadata" in str(error) + else: + raise AssertionError("missing metadata must fail loudly") + + +def test_stream_producer_display_name_authority_matches_pipeline_editor_indexing() -> None: + main_output = PipelineProducerFixture.main_output( + step_name="ConvertObjectsToImage", + pipeline_position=8, + ) + artifact_output = PipelineProducerFixture.artifact_output( + output_key="NucleiObjects3D", + step_name="ConvertObjectsToImage", + pipeline_position=8, + artifact_kind="object_labels", + ) + manual_output = StreamProducerIdentity.fixed_output( + FixedStreamProducerIdentityKind.MANUAL, + "selected_rois", + ) + assert ( + StreamProducerDisplayNameAuthority.producer_label(main_output) + == "9. ConvertObjectsToImage" + ) + assert ( + StreamProducerDisplayNameAuthority.output_label(main_output) + == "9. ConvertObjectsToImage" + ) + assert ( + StreamProducerDisplayNameAuthority.output_label(artifact_output) + == "9. ConvertObjectsToImage NucleiObjects3D" + ) + assert StreamProducerDisplayNameAuthority.output_label(manual_output) == "selected_rois" + assert ( + StreamProducerDisplayNameAuthority.disambiguation_label(main_output) + == "step 9" + ) -def test_napari_layer_key_builder_uses_slice_components_and_payload_type() -> None: + +def test_napari_route_key_builder_uses_producer_slice_components_and_payload_type() -> None: + producer = PipelineProducerFixture.artifact_output( + output_key="Nuclei", + step_name="Segment", + pipeline_position=2, + ) component_modes = {"well": "slice", "channel": "stack", "site": "slice"} component_order = ["well", "channel", "site"] component_info = {"well": "A01", "channel": 2, "site": 3} - key_image = build_layer_key( + key_image = build_route_key( + producer_identity=producer, component_info=component_info, - component_modes=component_modes, - component_order=component_order, + display_layout=ViewerBatchDisplayPayload( + component_modes=component_modes, + component_order=component_order, + ), data_type=StreamingDataType.IMAGE, ) - key_shapes = build_layer_key( + key_shapes = build_route_key( + producer_identity=producer, component_info=component_info, - component_modes=component_modes, - component_order=component_order, + display_layout=ViewerBatchDisplayPayload( + component_modes=component_modes, + component_order=component_order, + ), data_type=StreamingDataType.SHAPES, ) - assert key_image == "well_A01_site_3" - assert key_shapes == "well_A01_site_3_shapes" + assert key_image == "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_well_A01_site_3" + assert key_shapes == "origin_pipeline_kind_artifact_out_Nuclei_step_2_name_Segment_well_A01_site_3_shapes" + + +def test_napari_route_key_builder_rejects_missing_slice_component() -> None: + producer = PipelineProducerFixture.artifact_output( + output_key="Nuclei", + step_name="Segment", + pipeline_position=2, + ) + + try: + build_route_key( + producer_identity=producer, + component_info={"well": "A01"}, + display_layout=ViewerBatchDisplayPayload( + component_modes={"well": "slice", "site": "slice"}, + component_order=["well", "site"], + ), + data_type=StreamingDataType.IMAGE, + ) + except ValueError as error: + assert "site" in str(error) + else: + raise AssertionError("missing slice component must fail loudly") def test_normalize_component_layout_dict_config() -> None: - component_modes, component_order = normalize_component_layout( + display_layout = normalize_component_layout( { "component_modes": {"well": "slice", "channel": "stack"}, "component_order": ["well", "channel"], } ) - assert component_order == ["well", "channel"] - assert component_modes["well"] == "slice" + assert list(display_layout.component_order) == ["well", "channel"] + assert display_layout.component_modes["well"] == "slice" def test_debounced_batch_engine_flush_processes_pending_once() -> None: diff --git a/tests/test_viewer_transport.py b/tests/test_viewer_transport.py new file mode 100644 index 0000000..5074023 --- /dev/null +++ b/tests/test_viewer_transport.py @@ -0,0 +1,152 @@ +import pytest + +from polystore.streaming.identity import StreamProducerIdentity +from polystore.streaming.viewer_transport import BatchViewerStreamSourceMetadata +from polystore.streaming.viewer_transport import ExplicitViewerTransportConfig +from polystore.streaming.viewer_transport import IndexedViewerStreamSourceMetadata +from polystore.streaming.viewer_transport import ViewerStreamBackendKwargs +from polystore.streaming.viewer_transport import ViewerStreamKwarg +from polystore.streaming.viewer_transport import ViewerDisplayConfigABC +from polystore.streaming.viewer_transport import ViewerFilenameParserABC +from polystore.streaming.viewer_transport import ViewerMetadataHandlerABC +from polystore.streaming.viewer_transport import ViewerMicroscopeHandlerABC +from polystore.streaming.viewer_transport import ViewerStreamRequest +from polystore.streaming.viewer_transport import ViewerStreamProducer +from polystore.streaming.viewer_transport import ViewerStreamSource +from polystore.streaming.viewer_transport import ViewerStreamSourceIdentity +from polystore.streaming.viewer_transport import ViewerStreamSourceMetadata +from zmqruntime.config import TransportMode, ZMQConfig +from zmqruntime.viewer_protocol import ViewerTransportEndpoint + + +class DisplayConfigFixture(ViewerDisplayConfigABC): + COMPONENT_ORDER = ("well", "site", "channel") + + def component_modes(self): + return { + "well": "stack", + "site": "slice", + "channel": "channel", + } + + +class FilenameParserFixture(ViewerFilenameParserABC): + def parse_filename(self, filename): + return {"filename": filename} + + +class MetadataHandlerFixture(ViewerMetadataHandlerABC): + def get_component_values(self, plate_path, component_name): + return f"{plate_path}:{component_name}" + + +class MicroscopeHandlerFixture(ViewerMicroscopeHandlerABC): + parser = FilenameParserFixture() + metadata_handler = MetadataHandlerFixture() + + +EMPTY_SOURCE_METADATA = BatchViewerStreamSourceMetadata( + {"well": "A01", "site": 1, "channel": 1} +) + + +def stream_source( + source_metadata=EMPTY_SOURCE_METADATA, + *, + plate_path="/tmp/plate", +): + return ViewerStreamSource( + identity=ViewerStreamSourceIdentity( + microscope_handler=MicroscopeHandlerFixture(), + plate_path=plate_path, + ), + metadata=source_metadata, + ) + + +def required_stream_request(**kwargs): + values = { + "viewer_transport": ViewerTransportEndpoint( + host="127.0.0.1", + port=5555, + transport_mode=TransportMode.TCP, + ), + "display_config": DisplayConfigFixture(), + "source": stream_source(), + "producer": ViewerStreamProducer.from_identity( + StreamProducerIdentity.pipeline_output( + output_kind="main", + output_key="main", + step_name="IdentifyPrimaryObjects", + pipeline_position=2, + ) + ), + } + values.update(kwargs) + return ViewerStreamRequest(**values) + + +def test_viewer_stream_kwargs_declares_explicit_backend_request() -> None: + stream_kwargs = required_stream_request( + source=stream_source( + IndexedViewerStreamSourceMetadata( + metadata_by_index=( + {"well": "A01", "site": 1}, + {"well": "A01", "site": 2}, + ), + ), + plate_path="/tmp/plate", + ), + message_extra={"component_value_domain": {"well": ["A01"]}}, + images_dir="/tmp/images", + ) + + backend_kwargs = ViewerStreamBackendKwargs(stream_kwargs).to_kwargs() + + assert backend_kwargs == {ViewerStreamKwarg.STREAM_REQUEST.value: stream_kwargs} + assert ViewerStreamBackendKwargs.from_kwargs(backend_kwargs).stream_request is stream_kwargs + assert stream_kwargs.host == "127.0.0.1" + assert stream_kwargs.port == 5555 + assert stream_kwargs.transport_mode is TransportMode.TCP + assert stream_kwargs.producer.identity == StreamProducerIdentity.pipeline_output( + output_kind="main", + output_key="main", + step_name="IdentifyPrimaryObjects", + pipeline_position=2, + ) + assert stream_kwargs.source.metadata.metadata_by_index == ( + {"well": "A01", "site": 1}, + {"well": "A01", "site": 2}, + ) + default_config = ZMQConfig(default_port=9001) + assert stream_kwargs.transport_config.resolve(default_config) is default_config + + +def test_viewer_stream_source_metadata_is_abstract_boundary() -> None: + with pytest.raises(TypeError, match="abstract"): + ViewerStreamSourceMetadata() + + +def test_viewer_stream_backend_rejects_flat_kwargs() -> None: + with pytest.raises(ValueError, match="stream_request"): + ViewerStreamBackendKwargs.from_kwargs( + {"display_config": DisplayConfigFixture()} + ) + + +def test_viewer_stream_kwargs_preserves_explicit_transport_config() -> None: + explicit_config = ZMQConfig(shared_ack_port=8111) + default_config = ZMQConfig(shared_ack_port=8222) + + stream_kwargs = required_stream_request( + transport_config=ExplicitViewerTransportConfig(explicit_config) + ) + + assert stream_kwargs.transport_config.resolve(default_config) is explicit_config + + +def test_viewer_stream_backend_rejects_non_request_payload() -> None: + with pytest.raises(TypeError, match="ViewerStreamRequest"): + ViewerStreamBackendKwargs.from_kwargs( + {ViewerStreamKwarg.STREAM_REQUEST.value: DisplayConfigFixture()} + )