06月06, 2020

XLNET原理和源码分析

有很多大牛在网上分享了从不同角度对模型做了解读,下面是主要贡献

难点

1. permutation 机制的实现与论文有差异

  # Create `perm_mask`
  # `target_tokens` cannot see themselves
  self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

  # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
  # 0: can attend if i > j or j is non-masked
  perm_mask = tf.logical_and(
      self_rev_index[:, None] <= rev_index[None, :],
      masked_or_func_tokens)
  perm_mask = tf.cast(perm_mask, tf.float32)

里面有一段注释 :

# 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
# 0: can attend if i > j or j is non-masked

这里呢就是和论文不一样的地方(论文中给人感觉是 一个位置只能看到前面的部分):

  • 对于所有的非 masked 或者非 func (, 的位置是可以被所有其他位置看到的

里面有个shuffle index操作这个其实并不能对应上论文中的permutation ,因为这个shuffle只对masked的和func 位置有作用。 举个例子 3和 8两个位置是target,在shuffle过程中 如果 3位置的index shuffle成了 10, 8 位置的index shuffle成了2,那么10 > 2 ,所以 perm_mask[3][8] 是 0 (因为10 > 2) 所以 3 位置可以看到位置8,但是 perm_mask[8][3]=1 意思是 8位置看不到 3, 另外masked (即target位置)看不到自己

  • func token和 masked 流程一样,区别在于他们能看到自己
  • 关于 memory 概念,很容易理解,具体可以看transformer-xl的介绍,然后里面的reuse其实也很简单,就是每次会复用前面的前面的一半数据

2.双流注意力的G(查询流)可以看到H(内容流)

image.png

这里对layer0 选取了 非 target 的 h1 和 target h6 以及 target g3 来画图

  • h1 含有了 出去 h3,h6, h7 以外的信息
  • h6 (target位置,h层) 则包含了h1,h2,h4,h5,h6的信息 (h attention 是可以看到自己的)
  • g3 (target位置,query 部分) 则包含了 h1,h2,h4,h5,h6,h7 的信息

这里可以比较下 g6 其实是 看不到 h6 和 h7的信息的

那么到这里h层为什么要包含自己的信息,也可以理解到,因为部分target 位置的 g 是可以看到别的 target位置的 h 信息的,所以这样两个不同target位置的 g 才会有差别,这个也就是论文中2.3 开头所说的当permutation前缀一样的时候,我们怎么去区分不同的预测目标

最后 经过n个layer 后 g 层的输出 用来算 loss, 这里要注意的是只有被选做 target的才会计入 loss (即 位置3 和位置6)

注意最后一个位置是CLS 他是有一定的概率看到所有的位置的content 信息的,所以在fine-tuning的时候实现默认拿了最后一个的输出 作为 句子的embedding

3.相对位置编码的详细计算过程

在标准的Transformer里,同一个segment的qi和kj的attention score可以这样分解:

A_{i,j}^{abs} = \underbrace{E^T_{x_i}W_q^TW_kE_{x_j}}{(a)}+\underbrace{E^T{x_i}W_q^TW_kU_j}{(b)} + \underbrace{U_i^TW_q^TW_kE{x_j}}{(c)}+\underbrace{U_i^TW_q^TW_kU_j}{(d)}

参考上面的公式,并且因为希望只考虑相对的位置,所以我们(Transformer-XL)提出如下的相对位置Attention计算公式

A_{i,j}^{rel} = \underbrace{E^T_{x_i}W_q^TW_{k,E}E_{x_j}}{(a)}+\underbrace{E^T{x_i}W_q^TW_{k,R}\color{blue}{R_{i-j}}}{(b)}+ \underbrace{\color{red}{u^T}W{k,E}E_{x_j}}{(c)} + \underbrace{\color{red}{v^T}W{k,R}\color{blue}{R_{i-j}}}_{(d)}

  • 和前面的Aabsi,j相比,第一个是把(b)和(d)里的绝对位置编码Uj都替换成相对位置编码向量Ri−j。注意这里的R是之前介绍的”相对”的正弦函数的编码方式,它是固定的没有可以学习的参数。

  • 在(c)中用可训练的u∈Rd替代原来的UTiWTq。因为我们假设Attention score只依赖于i和j的相对位置,而与i的绝对位置无关,所以这里对于所有i都相同。也就是UTWTq,所以可以用一个新的u来表示。类似的是(d)中的v∈Rd。

  • 最后,我们把key的变换矩阵Wk拆分成Wk,E和Wk,R,分别表示与内容相关的key和与位置相关的key。

在上面的新公式里,每一项的意义都非常清晰:(a)表示内容的计算,也就是xi的Embedding乘以变换矩阵Wq和xj的Embedding乘以Wk,E的内积;(b)表示基于内容的位置偏置,也就是i的向量乘以相对位置编码;(c)全局的内容偏置;(d)全局的位置偏置。

4. XLNetModel在fine-tuning和pretraining的异同

温故而知新,我们再来看一下这个类的构造函数的参数,读者可以对比一下fine-tuning节点参数不同的地方。

  • xlnet_config: XLNetConfig,XLNet模型结构的超参数,比如层数,head数量等等

  • run_config: RunConfig,运行时的超参数,包括dropout、初始范围等等。

  • input_ids: int32 Tensor,shape是[len, bsz], 输入token的ID

  • seg_ids: int32 Tensor,shape是[len, bsz], 输入的segment ID

  • input_mask: float32 Tensor,shape是[len, bsz], 输入的mask,0是真正的tokens而1是padding的

  • mems: list,每个元素是float32 Tensors,shape是[mem_len, bsz, d_model], 上一个batch的memory。fine-tuning为None

  • perm_mask: float32 Tensor,shape是[len, len, bsz]。fine-tuning为None

    1) 如果perm_mask[i, j, k] = 0,则batch k的第i个Token可以attend to j

    2) 如果perm_mask[i, j, k] = 1, 则batch k的第i个Token不可以attend to j

    3) 如果是None,则每个位置都可以attend to 所有其它位置(包括自己)。

  • target_mapping: float32 Tensor,shape是[num_predict, len, bsz]。fine-tuning为None

    1) 如果target_mapping[i, j, k] = 1,则batch k的第i个要预测的是第j个Token,这是一种one-hot表示 2) 只是在pretraining的partial prediction时使用,finetuning时设置为None

  • inp_q: float32 Tensor,shape是[len, bsz]。fine-tuning为None

    需要计算loss的(Mask的位置)为1,不需要的值为0,只在pretraining使用,finetuning时应为None

5. FineTune喂数据给模型

见classifier_utils.py的convert_single_example函数

1) input_ids

2) seg_ids

3) target_mask

4) label_id

本文链接:http://57km.cc/post/XLNET methodology and dive into source code.html

-- EOF --

Comments