Commit 0c167832 authored by SurajSSingh

Finalized data importing for model

parent 2029fdee
# Provided under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0)
# Source: https://github.com/philipperemy/keras-attention-mechanism
import os

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Lambda, Dot, Activation, Concatenate, Layer

# KERAS_ATTENTION_DEBUG: if set to 1, switches to debug mode.
# In debug mode, the Attention class is no longer a Keras layer, which means in
# practice that we can access the internal values of each tensor. Without debug
# mode, Keras treats the object as a layer and we can only get the final output.
debug_flag = int(os.environ.get('KERAS_ATTENTION_DEBUG', 0))
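
# Usage note (not part of the original file): KERAS_ATTENTION_DEBUG must be set
# before this module is imported, e.g.
#   os.environ['KERAS_ATTENTION_DEBUG'] = '1'
# Setting it after import has no effect, because debug_flag is read once at import time.
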
class Attention(object if debug_flag else Layer):
    def __init__(self, units=128, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.units = units

    # noinspection PyAttributeOutsideInit
    def build(self, input_shape):
        input_dim = int(input_shape[-1])
        with K.name_scope(self.name if not debug_flag else 'attention'):
            self.attention_score_vec = Dense(input_dim, use_bias=False, name='attention_score_vec')
            self.h_t = Lambda(lambda x: x[:, -1, :], output_shape=(input_dim,), name='last_hidden_state')
            self.attention_score = Dot(axes=[1, 2], name='attention_score')
            self.attention_weight = Activation('softmax', name='attention_weight')
            self.context_vector = Dot(axes=[1, 1], name='context_vector')
            self.attention_output = Concatenate(name='attention_output')
            self.attention_vector = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')
        if not debug_flag:
            # In debug mode, the call to build() is done in call() instead.
            super(Attention, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.units

    def __call__(self, inputs, training=None, **kwargs):
        if debug_flag:
            return self.call(inputs, training, **kwargs)
        else:
            return super(Attention, self).__call__(inputs, training, **kwargs)

    # noinspection PyUnusedLocal
    def call(self, inputs, training=None, **kwargs):
        """
        Many-to-one attention mechanism for Keras.
        @param inputs: 3D tensor with shape (batch_size, time_steps, input_dim).
        @param training: not used in this layer.
        @return: 2D tensor with shape (batch_size, units).
        @author: felixhao28, philipperemy.
        """
        if debug_flag:
            self.build(inputs.shape)
        # Inside dense layer:
        # hidden_states dot W => score_first_part
        # (batch_size, time_steps, hidden_size) dot (hidden_size, hidden_size) => (batch_size, time_steps, hidden_size)
        # W is the trainable weight matrix of attention (Luong's multiplicative style score).
        score_first_part = self.attention_score_vec(inputs)
        # score_first_part dot last_hidden_state => attention_weights
        # (batch_size, time_steps, hidden_size) dot (batch_size, hidden_size) => (batch_size, time_steps)
        h_t = self.h_t(inputs)
        score = self.attention_score([h_t, score_first_part])
        attention_weights = self.attention_weight(score)
        # (batch_size, time_steps, hidden_size) dot (batch_size, time_steps) => (batch_size, hidden_size)
        context_vector = self.context_vector([inputs, attention_weights])
        pre_activation = self.attention_output([context_vector, h_t])
        attention_vector = self.attention_vector(pre_activation)
        return attention_vector

    def get_config(self):
        """
        Returns the config of the layer. This is used for saving and loading the model.
        :return: Python dictionary with the specs needed to rebuild the layer.
        """
        config = super(Attention, self).get_config()
        config.update({'units': self.units})
        return config
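

# Minimal usage sketch (not part of the original file). It shows how this layer is
# typically wired after a recurrent encoder; the sizes below (time_steps, input_dim,
# num_classes) are made-up, illustrative values. The Attention layer expects a 3D
# tensor (batch_size, time_steps, input_dim), e.g. the sequence output of an LSTM
# with return_sequences=True, and returns a 2D tensor of shape (batch_size, units).
# When reloading a saved model, pass custom_objects={'Attention': Attention} to
# tensorflow.keras.models.load_model.
if __name__ == '__main__':
    from tensorflow.keras.layers import Input, LSTM
    from tensorflow.keras.models import Model

    time_steps, input_dim, num_classes = 20, 8, 3  # hypothetical sizes
    model_input = Input(shape=(time_steps, input_dim))
    hidden_states = LSTM(64, return_sequences=True)(model_input)  # (batch, time_steps, 64)
    attended = Attention(units=32)(hidden_states)                 # (batch, 32)
    model_output = Dense(num_classes, activation='softmax')(attended)
    model = Model(inputs=model_input, outputs=model_output)
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    model.summary()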