HNSW的基本原理及使用

1. Small world vs. Random graph

在正式的介绍NSW和HNSW之前,先来了解一下小世界和随机图的概念方便后续理解为什么NSW能够做近邻查找。

1.1 Regular graph vs. Random graph

在图论中对正则图的定义如下:

正则图是指每个顶点都有相同数目邻居的图,即每个顶点的度相同。若每个顶点的度均为 $k$,称为 $k$-正则图。

正则图

随机图是指在随机过程的生成的图,也就是节点和节点之间的连接是随机建立的。

随机图和正则图的对比

在正则图中,当聚类系数接近饱和的时候,聚类系数比较高,平均路径也比较短,但是此时节点的度比较高。

随机图节点的聚类系数比较低,并且节点的度也比较低。

1.2 Small world

在介绍完了随机图和正则图,我们再来看一下小世界网络。

在1967年Stanley Milgram从Kansas和Nebraska两个州招募了一批志愿者, 请他们分别将一封信转寄给一个住在Cambridge神学院学生的妻子和一个住在Boston郊区的股票经纪人。 他给志愿者们这样的要求:

  1. 虽然有寄信目标的相关信息,如果不是私人关系,不能把信直接寄给TA.
  2. 每次只能把信寄给最有可能知道这个人的熟人。
  3. 原始信封里有15张追踪卡片,每次转寄都要回寄一张给实验者,其他的放在信封里寄给下一个人,这样研究员可以随时追踪这些信函的路径。

在到达的信函中,Stanley Milgram计算信函平均到达的节点为5个,也就是我们和一个陌生人建立连接只需要6步。

Stanley Milgram基于他的实验提出了著名的六度分离理论,这个理论指出:

  1. 现实世界中的短路径是普遍存在的。
  2. 人们可以有效地找到并且利用这些短路径。

在小世界网络中,可以把点与点之间的关系可以分为两种:

  • 同质性:同质性也就是相似的点会聚集到一起,相互连接具有邻接边。
  • 弱连接:弱连接是指从每一个节点上,会有一些随机的边随机连接到网络中的节点上,这些节点是随机均匀的。

1.3 三者之间的关系

有研究表明,小世界网络介于正则图和随机图之间,正则图随着随机性的增加具有小世界的特性。 Regular graph, Small-World和Random Graph的关系

我们可以这么理解:小世界在局部同类节点的连接呈现出规则,从全局来看不同类节点的连接呈现出随机性。 这两种性质也就是上面我们所说的同质性和弱连接。

2. Navigable Small World

可导航小世界的原理很简单,其基本原理如下图所示:

在NSW算法中通过构建一个小世界网络,希望通过黑色相似的近邻边来检索最近邻节点, 通过红色长边(高速公路)来实现不同类节点之间的快速检索。

这里我们不妨考虑一下,为什么regular graph不能做近邻检索? 为什么random graph不能做近邻检索,为什么small world 可以用来做近邻检索?

2.1 图的检索

在了解完NSW的基本思路之后,接下来我们看一下NSW当中,如何对整个图中的节点查找K个最近邻节点。

K 近邻查找

在NSW中K近邻检索的过程如下:

1) 随机选择1个元素,放入到candidates当中

2) 从candidates中选取最近邻节点c,将这些元素的邻居节点放置到q当中

3) 从candidates中移除最近邻节点c

4) 如果c的距离远大于result中的第k个节点,跳出循环

5) 否则,对于c的每个邻居节点,遍历其邻居,如果没有在visited set里面。

6) 将e加入到visited set, candidates, tempRes

7) 遍历完成candidate中所有的节点后,把tempRes的结果传入到result

8) 重复执行上述步骤m遍, 返回result中最优的k个近邻结果。

具体的伪代码描述如下:

K-NNSearch(object q,integer:m,k)
1	TreeSet[object]tempRes, candidates, visitedSet, result
2 	for(i<-0; i<m; i++) do:
3 		put random entry point in candidates
4 		tempRes<-null
5 		repeat:
6 			get element c closest from candidates to q
7 			remove c from candidates
8			#checks to p condition:
9			if c is further than k-th element from result
10			than break repeat
11			#update list of candidates:
12			for every element e from friends of c do:
13				if e is not in visited Set than
14					add e to visited Set, candidates, tempRes
15
16 		end repeat
17 		#aggregate the results:
18 		add objects from tempRes to result
19 	end for
20 	return best k elements from result

2.2 图的构建

