一.where函数

torch.where(condition, x, y)

where函数具体作用为:利用x和y生成一个新的tensor,其中的参数 condition 为一个与x和y一样shape的tensor,掌管着新的tensor的生成条件,如果condition某个位置的值为1,则新的tensor来自与x, 反之为0,则来自于y。(condition的使用很灵活)例如:

二.gather函数

torch.gather(input, dim, index, out=None)

该函数作用为查表,其中input作为参照表,使用dim来指定维度,index则是你要查找一一的应的tensor,最后将index以input的方式进行返回。具体使用方式及规则如下:

例如:

 

 

 

 

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