当前位置: 首页 > news >正文

网站制作网站设计营销网站推荐

网站制作网站设计,营销网站推荐,找最新游戏做视频网站,常见的静态网页模型推理详细步骤 模型加载步骤 首先,模型加载总共分为三步,第一步加载网络结构,需要和你训时的network结构一样。 model Model.FeedBack3(cfg, config_pathNone, pretrainedTrue).to(device)第二步,加载训练好的参数&#xf…

模型推理详细步骤

模型加载步骤

首先,模型加载总共分为三步,第一步加载网络结构,需要和你训时的network结构一样。

model = Model.FeedBack3(cfg, config_path=None, pretrained=True).to(device)

第二步,加载训练好的参数,实际上虽然我们一直说训练模型,实际上训练出来的就是一组参数,这个参数是一个字典类型,一般保存的名称为xxx.pt或者pth。里面存放的是模型每一层中的权重等数据。pytorch中对于加载参数使torch.load()

pretrained_dict = torch.load('outputmicrosoft-deberta-v3-base_fold3_best.pth')

第三步,将参数加载进模型里

model.load_state_dict(pretrained_dict['model_state_dict'], strict=True)

以上就是加载模型的所有步骤了

关于模型参数和字典对不上的问题

一般报错为:Missing key(s) in state_dict: xxxx
最近在做模型部署的时候发现了这个问题,并且之前也遇到过,由于急于求成就简单实在模型加载参数的时候用了strict=False这样的条件,这个条件会使模型直接忽略所有对不上的参数,本质上没有解决问题。今天在debug时对模型每一层的参数排查终于发现了问题所在。
首先开启debug模式,直接将断点打在模型加载的代码上:
首先查看model的结构有没有问题:
在这里插入图片描述
接下来进行下一步,执行到加载参数字典,同样查看你的参数字典(这里由于参数过多就不详细展示了):
在这里插入图片描述
那么要如何排查呢,具体步骤如下:
首先参数字典里都是以键值对和tensor型式存储的,那么我们只需要一一排查键值对和参数。比如首先是model建,那么只有你加载参数的时候只有加载里面的model建模型才能读到参数,实际上我就是错在这里了,因为我加载的是通常使用的‘model_state_dict’这个建,因为我训练部分是网上复制来的代码,没想到他把参数保存为model。
在这里插入图片描述

也就是我只需要把前面的

model.load_state_dict(pretrained_dict['model_state_dict'])

改成

model.load_state_dict(pretrained_dict['model'])