基于NSW的原理,我们希望NSW的局部节点之间的在距离上具有同质性(也就是近邻节点能够相互连接)。从而使得当我们检索 到一个近邻节点时,其大部分近邻节点都是近邻节点。同时也希望保留一些随机边,能够在不同区域之间快速跳转。

那么我们需要怎么样构建一个,具有同质性同时又具备随机性的小世界网络呢?

Delaunay 三角剖分

为了使得相邻的点在空间距离上相近,我们引入Delaunay三角剖分,相关的定义如下

  • Delaunay 边

在点集 $V$ 中存在两点 $a$ 和 $b$,圈内不包含点集 $V$ 中的任何其他点。这个特质被称为空圈特质个。 节点 $a$ 和节点 $b$ 连接起来的边称为Delaunay边。

  • Delaunay 三角剖分

如果一个点集 $V$ 的三角剖分 $T$ 都只包含 Delaunay边,那么该三角剖分称为Delaunay剖分。

NSW的构建

构建图的时候,理论上来说我们对所有的点做Delaunay三角剖分,然后添加一些随机的长边构建快速检索通道, 就构建了一个可导航的小世界网络。

由于构建Delaunay三角剖分的复杂度太高实际的代码实现过程中是通过节点随机插入来引入随机性,利用已有节点构建Delaunay边来引入同质性。

NSW的网络构建过程如下:

  1. 在候选节点$V$里面随机挑选一个节点$v_i$
  2. 将节点$v_i$插入到已经构建好的图中,并构建边。
  3. 边构建的规则:找到节点$v_i$最近邻的 $f$ 个邻居,建立$v_i$和这些邻居的边连接。

对应的伪代码如下:

Nearest_Neighbor_Insert(object: new_object,integer:f, integer:w)
1 SET[object]:neighbors<-k-NNSearch (new_object, w, f);
2 for(i<-0; i<f; i++) do
3 	neighbors[i].connect(new_object);
4 	new_object.connect(neighbors[i]);

在构建NSW图结构的时候,在局部通过寻找 $f$ 个最近邻来建立类似于Delaunay三角剖分的结构, 在全局通过随机顺序插入,引入随机边从而使得所以具备可导航小世界的特性。

3. Hierarchical Navigable Small World

在NSW中,构建图的阶段通过节点的随机插入来引入随机性,构建出一个类似于小世界的网络结构。在NSW中很明显地会存在 几个问题。

  • 对于最先插入的节点,其连接的邻居节点,基本都比较远(弱连接属性较强)
  • 对于最后插入的节点,其连接的邻居节点,基本都比较近(弱连接属性较弱)
  • 对于具有聚类效应的点,由于后续插入的点可能都和其建立连接,对应节点的度可能会比较高。

等等

如果继承NSW基于long link快速检索,short link具有聚类特性的思想。怎么样能够使得查找更为稳定, 或者怎么样能够把long link的查找和short link查找有效区分。在此基础上引入了分层图的思想。

基于这些问题在NSW的基础上我们来看一下HNSW。

根据上图可以直接看出HNSW在NSW基础上所作的优化。

在HNSW中,引入Layers的概念,总体思想如下:

  1. 在Layer = 0 层中,包含了连通图中所有的点。
  2. 随着层数的增加,每一层的点数逐渐减少并且遵循指数衰减定律
  3. 图节点的最大层数,由随机指数概率衰减函数决定。
  4. 从某个点所在的最高层往下的所有层中均存在该节点。
  5. 在对HNSW进行查询的时候,从最高层开始检索。

3.1 HNSW的查询

在HNSW的查询阶段,包括以下几个算法。

  • SEACHER-LAYER: 在指定层查询K个最近邻节点。
  • SELECT-NEIGHBORS-SIMPLE: 简单的查找某一层最近的邻居节点。
  • SELECT-NEIGHBORS-HEURISTIC: 探索式查找某一层最近的邻居节点。
  • K-NN-SEARCH: 从所有候选结果中找出K个最近邻结果。

接下来,我们来具体看一下这几个算法和对应的具体查询逻辑。

3.1.1 SEACHER-LAYER

算法伪代码

传入参数

q:表示需要查找的节点

eq: 固定的起始节点,如果Layer是最顶层,有固定的入口节点。如果不是最顶层则是上一层查询到的最近邻。

ef: 查找的邻居节点数目

lc: 查询的层数

输出 :

q元素最近邻的ef个节点。

功能:

SEARCH LAYER算法的功能是在给定一个节点q和起始查询节点eq、查询的层lc的情况下,查找出 节点q在层lc下的ef个最近邻。

