Skip to content
Snippets Groups Projects
Commit 6fead273 authored by SurajSSingh's avatar SurajSSingh
Browse files

Updated Attention model

parent 27cc5786
No related branches found
No related tags found
No related merge requests found
loss,categorical_crossentropy,categorical_accuracy,categorical_hinge,val_loss,val_categorical_crossentropy,val_categorical_accuracy,val_categorical_hinge
0.6102505922317505,0.6102591753005981,0.7721169590950012,0.5578135251998901,0.2365977168083191,0.23661620914936066,0.9482620358467102,0.21556608378887177
0.2410024255514145,0.24101324379444122,0.9489823579788208,0.22259768843650818,0.22830058634281158,0.22832421958446503,0.9520950317382812,0.219150111079216
0.22811080515384674,0.22812050580978394,0.9507037997245789,0.22212223708629608,0.20287160575389862,0.20289486646652222,0.9597100019454956,0.21067200601100922
0.2257438749074936,0.225755974650383,0.9488319158554077,0.22673241794109344,0.19838610291481018,0.19840732216835022,0.9598119854927063,0.1895865648984909
0.24677418172359467,0.24678997695446014,0.9415083527565002,0.24014297127723694,0.20171019434928894,0.20173151791095734,0.9596517086029053,0.1882585883140564
0.22916194796562195,0.22917920351028442,0.9510580897331238,0.22440536320209503,0.1999731957912445,0.19999487698078156,0.9598484039306641,0.19534263014793396
0.21525806188583374,0.2152738720178604,0.9554435610771179,0.21104231476783752,0.19979862868785858,0.19982029497623444,0.9598557353019714,0.19377031922340393
0.21643704175949097,0.2164536565542221,0.9556373953819275,0.21280315518379211,0.19801299273967743,0.19803403317928314,0.9598265886306763,0.1905786246061325
0.22762203216552734,0.2276403158903122,0.9521276950836182,0.2212676852941513,0.21021290123462677,0.21023407578468323,0.9591853022575378,0.20358486473560333
0.23046843707561493,0.2304823100566864,0.949951708316803,0.23079520463943481,0.19998398423194885,0.20000584423542023,0.9586096405982971,0.1999059021472931
0.21614259481430054,0.21616177260875702,0.9552162289619446,0.21368272602558136,0.19799767434597015,0.19801850616931915,0.9597464203834534,0.19059626758098602
0.2314920574426651,0.23151175677776337,0.950182318687439,0.23680256307125092,0.19927160441875458,0.1992924064397812,0.9598192572593689,0.19148322939872742
0.19912201166152954,0.19913870096206665,0.9593844413757324,0.197440966963768,0.1971137374639511,0.1971348226070404,0.9596953988075256,0.1911185085773468
0.22077631950378418,0.22079353034496307,0.9546881318092346,0.21700823307037354,0.19683417677879333,0.1968551129102707,0.9598119854927063,0.1903052181005478
0.19817253947257996,0.19818846881389618,0.9590468406677246,0.1978733092546463,0.19659003615379333,0.19661082327365875,0.9597901105880737,0.1899634152650833
0.21108753979206085,0.211105078458786,0.9546680450439453,0.21126605570316315,0.19589678943157196,0.19591759145259857,0.9598557353019714,0.1899382472038269
0.20053814351558685,0.2005544900894165,0.9582446217536926,0.20047253370285034,0.19655226171016693,0.19657309353351593,0.959863007068634,0.19047877192497253
0.19827072322368622,0.19828687608242035,0.95877605676651,0.19837716221809387,0.19621789455413818,0.19623875617980957,0.9598119854927063,0.189311683177948
0.2156815230846405,0.2156958281993866,0.955353319644928,0.2140825092792511,0.1963217407464981,0.19634267687797546,0.9598484039306641,0.19015157222747803
0.2074679136276245,0.20748521387577057,0.9574725031852722,0.2063201367855072,0.1954033225774765,0.19542405009269714,0.959863007068634,0.19042037427425385
root"_tf_keras_sequential*{"name": "sequential", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 10, 7]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}, {"class_name": "Attention", "config": {"name": "attention", "trainable": true, "dtype": "float32", "units": 16}}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 4, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "shared_object_id": 5, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 10, 7]}, "ndim": 3, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 10, 7]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 10, 7]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 10, 7]}, "float32", "input_1"]}, "keras_version": "2.8.0", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 10, 7]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "shared_object_id": 0}, {"class_name": "Attention", "config": {"name": "attention", "trainable": true, "dtype": "float32", "units": 16}, "shared_object_id": 1}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 4, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 4}]}}, "training_config": {"loss": {"class_name": "CategoricalCrossentropy", "config": {"reduction": "auto", "name": "categorical_crossentropy", "from_logits": true, "label_smoothing": 0.0, "axis": -1}, "shared_object_id": 7}, "metrics": [[{"class_name": "CategoricalCrossentropy", "config": {"name": "categorical_crossentropy", "dtype": "float32", "from_logits": true, "label_smoothing": 0}, "shared_object_id": 8}, {"class_name": "CategoricalAccuracy", "config": {"name": "categorical_accuracy", "dtype": "float32"}, "shared_object_id": 9}, {"class_name": "CategoricalHinge", "config": {"name": "categorical_hinge", "dtype": "float32"}, "shared_object_id": 10}]], "weighted_metrics": null, "loss_weights": null, "optimizer_config": {"class_name": "Adam", "config": {"name": "Adam", "learning_rate": 0.0010000000474974513, "decay": 0.0, "beta_1": 0.8999999761581421, "beta_2": 0.9990000128746033, "epsilon": 1e-07, "amsgrad": false}}}}2
root.layer_with_weights-0"_tf_keras_layer*{"name": "attention", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Attention", "config": {"name": "attention", "trainable": true, "dtype": "float32", "units": 16}, "shared_object_id": 1, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10, 7]}}2
root.layer_with_weights-1"_tf_keras_layer*{"name": "dense", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 4, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 4, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 16}}, "shared_object_id": 11}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 16]}}2
root"_tf_keras_sequential*{"name": "sequential", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 10, 7]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}, {"class_name": "CustomAttention", "config": {"name": "custom_attention", "trainable": true, "dtype": "float32", "units": 32}}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 4, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "shared_object_id": 5, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 10, 7]}, "ndim": 3, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 10, 7]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 10, 7]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 10, 7]}, "float32", "input_1"]}, "keras_version": "2.8.0", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 10, 7]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "shared_object_id": 0}, {"class_name": "CustomAttention", "config": {"name": "custom_attention", "trainable": true, "dtype": "float32", "units": 32}, "shared_object_id": 1}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 4, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 4}]}}, "training_config": {"loss": {"class_name": "CategoricalCrossentropy", "config": {"reduction": "auto", "name": "categorical_crossentropy", "from_logits": true, "label_smoothing": 0.0, "axis": -1}, "shared_object_id": 7}, "metrics": [[{"class_name": "CategoricalCrossentropy", "config": {"name": "categorical_crossentropy", "dtype": "float32", "from_logits": true, "label_smoothing": 0}, "shared_object_id": 8}, {"class_name": "CategoricalAccuracy", "config": {"name": "categorical_accuracy", "dtype": "float32"}, "shared_object_id": 9}, {"class_name": "CategoricalHinge", "config": {"name": "categorical_hinge", "dtype": "float32"}, "shared_object_id": 10}]], "weighted_metrics": null, "loss_weights": null, "optimizer_config": {"class_name": "Adam", "config": {"name": "Adam", "learning_rate": 0.001, "decay": 0.0, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-07, "amsgrad": false}}}}2
root.layer_with_weights-0"_tf_keras_layer*{"name": "custom_attention", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "CustomAttention", "config": {"name": "custom_attention", "trainable": true, "dtype": "float32", "units": 32}, "shared_object_id": 1, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10, 7]}}2
root.layer_with_weights-1"_tf_keras_layer*{"name": "dense", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 4, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 4, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 32}}, "shared_object_id": 11}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 32]}}2
 -root.layer_with_weights-0.attention_score_vec"_tf_keras_layer*{"name": "attention_score_vec", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Dense", "config": {"name": "attention_score_vec", "trainable": true, "dtype": "float32", "units": 7, "activation": "linear", "use_bias": false, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 12}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 13}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 14, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 7}}, "shared_object_id": 15}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10, 7]}}2
 root.layer_with_weights-0.h_t"_tf_keras_layer*{"name": "last_hidden_state", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Lambda", "config": {"name": "last_hidden_state", "trainable": true, "dtype": "float32", "function": {"class_name": "__tuple__", "items": ["4wEAAAAAAAAAAAAAAAEAAAAFAAAAUwAAAHMWAAAAfABkAGQAhQJkAWQAZACFAmYDGQBTACkCTun/\n////qQApAdoBeHICAAAAcgIAAAD6MS9ob21lL250b3AvUHljaGFybVByb2plY3RzL2ZpbmFsX2xh\nYi9hdHRlbnRpb24ucHnaCDxsYW1iZGE+GwAAAHMCAAAAFgA=\n", null, null]}, "function_type": "lambda", "module": "attention", "output_shape": {"class_name": "__tuple__", "items": [7]}, "output_shape_type": "raw", "output_shape_module": null, "arguments": {}}, "shared_object_id": 16}2
 root.layer_with_weights-0.h_t"_tf_keras_layer*{"name": "last_hidden_state", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Lambda", "config": {"name": "last_hidden_state", "trainable": true, "dtype": "float32", "function": {"class_name": "__tuple__", "items": ["4wEAAAAAAAAAAAAAAAEAAAAFAAAAUwAAAHMWAAAAfABkAGQAhQJkAWQAZACFAmYDGQBTACkCTun/\n////qQApAdoBeHICAAAAcgIAAAD6TS9Vc2Vycy9ub3dhZG1pbi9Eb2N1bWVudHMvU2Nob29sIEZv\nbGRlci9DUyA0MzcvTGFiL0ZpbmFsIFByb2plY3QvYXR0ZW50aW9uLnB52gg8bGFtYmRhPhoAAABz\nAgAAABYA\n", null, null]}, "function_type": "lambda", "module": "attention", "output_shape": {"class_name": "__tuple__", "items": [7]}, "output_shape_type": "raw", "output_shape_module": null, "arguments": {}}, "shared_object_id": 16}2
)root.layer_with_weights-0.attention_score"_tf_keras_layer*{"name": "attention_score", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Dot", "config": {"name": "attention_score", "trainable": true, "dtype": "float32", "axes": [1, 2], "normalize": false}, "shared_object_id": 17, "build_input_shape": [{"class_name": "TensorShape", "items": [null, 7]}, {"class_name": "TensorShape", "items": [null, 10, 7]}]}2
*root.layer_with_weights-0.attention_weight"_tf_keras_layer*{"name": "attention_weight", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Activation", "config": {"name": "attention_weight", "trainable": true, "dtype": "float32", "activation": "softmax"}, "shared_object_id": 18}2
(root.layer_with_weights-0.context_vector"_tf_keras_layer*{"name": "context_vector", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Dot", "config": {"name": "context_vector", "trainable": true, "dtype": "float32", "axes": [1, 1], "normalize": false}, "shared_object_id": 19, "build_input_shape": [{"class_name": "TensorShape", "items": [null, 10, 7]}, {"class_name": "TensorShape", "items": [null, 10]}]}2
*root.layer_with_weights-0.attention_output"_tf_keras_layer*{"name": "attention_output", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Concatenate", "config": {"name": "attention_output", "trainable": true, "dtype": "float32", "axis": -1}, "shared_object_id": 20, "build_input_shape": [{"class_name": "TensorShape", "items": [null, 7]}, {"class_name": "TensorShape", "items": [null, 7]}]}2
*root.layer_with_weights-0.attention_vector"_tf_keras_layer*{"name": "attention_vector", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Dense", "config": {"name": "attention_vector", "trainable": true, "dtype": "float32", "units": 16, "activation": "tanh", "use_bias": false, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 21}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 22}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 23, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 14}}, "shared_object_id": 24}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 14]}}2
droot.keras_api.metrics.0"_tf_keras_metric*{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 25}2
eroot.keras_api.metrics.1"_tf_keras_metric*{"class_name": "CategoricalCrossentropy", "name": "categorical_crossentropy", "dtype": "float32", "config": {"name": "categorical_crossentropy", "dtype": "float32", "from_logits": true, "label_smoothing": 0}, "shared_object_id": 8}2
froot.keras_api.metrics.2"_tf_keras_metric*{"class_name": "CategoricalAccuracy", "name": "categorical_accuracy", "dtype": "float32", "config": {"name": "categorical_accuracy", "dtype": "float32"}, "shared_object_id": 9}2
groot.keras_api.metrics.3"_tf_keras_metric*{"class_name": "CategoricalHinge", "name": "categorical_hinge", "dtype": "float32", "config": {"name": "categorical_hinge", "dtype": "float32"}, "shared_object_id": 10}2
\ No newline at end of file
*root.layer_with_weights-0.attention_vector"_tf_keras_layer*{"name": "attention_vector", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "class_name": "Dense", "config": {"name": "attention_vector", "trainable": true, "dtype": "float32", "units": 32, "activation": "tanh", "use_bias": false, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 21}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 22}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 23, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 14}}, "shared_object_id": 24}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 14]}}2
_root.keras_api.metrics.0"_tf_keras_metric*{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 25}2
`root.keras_api.metrics.1"_tf_keras_metric*{"class_name": "CategoricalCrossentropy", "name": "categorical_crossentropy", "dtype": "float32", "config": {"name": "categorical_crossentropy", "dtype": "float32", "from_logits": true, "label_smoothing": 0}, "shared_object_id": 8}2
aroot.keras_api.metrics.2"_tf_keras_metric*{"class_name": "CategoricalAccuracy", "name": "categorical_accuracy", "dtype": "float32", "config": {"name": "categorical_accuracy", "dtype": "float32"}, "shared_object_id": 9}2
broot.keras_api.metrics.3"_tf_keras_metric*{"class_name": "CategoricalHinge", "name": "categorical_hinge", "dtype": "float32", "config": {"name": "categorical_hinge", "dtype": "float32"}, "shared_object_id": 10}2
\ No newline at end of file
No preview for this file type
No preview for this file type
No preview for this file type
File added
File added
......@@ -6,17 +6,16 @@ from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Lambda, Dot, Activation, Concatenate, Layer
# KERAS_ATTENTION_DEBUG: If set to 1. Will switch to debug mode.
# In debug mode, the class Attention is no longer a Keras layer.
# In debug mode, the class CustomAttention is no longer a Keras layer.
# What it means in practice is that we can have access to the internal values
# of each tensor. If we don't use debug, Keras treats the object
# as a layer and we can only get the final output.
debug_flag = int(os.environ.get('KERAS_ATTENTION_DEBUG', 0))
class Attention(object if debug_flag else Layer):
class CustomAttention(object if debug_flag else Layer):
def __init__(self, units=128, **kwargs):
super(Attention, self).__init__(**kwargs)
super(CustomAttention, self).__init__(**kwargs)
self.units = units
# noinspection PyAttributeOutsideInit
......@@ -32,7 +31,7 @@ class Attention(object if debug_flag else Layer):
self.attention_vector = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')
if not debug_flag:
# debug: the call to build() is done in call().
super(Attention, self).build(input_shape)
super(CustomAttention, self).build(input_shape)
def compute_output_shape(self, input_shape):
return input_shape[0], self.units
......@@ -41,7 +40,7 @@ class Attention(object if debug_flag else Layer):
if debug_flag:
return self.call(inputs, training, **kwargs)
else:
return super(Attention, self).__call__(inputs, training, **kwargs)
return super(CustomAttention, self).__call__(inputs, training, **kwargs)
# noinspection PyUnusedLocal
def call(self, inputs, training=None, **kwargs):
......@@ -75,6 +74,6 @@ class Attention(object if debug_flag else Layer):
Returns the config of a the layer. This is used for saving and loading from a model
:return: python dictionary with specs to rebuild layer
"""
config = super(Attention, self).get_config()
config = super(CustomAttention, self).get_config()
config.update({'units': self.units})
return config
\ No newline at end of file
model_checkpoint_path: "TEST"
all_model_checkpoint_paths: "TEST"
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment