#!/usr/bin/env python3
"""Script for freezing TF trained graph so it can be used with LAMMPS and i-PI.
References
----------
https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
"""
import logging
import google.protobuf.message
from deepmd.env import tf, FITTING_NET_PATTERN, REMOVE_SUFFIX_DICT
from deepmd.utils.errors import GraphTooLargeError
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import get_pattern_nodes_from_graph_def
from os.path import abspath
import json
# load grad of force module
import deepmd.op
from typing import List, Optional
from deepmd.nvnmd.entrypoints.freeze import save_weight
__all__ = ["freeze"]
log = logging.getLogger(__name__)
def _transfer_fitting_net_trainable_variables(sess, old_graph_def, raw_graph_def):
    """Copy fitting-net weights from the raw graph into the frozen graph.

    Every node of ``old_graph_def`` whose name matches ``FITTING_NET_PATTERN``
    is overwritten with the value of the node ``<name>_1`` taken from
    ``raw_graph_def`` after freezing it with ``sess``. The ``_1`` nodes are
    presumably the trained copies of the fitting-net variables (TODO confirm).

    Parameters
    ----------
    sess : tf.Session
        Session holding the restored variable values.
    old_graph_def : tf.GraphDef
        The already-frozen graph; its matching constants are updated in place.
    raw_graph_def : tf.GraphDef
        The un-frozen graph stored from the checkpoint.

    Returns
    -------
    tf.GraphDef
        ``old_graph_def``, with fitting-net constants replaced when
        ``<name>_1`` counterparts exist.
    """
    old_pattern = FITTING_NET_PATTERN
    # Same pattern, but matching the "_<n>"-suffixed names in the raw graph.
    raw_pattern = (
        FITTING_NET_PATTERN.replace('idt', r'idt+_\d+')
        .replace('bias', r'bias+_\d+')
        .replace('matrix', r'matrix+_\d+')
    )
    old_graph_nodes = get_pattern_nodes_from_graph_def(old_graph_def, old_pattern)
    if not old_graph_nodes:
        # Nothing to transfer; skip the variable-to-constant conversion.
        return old_graph_def
    try:
        raw_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            raw_graph_def,  # The graph_def is used to retrieve the nodes
            # The output node names are used to select the useful nodes
            [name + '_1' for name in old_graph_nodes],
        )
    except AssertionError:
        # The requested "<name>_1" nodes are not in the raw graph, i.e. there
        # are no additional nodes to transfer.
        return old_graph_def
    raw_graph_nodes = get_pattern_nodes_from_graph_def(raw_graph_def, raw_pattern)
    for node in old_graph_def.node:
        if node.name not in old_graph_nodes:
            continue
        tensor = tf.make_ndarray(raw_graph_nodes[node.name + '_1'])
        # ``ndarray.tostring`` is a deprecated alias of ``tobytes``.
        node.attr["value"].tensor.tensor_content = tensor.tobytes()
    return old_graph_def
def _remove_fitting_net_suffix(output_graph_def, out_suffix):
    """Remove fitting net suffix for multi-task mode.

    Node names, node inputs and ``_class`` attribute entries that contain
    ``out_suffix`` are rewritten with the first matching
    ``REMOVE_SUFFIX_DICT`` template (each key is formatted with the suffix).

    Parameters
    ----------
    output_graph_def : tf.GraphDef
        The output graph to remove suffix from; it is modified in place.
    out_suffix : str
        The suffix to remove.

    Returns
    -------
    tf.GraphDef
        The same ``output_graph_def`` object.

    Raises
    ------
    AssertionError
        If a name still contains ``out_suffix`` after the replacement.
    """

    def change_name(name, suffix):
        # Replace the first template occurring in ``name``; any leftover
        # occurrence of the suffix means the name is not in a known form.
        if suffix in name:
            for item in REMOVE_SUFFIX_DICT:
                if item.format(suffix) in name:
                    name = name.replace(item.format(suffix), REMOVE_SUFFIX_DICT[item])
                    break
            assert suffix not in name, 'fitting net name illegal!'
        return name

    for node in output_graph_def.node:
        if out_suffix in node.name:
            node.name = change_name(node.name, out_suffix)
        # Inputs are checked on every node so that edges pointing at renamed
        # nodes stay valid.
        for idx, input_name in enumerate(list(node.input)):
            if out_suffix in input_name:
                node.input[idx] = change_name(input_name, out_suffix)
        # ``node.attr`` is a protobuf message map: indexing a missing key would
        # insert an empty ``_class`` AttrValue, so only touch existing entries.
        if '_class' not in node.attr:
            continue
        attr_list = node.attr['_class'].list.s
        for idx, raw_value in enumerate(list(attr_list)):
            decoded = raw_value.decode('utf-8')
            if out_suffix in decoded:
                attr_list[idx] = change_name(decoded, out_suffix).encode('utf-8')
    return output_graph_def
