Source code for epynn.flatten.backward

# EpyNN/epynn/flatten/backward.py
# Related third party imports
import numpy as np


def initialize_backward(layer, dX):
    """Set up the backward-pass cache of a flatten layer.

    The gradient received from the next layer is stored unchanged under
    ``layer.bc['dA']`` and handed back to the caller.

    :param layer: An instance of flatten layer.
    :type layer: :class:`epynn.flatten.models.Flatten`

    :param dX: Output of backward propagation from next layer.
    :type dX: :class:`numpy.ndarray`

    :return: Input of backward propagation for current layer.
    :rtype: :class:`numpy.ndarray`
    """
    # The flatten layer has no activation, so the incoming gradient is
    # cached as-is and becomes this layer's backward input.
    layer.bc['dA'] = dX

    return dX
def flatten_backward(layer, dX):
    """Backward propagate error gradients to previous layer.

    Undoes the forward flattening by reshaping the incoming gradient
    back to the shape recorded in ``layer.fs['X']``.

    :param layer: An instance of flatten layer.
    :type layer: :class:`epynn.flatten.models.Flatten`

    :param dX: Output of backward propagation from next layer.
    :type dX: :class:`numpy.ndarray`

    :return: Gradient passed on to the previous layer, also cached
        under ``layer.bc['dX']``.
    :rtype: :class:`numpy.ndarray`
    """
    # Step 1: cache the incoming gradient as this layer's backward input.
    dA = initialize_backward(layer, dX)

    # Step 2: reverse the forward reshape, (m, n) -> (m, ...).
    # layer.fs['X'] is used as the target shape for np.reshape.
    input_shape = layer.fs['X']
    grad_prev = np.reshape(dA, input_shape)
    layer.bc['dX'] = grad_prev

    return grad_prev    # To previous layer