# Source code for deepblink.datasets.spots

"""SpotsDataset class."""

import numpy as np

from ..data import get_prediction_matrix
from ..data import next_power
from ..data import normalize_image
from ..io import load_npz
from ._datasets import Dataset


class SpotsDataset(Dataset):
    """Class used to load all spots data.

    Construction is eager: ``__init__`` immediately calls :meth:`load_data`,
    which reads the dataset file, converts the coordinate labels into
    prediction matrices and normalizes the images.

    Args:
        name: Forwarded unchanged to the ``Dataset`` base class initializer.
        cell_size: Number of pixels (from original image) constituting
            one cell in the prediction matrix.
        smooth_factor: Value used to weigh true cells, weighs false cells
            with 1-smooth_factor.
    """

    def __init__(self, name: str, cell_size: int, smooth_factor: float = 1):
        super().__init__(name)
        self.cell_size = cell_size
        self.smooth_factor = smooth_factor
        self.load_data()

    def load_data(self) -> None:
        """Load dataset into memory.

        Reads ``self.data_filename`` (not assigned in this class; presumably
        provided by the ``Dataset`` base class) with ``load_npz``. The first
        four returned items become the train / valid images and labels, the
        last two are discarded. Labels are then converted and images
        normalized.
        """
        self.x_train, self.y_train, self.x_valid, self.y_valid, _, _ = load_npz(
            self.data_filename
        )
        self.prepare_data()
        self.normalize_dataset()

    @property
    def image_size(self) -> int:
        """Check if all images have the same square shape.

        Every access re-validates all train and valid images, so callers
        that need the value repeatedly should store it in a local.

        Returns:
            The first dimension of the (common) image shape.

        Raises:
            ValueError: If the training set is empty, if train / valid
                images do not all share one shape, if the first two
                dimensions differ, or if the side length is not equal to
                ``next_power`` of itself.
        """
        # Fail with a clear message instead of an opaque IndexError below.
        if len(self.x_train) == 0:
            raise ValueError("Training dataset is empty.")

        base_shape = self.x_train[0].shape
        if not all(
            base_shape == x.shape
            for dataset in [self.x_train, self.x_valid]
            for x in dataset
        ):
            raise ValueError("All images must have the same shape.")
        if base_shape[0] != base_shape[1]:
            raise ValueError("Images must be square. ")
        if base_shape[0] != next_power(base_shape[0]):
            raise ValueError(
                f"Images sidelength must be a power of two. {base_shape[0]} is not."
            )
        return base_shape[0]

    def prepare_data(self) -> None:
        """Convert raw labels into labels usable for training.

        In the "spots" format, training labels are stored as lists of
        coordinates, this format cannot be used for training. Here, this
        format is converted into prediction matrices.

        Channel 0 of every prediction matrix is rewritten: truthy cells
        become ``smooth_factor`` and all other cells ``1 - smooth_factor``.
        ``self.y_train`` and ``self.y_valid`` are replaced in place.
        """
        # Read the property once: each access validates every image shape.
        image_size = self.image_size
        cell_size = self.cell_size

        def convert(dataset):
            labels = []
            for coords in dataset:
                matrix = get_prediction_matrix(coords, image_size, cell_size)
                matrix[..., 0] = np.where(
                    matrix[..., 0], self.smooth_factor, 1 - self.smooth_factor
                )
                labels.append(matrix)
            return np.array(labels)

        self.y_train = convert(self.y_train)
        self.y_valid = convert(self.y_valid)

    def normalize_dataset(self) -> None:
        """Normalize all the images to have zero mean and standard deviation 1.

        Applies ``normalize_image`` to each image individually and replaces
        ``self.x_train`` and ``self.x_valid`` with the stacked results.
        """

        def normalize(dataset):
            return np.array([normalize_image(image) for image in dataset])

        self.x_train = normalize(self.x_train)
        self.x_valid = normalize(self.x_valid)