Source code for dffml.tuner.parameter_grid

from typing import Union, Dict, Any
import itertools
import logging

from ..base import (
    config,
    field,
)
from ..high_level.ml import train, score
from .tuner import Tuner, TunerContext
from ..util.entrypoint import entrypoint
from ..source.source import BaseSource, Record
from ..accuracy.accuracy import AccuracyScorer, AccuracyContext
from ..model import ModelContext
from ..feature.feature import Feature


@config
class ParameterGridConfig:
    parameters: dict = field("Parameters to be optimized")


class ParameterGridContext(TunerContext):
    """
    Parameter Grid Tuner
    """
    async def optimize(
        self,
        model: ModelContext,
        feature: Feature,
        accuracy_scorer: Union[AccuracyScorer, AccuracyContext],
        train_data: Union[BaseSource, Record, Dict[str, Any]],
        test_data: Union[BaseSource, Record, Dict[str, Any]],
    ):
        """
        Method to optimize hyperparameters via grid search.

        Uses a grid of hyperparameters given as a dictionary in the config.
        Trains the model on every combination of parameters in the grid and
        compares their accuracies. Sets the model to the best parameters found
        and returns the highest accuracy.

        Parameters
        ----------
        model : ModelContext
            The model which needs to be used.

        feature : Feature
            The target feature in the data.

        accuracy_scorer : AccuracyContext
            The accuracy scorer that needs to be used.

        train_data : SourcesContext
            The data to train models on with the hyperparameters provided.

        test_data : SourcesContext
            The data to score against when optimizing hyperparameters.

        Returns
        -------
        float
            The highest score value
        """
        highest_acc = -1
        best_config = dict()
        logging.info(
            f"Optimizing model with parameter grid: {self.parent.config.parameters}"
        )
        names = list(self.parent.config.parameters.keys())
        logging.info(names)
        with model.config.no_enforce_immutable():
            for combination in itertools.product(
                *list(self.parent.config.parameters.values())
            ):
                logging.info(combination)
                # Apply this combination of hyperparameters to the model config
                for param, value in zip(names, combination):
                    setattr(model.config, param, value)
                await train(model, *train_data)
                acc = await score(model, accuracy_scorer, feature, *test_data)
                logging.info(f"Accuracy of the tuned model: {acc}")
                # Keep track of the best-scoring combination seen so far
                if acc > highest_acc:
                    highest_acc = acc
                    for param in names:
                        best_config[param] = getattr(model.config, param)
            # Restore the best configuration and retrain the model on it
            for param in names:
                setattr(model.config, param, best_config[param])
            await train(model, *train_data)
        logging.info(f"\nOptimal Hyper-parameters: {best_config}")
        logging.info(f"Accuracy of Optimized model: {highest_acc}")
        return highest_acc
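
# Illustrative example (not part of the original module): with a config of
# parameters={"epochs": [10, 20], "learning_rate": [0.01, 0.1]} (hypothetical
# model config fields), itertools.product(*parameters.values()) yields the
# four combinations (10, 0.01), (10, 0.1), (20, 0.01) and (20, 0.1). Each
# combination is set on model.config, the model is retrained and scored, and
# the best-scoring combination is kept as best_config.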


@entrypoint("parameter_grid")
class ParameterGrid(Tuner):

    CONFIG = ParameterGridConfig
    CONTEXT = ParameterGridContext
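
Below is a minimal usage sketch, not part of the module above. It assumes the tuner follows DFFML's usual config plus double async context manager pattern; MyModel, MyScorer, train_source and test_source are hypothetical placeholders for a concrete DFFML model, scorer, and data sources, and "epochs"/"learning_rate" are hypothetical fields of that model's config. Check the Tuner base class in your DFFML version for the exact call signature.

# Usage sketch (hypothetical placeholders: MyModel, MyScorer, train_source, test_source)
import asyncio

from dffml import Feature
from dffml.tuner.parameter_grid import ParameterGrid, ParameterGridConfig


async def main():
    tuner = ParameterGrid(
        ParameterGridConfig(
            # Hypothetical model config fields to sweep over
            parameters={"epochs": [10, 20], "learning_rate": [0.01, 0.1]}
        )
    )
    async with MyModel() as model, MyScorer() as scorer, tuner:
        async with model() as mctx, scorer() as sctx, tuner() as tctx:
            best = await tctx.optimize(
                mctx,
                Feature("target", float, 1),
                sctx,
                [train_source],
                [test_source],
            )
            print(f"Best accuracy: {best}")


asyncio.run(main())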