就行了。
那么如果你的问题不是这里,接下来改如何排查呢
接着看OrderedDict里,这里面是模型每一层的参数,对照方法如下:
在这里插入图片描述
相当于网络结构中的每一层都会变为一个对应的tensor
(model)(embeddings)(LayerNorm)在参数中就会存为:(‘model.embdeddings.LayerNorm’, tensor([xxxxx])

这样就看懂了吧,如此对照每一层网络结构,只要你有耐心,就能找出来具体是那一层不对,不过大多情况下这种在网络中间层出现参数不对的情况很少,出现的原因也肯定是你推理部分加载的网络结构和训练时的网络结构不一致导致的。
顺便推荐一个能帮你排查模型参数的代码,他会输出具体有多少参数使用了和没使用:

def check_keys(model, pretrained_state_dict):ckpt_keys = set(pretrained_state_dict.keys())model_keys = set(model.state_dict().keys())used_pretrained_keys = model_keys & ckpt_keysunused_pretrained_keys = ckpt_keys - model_keysmissing_keys = model_keys - ckpt_keys# filter 'num_batches_tracked'missing_keys = [x for x in missing_keysif not x.endswith('num_batches_tracked')]if len(missing_keys) > 0:print('[Warning] missing keys: {}'.format(missing_keys))print('missing keys:{}'.format(len(missing_keys)))if len(unused_pretrained_keys) > 0:print('[Warning] unused_pretrained_keys: {}'.format(unused_pretrained_keys))print('unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))print('used keys:{}'.format(len(used_pretrained_keys)))assert len(used_pretrained_keys) > 0, \'check_key load NONE from pretrained checkpoint'return True

模型推理中的数据处理

首先模型推理中数据最终的处理格式要和训练时输入进网络中的格式一致,不过我们通常不再构造新的dataset和使用dataloader,而是直接针对input处理成我们需要的格式。
主要步骤为,读取数据,embedding,增加维度
读取的数据可以是本地存的,如果你是要将模型部署在web上那么数据就是从客户端传来的json格式的数据,因此通常需要先将真正的input取出来。
接下来是向量化,这里步骤和训练中的一致,比如训练中使用了resize([800,800])和toTensor,那么推理中也要这样设置。
由于我是NLP任务,那么处理的步骤为

inputs = cfg.tokenizer.encode_plus(input,return_tensors=None,add_special_tokens=True,max_length=cfg.max_lenth,pad_to_max_length=True,truncation=True)for k, v in inputs.items():inputs[k] = torch.tensor(v, dtype=torch.long)

至此,再次输出此时的tensor和训练时输入进模型的tensor相比,只是少了一个维度,这个维度通常可以理解我们在训练的时候是有batch_size的,而推理时没有,因此要手动升维,升维度的函数有很多,通常使用unsequeeze(1)或者expand:

for k, v in inputs.items():s = v.shapeinputs[k] = v.expand(1,-1).to(device) #-1自动计算

这样处理完数据格式就和训练时完全一致了,说白了还是要先debug一下训练时的数据,看看到底输进去的是什么格式,然后在推理部分照着一点一点改。

http://www.lbrq.cn/news/2443717.html

相关文章:

  • 莆田外贸专业建站关键词排名网站
  • 网站制作的方法网站推广方法有哪些
  • 最大网站建设公司郑州网站建设用户
  • 网站建设和网络搭建是一回事吗系统优化的意义
  • 徐州 网站建设如何做好百度推广
  • 成都锦江建设局网站b2b平台排名
  • 网站开发税点开户推广竞价开户
  • 汽车网站推广策略乐陵seo优化
  • 上海做得好的网站建设公司免费的舆情网站app
  • 兰州展柜公司网站建设线上销售平台有哪些
  • 百货商城网站建设网络营销是什么意思
  • wordpress和站点网页制作软件dreamweaver
  • 莱西网站建设公司企业网站建设方案
  • 历史价格查询百度seo关键词排名优化教程
  • 珠宝静态网站模板茂名seo顾问服务
  • 成都网上商城网站建设互联网推广招聘
  • 本网站建设广州网页推广公司
  • 网站设计方案及报价单湖南百度推广公司
  • 全球网站排名查询外贸软件
  • 算命网站建设开发沈阳seo按天计费
  • 织梦做的网站打包在dw修改友情链接的四个技巧
  • wordpress主题更改首页贵州seo技术查询
  • 网站标题写什么作用是什么网站推广优化价格
  • 网站建设合同封面网络营销是指
  • 个人网站免费模板seo外链友情链接
  • 学校网站模板html下载女孩子做运营是不是压力很大
  • 昆明定制化网站建设天津seo推广软件
  • 哈尔滨住房和城乡建设委员会网站百度app登录
  • 微信自己怎么弄小程序北京seo供应商
  • 酒泉如何做百度的网站经典软文广告
  • Linux权限机制:RUID/EUID/SUID与进程安全
  • TwinCAT3编程入门1
  • Ethereum:Geth + Clef 本地开发环境,如何优雅地签名并发送一笔以太坊交易?
  • Java学习第七十五部分——Docker
  • Python常用医疗AI库以及案例解析(场景化进阶版)
  • 【C++】使用中值滤波算法过滤数据样本中的尖刺噪声