查询步骤:

1) 首先根据ep 初始化visited set V, candidate set C, 以及动态最近邻W

2) 当 candidate set 不为空的时候执行:

2.1) 从candidate set C中选取离q最近的点c,

2.2) 从动态最近邻中选取最远的点f,

2.3) 比较distance(c,q)和distance(f,q)

2.4) 如果distance(c,q) > distance(f,q)执行步骤 3 否则继续执行 2.5

2.5) 对在lc层中c节点的每个邻居e。如果e在visited中,重新执行步骤 2, 否则继续执行 2.6

2.6) 将e节点加入visited set

2.7) 从W中获取最远的节点f

2.8) 如果distance(e,q) < distance(f,q) 或者 |W| < ef 将 e分别加入 candidate set C和动态最近邻W

2.9) 如果 |W| > ef 移除最大元素。

3) 返回集合W

3.1.2 SELECT-NEIGHBORS

在select neighbors主要分为两个部分由SELECT-NEIGHBORS-SIMPLE以及SELECT-NEIGHBORS-HEURISTIC两个算法组成。

SELECT-NEIGHBORS-SIMPLE和SELECT-NEIGHBORS-HEURISTIC两个算法都是用在图构建的过程中,而不用在KNN的近邻 检索,与SIMPLE不同HEURISTIC方法添加了更多的随机性,从而同一层节点之间的连接随机性更强。

  • SELECT-NEIGHBORS-SIMPLE

算法伪代码

参数输入

q:表示需要查询的节点。

C:表示查询的候选节点集合。

M:表示返回最近邻居的个数。

输出

q在C中的M个最近邻居

功能

选取出节点q在候选集C中的M个最近邻居。

  • SELECT-NEIGHBORS-HEURISTIC

算法伪代码

参数输入

q:表示我们需要查询的节点

C:表示candidate 节点

M:表示返回的最近邻节点的个数M

lc:表示返回的层的编号

extendCandidates:表示是否需要扩展candidate

keepPrunedConnection:表示是否需要把废弃节点加入到返回结果中

返回结果

通过探索式查找返回最近邻的M个结果。

3.1.3 K-NN-SEACHER

KNN查询的逻辑很简单,从固定的enter节点进入,在顶层开始检索。 在每一层检索到唯一一个最近邻然后作为下一层入口节点。最后在底层检索top K个最相似节点。

算法伪代码

3.2 HNSW的插入

在HNSW中,通过插入算法来构建整个图结构并在此基础上进行检索。HNSW的插入算法如下。

算法伪代码

算法参数

  • hnsw: 节点所需要插入的目标图结构
  • q: 需要插入的节点
  • M: 每个节点需要与其他节点建立的连接数,也就是节点的度。
  • efConstruction: 用来设置查询网络节点集合的动态大小
  • mL: 用来选择节点q的层数的时候所需要用到的归一化因子。

算法输出

插入节点q后的hnsw网络结构。

节点插入过程

在整个HNSW的insert的过程中包含以下几个部分。

1) 初始化当前最近邻集合W,初始化固定节点ep,获取顶层编号L,获取新插入节点的层l

2) 对于属于L->l+1的每一层查找出q的最近邻节点。

3) 对于lc <- min(L,l)..0的每一层执行以下步骤:

3.1) 每一层查找出最近的efConstruction个节点得到集合M。

3.2) 在每个节点中查找到最近的M个neighbors。(采用算法3,或者算法4)

3.2) 将在层lc中的所有neighbors和节点q建立连接。

3.3) 对于neighbors中的每个节点e重新判断一下节点个数,然后减少e节点的邻居节点重新建立连接。

4) 如果 l > L,将q设置为hnsw的enter point

在上述伪代码中对于每个新节点而言获取器layer id的公式如下。

\[l = \lfloor -ln(unif(0,1)) \cdot mL \rfloor\]

HNSW通过一个随机函数,将所有的点划分到不同层次,越往上节点数越少,边越少。这种情况下节点和节点之间寻找最近邻居的 距离也就越远。因此在从上到小检索的过程中,先通过Long Link找到全局可能的最近节点,然后往下层以该节点为入口 进一步做局部检索。

4 总结

在这篇文章中,主要介绍了NSW和HNSW的算法原理。NSW算法基于六度分离理论将小世界的特性用于近邻检索, 提出了基于图结构的检索方案。