def _modify_model_suffix(output_graph_def, out_suffix, freeze_type):
    """Modify model suffix in graph nodes for multi-task mode, including fitting net, model attr and training script.

    Parameters
    ----------
    output_graph_def : tf.GraphDef
        The output graph to remove suffix from; modified in place.
    out_suffix : str
        The suffix to remove.
    freeze_type : str
        The model type to freeze.

    Returns
    -------
    tf.GraphDef
        The graph with the suffix removed, ``model_attr/model_type`` set to
        ``freeze_type`` and the stored training script reduced to one task.
    """
    output_graph_def = _remove_fitting_net_suffix(output_graph_def, out_suffix)
    for node in output_graph_def.node:
        # Exact comparison: a substring test would also rewrite scoped copies
        # such as "dipole_charge/model_attr/model_type" of a modifier model.
        if node.name == 'model_attr/model_type':
            node.attr['value'].tensor.string_val[0] = bytes(freeze_type, encoding='utf8')
        elif node.name == 'train_attr/training_script':
            # change the input script for frozen model
            jdata = json.loads(node.attr['value'].tensor.string_val[0])
            model = jdata['model']
            # fitting net: keep only the selected task's fitting net
            assert out_suffix in model['fitting_net_dict']
            model['fitting_net'] = model.pop('fitting_net_dict')[out_suffix]
            # data systems: keep only the selected task's data
            training = jdata['training']
            systems = training.pop('data_dict')
            if out_suffix in systems:
                task_systems = systems[out_suffix]
                training['training_data'] = task_systems['training_data']
                if 'validation_data' in task_systems:
                    training['validation_data'] = task_systems['validation_data']
            else:
                training['training_data'] = {}
                log.warning(
                    "The fitting net %s has no training data in input script, resulting in "
                    "untrained frozen model, and cannot be compressed directly!",
                    out_suffix,
                )
            # loss: keep only the selected task's loss, when one is given
            if 'loss_dict' in jdata:
                loss_dict = jdata.pop('loss_dict')
                if out_suffix in loss_dict:
                    jdata['loss'] = loss_dict[out_suffix]
            # fitting weight only makes sense for multi-task training
            training.pop('fitting_weight', None)
            node.attr['value'].tensor.string_val[0] = bytes(json.dumps(jdata), encoding='utf8')
    return output_graph_def
