Source code for deepmd.fit.fitting

from typing import (
    Callable,
)

from deepmd.env import (
    tf,
)
from deepmd.utils import (
    Plugin,
    PluginVariant,
)


[docs]class Fitting(PluginVariant): __plugins = Plugin()
[docs] @staticmethod def register(key: str) -> Callable: """Register a Fitting plugin. Parameters ---------- key : str the key of a Fitting Returns ------- Fitting the registered Fitting Examples -------- >>> @Fitting.register("some_fitting") class SomeFitting(Fitting): pass """ return Fitting.__plugins.register(key)
def __new__(cls, *args, **kwargs): if cls is Fitting: try: fitting_type = kwargs["type"] except KeyError: raise KeyError("the type of fitting should be set by `type`") if fitting_type in Fitting.__plugins.plugins: cls = Fitting.__plugins.plugins[fitting_type] else: raise RuntimeError("Unknown descriptor type: " + fitting_type) return super().__new__(cls) @property def precision(self) -> tf.DType: """Precision of fitting network.""" return self.fitting_precision
[docs] def init_variables( self, graph: tf.Graph, graph_def: tf.GraphDef, suffix: str = "", ) -> None: """Init the fitting net variables with the given dict. Parameters ---------- graph : tf.Graph The input frozen model graph graph_def : tf.GraphDef The input frozen model graph_def suffix : str suffix to name scope Notes ----- This method is called by others when the fitting supported initialization from the given variables. """ raise NotImplementedError( "Fitting %s doesn't support initialization from the given variables!" % type(self).__name__ )