在NSW的基础上,HNSW利用多层的图结构来完成图的构建和检索,使得通过将节点随机划分到不同的layer, 从上层图到下层图的检索中,越往下层节点之间的距离越近, 随机性也越差,聚类系数越高。 HNSW通过从上到下的检索,完成了NSW中Long Link高速公路快速检索的作用,通过最后底层的近邻检索, 完成局部最近邻的查找。

参考资料

[1] Navigable Small-World Networks

[2] 一文看懂HNSW算法理论的来龙去脉

[3] HNSW学习笔记

[4] 近似最近邻算法 HNSW 学习笔记(一)介绍

[5] 近似最近邻算法 HNSW 学习笔记(二) 主要算法伪代码分析

[6] 近似最近邻算法 HNSW 学习笔记(三)对于启发式近邻选择算法的一些看法

[7] Delaunay三角剖分实践与原理

[8] Hierarchical Navigable Small World

[9] Small-World Experiment or Just Six Steps Away off Loneliness…

[10] Navigable Small-World Networks

[11] HNSW学习笔记

Candidate Sampling

在这篇文章中主要介绍一下Candidate sampling在模型训练中的使用。

作为一个菜鸡的推荐炼丹师,前段时间看YouTube DNN的炼丹手册和双塔模型(DSSM)的配药指南。 发现在关于计算优化的部分YouTube DNN和DSSM都用了importance sampling进负样本的选取。

从网络结构的搭建和数据的选取来看,YouTube DNN怎么看都像是一个广义的word2vec。可是在word2vec模型的TensorFlow实现里面用的是NCE做计算性能的采样优化。

此外,对于YouTube DNN的负样本选取知乎上也有广泛的讨论。例如:知乎石塔西在《负样本为王:评Facebook的向量化召回算法》提出的hard模式和easy模式假设。

为了弄清楚不同网络采样的细节和采样的方法,我翻阅大量偏方,并且稍微做一下梳理。

在介绍candidate sampling之前我们先了解一下问题的背景。

1. Softmax 和 Cross Entropy

在多分类问题中,模型训练的目的是在训练集上学习到一个函数$F(x,y)$,该模型对于测试集和验证集上 每一个输入$x$,能够准确地预测到对应的类别$y$。

在一个类别数为$K$的多分类问题中,模型的softmax层对每一个类别计算可能的概率:

\[P(y_j|x) = \frac{\mathrm{exp}(h^\mathsf{T}v'_j)}{\sum^K_{i=1} \mathrm{exp}(h^\mathsf{T}v'_i) }\]

在softmax计算每一个输入类别概率的基础上,需要构建损失函数$J$来评估模型训练中学习得到$F(x,y)$ 的效果。基于最朴素的想法,我们希望预测的类别$y_{prediction}$和真实的类别$y_{label}$具有更为接近的分布。

在评估两个分布的距离上,很自然地联想到使用$KL$散度,那么对于损失函数的相关定义如下:

在模型训练过程中,对应的输入为: $x$

样本对应标签的期望分布为$D(y|x)$

模型$F(x,y)$预测出的类别分布为$P(y|x)$

\[\begin{align} J&= H(D(y|x),P(y|x)) \\ &= -D(y|x) \mathrm{ln}\frac{P(y|x)}{D(y|x)} \\ &= D(y|x)\mathrm{ln}D(y|x) - D(y|x)\mathrm{ln}P(y|x) \end{align}\]

在模型训练过程中,对于相同数据集来说$D(y|x)\mathrm{ln}D(y|x)$ 可以看成是常数。

那么用来评估训练结果的损失函数可以表示为:

\[\begin{align} \mathrm{min}\;(J) & = \mathrm{min}\;(D(y|x)\mathrm{ln}D(y|x) - D(y|x)\mathrm{ln}P(y|x)) \\ & = \mathrm{min}(\mathcal{K} - D(y|x)\mathrm{ln}P(y|x)) \\ & \sim \mathrm{min}(- D(y|x)\mathrm{ln}P(y|x)) \\ \end{align}\]

其中,$\mathcal{K}$ 为常数,那么损失函数的形式可以表示为期望分布$D(y|x)$和真实分布$P(y|x)$的交叉熵。

对所有$K$个类别求和,并且$D(y|x)$用标签$y$的值表示得到损失函数形式如下:

\[\begin{align} J = -\sum_{i=1}^{K} y_i \mathrm{ln}(P(y_i|x)) \\ \end{align}\]

其中,当$i$对应的类别为正样本时$y_i=1$,当$i$为负样本时$y_i=0$,上述公式简化为:

