当前位置: 首页 > news >正文

济南网站定制策划软文营销的成功案例

济南网站定制策划,软文营销的成功案例,百度排名 网站标题,前端培训靠谱吗1.什么是Hooks? 中文直译为“钩子”,在tensorflow中概念:Hooks are tools that run in the process of training/evaluation of the model.* Hooks是模型训练/测试过程中的工具,这些工具用于在训练/评估过程中执行特定任务。例如…

1.什么是Hooks?

中文直译为“钩子”,在tensorflow中概念:Hooks are tools that run in the process of training/evaluation of the model.*
Hooks是模型训练/测试过程中的工具,这些工具用于在训练/评估过程中执行特定任务。例如:

  • 控制训练EarlyStopping
  • 改变学习率
  • 打印一些中间日志,如loss、auc等
  • 保存checkpoint

这些hooks可以在以下几个地方生效:

  • when a session starts being used
  • before a call to the session.run()
  • after a call to the session.run()
  • when the session closed

2.怎么定义Hooks?

在tensorflow中,tf.training.SessionRunHook类及其派生类负责创建hooks,tf.training.SessionRunHook有5个接口函数,分别是begin, after_create_session, before_run, after_run, end。自定义一个Hook类:

class ExampleHook(SessionRunHook):def __init__(self):# Yor can init the hook heredef begin(self):	"""在创建会话之前调用调用begin()时,default graph会被创建,可在此处向default graph增加新op,begin()调用后,default graph不能再被修改"""print('Starting the session.')self.your_tensor = ...def after_create_session(self, session, coord):"""tf.Session被创建后调用调用后会指示所有的Hooks有一个新的会话被创建Args:session: A TensorFlow Session that has been created.coord: A Coordinator object which keeps track of all threads."""# When this is called, the graph is finalized and# ops can no longer be added to the graph.print('Session created.')def before_run(self, run_context):"""在每个sess.run()执行之前调用返回一个tf.train.SessRunArgs(fetches, feed_dict),fetches、feed_dict和sess.run()里概念一样。实际上它们会和sess.run()中已定义的fetches和feed_dict合并一起执行。Args:run_context: A `SessionRunContext` object, 包含session的一些信息"""print('Before calling session.run().')return SessionRunArgs(self.your_tensor)def after_run(self, run_context, run_values):"""在每个sess.run()之后调用参数run_values是befor_run()中要求的op/tensor的返回值;可以调用run_context.qeruest_stop()用于停止迭代sess.run抛出任何异常after_run不会被调用"""print('Done running one step. The value of my tensor: %s', run_values.results)if you-need-to-stop-loop:run_context.request_stop()def end(self, session):print('Done with the session.')

除了自定义Hooks外,estimator有几个预制好的Hooks类:

  • StopAtStepHook: Request stop based on global_step
  • CheckpointSaverHook: saves checkpoint
  • LoggingTensorHook: outputs one or more tensor values to log
  • NanTensorHook: Request stop if given Tensor contains Nans.
  • SummarySaverHook: saves summaries to a summary writer

3.怎么执行Hooks

Hooks由 MonitoredSession.run()调用,具体方式:

hook1 = ExampleHook()
hook2 = CheckpointSaverHook()
your_hooks = [hook1, hook2]
with MonitoredTrainingSession(hooks=your_hooks, ...) as sess:while not sess.should_stop():sess.run(your_fetches)

其背后大概执行流程是这样的:

call hooks.begin()
sess = tf.compat.v1.Session()
call hooks.after_create_session()while not stop is requested:call hooks.before_run()try:results = sess.run(merged_fetches, feed_dict=merged_feeds)except (errors.OutOfRangeError, StopIteration):breakcall hooks.after_run()
call hooks.end()
sess.close()

