Source code for deepblink.training

"""Training functions."""
# pylint: disable=C0415

from typing import Dict
import datetime
import os
import platform

import tensorflow as tf

from .datasets import Dataset
from .models import Model
from .util import get_from_module


def train_model(
    model: Model,
    dataset: Dataset,
    cfg: Dict,
    run_name: str = "model",
    use_wandb: bool = True,
) -> Model:
    """Model training loop with callbacks.

    Args:
        model: Model class with the .fit method.
        dataset: Dataset class with access to train and validation images.
        cfg: Configuration dictionary equivalent to the one used in
            deepblink.training.run_experiment.
        run_name: Name given to the saved model.h5 file.
        use_wandb: Whether wandb should be used.
    """
    callbacks = []

    # Save the best model (lowest validation loss) to "<savedir>/<run_name>.h5".
    cb_saver = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(cfg["savedir"], f"{run_name}.h5"), save_best_only=True,
    )
    callbacks.append(cb_saver)

    if use_wandb:
        # Imported lazily so deepBlink can run without wandb installed.
        from ._wandb import WandbComputeMetrics
        from ._wandb import WandbImageLogger
        from ._wandb import wandb_callback

        cb_image = WandbImageLogger(model, dataset)
        cb_wandb = wandb_callback()
        cb_metrics = WandbComputeMetrics(model, dataset, mdist=3)
        callbacks.extend([cb_image, cb_wandb, cb_metrics])

    model.fit(dataset=dataset, callbacks=callbacks)
    return model
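

# Illustrative sketch, not part of deepBlink: calling train_model directly with a
# model and dataset prepared elsewhere (e.g. by run_experiment below). The only cfg
# key train_model itself reads is "savedir"; wandb logging is skipped here.
#
#     cfg = {"savedir": "checkpoints"}
#     trained = train_model(model, dataset, cfg, run_name="my_run", use_wandb=False)
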

def run_experiment(cfg: Dict, pre_model: tf.keras.models.Model = None):
    """Run a training experiment.

    The configuration file can be generated using "deepblink config".

    Args:
        cfg: Dictionary configuration file.
        pre_model: Pre-trained model if not training from scratch.
    """
    # Resolve classes / functions from their names in the configuration.
    dataset_class = get_from_module("deepblink.datasets", cfg["dataset"])
    model_class = get_from_module("deepblink.models", cfg["model"])
    network_fn = get_from_module("deepblink.networks", cfg["network"])
    optimizer_fn = get_from_module("deepblink.optimizers", cfg["optimizer"])
    loss_fn = get_from_module("deepblink.losses", cfg["loss"])

    # Arguments
    augmentation_args = cfg.get("augmentation_args", {})
    dataset_args = cfg.get("dataset_args", {})
    dataset = dataset_class(**dataset_args)
    network_args = (
        cfg.get("network_args", {}) if cfg.get("network_args", {}) is not None else {}
    )
    network_args["cell_size"] = dataset_args["cell_size"]
    train_args = cfg.get("train_args", {})

    model = model_class(
        augmentation_args=augmentation_args,
        dataset_args=dataset_args,
        dataset_cls=dataset,
        loss_fn=loss_fn,
        network_args=network_args,
        network_fn=network_fn,
        optimizer_fn=optimizer_fn,
        train_args=train_args,
        pre_model=pre_model,
    )

    # Record system information alongside the configuration.
    cfg["system"] = {
        "gpus": tf.config.list_logical_devices("GPU"),
        "version": platform.version(),
        "platform": platform.platform(),
    }

    now = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
    run_name = f"{now}_{cfg['run_name']}"

    use_wandb = cfg["use_wandb"]
    if use_wandb:
        try:
            import wandb

            # Naive string comparison to require a reasonably recent wandb version.
            if wandb.__version__ <= "0.10.03":
                raise AssertionError
        except (ModuleNotFoundError, AttributeError, AssertionError):
            raise ImportError(
                "To support conda packages we don't ship deepBlink with wandb. "
                "Please install it using pip: 'pip install \"wandb>=0.10.3\"'"
            )

        # pylint:disable=E1101
        wandb.init(name=run_name, project=cfg["name"], config=cfg)

    model = train_model(model, dataset, cfg, run_name, use_wandb)

    if use_wandb:
        wandb.join()  # pylint:disable=E1101
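

# Illustrative sketch, not part of deepBlink: a minimal hand-written configuration for
# run_experiment. In practice the configuration is generated with "deepblink config";
# the class/function names and values below (SpotsDataset, SpotsModel, unet, adam,
# combined_bce_rmse, cell_size, train_args, ...) are placeholders, and only the key
# names are taken from the code above. Real dataset_args also need to point at the
# training data file.
if __name__ == "__main__":
    example_cfg = {
        "run_name": "example",
        "name": "deepblink",  # wandb project name, only used if use_wandb is True
        "savedir": "models",  # directory for the saved .h5 checkpoint
        "use_wandb": False,
        "dataset": "SpotsDataset",         # placeholder dataset class name
        "dataset_args": {"cell_size": 4},  # placeholder; cell_size is required
        "model": "SpotsModel",             # placeholder model class name
        "network": "unet",                 # placeholder network function name
        "network_args": {},
        "loss": "combined_bce_rmse",       # placeholder loss function name
        "optimizer": "adam",               # placeholder optimizer function name
        "augmentation_args": {},
        "train_args": {"batch_size": 2, "epochs": 1},
    }
    run_experiment(example_cfg)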