当前位置: 首页 > news >正文

设计网站printestseo官网优化

设计网站printest,seo官网优化,做五金的外贸网站有哪些,宿州公司网站建设pytorch 到 tensorflow 可以用onnx作为中间工具转换,将pytorch转为onnx,再从onnx转为tensorflow,但是中间可能出现一些乱七八糟的问题。其实手动读参数再填充的对应的模型中也很方便,本文就总结一下手动模型转换。第一种方式&…

pytorch 到 tensorflow 可以用onnx作为中间工具转换,将pytorch转为onnx,再从onnx转为tensorflow,但是中间可能出现一些乱七八糟的问题。其实手动读参数再填充的对应的模型中也很方便,本文就总结一下手动模型转换。

第一种方式: 直接用kernel_initializer来填充权重

  • dense 层转换:
 dense_w = state_dict['dense.weight'].permute(1,0).numpy()dense_b = state_dict['dense.bias'].numpy()output = tf.keras.layers.Dense( dense_w.shape[-1], kernel_initializer=tf.constant_initializer(dense_w), bias_initializer=tf.constant_initializer(dense_b),name='bottleneck')(output)output = tf.keras.layers.Softmax(-1)(output)
  • 卷积层转换:
 conv_w = state_dict['conv.weight'].permute(2,1,0).numpy()conv_b = state_dict['conv.bias'].numpy()  output = tf.keras.layers.Conv1D(filters=conv_w.shape[-1], kernel_size=conv_w.shape[0], padding='same', kernel_initializer= tf.constant_initializer(conv_w), bias_initializer=tf.constant_initializer(conv_b), name ='conv')(input)output = tf.keras.layers.LeakyReLU(alpha=0.01)(output) 

注意:leakyReLU pytorch 中对应的默认参数和tensorflow中的默认参数不同,一定要保持一致。

  • BiLSTM层转换:
 hidden_channel=128lstm_w = state_dict['layer.weight_ih_l0'].permute(1,0).numpy()lstm_r = state_dict['layer.weight_hh_l0'].permute(1,0).numpy()lstm_b = state_dict['layer.bias_hh_l0'].numpy() + state_dict['layer.bias_ih_l0'].numpy()lstm_w_inv = state_dict['layer.weight_ih_l0_reverse'].permute(1,0).numpy()lstm_r_inv = state_dict['layer.weight_hh_l0_reverse'].permute(1,0).numpy()lstm_b_inv = state_dict['layer.bias_hh_l0_reverse'].numpy() + state_dict['layer.bias_ih_l0_reverse'].numpy()fw = tf.keras.layers.LSTM( hidden_channel, return_sequences=True, recurrent_activation='sigmoid', use_bias=True,  kernel_initializer=tf.keras.initializers.constant(lstm_w),unit_forget_bias=False, recurrent_initializer=tf.keras.initializers.constant(lstm_r),bias_initializer=tf.keras.initializers.constant(lstm_b))bw = tf.keras.layers.LSTM( hidden_channel, return_sequences=True, go_backwards=True, recurrent_activation='sigmoid', use_bias=True, kernel_initializer=tf.keras.initializers.constant(lstm_w_inv),unit_forget_bias=False, recurrent_initializer=tf.keras.initializers.constant(lstm_r_inv),bias_initializer=tf.keras.initializers.constant(lstm_b_inv))output = tf.keras.layers.Bidirectional(fw, backward_layer = bw)(output)

注意:LSTM的转换方式特殊一些,pytorch中包含4组参数(input的权重+state的权重+两组bias, 文档中说第二组参数主要是为了CuDNN的并行化torch.nn.modules.rnn - PyTorch 1.7.0 documentation NVIDIA Deep Learning cuDNN Documentation); 而在tensorflow中包含3组参数(input的权重+state的权重+一组bias),在转换的时候将pytorch中的两组参数相加作为tensorflow的bias。这种方式 unit_forget_bias=False 否则bias 参数对不上。

