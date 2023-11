1

import keras
from keras import ops


class TokenAndPositionEmbedding(keras.Layer):
    """Embed token ids, add a learned position embedding, and L2-normalize.

    The layer owns two trainable matrices created in the constructor:

    * ``token_embed`` of shape ``(vocab_size, embed_dim)``
    * ``position_embed`` of shape ``(max_length, embed_dim)``

    Args:
        max_length: Number of rows in the position table, i.e. the longest
            sequence (size of the last axis of the input) that can be
            embedded.
        vocab_size: Number of rows in the token table.
        embed_dim: Size of each embedding vector.
        **kwargs: Forwarded to ``keras.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, max_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Kept so that `get_config` can rebuild the layer and `call` can
        # validate the sequence length.
        self.max_length = max_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        self.token_embed = self.add_weight(
            shape=(vocab_size, embed_dim),
            initializer="random_uniform",
            trainable=True,
        )
        self.position_embed = self.add_weight(
            shape=(max_length, embed_dim),
            initializer="random_uniform",
            trainable=True,
        )

    def call(self, token_ids):
        """Return normalized token + position embeddings.

        Args:
            token_ids: Tensor of token indices; it is cast to ``int32`` and
                its last axis is treated as the sequence axis.

        Returns:
            Tensor with the shape of ``token_ids`` plus a trailing axis of
            size ``embed_dim``. Each vector is divided by the square root of
            its sum of squares (floored at ``1e-7``).

        Raises:
            ValueError: If the sequence length is statically known and
                exceeds ``max_length``.
        """
        # Embed positions. Prefer the static length; fall back to the
        # runtime shape when the sequence axis is dynamic (static is None).
        length = token_ids.shape[-1]
        if length is None:
            length = ops.shape(token_ids)[-1]
        elif length > self.max_length:
            # Out-of-range gathers fail late or silently depending on the
            # backend, so fail early with a clear message instead.
            raise ValueError(
                f"Sequence length {length} exceeds max_length "
                f"{self.max_length} of the position embedding."
            )
        positions = ops.arange(0, length, dtype="int32")
        positions_vectors = ops.take(self.position_embed, positions, axis=0)

        # Embed tokens
        token_ids = ops.cast(token_ids, dtype="int32")
        token_vectors = ops.take(self.token_embed, token_ids, axis=0)

        # Sum both; position vectors broadcast over the leading axes.
        embed = token_vectors + positions_vectors

        # Normalize embeddings; the floor avoids division by zero.
        power_sum = ops.sum(ops.square(embed), axis=-1, keepdims=True)
        return embed / ops.sqrt(ops.maximum(power_sum, 1e-7))

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "max_length": self.max_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config