Module pysimt.layers.selector
A utility layer that returns a particular element from the previous layer.
Expand source code
"""A utility layer that returns a particular element from the previous layer."""
from typing import Any, Iterable, Sequence

from torch import nn, Tensor
class Selector(nn.Module):
    """Utility layer that selects and returns a particular element out of
    a tuple. It is useful to select a particular output from the previous layer,
    when used in constructs such as `torch.nn.Sequential()`.

    Args:
        index: The position to select from the given input. Selection is
            plain indexing (``x[index]``), so negative indices count from
            the end as usual.

    Example:
        >>> import torch
        >>> layers = []
        >>> layers.append(torch.nn.GRU(200, 400))
        >>> # By default, GRU returns (output, h_n) but we are not interested in h_n
        >>> layers.append(Selector(0))
        >>> layers.append(torch.nn.Dropout(0.2))
        >>> block = nn.Sequential(*layers)
    """

    def __init__(self, index: int):
        """Stores the position that `forward()` will select.

        Args:
            index: The position to select from the given input.
        """
        super().__init__()
        self.index = index

    def forward(self, x: Sequence[Any]) -> Any:
        """Returns the pre-determined `self.index`'th position of `x`.

        Args:
            x: An indexable container, e.g. the tuple produced by the
                previous layer. A one-shot iterable such as a generator is
                not sufficient, because the element is fetched with
                ``x[self.index]``.

        Returns:
            The element ``x[self.index]``; its type is whatever the previous
            layer placed at that position.
        """
        return x[self.index]

    def extra_repr(self) -> str:
        # `nn.Module.__repr__` wraps this as "Selector(index=...)", which
        # matches the former hand-written `__repr__` while following the
        # standard torch hook for describing module configuration.
        return f"index={self.index}"
Classes
class Selector (index: int)
-
Utility layer that selects and returns a particular element out of a tuple. It is useful to select a particular output from the previous layer, when used in constructs such as `torch.nn.Sequential()`.
Args
index
- The position to select from the given input.
Example
>>> layers = []
>>> layers.append(torch.nn.GRU(200, 400))
# By default, GRU returns (output, h_n) but we are not interested in h_n
>>> layers.append(Selector(0))
>>> layers.append(torch.nn.Dropout(0.2))
>>> self.block = nn.Sequential(*layers)
Expand source code
class Selector(nn.Module):
    """Utility layer that selects and returns a particular element out of
    a tuple. It is useful to select a particular output from the previous layer,
    when used in constructs such as `torch.nn.Sequential()`.

    Args:
        index: The position to select from the given input. Selection is
            plain indexing (``x[index]``), so negative indices count from
            the end as usual.

    Example:
        >>> import torch
        >>> layers = []
        >>> layers.append(torch.nn.GRU(200, 400))
        >>> # By default, GRU returns (output, h_n) but we are not interested in h_n
        >>> layers.append(Selector(0))
        >>> layers.append(torch.nn.Dropout(0.2))
        >>> block = nn.Sequential(*layers)
    """

    def __init__(self, index: int):
        """Stores the position that `forward()` will select.

        Args:
            index: The position to select from the given input.
        """
        super().__init__()
        self.index = index

    def forward(self, x: Sequence[Any]) -> Any:
        """Returns the pre-determined `self.index`'th position of `x`.

        Args:
            x: An indexable container, e.g. the tuple produced by the
                previous layer. A one-shot iterable such as a generator is
                not sufficient, because the element is fetched with
                ``x[self.index]``.

        Returns:
            The element ``x[self.index]``; its type is whatever the previous
            layer placed at that position.
        """
        return x[self.index]

    def extra_repr(self) -> str:
        # `nn.Module.__repr__` wraps this as "Selector(index=...)", which
        # matches the former hand-written `__repr__` while following the
        # standard torch hook for describing module configuration.
        return f"index={self.index}"
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def forward(self, x: Iterable[torch.Tensor]) -> torch.Tensor
-
Returns the pre-determined `self.index`'th position of `x`.
Expand source code
def forward(self, x: Sequence[Any]) -> Any:
    """Returns the pre-determined `self.index`'th position of `x`.

    Args:
        x: An indexable container, e.g. the tuple produced by the previous
            layer. A one-shot iterable such as a generator is not
            sufficient, because the element is fetched with
            ``x[self.index]``.

    Returns:
        The element ``x[self.index]``; its type is whatever the previous
        layer placed at that position.
    """
    return x[self.index]