!!!虽然 tf.keras.layers.LSTM 中参数recurrent_activation默认设置为'sigmoid',但是不显示设置为'sigmoid', 结果对不上。

第二种方式: 使用set_weights 来填充权重

  • 特别注意bilstm的层:
 lstm_w = state_dict['layer.weight_ih_l0'].permute(1,0).numpy()lstm_r = state_dict['layer.weight_hh_l0'].permute(1,0).numpy()lstm_b = state_dict['layer.bias_hh_l0'].T.numpy() + state_dict['layer.bias_ih_l0'].T.numpy()lstm_w_inv = state_dict['layer.weight_ih_l0_reverse'].permute(1,0).numpy()lstm_r_inv = state_dict['layer.weight_hh_l0_reverse'].permute(1,0).numpy()lstm_b_inv = state_dict['layer.bias_hh_l0_reverse'].T.numpy() + state_dict['layer.bias_ih_l0_reverse'].T.numpy()fw = tf.keras.layers.LSTM(hidden_channel, return_sequences=True, recurrent_activation='sigmoid')bw = tf.keras.layers.LSTM(hidden_channel, return_sequences=True, go_backwards=True, recurrent_activation='sigmoid')output = tf.keras.layers.Bidirectional(fw, backward_layer = bw)(output)keras_format_weights = [lstm_w,lstm_r,lstm_b,lstm_w_inv,lstm_r_inv,lstm_b_inv]  model.get_weights(‘xxxx’).set_weights(keras_format_weights)

注意:bilstm包含6组权重矩阵,如果是单向lstm是3组权重矩阵,其他层的转换方式类似。

http://www.lbrq.cn/news/2453293.html

相关文章:

  • 扁平化设计风格网站网站建设公司官网
  • 广州微网站建设效果刷钻业务推广网站
  • lua做网站焊工培训ppt课件
  • 南京网站排名关键词优化百家号
  • 全国疫情最新资讯windows优化大师会员兑换码
  • 江油网站建设自制网站
  • 建设工程质量监督站网站站长工具ip查询
  • php怎么做网站后台品牌推广营销平台
  • 西安网络公司做网站快照网站
  • wordpress漫画模板宁波seo推广哪家好
  • 介休做网站江阴百度推广公司
  • 电子商务网站建设规划书范文青岛网站制作公司
  • 网站效果用什么软件做百度seo泛解析代发排名
  • 荣成网站制作公司谷歌play
  • wordpress后台编辑网站seo关键词排名查询
  • 实时爬虫网站是怎么做的网络推广员每天的工作是什么
  • 重庆网站营销公司友情链接平台站长资源
  • 不会被封的网站谁做如何让自己网站排名提高
  • 淮安网站建设制作网络营销企业有哪些公司
  • 太仓公司做网站潍坊网站建设
  • 怎么做英文垃圾网站怎么自己弄一个网站
  • 深圳哪里有做网站推广的搜狗搜索排名优化
  • 织梦儿童早教教育培训网站模板石家庄seo
  • 58同城给做网站seo网站快速排名
  • 制作手机网页教程网络seo
  • 打开网站notfound企业网站seo点击软件
  • 类似于kobas的网站做kegg分析新闻头条最新
  • 2021建站公司营销推广投放平台
  • 为什么不能用来名字做网站名拓客软件
  • 网站建设公司简介企业产品推广运营公司
  • 告别配置混乱!Spring Boot 中 Properties 与 YAML 的深度解析与最佳实践
  • 【Python小工具】-英文大小写转换功能的GUI工具
  • SQL性能优化
  • 12. isaacsim4.2教程-ROS 导航
  • 数据结构3-单双链表的泛型实现及ArrayList与LinkedList的区别
  • 【硬件-笔试面试题】硬件/电子工程师,笔试面试题-15,(知识点:DC-DC电源,BUCK电路,铁损,铜损)