def _make_node_names(model_type: str, modifier_type: Optional[str] = None, out_suffix: str = '') -> List[str]:
"""Get node names based on model type.
Parameters
----------
model_type : str
str type of model
modifier_type : Optional[str], optional
modifier type if any, by default None
out_suffix : str
suffix for output nodes
Returns
-------
List[str]
list with all node names to freeze
Raises
------
RuntimeError
if unknown model type
"""
nodes = [
"model_type",
"descrpt_attr/rcut",
"descrpt_attr/ntypes",
"model_attr/tmap",
"model_attr/model_type",
"model_attr/model_version",
"train_attr/min_nbor_dist",
"train_attr/training_script",
]
if model_type == "ener":
nodes += [
"o_energy",
"o_force",
"o_virial",
"o_atom_energy",
"o_atom_virial",
"fitting_attr/dfparam",
"fitting_attr/daparam",
]
elif model_type == "wfc":
nodes += [
"o_wfc",
"model_attr/sel_type",
"model_attr/output_dim",
]
elif model_type == "dipole":
nodes += [
"o_dipole",
"o_global_dipole",
"o_force",
"o_virial",
"o_atom_virial",
"o_rmat",
"o_rmat_deriv",
"o_nlist",
"o_rij",
"descrpt_attr/sel",
"descrpt_attr/ndescrpt",
"model_attr/sel_type",
"model_attr/output_dim",
]
elif model_type == "polar":
nodes += [
"o_polar",
"o_global_polar",
"o_force",
"o_virial",
"o_atom_virial",
"model_attr/sel_type",
"model_attr/output_dim",
]
elif model_type == "global_polar":
nodes += [
"o_global_polar",
"model_attr/sel_type",
"model_attr/output_dim",
]
else:
raise RuntimeError(f"unknow model type {model_type}")
if modifier_type == "dipole_charge":
nodes += [
"modifier_attr/type",
"modifier_attr/mdl_name",
"modifier_attr/mdl_charge_map",
"modifier_attr/sys_charge_map",
"modifier_attr/ewald_h",
"modifier_attr/ewald_beta",
"dipole_charge/model_type",
"dipole_charge/descrpt_attr/rcut",
"dipole_charge/descrpt_attr/ntypes",
"dipole_charge/model_attr/tmap",
"dipole_charge/model_attr/model_type",
"dipole_charge/model_attr/model_version",
"o_dm_force",
"dipole_charge/model_attr/sel_type",
"dipole_charge/o_dipole",
"dipole_charge/model_attr/output_dim",
"o_dm_virial",
"o_dm_av",
]
if out_suffix != '':
for ind in range(len(nodes)):
if (nodes[ind][:2] == 'o_' and nodes[ind] not in ["o_rmat", "o_rmat_deriv", "o_nlist", "o_rij"]) \
or nodes[ind] == "model_attr/sel_type" \
or nodes[ind] == "model_attr/output_dim":
nodes[ind] += '_{}'.format(out_suffix)
elif 'fitting_attr' in nodes[ind]:
content = nodes[ind].split('/')[1]
nodes[ind] = 'fitting_attr_{}/{}'.format(out_suffix, content)
return nodes
def freeze_graph(sess, input_graph, input_node, freeze_type, modifier, out_graph_name, node_names=None, out_suffix=''):
    """Freeze the single graph with chosen out_suffix.

    Parameters
    ----------
    sess : tf.Session
        The default session.
    input_graph : tf.GraphDef
        The input graph_def stored from the checkpoint.
    input_node : List[str]
        The names of the nodes available in ``input_graph``.
    freeze_type : str
        The model type to freeze.
    modifier : Optional[str], optional
        Modifier type if any, by default None.
    out_graph_name : str
        The output graph file.
    node_names : Optional[str], optional
        Comma-separated names of nodes to output, by default None, in which
        case the list is derived from ``freeze_type`` and ``modifier``.
    out_suffix : str
        The chosen suffix to freeze in the input_graph; empty for a
        single-task model.
    """
    if node_names is None:
        output_node = _make_node_names(freeze_type, modifier, out_suffix=out_suffix)
        available = set(input_node)
        missing = [name for name in output_node if name not in available]
        if missing:
            log.warning(
                "The following nodes are not in the graph: %s. "
                "Skip freezing these nodes. You may be freezing "
                "a checkpoint generated by an old version.",
                missing,
            )
        # Keep only the nodes present in the graph, preserving the order from
        # _make_node_names (a set intersection would make it hash-dependent).
        output_node = [name for name in output_node if name in available]
    else:
        # Tolerate whitespace around commas and empty entries, e.g. "a, b,".
        output_node = [name.strip() for name in node_names.split(",") if name.strip()]
    log.info(f"The following nodes will be frozen: {output_node}")

    # We use a built-in TF helper to export variables to constants
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,  # The session is used to retrieve the weights
        input_graph,  # The graph_def is used to retrieve the nodes
        output_node,  # The output node names are used to select the useful nodes
    )

    # if multi-task, change fitting_net suffix and model_type
    if out_suffix:
        output_graph_def = _modify_model_suffix(output_graph_def, out_suffix, freeze_type)

    # If we need to transfer the fitting net variables
    output_graph_def = _transfer_fitting_net_trainable_variables(sess, output_graph_def, input_graph)

    # Finally we serialize and dump the output graph to the filesystem
    with tf.gfile.GFile(out_graph_name, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    log.info(f"{len(output_graph_def.node):d} ops in the final graph.")
def freeze_graph_multi(sess, input_graph, input_node, modifier, out_graph_name, node_names):
    """Freeze one graph per fitting net of a multi-task model.

    The training script stored in ``train_attr/training_script`` is read
    through the session. For every key ``k`` of its
    ``model.fitting_net_dict``, a graph is frozen with ``k`` as the suffix
    and the fitting net's ``type`` as the model type. It is written to
    ``<stem>_k.pb`` when ``out_graph_name`` ends with ``.pb``, and to
    ``<out_graph_name>_k`` otherwise.

    Parameters
    ----------
    sess : tf.Session
        The default session.
    input_graph : tf.GraphDef
        The input graph_def stored from the checkpoint.
    input_node : List[str]
        The expected nodes to freeze.
    modifier : Optional[str], optional
        Modifier type if any, by default None.
    out_graph_name : str
        The output graph.
    node_names : Optional[str], optional
        Names of nodes to output, by default None.
    """
    training_script = json.loads(run_sess(sess, "train_attr/training_script:0", feed_dict={}))
    assert 'model' in training_script.keys() and 'fitting_net_dict' in training_script['model']

    # Split the output name once: "<stem>.pb" -> "<stem>_<key>.pb", else "<name>_<key>".
    if out_graph_name.endswith('.pb'):
        stem, extension = out_graph_name[:-3], '.pb'
    else:
        stem, extension = out_graph_name, ''

    fitting_nets = training_script['model']['fitting_net_dict']
    for task_key in fitting_nets:
        task_type = fitting_nets[task_key]['type']
        task_output = f"{stem}_{task_key}{extension}"
        freeze_graph(
            sess,
            input_graph,
            input_node,
            task_type,
            modifier,
            task_output,
            node_names=node_names,
            out_suffix=task_key,
        )
def freeze(
    *, checkpoint_folder: str, output: str, node_names: Optional[str] = None, nvnmd_weight: Optional[str] = None, **kwargs
):
    """Freeze the graph in supplied folder.

    Parameters
    ----------
    checkpoint_folder : str
        location of the folder with model
    output : str
        output file name
    node_names : Optional[str], optional
        comma-separated names of nodes to output, by default None (the node
        list is then derived from the model type)
    nvnmd_weight : Optional[str], optional
        if given, forwarded with the restored session to the NVNMD
        ``save_weight`` helper, by default None
    **kwargs
        accepted for interface compatibility and ignored

    Raises
    ------
    RuntimeError
        if no checkpoint can be found in ``checkpoint_folder``
    GraphTooLargeError
        if the graph exceeds the 2 GB protobuf size limit
    """
    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(checkpoint_folder)
    # get_checkpoint_state returns None when no checkpoint state exists; fail
    # with a clear message instead of an AttributeError on None.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise RuntimeError(
            f"No checkpoint found in {checkpoint_folder!r}; cannot freeze the model."
        )
    input_checkpoint = checkpoint.model_checkpoint_path

    # expand the output file to full path
    output_graph = abspath(output)

    # The output nodes (what part of the graph TF keeps) are chosen later in
    # freeze_graph, either from `node_names` or from the model type.

    # We clear devices to allow TensorFlow to control
    # on which device it will load operations
    clear_devices = True

    # We import the meta graph and retrieve a Saver
    try:
        # In case of parallel (Horovod) training; presumably the import is
        # needed so the meta graph's Horovod ops can be resolved.
        import horovod.tensorflow as _  # noqa: F401
    except ImportError:
        pass
    saver = tf.train.import_meta_graph(
        f"{input_checkpoint}.meta", clear_devices=clear_devices
    )

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    try:
        input_graph_def = graph.as_graph_def()
    except google.protobuf.message.DecodeError as e:
        raise GraphTooLargeError(
            "The graph size exceeds 2 GB, the hard limitation of protobuf."
            " Then a DecodeError was raised by protobuf. You should "
            "reduce the size of your model."
        ) from e
    nodes = [n.name for n in input_graph_def.node]

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        model_type = run_sess(sess, "model_attr/model_type:0", feed_dict={}).decode("utf-8")
        if "modifier_attr/type" in nodes:
            modifier_type = run_sess(sess, "modifier_attr/type:0", feed_dict={}).decode(
                "utf-8"
            )
        else:
            modifier_type = None
        if nvnmd_weight is not None:
            save_weight(sess, nvnmd_weight)  # nvnmd
        if model_type != 'multi-task':
            freeze_graph(sess, input_graph_def, nodes, model_type, modifier_type, output_graph, node_names)
        else:
            freeze_graph_multi(sess, input_graph_def, nodes, modifier_type, output_graph, node_names)