\[\begin{align} J &= - \mathrm{ln}\;P(y_{pos}|x) \\ & = - \mathrm{ln}\; \frac{\mathrm{exp}(h^\mathsf{T}v'_{pos})}{\sum^K_{i=1} \mathrm{exp}(h^\mathsf{T}v'_i) } \\ \end{align}\]

其中$y_{pos}$表示正样本类别。

2. candidate sampling

在上面的章节中,我们讨论了模型训练在多分类问题当中的损失函数。

那么如果在分类数$K$非常多的情况下,对于每个样本分类的预测都需要计算$K$个类别的概率。

显然,在分类数较小的情况下softmax的计算量可以接受,但是当分类数目扩增到百万甚至千万量级的情况下会单个样本的计算量过大。

假设,模型训练的数据集中有1000万条样本数据,其中分类的个数为百万量级。每个分类概率的计算为0.01ms 那么所需要的计算时间:

\[T_{cost} = 10000000 \times 1000000 \times 0.01 \div 1000 \ 3600 \ 24 = 1157.d\]

显然这个计算量级是没有办法被接受的。

那么基本思想就是在怎么样不影响计算效果的前提下减小计算量。

针对这个问题目前具备两种方法:

  • softmax-based approach: 基于树结构的分层softmax,减少损失函数计算过程中计算量。
  • sampling-based approach: 通过用采样的方式,通过计算样本的损失来代替全量的样本计算。

这里我们主要介绍sampling-based approach的方法,也就是candidate sampling。

回到损失函数的计算公式,并做进一步的简化:

\[\begin{align} J & = - \mathrm{ln}\; \frac{\mathrm{exp}(h^\mathsf{T}v'_{pos})}{\sum^K_{i=1} \mathrm{exp}(h^\mathsf{T}v'_i) } \\ &= - h^\mathsf{T}v'_{pos} + \mathrm{ln}\sum^K_{i=1} \mathrm{exp}(h^\mathsf{T}v'_i) \end{align}\]

在这个公式中用$\xi(w) = - h^\mathsf{T}v’_{pos}$,简化为:

\[\begin{align} J &= \xi (w_{pos}) + \mathrm{ln}\sum^K_{i=1} \mathrm{exp}(-\xi(w_i)) \end{align}\]

对损失函数求导并计算梯度

\[\begin{align} \nabla_\theta J &= \nabla_\theta \xi (w_{pos}) + \nabla_\theta \mathrm{ln}\sum^K_{i=1} \mathrm{exp}(-\xi(w_i)) \end{align}\]

因为$log(x)$的梯度为$\frac{1}{x}$

\[\begin{align} \nabla_\theta J &= \nabla_\theta \xi (w_{pos}) + \frac{1}{\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))} \nabla_\theta\sum^K_{i=1} \mathrm{exp}(-\xi(w_i)) \end{align}\]

然后我们把求导符号放到累加符内得到:

\[\begin{align} \nabla_\theta J &= \nabla_\theta \xi (w_{pos}) + \frac{1}{\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))} \sum^K_{i=1} \nabla_\theta \mathrm{exp}(-\xi(w_i)) \end{align}\]

并且有$\nabla_x \mathrm{exp}(x) = exp(x)$那么:

\[\begin{align} \nabla_\theta J &= \nabla_\theta \xi (w_{pos}) + \frac{1}{\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))} \sum^K_{i=1} \mathrm{exp}(-\xi(w_i)) \nabla_\theta (-\xi(w_i) \end{align}\]

上面的公式可以重写成:

\[\begin{align} \nabla_\theta J &= \nabla_\theta \xi (w_{pos}) + \sum^K_{i=1} \frac{\mathrm{exp}(-\xi(w_i))}{\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))} \nabla_\theta (-\xi(w_i) \end{align}\]

其中 $\frac{\mathrm{exp}(-\xi(w_i))}{\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))}$就是输入上下文$c$在类别$i$上的的概率 $P(w_i | c)$

最终要计算的梯度形式如下:

