# Source code for cellector.roi_processor

from typing import List, Optional, Dict, Union
from pathlib import Path
from copy import deepcopy
import numpy as np
from . import utils
from .filters import filter
from .features import FeaturePipeline, standard_pipelines, functional_pipelines

# Might be useful for optimizing parameters
# from sklearn.model_selection import ParameterGrid

# Default preprocessing parameters for RoiProcessor. Individual entries can be
# overridden via RoiProcessor(..., **kwargs) or RoiProcessor.update_parameters.
DEFAULT_PARAMETERS = {
    # Not read in this module; presumably consumed by feature pipelines.
    "um_per_pixel": None,
    # Not read in this module; presumably consumed by feature pipelines.
    "surround_iterations": 2,
    # Passed as `fill_value` to the ROI-centering helpers in utils.
    "fill_value": 0.0,
    # Half-width of the ROI-centered windows (window side = centered_width * 2 + 1).
    "centered_width": 40,
    # Passed as `method` to utils.get_roi_centroids.
    "centroid_method": "median",
    # Not read in this module; presumably a windowing function used by a pipeline.
    "window_kernel": np.hanning,
    # Not read in this module. NOTE(review): 1e6 is unusually large for an
    # "eps" value -- confirm it is not meant to be 1e-6.
    "phase_corr_eps": 1e6,
    # Butterworth bandpass settings passed to filter(..., "butterworth_bpf", ...).
    "lowcut": 12,
    "highcut": 250,
    "order": 3,
}

# Mapping of parameters to cache entries that are affected by the change.
# RoiProcessor.update_parameters pops the listed keys from its cache whenever
# the parameter's value changes, so the matching properties are recomputed
# lazily with the new value.

# Cache entries holding ROI-centered stacks; these depend on centering settings.
_CENTERED_CACHE_KEYS = (
    "centered_masks",
    "centered_reference",
    "filtered_centered_reference",
    "centered_reference_functional",
    "filtered_centered_reference_functional",
)

# Cache entries produced by (or derived from) the Butterworth bandpass filter.
# The full-frame filtered images must be listed too: the filtered centered
# stacks are rebuilt from them, so leaving them cached would reuse a stale
# filter output after lowcut/highcut/order change.
_FILTER_CACHE_KEYS = (
    "filtered_reference",
    "filtered_reference_functional",
    "filtered_centered_reference",
    "filtered_centered_reference_functional",
)

PARAM_CACHE_MAPPING = dict(
    surround_iterations=[],
    fill_value=list(_CENTERED_CACHE_KEYS),
    centered_width=list(_CENTERED_CACHE_KEYS),
    centroid_method=["centroids", *_CENTERED_CACHE_KEYS],
    window_kernel=[],
    phase_corr_eps=[],
    lowcut=list(_FILTER_CACHE_KEYS),
    highcut=list(_FILTER_CACHE_KEYS),
    order=list(_FILTER_CACHE_KEYS),
)


