Source code for deepblink.models.spots

"""SpotsModel class."""

import functools

import numpy as np

from ..augment import augment_batch_baseline
from ..losses import combined_f1_rmse
from ..losses import f1_score
from ..losses import rmse
from ._models import Model


class SpotsModel(Model):
    """Model that predicts spot localizations; see the ``Model`` base class.

    On top of the base class this subclass builds the batch augmentation
    function from the augmentation/dataset arguments, declares the metrics
    reported during training, and offers single-image inference.
    """

    def __init__(self, **kwargs) -> None:
        """Initialize via the base class, then build ``self.batch_augment_fn``.

        Args:
            **kwargs: Forwarded unchanged to ``Model.__init__``.

        Raises:
            KeyError: If ``augmentation_args`` lacks one of ``"flip"``,
                ``"illuminate"``, ``"gaussian_noise"``, ``"rotate"`` or
                ``"translate"``, or ``dataset_args`` lacks ``"cell_size"``.
        """
        super().__init__(**kwargs)

        # NOTE(review): ``augmentation_args`` and ``dataset_args`` are read as
        # attributes here but are not set in this class; presumably the base
        # ``Model.__init__`` populates them from ``kwargs`` -- confirm.
        aug = self.augmentation_args

        # Bind the configured augmentation options up front; callers of
        # ``batch_augment_fn`` then supply only the remaining arguments of
        # ``augment_batch_baseline``. Lookups run in the same order as before.
        self.batch_augment_fn = functools.partial(
            augment_batch_baseline,
            flip_=aug["flip"],
            illuminate_=aug["illuminate"],
            gaussian_noise_=aug["gaussian_noise"],
            rotate_=aug["rotate"],
            translate_=aug["translate"],
            cell_size=self.dataset_args["cell_size"],
        )

    @property
    def metrics(self) -> list:
        """Return a new list of the metric functions recorded during training.

        The list holds ``f1_score``, ``rmse`` and ``combined_f1_rmse``, in that
        order.
        """
        tracked = [f1_score, rmse, combined_f1_rmse]
        return tracked

    def predict_on_image(self, image: np.ndarray) -> np.ndarray:
        """Run the network on a single image and return the squeezed output.

        Args:
            image: One input image. It is presumably 2D (height, width), since
                a batch axis and a channel axis are added here -- TODO confirm.

        Returns:
            The network prediction with every length-1 axis removed.
        """
        # Prepend a batch axis and append a channel axis (np.newaxis is None).
        batch = image[np.newaxis, ..., np.newaxis]
        prediction = self.network.predict(batch, batch_size=1)
        # NOTE(review): ``squeeze()`` drops *all* singleton axes, not only the
        # batch axis, so any other output axis of size 1 is removed as well.
        return prediction.squeeze()