neupy.layers.GatedAverage

class neupy.layers.GatedAverage

Uses the output of the gating layer to weight the outputs of the other layers and sums them.
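A minimal NumPy sketch of the weighted sum described above. It assumes the gating output holds one weight per gated network (for example, Softmax probabilities) and that all gated networks produce outputs of the same shape. The function `gated_average` is hypothetical and works on plain arrays rather than NeuPy layers; it is not NeuPy's implementation.

```python
import numpy as np


def gated_average(gating_output, *network_outputs):
    """Weight each network's output by its gating value and sum the results.

    gating_output: array of shape (batch, n_networks).
    network_outputs: n_networks arrays, each of shape (batch, n_features).
    """
    # Stack outputs into shape (batch, n_networks, n_features)
    stacked = np.stack(network_outputs, axis=1)
    # Add a feature axis so the weights broadcast: (batch, n_networks, 1)
    weights = gating_output[:, :, np.newaxis]
    # Weighted sum over the networks axis -> (batch, n_features)
    return (weights * stacked).sum(axis=1)


gates = np.array([[0.25, 0.75]])
out_1 = np.array([[4.0, 8.0]])
out_2 = np.array([[0.0, 4.0]])
print(gated_average(gates, out_1, out_2))  # [[1. 5.]]
```

Each feature is computed as 0.25 * out_1 + 0.75 * out_2, so the gating network decides how much each network contributes per sample.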

Parameters:
gating_layer_index : int

Input layers are passed as a list, and this parameter specifies the index of the gating network within that list. Defaults to 0, which means the gating layer is expected in the first position.
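The role of the index can be sketched with plain Python lists. The helper `split_inputs` is hypothetical, and it assumes the remaining (gated) inputs keep their original relative order.

```python
def split_inputs(inputs, gating_layer_index=0):
    """Separate the gating input from the gated network inputs."""
    inputs = list(inputs)  # copy so the caller's list is not modified
    gating_input = inputs.pop(gating_layer_index)
    return gating_input, inputs


# Default: the gating network is the first element
print(split_inputs(["gate", "net_1", "net_2"]))
# Gating network placed last, so the index is 2
print(split_inputs(["net_1", "net_2", "gate"], gating_layer_index=2))
```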

name : str or None

Layer’s identifier. If name is None, then the name will be generated automatically. Defaults to None.

Examples

>>> from neupy.layers import *
>>>
>>> gating_network = Input(10) > Softmax(2)
>>> network_1 = Input(20) > Relu(10)
>>> network_2 = Input(20) > Relu(20) > Relu(10)
>>>
>>> network = [gating_network, network_1, network_2] > GatedAverage()
>>> network
[(10,), (20,), (20,)] -> [... 8 layers ...] -> 10

Attributes:
input_shape : tuple

Returns the layer’s input shape as a tuple. The shape does not include the batch size dimension.

output_shape : tuple

Returns the layer’s output shape as a tuple. The shape does not include the batch size dimension.

training_state : bool

Defines whether the layer is in the training state. The training state enables some operations inside the layers that are not available otherwise.

parameters : dict

Parameters that the network uses during propagation. They may include both trainable and non-trainable parameters.

graph : LayerGraph instance

Graph that stores all relations between layers.

Methods

disable_training_state() Context manager that switches off the training state.
initialize() Sets up important configurations related to the layer.
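To illustrate what a context manager of this kind does, here is a minimal sketch built on `contextlib.contextmanager`. The `ToyLayer` class is hypothetical and assumes `training_state` is a plain boolean attribute; NeuPy's own implementation may differ.

```python
from contextlib import contextmanager


class ToyLayer:
    """Hypothetical stand-in for a layer with a training flag."""

    def __init__(self):
        self.training_state = True

    @contextmanager
    def disable_training_state(self):
        # Remember the current value so nested calls restore it correctly
        previous = self.training_state
        self.training_state = False
        try:
            yield
        finally:
            # Restore the flag even if the body raises an exception
            self.training_state = previous


layer = ToyLayer()
with layer.disable_training_state():
    print(layer.training_state)  # False
print(layer.training_state)  # True
```

Restoring the previous value in `finally`, rather than always setting it back to True, keeps the flag correct when the context manager is nested or the body fails.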
gating_layer_index = None
options = {'gating_layer_index': Option(class_name='GatedAverage', value=IntProperty(name="gating_layer_index")), 'name': Option(class_name='BaseLayer', value=Property(name="name"))}
output(*input_values)

Returns the output based on the input values.

Parameters:
input_values
output_shape
validate(input_shapes)

Validates the input shapes before assigning them.

Parameters:
input_shapes : list of tuples of int
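The kind of checks such a validation step could perform can be sketched as follows. The specific rules here are an assumption derived from the layer's description (the gating output is a vector with one value per gated network, and all gated networks share one output shape); the function `validate_gated_shapes` is hypothetical, and NeuPy's actual checks and error messages may differ. Shapes exclude the batch dimension.

```python
def validate_gated_shapes(input_shapes, gating_layer_index=0):
    """Check that shapes are compatible and return the expected output shape.

    Assumed rules, not NeuPy's exact logic.
    """
    shapes = list(input_shapes)
    gating_shape = shapes.pop(gating_layer_index)

    if not shapes:
        raise ValueError("At least one gated network is required")

    if len(gating_shape) != 1:
        raise ValueError(
            "Gating layer output should be a vector, got {}".format(gating_shape))

    if gating_shape[0] != len(shapes):
        raise ValueError(
            "Gating layer has {} outputs, but there are {} gated networks"
            "".format(gating_shape[0], len(shapes)))

    if any(shape != shapes[0] for shape in shapes):
        raise ValueError("All gated networks should have the same output shape")

    # The weighted sum keeps the shape of a single gated network's output
    return shapes[0]


# Shapes from the example: Softmax(2) gating, two networks ending in Relu(10)
print(validate_gated_shapes([(2,), (10,), (10,)]))  # (10,)
```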