Source code for deepblink.datasets.sequence

"""SequenceDataset class."""

from typing import Callable, Tuple
import warnings

import numpy as np
import tensorflow as tf

from ..util import relative_shuffle

class SequenceDataset(tf.keras.utils.Sequence):
    """Custom Sequence class used to feed mini-batches of data into a Keras model.

    Batches are taken as consecutive slices of ``x`` and ``y``. Each batch is
    optionally formatted and then augmented before being returned.

    Args:
        x: Inputs, sliced along the first axis to form batches.
        y: Targets, sliced along the first axis in step with ``x``.
        batch_size: Size of one mini-batch.
        augment_fn: Function to augment one mini-batch of x and y.
        format_fn: Function to format raw data to model input.
        overfit: If only one batch should be used thereby causing overfitting.
    """

    def __init__(
        self,
        x: np.ndarray,
        y: np.ndarray,
        batch_size: int = 16,
        augment_fn: Callable = None,
        format_fn: Callable = None,
        overfit: bool = False,
    ):
        # Newer Keras versions expect the base-class initialiser to run;
        # on older versions this resolves to a harmless no-op.
        super().__init__()
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.augment_fn = augment_fn
        self.format_fn = format_fn
        self.overfit = overfit

    def __len__(self) -> int:
        """Return length of the dataset in unit of batch size.

        If the batch size exceeds the number of samples, a ``RuntimeWarning``
        is emitted and ``self.batch_size`` is reduced to the number of samples
        so that exactly one batch is available. An empty dataset has length 0.
        """
        n_samples = len(self.x)
        if n_samples == 0:
            # Nothing to batch; avoid clamping batch_size to 0 and dividing by it.
            return 0

        # Strict comparison: a batch size equal to the dataset length is valid
        # and must not trigger the warning (nor re-trigger it after clamping).
        if n_samples < self.batch_size:
            warnings.warn(
                "Batch size larger than dataset, setting batch size to match length of dataset",
                RuntimeWarning,
            )
            self.batch_size = n_samples

        # Incomplete trailing batches are dropped.
        return int(n_samples // self.batch_size)

    def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
        """Return a single batch.

        Args:
            idx: Index of the batch. Ignored when ``overfit`` is set, in which
                case the first batch is always returned.

        Returns:
            Tuple ``(batch_x, batch_y)`` after formatting and augmentation.
            Any of the two with fewer than four dimensions gets one trailing
            axis appended.
        """
        if self.overfit:
            idx = 0

        begin = idx * self.batch_size
        end = (idx + 1) * self.batch_size
        batch_x = self.x[begin:end]
        batch_y = self.y[begin:end]

        # Formatting runs before augmentation so augment_fn sees model-ready data.
        if self.format_fn:
            batch_x, batch_y = self.format_fn(batch_x, batch_y)
        if self.augment_fn:
            batch_x, batch_y = self.augment_fn(batch_x, batch_y)

        # Append a trailing axis (presumably the channel axis - confirm against
        # the model's expected input layout).
        if batch_x.ndim < 4:
            batch_x = np.expand_dims(batch_x, -1)
        if batch_y.ndim < 4:
            batch_y = np.expand_dims(batch_y, -1)

        return batch_x, batch_y

    def on_epoch_end(self) -> None:
        """Shuffle data after every epoch.

        Skipped when ``overfit`` is set so the same first batch is reused.
        """
        if not self.overfit:
            # relative_shuffle presumably applies one shared permutation to x
            # and y so pairs stay aligned - verify in ..util.
            self.x, self.y = relative_shuffle(self.x, self.y)