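1. What are Hooks?
In TensorFlow the concept is defined as follows: "Hooks are tools that run in the process of training/evaluation of the model."
In other words, hooks are tools that perform specific tasks while a model is being trained or evaluated. For example:
- Controlling training, e.g. early stopping
- Changing the learning rate
- Printing intermediate logs such as loss and AUC
- Saving checkpoints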
- …
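To make the first item concrete, the decision logic behind an early-stopping hook can be sketched in plain Python. This is only an illustration under stated assumptions: `PatienceTracker`, `steps_until_stop`, and the parameters `patience` and `min_delta` are hypothetical names, not part of TensorFlow. In a real hook, `update()` would be fed the loss fetched at each step, and a `True` result would lead to a stop request.

```python
class PatienceTracker:
    """Hypothetical helper: decides when training should stop early."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # steps to wait without improvement
        self.min_delta = min_delta    # smallest decrease that counts as improvement
        self.best_loss = float("inf")
        self.bad_steps = 0

    def update(self, loss):
        """Record a new loss value; return True if training should stop."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.bad_steps = 0
        else:
            self.bad_steps += 1
        return self.bad_steps >= self.patience


def steps_until_stop(losses, patience=3):
    """Feed losses in order; return the 1-based step that requests a stop, or None."""
    tracker = PatienceTracker(patience=patience)
    for step, loss in enumerate(losses, start=1):
        if tracker.update(loss):
            return step
    return None
```

Keeping the stopping rule in a small class of its own means it can be unit-tested without building a graph or opening a session.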
Hooks can take effect at the following points:
- when a session starts being used
- before a call to session.run()
- after a call to session.run()
- when the session is closed
2. How are Hooks defined?
In TensorFlow, hooks are created from the tf.train.SessionRunHook class and its subclasses. tf.train.SessionRunHook has five interface methods: begin, after_create_session, before_run, after_run, and end. A custom hook class looks like this:
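import tensorflow as tf

# TF 1.x names; under TF 2.x the same classes are available in tf.compat.v1.train.
SessionRunHook = tf.train.SessionRunHook
SessionRunArgs = tf.train.SessionRunArgs


class ExampleHook(SessionRunHook):
    def __init__(self):
        # You can initialize the hook here.
        pass

    def begin(self):
        """Called before the session is created.

        When begin() is called, the default graph has been created and new ops
        can be added to it here. After begin() returns, the default graph can
        no longer be modified.
        """
        print('Starting the session.')
        self.your_tensor = ...

    def after_create_session(self, session, coord):
        """Called after the tf.Session has been created.

        Signals to all hooks that a new session has been created.

        Args:
          session: A TensorFlow Session that has been created.
          coord: A Coordinator object which keeps track of all threads.
        """
        # When this is called, the graph is finalized and
        # ops can no longer be added to the graph.
        print('Session created.')

    def before_run(self, run_context):
        """Called before each sess.run().

        Returns a tf.train.SessionRunArgs(fetches, feed_dict), where fetches
        and feed_dict mean the same as in sess.run(). They are merged with the
        fetches and feed_dict already passed to sess.run() and run together.

        Args:
          run_context: A `SessionRunContext` object holding information
            about the session.
        """
        print('Before calling session.run().')
        return SessionRunArgs(self.your_tensor)

    def after_run(self, run_context, run_values):
        """Called after each sess.run().

        run_values holds the values of the ops/tensors requested in
        before_run(). Call run_context.request_stop() to stop the iteration.
        If sess.run() raises any exception, after_run() is not called.
        """
        print('Done running one step. The value of my tensor:', run_values.results)
        you_need_to_stop_loop = False  # placeholder: replace with your own stop condition
        if you_need_to_stop_loop:
            run_context.request_stop()

    def end(self, session):
        print('Done with the session.')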
In addition to custom hooks, several prebuilt hook classes are available for use with an estimator:
- StopAtStepHook: requests a stop based on global_step
- CheckpointSaverHook: saves checkpoints
- LoggingTensorHook: outputs one or more tensor values to the log
- NanTensorHook: requests a stop if the given Tensor contains NaNs
- SummarySaverHook: saves summaries to a summary writer
3. How are Hooks executed?
Hooks are invoked by MonitoredSession.run(), as follows:
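hook1 = ExampleHook()
# checkpoint_dir is required, plus exactly one of save_secs / save_steps;
# your_checkpoint_dir and your_save_steps are placeholders.
hook2 = CheckpointSaverHook(checkpoint_dir=your_checkpoint_dir, save_steps=your_save_steps)
your_hooks = [hook1, hook2]
with MonitoredTrainingSession(hooks=your_hooks, ...) as sess:
    while not sess.should_stop():
        sess.run(your_fetches)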
Behind the scenes, the execution flow is roughly as follows:
call hooks.begin()
sess = tf.compat.v1.Session()
call hooks.after_create_session()
while not stop is requested:
    call hooks.before_run()
    try:
        results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
        break
    call hooks.after_run()
call hooks.end()
sess.close()
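The flow above can be mimicked with a small pure-Python model. This is a toy sketch, not TensorFlow's implementation: `RunContext`, `FakeSession`, `CountingHook`, `run_with_hooks`, and the use of plain strings as "fetches" are all hypothetical stand-ins. They only show how the fetches requested by hooks are merged with the caller's fetches, and how a stop request ends the loop.

```python
class RunContext:
    """Toy stand-in for SessionRunContext: only carries the stop flag."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class FakeSession:
    """Toy session: 'running' a fetch name looks it up in a state dict."""

    def __init__(self):
        self.state = {"step": 0, "loss": 1.0}

    def run(self, fetches):
        self.state["step"] += 1
        self.state["loss"] *= 0.5  # pretend the loss halves on every step
        return {name: self.state[name] for name in fetches}


class CountingHook:
    """Hypothetical hook: asks for 'step' and stops after max_steps."""

    def __init__(self, max_steps):
        self.max_steps = max_steps
        self.events = []

    def begin(self):
        self.events.append("begin")

    def after_create_session(self, session):
        self.events.append("after_create_session")

    def before_run(self, run_context):
        return ["step"]  # extra fetches requested by this hook

    def after_run(self, run_context, results):
        self.events.append("after_run:%d" % results["step"])
        if results["step"] >= self.max_steps:
            run_context.request_stop()

    def end(self, session):
        self.events.append("end")


def run_with_hooks(hooks, user_fetches):
    """Drive the toy session through the same phases as the flow above."""
    for hook in hooks:
        hook.begin()
    sess = FakeSession()
    for hook in hooks:
        hook.after_create_session(sess)
    context = RunContext()
    user_results = []
    while not context.stop_requested:
        hook_fetches = [hook.before_run(context) for hook in hooks]
        merged = list(user_fetches)
        for fetches in hook_fetches:
            merged += [name for name in fetches if name not in merged]
        results = sess.run(merged)  # a single run serves the user and all hooks
        for hook, fetches in zip(hooks, hook_fetches):
            hook.after_run(context, {name: results[name] for name in fetches})
        user_results.append({name: results[name] for name in user_fetches})
    for hook in hooks:
        hook.end(sess)
    return user_results


hook = CountingHook(max_steps=3)
losses = run_with_hooks([hook], ["loss"])
print(losses)
print(hook.events)
```

The caller only ever sees the values it asked for (`loss`), while the hook receives only the values it requested (`step`), even though both came from the same call to `run()`.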
Here is a concrete example (from qq924178473: https://blog.csdn.net/h_jlwg6688/article/details/117514323):
import logging

import numpy as np
import tensorflow as tf

logger = logging.getLogger(__name__)

# Note: feature_columns and DNNModel are not defined in this snippet;
# they must be provided by the surrounding code.


# Define your own hook class that prints logs after every step
class YourOwnHook(tf.train.SessionRunHook):
    def __init__(self):
        np.set_printoptions(suppress=True)
        np.set_printoptions(linewidth=400)

    def before_run(self, run_context):
        """Return SessionRunArgs to be run together with the session run."""
        v1 = tf.get_collection('logis')
        prob = tf.get_collection('prob')
        return tf.train.SessionRunArgs(fetches=[v1, prob])

    def after_run(self, run_context, run_values):
        v1, batch_labels = run_values.results
        logger.info("logis value:{}".format(v1))
        print("prob :", batch_labels)


# Implement the estimator
class MyEstimator(tf.estimator.Estimator):
    def __init__(self,
                 model_dir,
                 hidden_units,
                 optimizer,
                 activation_fn,
                 dropout=None,
                 batch_norm=False,
                 weight_column=None,
                 label_vocabulary=None,
                 loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,
                 params=None,
                 config=None,
                 warm_start_from=None):

        def model_fn(features, labels, mode):
            inputs_layers = tf.feature_column.input_layer(features, feature_columns)
            # Custom network layers
            user_hidden_fn = DNNModel(hidden_units=hidden_units,
                                      activation_fn=activation_fn,
                                      dropout=dropout,
                                      batch_norm=batch_norm,
                                      name="user_dnn")
            user_hidden_net = user_hidden_fn(inputs_layers, mode=mode)
            with tf.name_scope("logits"):
                logits = tf.keras.layers.Dense(units=2, activation=None)(user_hidden_net)
            loss = tf.reduce_mean(
                tf.losses.sparse_softmax_cross_entropy(
                    labels=tf.reshape(labels['label'], [-1]), logits=logits))
            train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
            # Compute predictions.
            predicted_classes = tf.argmax(logits, 1)
            # Set up the model evaluation metrics
            accuracy = tf.metrics.accuracy(labels=labels["label"],
                                           predictions=predicted_classes,
                                           name='acc_op')
            auc = tf.metrics.auc(labels=labels["label"],
                                 predictions=predicted_classes,
                                 name='auc_op')
            metrics = {'accuracy': accuracy, 'auc': auc}
            tf.summary.scalar('accuracy', accuracy[1])
            if mode == tf.estimator.ModeKeys.TRAIN:
                # Create the custom hook and register the intermediate values to output
                ownhook = YourOwnHook()
                tf.add_to_collection('logis', logits)
                tf.add_to_collection('prob', predicted_classes)
                # Attach the custom hook to the training EstimatorSpec
                return tf.estimator.EstimatorSpec(mode=mode,
                                                  loss=loss,
                                                  train_op=train_op,
                                                  training_hooks=[ownhook])
            if mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.EVAL,
                                                  loss=loss,
                                                  eval_metric_ops=metrics)

        super(MyEstimator, self).__init__(model_fn=model_fn,
                                          model_dir=model_dir,
                                          params=params,
                                          config=config,
                                          warm_start_from=warm_start_from)
Reference
session_run_hook.py source code
Hook? An introduction to tf.train.SessionRunHook() (recommended)
TensorFlow series: using TensorBoard and printing intermediate data in a custom standard estimator