关于ZAKER 融媒体解决方案 合作 加入

如何将 Keras .h5 导出到 tensorflow .pb?

CocoaChina 09-19

Keras 本身不包括将 TensorFlow 图导出为协议缓冲区文件的任何方法 , 但您可以使用常规 TensorFlow 实用程序来完成 . Here是一篇博客文章 , 解释了如何使用 TensorFlow 中包含的实用程序脚本freeze_graph.py来完成它 , 这是它的 " 典型 " 方式 .

但是 , 我个人觉得必须制作检查点 , 然后运行外部脚本来获取模型 , 而不是喜欢从我自己的 Python 代码中执行它 , 所以我使用这样的函数:

def freeze_session ( session, keep_var_names=None, output_names=None, clear_devices=True ) : """ Freezes the state of a session into a pruned computation graph. Creates a new computation graph where variable nodes are replaced by constants taking their current value in the session. The new graph will be pruned so subgraphs that are not necessary to compute the requested outputs are removed. @param session The TensorFlow session to be frozen. @param keep_var_names A list of variable names that should not be frozen, or None to freeze all the variables in the graph. @param output_names Names of the relevant graph outputs. @param clear_devices Remove the device directives from the graph for better portability. @return The frozen graph definition. """ graph = session.graph with graph.as_default ( ) : freeze_var_names = list ( set ( v.op.name for v in tf.global_variables ( ) ) .difference ( keep_var_names or [ ] ) ) output_names = output_names or [ ] output_names += [ v.op.name for v in tf.global_variables ( ) ] input_graph_def = graph.as_graph_def ( ) if clear_devices: for node in input_graph_def.node: node.device = "" frozen_graph = tf.graph_util.convert_variables_to_constants ( session, input_graph_def, output_names, freeze_var_names ) return frozen_graph

这是在 freeze_graph.py 的实现中受到启发的 . 参数也类似于脚本 . session 是 TensorFlow 会话对象 . 只有当你想保留一些未冻结的变量时 ( 例如有状态模型 ) 才需要 keep_var_names, 所以通常不需要 . output_names 是一个列表 , 其中包含生成所需输出的操作的名称 . clear_devices 只删除任何设备指令 , 使图形更具可移植性 . 因此 , 对于具有一个输出的典型 Keras 模型 , 您可以执行以下操作:

from keras import backend as K# Create, compile and train model...frozen_graph = freeze_session ( K.get_session ( ) , output_names= [ out.op.name for out in model.outputs ] )

然后您可以像往常一样将图形写入文件tf.train.write_graph

tf.train.write_graph ( frozen_graph, "some_directory", "my_model.pb", as_text=False )

以上内容由"CocoaChina"上传发布 查看原文
相关标签 python

觉得文章不错,微信扫描分享好友

扫码分享