\[\begin{align} \nabla_\theta J &= \nabla_\theta \xi (w_{pos}) + \sum_{i=1}^K P(w_i|c) \nabla_\theta (-\xi(w_i) \\ & = \nabla_\theta \xi (w_{pos}) - \sum_{i=1}^K P(w_i|c) \nabla_\theta \xi(w_i) \end{align}\]

根据最终的公式,我们可以将梯度的计算分为两部分:

  • $\nabla_\theta \xi (w_{pos})$: 是参数关于正样本 $y_{pos}$ 的梯度,可以理解为对目标词的正面优化。

  • $-\sum_{i=1}^K P(w_i|c)\nabla_\theta \xi(w_i)$: 是所有样本概率对应梯度的累加和,可以理解为对其他词汇的负向优化。

在基于采样的优化当中,我们不需要计算所有类别的累加,只需要通过采样求到$\nabla_\theta \xi(w_i)$ 在分布$P(w_i|c)$的期望即可。

那么:

\[\begin{align} \sum_{i=1}^K P(w_i|c)\nabla_\theta \xi(w_i) = \mathbb{E}_{w_i\sim P} [\nabla_\theta \xi(w_i)] \\ \end{align}\]

那么接下来的问题就变成了如何准确的计算梯度在概率分布$P(w_i)$上的期望:

\[\mathbb{E}_{w_i\sim P} [\nabla_\theta \xi(w_i)] \\\]

3. 常见的candidate sampling方法

在了解了candidate sampling方法的基本思想之后,我们怎么样计算期望$\mathbb{E_{w_i\sim P(w_i)}\nabla_\theta \xi(w_i)}$ 成为一个值得考虑的问题。

3.1 Importance Sampling

对于任何概率分布我们计算期望$\mathbb{E}$的时候,可以采用蒙特卡洛方法,根据分布随机采样出一系列样本,然后计算样本 的平均值。

对于上述的例子,如果我们知道模型在不同类别的概率分布$P(w_i)$,在计算期望的时候可以直接采样出$m$个类别$w_1,…,w_m$ 并且计算期望:

\[\mathbb{E}_{w_i \sim P}[\nabla_\theta\xi(w_i)] \approx \frac{1}{m}\sum_i^m \nabla_\theta\xi(w_i)\]

但是为了从分布$P$中采样样本,我们首先需要计算分布$P$。可是candidate sampling的目的就是为了避免计算分布$P$。 为了解决这个问题,能够使用的基本方法是重要性采样:

重要性采样(importance sampling)算法

假设我们需要计算概率密度函数$h(x)$在$\pi(x)$上的期望

\[\mu = \mathbb{E}_\pi{h(x)} = \int h(x)\pi(x)\]

那么重要性采样算法对应的形式如下:

(a) 首先,从分布 $g(\cdot)$ 中随机采样出 $m$ 个样本 $\mathrm{x}_1,…,\mathrm{x}_m$

(b) 计算重要性权重:

\[r(\mathrm{x}_i) = \frac{\pi(\mathrm{x}_i)}{g(\mathrm{x}_i)}, for\;\;j=1,...,m\]

(c) 近似期望 $\hat \mu$

\[\hat u=\frac{r_1 h(\mathrm{x}_1)+...+r_m h(\mathrm{x}_m)}{r(\mathrm{x}_1) + ... + r(\mathrm{x}_m)}\]

为了使得估计的时候误差更小,我们需要尽可能地使得$g(\cdot)$接近原来的$\pi(\mathrm{x})$。

这个时候上述公式可以描述为:

\[\hat \mu = \frac{1}{m} \{r(\mathrm{x}_1) h(\mathrm{x}_1) + ... + r(\mathrm{x}_m) h(\mathrm{x}_m)\}\]

根据上述描述,我们先预设一个分布$Q(w)$,为了使得$Q(w)$尽可能接近$P(w)$,一般可以采样一元分布。

对应的重要性权重 $r(w) = \frac{\mathrm{exp}(-\xi(w))}{Q(w)}$,那么对应的期望计算公式如下:

\[\begin{align} \mathbb{E}_{w_i \sim P} & \approx \frac{r(w_1) \nabla_\theta \xi(w_1) + ... + r(w_m) \nabla_\theta \xi(w_m)}{r(w_i) +...+r(w_m)} \\ & = \frac{\sum_{i=1}^m r(w_i) \nabla_\theta \xi(w_i)}{\sum_{i=1}^m r(w_i)} \end{align}\]

令 $R = \sum_{i=1}^m r(w_i)$ 得到

\[\begin{align} \mathbb{E}_{w_i \sim P} & \approx \frac{1}{R} \sum_{i=1}^m r(w_i) \nabla_\theta \xi(w_i) \end{align}\]

3.2 Noise Contrastive Estimation

在上面介绍完成Importance Sampling之后,我们来看一下Noise Contrastive Estimation(NCE)。抛开上面通过采样的思想 利用importance sampling近似计算多分类问题softmax损失的方法。

在NCE中,完全推翻上述方法并从试图从另外一个角度来解决多分类问题loss计算的问题——我们能否找到 一个损失函数用于替代原来的损失计算,从而避免softmax中归一化因子的计算。

NCE的基本思想是将多分类问题转换成为二分类问题,从噪音分布中采样,减少优化过程的计算复杂度。

在采样NCE方式计算loss的过程中,我们引入噪音分布$Q(w)$。这个噪音分布可以跟语境有关,也可以跟语境无关。 在噪音分布和语境无关的情况下,我们设置噪音分布的强度是真实数据分布的$m$倍。

那么对于训练数据$(c,w)$可以得到真实分布和噪音分布的概率:

\[\begin{align} &P(y=1|w,c) = \frac{P_{train}(w|c)}{P_{train}(w|c) + mQ(w|c)}\\ \\ &P(y=0|w,c) = \frac{mQ(w|c)}{P_{train}(w|c) + mQ(w|c)} \end{align}\]

得到

\[\begin{align} P(w|c) = P_{train}(w|c) + mQ(w|c) \end{align}\]

在原来的推导中:

\[\begin{align} P(w|c) = \frac{\mathrm{exp}(h^\mathrm{T} v'_{w})}{\sum_{i=1}^K \mathrm{exp}(h^\mathrm{T} v'_{w_i})} \end{align}\]

在NCE中为了避免对分母部分归一化因子的计算,将归一化因子表示为一个学习的参数$Z(c)$

\[\begin{align} Z(c) = \sum_{i=1}^K \mathrm{exp}(h^\mathsf{T} v_{w'_i}) \end{align}\]

这个时候简化为:

\[\begin{align} P(w|c) = \mathrm{exp}(h^\mathsf{T} v'_{w}) \end{align}\]

那么对于这个二分类问题计算Logistic regression损失:

\[\begin{align} J = [ln \frac{\mathrm{exp}(h^\mathsf{T} v'_{w_i})}{\mathrm{exp}(h^\mathsf{T} v'_{w_i}) + mQ(w_i)}] + \sum_{j=1}^m [ln(1-ln \frac{\mathrm{exp}(h^\mathsf{T} v'_{w_{i,j}})}{\mathrm{exp}(h^\mathsf{T} v'_{w_{i,j}}) + mQ(w_{i,j})})] \end{align}\]

在上述公式中,当$m\rightarrow \infty$, 上述公式和softmax的损失函数相似。

从NCE采样方法中可知:

  • 基于softmax的多分类问题的损失函数可以表示成为logistic regression二分类的形式。
  • NCE方法中,在梯度更新中放弃了对负样本参数的更新。

4. Tensorflow中candidate sampling的实现

理论很丰满,落地很骨感。

在了解完candidate sampling中的Importance sampling和Noise Contrastive Estimation的原理之后如果要工程落地还是需要依赖 可用的计算框架。在TensorFlow中就实现了这两个方法对应可以调用的API分别是:

  • importance sampling: tf.nn.sampled_softmax_loss()
  • Noise Contrastive Estimation: tf.nn.nce_loss()

4.1 tf.nn.sampled_softmax_loss()

sampled_softmax_loss()中包含了两部分内容。

  1. _compute_sampled_logits()
  2. softmax_cross_entropy_with_logits_v2()

_compute_sampled_logits() 主要进行采样并计算logit。

softmax_cross_entropy_with_logits_v2() 主要计算softmax的交叉熵损失。 接下来我们主要看一下_compute_sampled_logits()的源码。

4.2 tf.nn.nce_loss()

nce_loss()中包含了两部分内容。

  1. _compute_sampled_logits()
  2. sigmoid_cross_entropy_with_logits()

_compute_sampled_logits() 主要进行采样并计算logit。

sigmoid_cross_entropy_with_logits() 主要计算sigmoid的交叉熵损失。

接下来我们主要看一下_compute_sampled_logits()的源码。

def _compute_sampled_logits(weights,
                            biases,
                            labels,
                            inputs,
                            num_sampled,
                            num_classes,
                            num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            remove_accidental_hits=False,
                            partition_strategy="mod",
                            name=None,
                            seed=None):

  if isinstance(weights, variables.PartitionedVariable):
    weights = list(weights)
  if not isinstance(weights, list):
    weights = [weights]

  with ops.name_scope(name, "compute_sampled_logits",
                      weights + [biases, inputs, labels]):
    if labels.dtype != dtypes.int64:
      labels = math_ops.cast(labels, dtypes.int64)
    labels_flat = array_ops.reshape(labels, [-1])

    # Sample the negative labels.
    #   sampled shape: [num_sampled] tensor
    #   true_expected_count shape = [batch_size, 1] tensor
    #   sampled_expected_count shape = [num_sampled] tensor
    if sampled_values is None:
      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes,
          seed=seed)
    # NOTE: pylint cannot tell that 'sampled_values' is a sequence
    # pylint: disable=unpacking-non-sequence
    sampled, true_expected_count, sampled_expected_count = (
        array_ops.stop_gradient(s) for s in sampled_values)
    # pylint: enable=unpacking-non-sequence
    sampled = math_ops.cast(sampled, dtypes.int64)

    # labels_flat is a [batch_size * num_true] tensor
    # sampled is a [num_sampled] int tensor
    all_ids = array_ops.concat([labels_flat, sampled], 0)

    # Retrieve the true weights and the logits of the sampled weights.

    # weights shape is [num_classes, dim]
    all_w = embedding_ops.embedding_lookup(
        weights, all_ids, partition_strategy=partition_strategy)
    if all_w.dtype != inputs.dtype:
      all_w = math_ops.cast(all_w, inputs.dtype)

    # true_w shape is [batch_size * num_true, dim]
    true_w = array_ops.slice(all_w, [0, 0],
                             array_ops.stack(
                                 [array_ops.shape(labels_flat)[0], -1]))

    sampled_w = array_ops.slice(
        all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
    # inputs has shape [batch_size, dim]
    # sampled_w has shape [num_sampled, dim]
    # Apply X*W', which yields [batch_size, num_sampled]
    sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)

    # Retrieve the true and sampled biases, compute the true logits, and
    # add the biases to the true and sampled logits.
    all_b = embedding_ops.embedding_lookup(
        biases, all_ids, partition_strategy=partition_strategy)
    if all_b.dtype != inputs.dtype:
      all_b = math_ops.cast(all_b, inputs.dtype)
    # true_b is a [batch_size * num_true] tensor
    # sampled_b is a [num_sampled] float tensor
    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])

    # inputs shape is [batch_size, dim]
    # true_w shape is [batch_size * num_true, dim]
    # row_wise_dots is [batch_size, num_true, dim]
    dim = array_ops.shape(true_w)[1:2]
    new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
    row_wise_dots = math_ops.multiply(
        array_ops.expand_dims(inputs, 1),
        array_ops.reshape(true_w, new_true_w_shape))
    # We want the row-wise dot plus biases which yields a
    # [batch_size, num_true] tensor of true_logits.
    dots_as_matrix = array_ops.reshape(row_wise_dots,
                                       array_ops.concat([[-1], dim], 0))
    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
    true_b = array_ops.reshape(true_b, [-1, num_true])
    true_logits += true_b
    sampled_logits += sampled_b

    if remove_accidental_hits:
      acc_hits = candidate_sampling_ops.compute_accidental_hits(
          labels, sampled, num_true=num_true)
      acc_indices, acc_ids, acc_weights = acc_hits

      # This is how SparseToDense expects the indices.
      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
      acc_ids_2d_int32 = array_ops.reshape(
          math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
      sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,
                                        "sparse_indices")
      # Create sampled_logits_shape = [batch_size, num_sampled]
      sampled_logits_shape = array_ops.concat(
          [array_ops.shape(labels)[:1],
           array_ops.expand_dims(num_sampled, 0)], 0)
      if sampled_logits.dtype != acc_weights.dtype:
        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
      sampled_logits += gen_sparse_ops.sparse_to_dense(
          sparse_indices,
          sampled_logits_shape,
          acc_weights,
          default_value=0.0,
          validate_indices=False)

    if subtract_log_q:
      # Subtract log of Q(l), prior probability that l appears in sampled.
      true_logits -= math_ops.log(true_expected_count)
      sampled_logits -= math_ops.log(sampled_expected_count)

    # Construct output logits and labels. The true labels/logits start at col 0.
    out_logits = array_ops.concat([true_logits, sampled_logits], 1)

    # true_logits is a float tensor, ones_like(true_logits) is a float
    # tensor of ones. We then divide by num_true to ensure the per-example
    # labels sum to 1.0, i.e. form a proper probability distribution.
    out_labels = array_ops.concat([
        array_ops.ones_like(true_logits) / num_true,
        array_ops.zeros_like(sampled_logits)
    ], 1)

    return out_logits, out_labels

参考资料

[1] 从最优化的角度看待Softmax损失函数

[2] On word embeddings - Part 2: Approximating the Softmax

[3] 重要性采样

[4] Noise Contrastive Estimation