RNN Attention の理解と bias の議論

背景

Quora Insincere Questions Classification | Kaggle というコンペで public kernel で使われていた PyTorch Attention 実装に pitfall があったのではという指摘 Common pitfalls of public kernels | Kaggle があった。それを発端に Attention を正しく理解できていないことが分かったのでここにまとめる。

参照

前提 RNN を利用した Encoder& Decoder MNP モデル

特徴

  • timestep t において過去の入力列は fixed size の vector として表現される。この vector は sentence embedding と呼ぶこともできる。つまり文章を embedding したもの。(word embedding との対比を考えるとわかりやすい)
  • 長い sequence に弱い。timestep = 1 のときの入力から max_seq_len step 後に decode が始めるので timestep = 1 のときの影響力が弱い。
    • 語順が似ている言語同士の翻訳だと入力を reverse すると精度が上がるのは上記のような背景がある。

f:id:higepon:20190209144944j:plain

ここで

  •  a^{\langle t\rangle} は Endoder の timestep = t での hidden_state
  •  s^{\langle t^{\prime}\rangle} は Decoder の timestep = t' での hidden_state

Attention

理解のキーポイント

概要は C5W3L08 Attention Model - YouTube を見るのが良いと思う。入出力に注目すると

  • Decoder は timestep = 1 の decode 時に hidden_state vector ではなく c^{<1>} =  \sum_{t^{\prime}} \alpha^{\langle1, t^{\prime}\rangle}a^{\langle t^{\prime}\rangle} を受け取る。
    •  \alpha^{\langle1, t^{\prime}\rangle} が timestep = t' のときにおける入力に pay attention すべき weight。 (すべて足し合わせると1)。
    •  c^{<1>} は encoder の各 timestep における hidden_state の重み付きの和になっている。
  •  \alpha^{\langle t, t^{\prime}\rangle} y^{\langle t\rangle} を出力するときに  a^{\langle t^\prime\rangle} にどれだけ注目 (pay attention) するべきかを表す。
  • 直感的にはまさに下図のように Hi を出力しようとしているときには オッス:0.9 オラ:0.07 悟空:0.03 みたいな重みが提供されて(後述)。なるほどオッスがいちばん大事なのだなと分かるイメージ。

f:id:higepon:20190209155106j:plain

重み

上述の重みはどこからやってくるのか。ここを曖昧に理解していたのだが小さい Neural Network に学習させてしまおうというのが肝っぽい。 ここで下記のように  e^{\langle t, t^{\prime}\rangle} なるものを導入して、重みの和が1 になるようにする。

 \alpha^{\langle t, t^{\prime}\rangle} = \frac{\exp(e^{\langle t, t^{\prime}\rangle} )}{\sum_{t^\prime}^{Tx}{\exp(e^{\langle t, t^{\prime}\rangle}})}

この  e^{\langle t, t^{\prime}\rangle} を Neural Network で下図のように学習する。 f:id:higepon:20190209154549j:plain

PyTorch での Attention の実装例

Common pitfalls of public kernels | Kaggle で触れられているがオリジナルの実装者はここ

# Written by Benjamin Minixhofer
# in https://www.kaggle.com/bminixhofer/deterministic-neural-networks-using-pytorch
class Attention(nn.Module):
    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)
        
        self.supports_masking = True

        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim
        self.features_dim = 0
        
        weight = torch.zeros(feature_dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)
        
        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))
        
    def forward(self, x, mask=None):
        feature_dim = self.feature_dim
        step_dim = self.step_dim

        eij = torch.mm(
            x.contiguous().view(-1, feature_dim), 
            self.weight
        ).view(-1, step_dim)
        
        if self.bias:
            eij = eij + self.b
            
        eij = torch.tanh(eij)
        a = torch.exp(eij)
        
        if mask is not None:
            a = a * mask

        a = a / torch.sum(a, 1, keepdim=True) + 1e-10

        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)

コードの解説

上述の Attention の解説と実装ではいくつか異なる点がある(おそらく亜種)

  • eij の学習時に self.weight だけでなく self.bias も使っている
  • mask をサポートしている(おそらく各入力が長さ違いで pad されているときに pad を無視するため)

基本的には self.weight と self.b が先程の Neural Network を定義するもの。self.weight は上述の  \alpha とは別物であることに注意。NN の出力サイズが (-1, step_dim) なのは各 timestep ごとに e を求めたいから。

問題とされた点

bias 項が以下のように sequence の位置に対して、それぞれ異なる値が持てるように定義されていること。

self.b = nn.Parameter(torch.zeros(step_dim))

Dmitriy Danevskiy さんは以下の2点を問題として挙げている。

  • 長さが異なる入力が想定されているのに step_dim が 70 として与えられている。これは長さが 70 の入力に対する bias になってしまう。truncate されていたり pad されていたりする場合に対応していない
  • bias が token の絶対的位置にひもづいているが、異なる入力 sequence に対して位置が重要なことはあまりない(本当?)

氏はそのかわり bias をすべての sequence element で共有するように定義すべきだと提案している。

筆者の疑問点

bias の値が token の位置に依存していなければ、学習によって bias の値はすべて同じような値になっていくので問題ないのではないか?

間違いを見つけたら

https://twitter.com/HigeponJa にお願いします。