Transformer ベースのモデル は、複雑なテキストを解析して解釈する機能で有名です。これらは、単語の順序と文脈を理解することに依存しており、従来の位置エンコーディング方法では限界が示されています。このギャップに対処するために、Rotary Position Embedding (RoPE) を活用した ROFORMER モデルは、位置エンコーディングへのアプローチを再定義します。
従来の位置エンコーディング
トランスフォーマーはテキストを一連のトークンとして扱い、シーケンスの並列処理を可能にして効率を高めます。ただし、この強みは、トークンの順序に対するモデル固有の不可知論という課題をもたらしました。 位置エンコーディングがその答えであり、各トークンにそのシーケンスの位置を示す一意の署名を提供します。
絶対位置の埋め込み
当初、BERT のようなモデルは絶対位置埋め込みを使用し、シーケンス内の各位置に固定ベクトルを割り当てました。この方法は単純ではありますが、シーケンスの長さの変化に適応したり、多くの言語構造を理解するために重要なトークン間の相対的な距離を強調したりする機能が本質的に欠けています。
相対位置の埋め込み
言語の動的な性質を捉えるために、絶対的な位置ではなくトークン間の距離に焦点を当てた相対位置埋め込みが導入されました。概念的な利点にもかかわらず、これらの埋め込みは 計算の複雑さをもたらし、トランスフォーマーの自己注意メカニズムにシームレスに統合できず、その有効性が制限されました。
ROFORMER とロータリー位置の埋め込み
既存の位置エンコーディング戦略の限界を認識し、ROFORMER は、絶対位置情報と相対位置情報の利点をそれぞれの欠点なしに組み合わせたアプローチである Rotary Position Embedding (RoPE) を導入しました。
ロータリー位置の埋め込み
RoPE は、回転行列を使用して位置情報をエンコードし、モデルがトークンがどこにあるかだけでなく、シーケンス内の他のすべてのトークンとどのように関連しているかを理解できるようにします。
Credit: ArXiv
これは幾何学的なレンズを通して機能し、トークンの位置を多次元空間内の点として扱い、それらの連続的な関係をマークするために回転されます。この回転により、モデルは自己注意メカニズム内で絶対位置と相対位置の両方の手がかりを保存し、活用することができます。
RoPE の実装
RoPE の実装には、各トークンの位置を回転行列にエンコードし、この行列を Transformer のセルフ アテンション メカニズム内に適用することが含まれます。このプロセスにより、位置情報の柔軟で動的な解釈が可能になり、さまざまなシーケンス長に対応し、大きな計算オーバーヘッドを発生させることなくトークンの相互関係の本質を捉えることができます。
まず、回転エンベディングを生成する関数が必要です。次に、これらのエンベディングをモデルに統合します。以下の例は、Keras でのカスタム レイヤーの作成に慣れていることを前提としています。
ステップ 1: ロータリー埋め込み関数を定義する
この関数は、最大シーケンス長と埋め込みの次元を指定して回転埋め込みを生成します。
from tensorflow.keras.layers import Layer
import numpy as np
def get_rotary_embedding(dim, max_seq_len):
inv_freq = 1.0 / (10000 ** (tf.range(0, dim, 2, dtype=tf.float32) / dim))
t = tf.range(max_seq_len, dtype=tf.float32)
freqs = tf.einsum('i,j->ij', t, inv_freq)
emb = tf.concat((tf.cos(freqs), tf.sin(freqs)), axis=-1)
return emb
inv_freq = 1.0 / (10000 ** (tf.range(0, dim, 2, dtype=tf.float32) / dim))
この行は、位置インデックスに基づいて指数関数的にスケールされた周波数の逆数を計算します。これらの周波数は、回転埋め込み用の正弦波パターンの生成に使用され、シーケンス内の相対位置情報のエンコードに役立ちます。このメカニズムは、自然言語処理や時系列分析など、要素の順序と相対的な位置を理解することが重要なタスクで特に役立ちます。
詳細に:
-
tf.range(0, dim, 2, dtype=tf.float32)
は、0 から 2 ずつステップでdim
(排他的) までの値の範囲を作成します。dtype=tf.float32
引数は、次の値を指定します。このテンソルの要素は 32 ビット浮動小数点数であることがわかります。たとえば、「dim」が 8 の場合、「[0, 2, 4, 6]」が生成されます。 -
tf.range
によって生成されたテンソルは、埋め込みの次元 (dim
) で除算されます。この操作は、これらのインデックスを 0 から 1 までの範囲にスケールダウンします (範囲ステップは 1 つおきの値をスキップするため、dim
が偶数の場合は排他的ですが、dim
が奇数の場合はわずかに排他的です)。dim
= 8 で例を続けると、8 で割ると[0.0, 0.25, 0.5, 0.75]
が得られます。 -
10000 ** (...)
演算は、以前にスケーリングされたテンソルの各要素の 10,000 乗を計算します。 10,000 の基数はある程度任意ですが、周波数が広範囲にわたって変化するように選択されており、モデルが異なる位置をより効果的に区別するのに役立ちます。[0.0, 0.25, 0.5, 0.75]
の場合、それぞれにべき乗演算が適用され、上位の要素ほど値が大きくなります。 -
最後に、前のステップの値の逆数 (1/x) を計算して逆周波数を取得します。逆周波数はインデックスが高くなるほど小さくなります。つまり、シーケンス内の要素の周波数が小さくなり、その位置がモデルにエンコードされる方法に影響します。これにより、モデルのアテンション メカニズムを通じて相対位置を推測できる方法で、埋め込みをスケーリングすることができます。
この線:
freqs = tf.einsum('i,j->ij', t, inv_freq)
TensorFlow の tf.einsum
関数を使用します。これは、Einstein 総和表記を使用してテンソル演算を簡潔かつ効率的に表現できるツールです。
この演算は、「t」ベクトルと「inv_freq」ベクトルの外積を効果的に計算し、各要素「(i, j)」が「t」の「i」番目の要素と「 inv_freq
の j 番目の要素。この行列 (「freqs」) は、回転埋め込みの正弦波パターンを生成するために使用される周波数を表します。
ステップ 2: ロータリー埋め込み用のカスタム Keras レイヤー
次に、入力テンソルに回転埋め込みを適用するカスタム Keras レイヤーを作成しましょう。この層は、入力テンソルの形状が「(batch_size, sequence_length, embedding_dim)」であると仮定します。
class RotaryEmbeddingLayer(Layer):
def __init__(self, dim, max_seq_len, **kwargs):
super().__init__(**kwargs)
self.dim = dim
self.max_seq_len = max_seq_len
self.rotary_embeddings = get_rotary_embedding(dim, max_seq_len)
def call(self, inputs):
seq_len = tf.shape(inputs)[1]
embeddings = self.rotary_embeddings[:seq_len]
cos_emb = embeddings[:, None, :self.dim // 2]
sin_emb = embeddings[:, None, self.dim // 2:]
# Decompose inputs into sine and cosine components
inputs_cos = inputs[..., :self.dim // 2]
inputs_sin = inputs[..., self.dim // 2:]
# Apply rotary embeddings
rotated_cos = inputs_cos * cos_emb - inputs_sin * sin_emb
rotated_sin = inputs_sin * cos_emb + inputs_cos * sin_emb
return tf.concat([rotated_cos, rotated_sin], axis=-1)
def get_config(self):
config = super().get_config()
config.update({
"dim": self.dim,
"max_seq_len": self.max_seq_len
})
return config
行 embeddings = self.rotary_embeddings[:seq_len]
は、現在の入力シーケンスの長さに基づいて、事前に計算された回転埋め込みの適切なサブセットを選択します。シーケンスの長さはバッチごとに異なる可能性があるため、このスライス操作により、実際のシーケンスの長さに対応する埋め込みのみが使用されることが保証されます。
変数 embeddings
は、形状 (seq_len, embedding_dim)
のテンソルを保持します。ここで、 seq_len
は現在のバッチ内のシーケンスの長さ、 embedding_dim
は埋め込みの次元です。このテンソルには、「seq_len」までのシーケンス内の各位置の回転位置埋め込みが含まれています。
emb = tf.concat((tf.cos(freqs), tf.sin(freqs)), axis=-1)
は、位置周波数のサイン変換とコサイン変換を単一のテンソルに結合します。
-tf.cos(freqs)
と tf.sin(freqs)
は、それぞれコサイン変換とサイン変換を freqs
テンソルに適用します。 「freqs」テンソルには、入力シーケンス内の各位置と、シーケンスの位置と埋め込み次元の逆周波数に基づいて計算された埋め込み空間の各次元の周波数値が含まれます。サイン関数とコサイン関数は要素ごとに適用され、その結果、「freqs」と同じ形状の 2 つのテンソルが生成されます。これらの変換は、位置関係の周期的な性質を捉える方法で位置をエンコードするのに役立ち、相対位置を理解するモデルの能力を促進します。
-tf.concat((tf.cos(freqs), tf.sin(freqs)), axis=-1)
は、最後の軸 (axis=-1
で示される) に沿ってコサインおよびサイン変換されたテンソルを連結します。これらのテンソルを並べて連結すると、「freqs」テンソルの次元が効果的に 2 倍になり、前半は各位置のコサイン変換された値を表し、後半はサイン変換された値を表します。連結により、各位置エンコードにサイン情報とコサイン情報の両方が含まれるようになり、位置信号の振幅と位相の両方に関する情報を保存できるようになります。
- 連結されたテンソル
emb
は、入力位置の完全な回転埋め込みを保持するようになりました。emb
の形状は、最初の 2 次元 (シーケンス位置と埋め込み次元に対応) ではfreqs
と同じになりますが、最後の次元はサイン値とコサイン値の両方を考慮して 2 倍の大きさになります。これらのエンベディングは、回転等価な方法で位置情報を追加することによって入力エンベディングを変調するために使用されます。
-cos_emb = embeddings[:, None, :self.dim // 2]
:
-
最初のコロン
:
は、「この次元内のすべての要素を選択する」ことを意味し、この場合、シーケンス内のすべての位置を指します。 -
None
を使用して追加の次元を追加し、テンソルを 3 次元にします。これは多くの場合、特定の次元数の入力を期待する特定の操作との互換性を確保するために行われます。たとえば、3 次元の別のテンソルと要素ごとの乗算を実行する場合、形状はブロードキャスト ルールに従って整列する必要があります。 -
:self.dim // 2
、最後の軸の次元の前半を選択します。embedding_dimension
はサイン値とコサイン値の両方を含むように 2 倍になるため、2 で割ることにより、埋め込みのコサイン成分のみが効果的に選択されます。
ステップ 3: Keras モデルとの統合
RotaryEmbeddingLayer
を定義したら、それを Keras モデルに統合できます。このレイヤーは、アテンション レイヤーまたは後続のモデル レイヤーにエンベディングをフィードする前に、エンベディングに適用する必要があります。
以下は、回転埋め込みをモデルに統合する方法の簡略化された例です。
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, Dense
max_seq_len = 512
embedding_dim = 64
inp = Input(shape=(max_seq_len,))
x = Embedding(input_dim=10000, output_dim=embedding_dim)(inp)
x = RotaryEmbeddingLayer(dim=embedding_dim, max_seq_len=max_seq_len)(x)
# Add your model's layers here, e.g., Transformer blocks
x = Dense(1, activation='sigmoid')(x)
model = Model(inputs=inp, outputs=x)
model.summary()