• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

python/pytorch计算tensor的余弦相似度

武飞扬头像
Jumbo星
帮助1

一、相似度和点积

很多场景里,需要比较两个tensor的相似度(NLP或者CV里都有可能),这种相似度的计算一般用余弦相似度来计算,也就是常说的向量点积(dot-product),比如Transformer里self-attention的相关操作,用点积来计算Q和K的“相似度”

学新通

学新通

二、Pytorch的简单实现

很好的是,torch里有现成的函数cosine_similarity,不需要像网上那种要自己定义一个复杂的类来实现。

torch.cosine_similarity(input1,input2,dim=1, eps=1e-8)

  • input1和input2都需要是两个torch.Tensor类型的变量
  • dim指定在某个维度上进行计算相似度,default=1,即可以不输入
  • eps是避免出现除数为0的一个极小值,一般不输入

例:

通过transform的编码器对两张图进行编码,得到了两个shape为[1,1,768]的tensor:img1和img2

  1.  
    import torch
  2.  
    # img1.shape = [1,1,768] = img2.shape
  3.  
    cos_sim = torch.cosine_similarity(img1, img2, dim=2)
  4.  
    # tensor([[0.9457]], device='cuda:0')
  5.  
    print(cos_sim)

可以看到这两张图的相似度是0.9457

如果是批量化计算,得到一组cos,怎么方便计算平均余弦相似度呢?

参考做法:

  1.  
    import torch
  2.  
    # img1.shape = [1,1,768] = img2.shape
  3.  
     
  4.  
     
  5.  
    cos_list = []
  6.  
    for i in range(n):
  7.  
    cos_sim = torch.cosine_similarity(...)
  8.  
    cos_list.append(cos_sim)
  9.  
    #此时cos_list为list,但是里面都是一个个tensor 不方便计算
  10.  
    # cos_list.shape = [9,1,1]
  11.  
    # 可以用下面的方法 先建一个新维度 然后在这个维度上mean
  12.  
    mean_cos=torch.stack(cos_list,dim=0).mean(dim=0)
  13.  
    # tensor([[0.9599]], device='cuda:0')
  14.  
    print(mean_cos)

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhiagjag
系列文章
更多 icon
同类精品
更多 icon
继续加载