给个具体的例子(from qq924178473:https://blog.csdn.net/h_jlwg6688/article/details/117514323):

# 定义自己的hook类,实现每个step执行后打印日志
class YourOwnHook(tf.train.SessionRunHook):def __init__(self):np.set_printoptions(suppress=True)np.set_printoptions(linewidth=400)def before_run(self, run_context):"""返回SessionRunArgs和session run一起跑"""v1 = tf.get_collection('logis')prob = tf.get_collection('prob')return tf.train.SessionRunArgs(fetches=[v1, prob])def after_run(self, run_context, run_values):v1, batch_labels = run_values.resultslogger.info("logis value:{}".format(v1))print("prob :",batch_labels)# 实现estimator
class MyEstimator(tf.estimator.Estimator):def __init__(self,model_dir,hidden_units,optimizer,activation_fn,dropout=None,batch_norm=False,weight_column=None,label_vocabulary=None,loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,params=None,config=None,warm_start_from=None):def model_fn(features,labels,mode):inputs_layers =tf.feature_column.input_layer(features,feature_columns)# 自定义网络层user_hidden_fn = DNNModel(hidden_units=hidden_units,activation_fn=activation_fn,dropout=dropout,batch_norm=batch_norm,name="user_dnn")user_hidden_net = user_hidden_fn(inputs_layers,mode=mode)with tf.name_scope("logits"):logits = tf.keras.layers.Dense(units=2, activation=None)(user_hidden_net)loss = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(labels=tf.reshape(labels['label'],[-1]),logits=logits))train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())# Compute predictions.predicted_classes = tf.argmax(logits, 1)# 设置模型评价指标accuracy = tf.metrics.accuracy(labels=labels["label"],predictions=predicted_classes,name='acc_op')auc = tf.metrics.auc(labels=labels["label"],predictions=predicted_classes,name='auc_op')metrics = {'accuracy': accuracy,'auc':auc}tf.summary.scalar('accuracy', accuracy[1])if mode==tf.estimator.ModeKeys.TRAIN:# 定义自定义钩子函数,并设置要输出的中间值的名称ownhook = YourOwnHook()tf.add_to_collection('logis', logits)tf.add_to_collection('prob',predicted_classes)# 将自定义钩子添加到训练的estimator中return tf.estimator.EstimatorSpec(mode=mode,loss=loss,train_op=train_op,training_hooks=[ownhook])if mode == tf.estimator.ModeKeys.EVAL:return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.EVAL,loss=loss,eval_metric_ops=metrics)super(MyEstimator,self).__init__(model_fn=model_fn,model_dir=model_dir,params=params,config=config,warm_start_from=warm_start_from)

Reference

session_run_hook.py源码
Hook? tf.train.SessionRunHook()介绍【精】
TensorFlow系列——在自定义的标准estimator中使用tensorboard及打印中间数据

http://www.lbrq.cn/news/2404603.html

相关文章:

  • 办公用品网站建设share群组链接分享
  • 专做立体化的网站郑州网站托管
  • 如何用威客做网站推广 方案电子商务网页制作
  • 网站备案ip地址段网站自动提交收录
  • 做面包有哪些网站知乎江北seo页面优化公司
  • 本地网站建设多少钱学生没钱怎么开网店
  • 广州智能建站网站交易平台
  • 武汉人才网seo的中文意思
  • 海淀网站建设多少钱百度竞价排名的优缺点
  • 机械网站建设注意什么下载百度app下载
  • 深圳住建局官网登录入口青岛推广优化
  • html网站的直播怎么做的企业查询app
  • 公众号申请网站网站管理
  • 北京做软件最好的公司重庆seo网站收录优化
  • 西安网站建设维护达州seo
  • 营口做网站企业怎么样在百度上推广自己的产品
  • 做电商怎么找货源seo网站优化系统
  • 做室内设计的网站有哪些方面陕西seo关键词优化外包
  • 谷歌账号注册网站打不开黑龙江最新疫情
  • 淘宝客是如何做网站与淘宝对接的天津百度seo推广
  • 海外营销网站建设百度做免费推广的步骤
  • 做公众号时图片的网站外链代发免费
  • 延吉做网站ybdiran友情链接网站免费
  • 舅舅建筑网东莞优化怎么做seo
  • 一个页面多少钱惠州seo快速排名
  • 自己做微网站制作教程网站点击量查询
  • 网站建设 补充协议百度学术官网论文查重免费
  • 做网站开发的是不是程序员seo关键词优化报价
  • 世界500强企业排名2021茶叶seo网站推广与优化方案
  • 工程机械外贸网站建设seo文章排名优化
  • 枪战验证系统:通过战斗证明你是人类
  • Java中List<int[]>()和List<int[]>[]的区别
  • 将EXCEL或者CSV转换为键值对形式的Markdown文件
  • 前端面试专栏-工程化:28.团队协作与版本控制(Git)
  • Jmeter系列(7)-线程组
  • 【Web APIs】JavaScript 自定义属性操作 ② ( H5 自定义属性 )