Module pysimt.layers.fusion

A convenience layer that merges an arbitrary number of inputs.

Expand source code
"""A convenience layer that merges an arbitrary number of inputs."""

import operator
from typing import Optional
from functools import reduce

import torch

from . import FF
from ..utils.nn import get_activation_fn


class Fusion(torch.nn.Module):
    """A convenience layer that merges an arbitrary number of inputs using
    concatenation, addition or multiplication. It then applies an optional
    non-linearity given by the `activ` argument. If `operation==concat`,
    additional arguments should be provided to define an adaptor MLP
    that will project the concatenated vector into a lower dimensional space.

    Args:
        operation: `concat`, `sum` or `mul` for concatenation, addition, and
            multiplication respectively
        activ: The activation function name that will be searched
            in `torch` and `torch.nn.functional` modules. `None` or `linear`
            disables the activation function
        input_size: Only required for `concat` fusion, to denote the concatenated
            input vector size. This will be used to add an MLP adaptor layer
            after concatenation to project the fused vector into a lower
            dimension
        output_size: Only required for `concat` fusion, to denote the
            output size of the aforementioned adaptor layer
    """
    def __init__(self,
                 operation: str = 'concat',
                 activ: Optional[str] = 'linear',
                 input_size: Optional[int] = None,
                 output_size: Optional[int] = None):
        """"""
        super().__init__()

        self.operation = operation
        # Bind `forward` to one of `_sum`, `_mul` or `_concat`
        self.forward = getattr(self, '_{}'.format(self.operation))
        # Resolve the activation name into a callable ('linear' or None keeps it identity)
        self.activ = get_activation_fn(activ)
        # Identity adaptor by default; replaced by a linear projection below
        self.adaptor = lambda x: x

        # `concat` always requires the adaptor; `sum`/`mul` only when the
        # input and output sizes differ
        if self.operation == 'concat' or input_size != output_size:
            self.adaptor = FF(input_size, output_size, bias=False, activ=None)

    def _sum(self, inputs):
        return self.activ(self.adaptor(reduce(operator.add, inputs)))

    def _mul(self, inputs):
        return self.activ(self.adaptor(reduce(operator.mul, inputs)))

    def _concat(self, inputs):
        return self.activ(self.adaptor(torch.cat(inputs, dim=-1)))

    def __repr__(self):
        return f"Fusion(type={self.operation}, activ={self.activ})"

Classes

class Fusion (operation: str = 'concat', activ: Optional[str] = 'linear', input_size: Optional[int] = None, output_size: Optional[int] = None)

A convenience layer that merges an arbitrary number of inputs using concatenation, addition or multiplication. It then applies an optional non-linearity given by the activ argument. If operation==concat, additional arguments should be provided to define an adaptor MLP that will project the concatenated vector into a lower dimensional space.

Args

operation
concat, sum or mul for concatenation, addition, and multiplication respectively
activ
The activation function name that will be searched in torch and torch.nn.functional modules. None or linear disables the activation function
input_size
Only required for concat fusion, to denote the concatenated input vector size. This will be used to add an MLP adaptor layer after concatenation to project the fused vector into a lower dimension
output_size
Only required for concat fusion, to denote the output size of the aforementioned adaptor layer
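
A short usage sketch under these assumptions (batch size and feature sizes are illustrative):

import torch
from pysimt.layers.fusion import Fusion

a = torch.rand(8, 256)
b = torch.rand(8, 256)

# Element-wise sum of same-sized inputs; no input_size/output_size needed,
# so the adaptor stays an identity function
fuse_sum = Fusion(operation='sum', activ='tanh')
out = fuse_sum([a, b])  # expected shape: (8, 256)

# Concatenation fuses along the last dimension, so input_size is the sum of
# the individual feature sizes; the FF adaptor projects the result back down
fuse_cat = Fusion(operation='concat', activ='tanh',
                  input_size=512, output_size=256)
out = fuse_cat([a, b])  # cat -> (8, 512), adaptor -> (8, 256)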

Ancestors

  • torch.nn.modules.module.Module

Class variables

var dump_patches : bool
var training : bool

Methods

def forward(self, *input: Any) -> None

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself rather than this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.

Expand source code
def _forward_unimplemented(self, *input: Any) -> None:
    r"""Defines the computation performed at every call.

    Should be overridden by all subclasses.

    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.
    """
    raise NotImplementedError
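
Note that Fusion never relies on this inherited placeholder: its `__init__` rebinds the instance attribute `forward` to `_sum`, `_mul` or `_concat`, so calling the module dispatches straight to the selected merge function. A quick illustrative check:

from pysimt.layers.fusion import Fusion

fuse = Fusion(operation='mul')
assert fuse.forward == fuse._mul  # forward was bound at construction time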