class RoiProcessor:
    """Process and analyze mask & fluorescence data across multiple image planes.

    This class handles the processing of mask & fluorescence data by managing
    masks and reference images across multiple planes, providing functionality
    for feature calculation and analysis.

    Attributes
    ----------
    root_dir : Path
        Path to the root directory where features are saved to / loaded from.
    zpix : Union[np.ndarray, List[np.ndarray]]
        Numpy array containing the plane index for each ROI. Or, if
        volumetric=True, a list of numpy arrays containing the z-pixel indices
        for each ROI.
    ypix : List[np.ndarray]
        List of numpy arrays containing the y-pixel indices for each ROI.
    xpix : List[np.ndarray]
        List of numpy arrays containing the x-pixel indices for each ROI.
    lam : List[np.ndarray]
        List of numpy arrays containing the pixel intensities for each ROI.
    reference : np.ndarray
        3D numpy array of reference images with shape (lz, ly, lx).
    functional_reference : Union[np.ndarray, None]
        Optional 3D numpy array of functional reference images, with the same
        shape as ``reference``.
    volumetric : bool
        Whether ROI masks carry per-pixel z indices (e.g. suite3d output).
    save_features : bool
        Whether features stored through ``add_feature`` are also saved to disk.
    lx, ly, lz : int
        Dimensions of the imaging volume, taken from ``reference.shape``.
    num_rois : int
        Total number of ROIs across all planes.
    features : dict
        Computed features for all ROIs, keyed by feature name.
    feature_pipeline_methods : dict
        Mapping of feature pipeline names to their corresponding methods.
    feature_pipeline_dependencies : dict
        Mapping of feature pipeline names to the preprocessing parameters
        (keys of ``parameters``) they depend on.
    parameters : dict
        Dictionary containing all the preprocessing parameters used.
    _cache : dict
        Cached values of properties that are expensive to compute.
    """

    def __init__(
        self,
        root_dir: Union[Path, str],
        zpix: Union[np.ndarray, List[np.ndarray]],
        ypix: List[np.ndarray],
        xpix: List[np.ndarray],
        lam: List[np.ndarray],
        reference: np.ndarray,
        functional_reference: Union[np.ndarray, None] = None,
        extra_features: Optional[Dict[str, np.ndarray]] = None,
        volumetric: bool = False,
        autocompute: bool = True,
        use_saved: bool = True,
        save_features: bool = True,
        **kwargs: dict,
    ):
        """Initialize the RoiProcessor with ROI stats and reference images.

        Parameters
        ----------
        root_dir : Union[Path, str]
            Path to an existing directory. This is used to save and load
            features from disk.
        zpix : Union[np.ndarray, List[np.ndarray]]
            Numpy array containing the plane index for each ROI. Or, if
            volumetric=True, a list of numpy arrays containing the z-pixel
            indices for each ROI.
        ypix : List[np.ndarray]
            List of numpy arrays containing the y-pixel indices for each ROI.
        xpix : List[np.ndarray]
            List of numpy arrays containing the x-pixel indices for each ROI.
        lam : List[np.ndarray]
            List of numpy arrays containing the pixel intensities for each ROI.
        reference : np.ndarray
            3D numpy array containing reference images for each plane. The
            first dimension is the number of planes (lz), and the second and
            third dimensions are the height and width of the images (ly, lx).
        functional_reference : Union[np.ndarray, None], optional
            3D numpy array containing functional reference images for each
            plane. Must have the same shape as ``reference``. Default is None.
        extra_features : Optional[Dict[str, np.ndarray]], optional
            Dictionary of extra features to add. Each key is the feature name
            (a str) and each value is a numpy array with one entry per ROI
            (length equal to the total number of ROIs). Default is None.
        volumetric : bool, optional
            Set as True to indicate that the data is volumetric (e.g. processed
            with suite3d). This switches all processing operations to use
            volumetric information about ROI masks. Default is False.
        autocompute : bool, optional
            If True, compute all registered features upon initialization. Set
            to False to compute a subset of features manually. Default is True.
        use_saved : bool, optional
            If True, attempt to load saved features from disk if they exist.
            Default is True.
        save_features : bool, optional
            If True, save computed features to disk. Default is True.
        **kwargs : dict
            Overrides for the default preprocessing parameters. Every key must
            be a key of DEFAULT_PARAMETERS.

        Raises
        ------
        ValueError
            If the mask data lengths disagree, pixel/plane indices exceed the
            reference dimensions, root_dir is not an existing directory, or an
            unknown preprocessing parameter is passed.
        TypeError
            If reference is not a 3D array, functional_reference does not match
            its shape, or extra_features is malformed.
        """
        # `a != b != c != d` chains to `a != b and b != c and c != d`, which only
        # fires when *every* adjacent pair differs; chained equality catches any
        # mismatch.
        if not (len(lam) == len(ypix) == len(xpix) == len(zpix)):
            raise ValueError(
                "Lengths of mask data do not match each other (inspect lam, xpix, ypix, and zpix)!"
            )
        if not isinstance(reference, np.ndarray) or reference.ndim != 3:
            raise TypeError("reference must be a 3D numpy array")
        if (
            functional_reference is not None
            and functional_reference.shape != reference.shape
        ):
            raise TypeError(
                "Functional reference must have the same shape as reference"
            )
        root_dir = Path(root_dir)
        if not root_dir.is_dir():
            raise ValueError("root_dir must be existing directory.")

        # Initialize attributes
        self.zpix = zpix
        self.ypix = ypix
        self.xpix = xpix
        self.lam = lam
        self.lz, self.ly, self.lx = reference.shape
        self.num_rois = len(lam)
        self.reference = reference
        self.functional_reference = functional_reference
        self.volumetric = volumetric
        self.root_dir = root_dir
        self.save_features = save_features

        self._validate_mask_data()

        # Reject unknown parameters before anything (e.g. extra features) is
        # written to disk.
        invalid_params = set(kwargs) - set(DEFAULT_PARAMETERS)
        if invalid_params:
            raise ValueError(
                f"Invalid parameter(s): {', '.join(sorted(invalid_params))}"
            )

        # Store flattened mask data for some optimized implementations
        self._flattened_roi_data = utils.flatten_roi_data(
            self.lam,
            self.ypix,
            self.xpix,
            self.zpix if self.volumetric else None,
        )

        # Initialize feature and pipeline dictionaries and the preprocessing cache
        self.features = {}
        self.feature_pipeline_methods = {}
        self.feature_pipeline_dependencies = {}
        self._cache = {}

        # Establish preprocessing parameters (deepcopy so the module defaults
        # are never mutated through an instance)
        self.parameters = deepcopy(DEFAULT_PARAMETERS)
        self.parameters.update(kwargs)

        # Validate every extra feature before storing any, so a bad entry
        # doesn't leave a partial set saved on disk.
        if extra_features is not None:
            self._validate_extra_features(extra_features)
            for name, value in extra_features.items():
                self.add_feature(name, value)

        # Register feature pipelines
        for pipeline in standard_pipelines:
            self.register_feature_pipeline(pipeline)
        if self.functional_reference is not None:
            for pipeline in functional_pipelines:
                self.register_feature_pipeline(pipeline)

        # Measure features
        if autocompute:
            self.compute_features(use_saved)

    def _validate_mask_data(self):
        """Check per-ROI mask lengths and that indices fit in the imaging volume.

        Raises
        ------
        ValueError
            If any ROI's mask arrays have different lengths, or any pixel or
            plane index is >= the matching reference dimension.
        """
        if self.volumetric:
            per_roi = zip(self.lam, self.xpix, self.ypix, self.zpix)
        else:
            per_roi = zip(self.lam, self.xpix, self.ypix)
        for arrays in per_roi:
            if len({len(array) for array in arrays}) != 1:
                raise ValueError("Mismatched lengths of mask data")

        exceeds = (
            max(max(x) for x in self.xpix) >= self.lx
            or max(max(y) for y in self.ypix) >= self.ly
        )
        if self.volumetric:
            exceeds = exceeds or max(max(z) for z in self.zpix) >= self.lz
        else:
            # Here zpix holds one plane index per ROI, which must index into
            # the first axis of the reference stack.
            exceeds = exceeds or np.max(self.zpix) >= self.lz
        if exceeds:
            raise ValueError("Pixel indices exceed image dimensions")

    def _validate_extra_features(self, extra_features):
        """Check that extra features map str names to per-ROI numpy arrays.

        Raises
        ------
        TypeError
            If extra_features is not a dict, a name is not a str, or a value is
            not a numpy array of length num_rois.
        """
        if not isinstance(extra_features, dict):
            raise TypeError("Extra features must be a dictionary")
        for name, value in extra_features.items():
            if not isinstance(name, str):
                raise TypeError("Extra feature names must be strings")
            if not isinstance(value, np.ndarray) or len(value) != self.num_rois:
                raise TypeError(
                    "Extra feature values must be a numpy array with length equal to the number of ROIs"
                )

    def compute_features(self, use_saved: bool = True):
        """Compute all registered features for each ROI.

        FeaturePipelines are registered with the RoiProcessor instance, and each
        pipeline defines a method that computes a feature from the attributes
        of the RoiProcessor instance. compute_features iterates over each
        pipeline and computes the feature values for each ROI. Resulting
        feature values are stored in the self.features dictionary.

        Parameters
        ----------
        use_saved : bool, optional
            If True, load a feature from disk when a saved copy exists and has
            one value per ROI, instead of recomputing it. Default is True.
        """
        from .io.base import load_feature, is_feature_saved

        for name, method in self.feature_pipeline_methods.items():
            if use_saved and is_feature_saved(self.root_dir, name):
                value = load_feature(self.root_dir, name)
                if len(value) == self.num_rois:
                    # Already on disk: store directly rather than re-saving it
                    # through add_feature.
                    self.features[name] = value
                    continue
            # Not saved (or wrong length on disk): compute and add it.
            self.add_feature(name, method(self))

    def add_feature(self, name: str, values: np.ndarray):
        """Add (or update) the name and values in the self.features dictionary.

        If self.save_features is True, the values are also saved to disk in
        root_dir.

        Parameters
        ----------
        name : str
            Name of the feature.
        values : np.ndarray
            Feature values for each ROI. Must have the same length as the
            number of ROIs across all planes.

        Raises
        ------
        ValueError
            If len(values) differs from num_rois.
        """
        if len(values) != self.num_rois:
            raise ValueError(
                f"Length of feature values ({len(values)}) for feature {name} must match number of ROIs ({self.num_rois})"
            )
        self.features[name] = values
        if self.save_features:
            # Imported lazily: the io layer is only needed when writing to disk.
            from .io.base import save_feature

            save_feature(self.root_dir, name, values)

    def register_feature_pipeline(self, pipeline: "FeaturePipeline"):
        """Register a feature pipeline with the RoiProcessor instance.

        pipeline is a FeaturePipeline object that defines a method to compute a
        feature from the attributes of the RoiProcessor instance. The method
        takes the RoiProcessor instance as its argument and returns a numpy
        array of feature values. The dependencies attribute of the pipeline is
        a list of preprocessing parameter names; when any of them changes via
        update_parameters, the feature is recomputed.

        Parameters
        ----------
        pipeline : FeaturePipeline
            FeaturePipeline object that defines a method to compute a feature
            based on the attributes of the RoiProcessor instance.

        Raises
        ------
        TypeError
            If pipeline is not a FeaturePipeline.
        ValueError
            If a pipeline with the same name is already registered, or any of
            its dependencies is not a known parameter.
        """
        if not isinstance(pipeline, FeaturePipeline):
            raise TypeError("Pipeline must be an instance of FeaturePipeline")
        if (
            pipeline.name in self.feature_pipeline_methods
            or pipeline.name in self.feature_pipeline_dependencies
        ):
            raise ValueError(
                f"A pipeline called {pipeline.name} has already been registered."
            )
        # Report only the dependencies that are actually missing.
        missing = [dep for dep in pipeline.dependencies if dep not in self.parameters]
        if missing:
            raise ValueError(
                f"The following dependencies for pipeline {pipeline.name} not found in parameters ({', '.join(missing)})"
            )
        self.feature_pipeline_methods[pipeline.name] = pipeline.method
        self.feature_pipeline_dependencies[pipeline.name] = pipeline.dependencies

    def update_parameters(self, **kwargs: dict):
        """Update preprocessing parameters and clear affected cache entries.

        Preprocessing parameters are used to compute cached properties of self
        and for feature computation. For every parameter whose value actually
        changes, the cache entries listed for it in PARAM_CACHE_MAPPING are
        cleared so they are recomputed lazily on next access. Features whose
        pipeline lists the parameter in feature_pipeline_dependencies are
        recomputed immediately if they have already been computed.

        Parameters
        ----------
        **kwargs : dict
            New parameter values. Every key must already be in self.parameters.

        Raises
        ------
        ValueError
            If any key is not an existing parameter.
        """
        extra_kwargs = set(kwargs) - set(self.parameters)
        if extra_kwargs:
            raise ValueError(f"Invalid parameter(s): {', '.join(sorted(extra_kwargs))}")

        # For every changed parameter, identify affected cache entries / features
        affected_cache = set()
        affected_features = set()
        for key, value in kwargs.items():
            if self.parameters[key] != value:
                affected_cache.update(PARAM_CACHE_MAPPING.get(key, []))
                affected_features.update(
                    name
                    for name, dependencies in self.feature_pipeline_dependencies.items()
                    if key in dependencies
                )
            self.parameters[key] = value

        # Clear affected cache to be recomputed lazily whenever it is needed again
        for cache_key in affected_cache:
            self._cache.pop(cache_key, None)

        # Recompute affected features only if they have already been computed
        for feature_key in affected_features:
            if feature_key in self.features:
                self.add_feature(
                    feature_key, self.feature_pipeline_methods[feature_key](self)
                )

    def copy_with_params(self, params: dict):
        """Create a new processor instance with updated parameters.

        Parameters
        ----------
        params : dict
            New parameter values for the copy. Keys must be existing parameter
            names, otherwise update_parameters raises ValueError.

        Returns
        -------
        RoiProcessor
            Deep copy of self with the parameters updated (self is unchanged).
        """
        copy_of_self = deepcopy(self)
        copy_of_self.update_parameters(**params)
        return copy_of_self

    def _bandpass_kwargs(self):
        """Collect the Butterworth bandpass settings from the current parameters."""
        return dict(
            lowcut=self.parameters["lowcut"],
            highcut=self.parameters["highcut"],
            order=self.parameters["order"],
        )

    def _center_on_rois(self, image):
        """Build an ROI-centered stack from `image` using the current centering parameters."""
        return utils.get_centered_reference(
            image,
            self.centroids,
            width=self.parameters["centered_width"],
            fill_value=self.parameters["fill_value"],
            volumetric=self.volumetric,
        )

    @property
    def centroids(self):
        """Return the centroids of every ROI (cached).

        Centroids are computed with utils.get_roi_centroids using the
        centroid_method parameter and are requested as integers.

        Returns
        -------
        dict
            Dictionary with keys "zc", "yc", and "xc" holding the z-, y-, and
            x-centroids of the ROIs.
        """
        if "centroids" not in self._cache:
            zc, yc, xc = utils.get_roi_centroids(
                self.lam,
                self.zpix,
                self.ypix,
                self.xpix,
                method=self.parameters["centroid_method"],
                asint=True,
            )
            self._cache["centroids"] = dict(zc=zc, yc=yc, xc=xc)
        return self._cache["centroids"]

    @property
    def centered_masks(self):
        """Return the centered mask images for each ROI (cached).

        Returns
        -------
        np.ndarray
            The centered mask images of each ROI. If volumetric=False, has shape
            (numROIs, centered_width*2+1, centered_width*2+1). If
            volumetric=True, has shape
            (numROIs, num_planes, centered_width*2+1, centered_width*2+1).
        """
        if "centered_masks" not in self._cache:
            self._cache["centered_masks"] = utils.get_centered_masks(
                self._flattened_roi_data,
                self.centroids,
                width=self.parameters["centered_width"],
                fill_value=self.parameters["fill_value"],
                num_planes=self.lz,
                volumetric=self.volumetric,
            )
        return self._cache["centered_masks"]

    @property
    def centered_reference(self):
        """Return the centered reference image for each ROI (cached).

        Returns
        -------
        np.ndarray
            The centered reference image around each ROI. If volumetric=False,
            has shape (numROIs, centered_width*2+1, centered_width*2+1). If
            volumetric=True, has shape
            (numROIs, num_planes, centered_width*2+1, centered_width*2+1).
        """
        if "centered_reference" not in self._cache:
            self._cache["centered_reference"] = self._center_on_rois(self.reference)
        return self._cache["centered_reference"]

    @property
    def filtered_reference(self):
        """Return the reference stack passed through a Butterworth bandpass filter (cached).

        The filter is applied to the full reference stack, not per ROI.

        Returns
        -------
        np.ndarray
            The filtered reference stack; presumably the same shape as
            ``reference`` (lz, ly, lx) -- depends on filters.filter.
        """
        if "filtered_reference" not in self._cache:
            self._cache["filtered_reference"] = filter(
                self.reference, "butterworth_bpf", **self._bandpass_kwargs()
            )
        return self._cache["filtered_reference"]

    @property
    def filtered_centered_reference(self):
        """Return the filtered centered reference image for each ROI (cached).

        Uses a Butterworth bandpass filter on the reference image, then builds a
        centered reference stack around each ROI from the filtered reference.

        Returns
        -------
        np.ndarray
            The filtered centered reference image around each ROI, with shape
            (numROIs, centered_width*2+1, centered_width*2+1) when
            volumetric=False.
        """
        if "filtered_centered_reference" not in self._cache:
            self._cache["filtered_centered_reference"] = self._center_on_rois(
                self.filtered_reference
            )
        return self._cache["filtered_centered_reference"]

    @property
    def centered_reference_functional(self):
        """Return the centered functional reference image for each ROI (cached).

        Returns
        -------
        np.ndarray
            The centered functional reference image around each ROI, with shape
            (numROIs, centered_width*2+1, centered_width*2+1) when
            volumetric=False.

        Raises
        ------
        ValueError
            If functional reference is not available.
        """
        if self.functional_reference is None:
            raise ValueError("Functional reference are not available")
        if "centered_reference_functional" not in self._cache:
            self._cache["centered_reference_functional"] = self._center_on_rois(
                self.functional_reference
            )
        return self._cache["centered_reference_functional"]

    @property
    def filtered_reference_functional(self):
        """Return the functional reference stack passed through a Butterworth bandpass filter (cached).

        The filter is applied to the full functional reference stack, not per ROI.

        Returns
        -------
        np.ndarray
            The filtered functional reference stack; presumably the same shape
            as ``functional_reference`` (lz, ly, lx) -- depends on filters.filter.

        Raises
        ------
        ValueError
            If functional reference is not available.
        """
        if self.functional_reference is None:
            raise ValueError("Functional reference are not available")
        if "filtered_reference_functional" not in self._cache:
            self._cache["filtered_reference_functional"] = filter(
                self.functional_reference,
                "butterworth_bpf",
                **self._bandpass_kwargs(),
            )
        return self._cache["filtered_reference_functional"]

    @property
    def filtered_centered_reference_functional(self):
        """Return the filtered centered functional reference image for each ROI (cached).

        Uses a Butterworth bandpass filter on the functional reference image,
        then builds a centered reference stack around each ROI from it.

        Returns
        -------
        np.ndarray
            The filtered centered functional reference image around each ROI,
            with shape (numROIs, centered_width*2+1, centered_width*2+1) when
            volumetric=False.

        Raises
        ------
        ValueError
            If functional reference is not available.
        """
        if self.functional_reference is None:
            raise ValueError("Functional reference are not available")
        if "filtered_centered_reference_functional" not in self._cache:
            self._cache["filtered_centered_reference_functional"] = (
                self._center_on_rois(self.filtered_reference_functional)
            )
        return self._cache["filtered_centered_reference_functional"]