Module pysimt.utils.topology

Expand source code
from collections import UserString, OrderedDict

from .. import datasets


class DataSource(UserString):
    """A string-like name of a data source, tagged with its dataset type.

    The instance behaves like the string ``name`` (via ``UserString``) and
    additionally records which ``<_type>Dataset`` class from ``datasets``
    handles it and whether it was declared as a source and/or a target.
    """

    def __init__(self, name, _type, src=False, trg=False):
        """Create the data source.

        Arguments:
            name: The name of the modality; becomes the string value.
            _type: Prefix of the dataset class, looked up in ``datasets``
                as ``'<_type>Dataset'``.
            src: Flag stored as-is in ``self.src``.
            trg: Flag stored as-is in ``self.trg``.
        """
        super().__init__(name)
        self._type = _type
        self.src = src
        self.trg = trg

        # Anything that is not flagged as a source is labelled 'trg'
        if self.src:
            self.side = 'src'
        else:
            self.side = 'trg'

        # Resolve the dataset class that knows how to turn a batch of this
        # type into a tensor
        dataset_cls = getattr(datasets, '{}Dataset'.format(_type))

        # ``kwargs`` is read on every ``torchify`` call, so entries added to
        # it after construction are forwarded to ``to_torch`` as well
        self.kwargs = {}
        self.torchify = lambda batch: dataset_cls.to_torch(
            batch, **self.kwargs)

    def __repr__(self):
        """Return a debug representation with the name and ``kwargs``."""
        template = "DataSource('{}', kwargs:{})"
        return template.format(self.data, self.kwargs)


class Topology:
    """A simple object that parses the direction string provided through the
    experiment configuration file.

    A direction is a string with the following syntax:
        feat:<type>, feat:<type>, ... -> feat:<type>, feat:<type>, ...

    where
        feat determines the name of the modality, i.e. 'en', 'image', etc.
        type is the prefix of the actual ``Dataset`` class to be used
            with this modality, i.e. Text, ImageFolder, OneHot, etc.
        if type is omitted, the default is Text.

    Example:
        de:Text (no target side)
        de:Text -> en:Text
        de:Text -> en:Text, en_pos:OneHot
        de:Text, image:ImageFolder -> en:Text

    Attributes:
        direction: The direction string exactly as it was given.
        srcs: ``OrderedDict`` of source name -> ``DataSource``.
        trgs: ``OrderedDict`` of target name -> ``DataSource``.
        all: ``OrderedDict`` with the sources followed by the targets.
        first_src: Name of the first source, ``None`` if there are none.
        first_trg: Name of the first target, ``None`` if there are none.
    """
    def __init__(self, direction):
        """Parse ``direction`` and fill in the source/target dictionaries.

        Arguments:
            direction: The direction string, see the class docstring.

        Raises:
            RuntimeError: If the same name is given more than once.
        """
        self.direction = direction
        self.srcs = OrderedDict()
        self.trgs = OrderedDict()
        self.all = OrderedDict()

        parts = direction.strip().split('->')
        if len(parts) == 1:
            # No '->' in the string: everything is a source
            srcs, trgs = parts[0].strip().split(','), []
        else:
            # Only the first two '->'-separated parts are considered
            srcs = parts[0].strip().split(',') if parts[0].strip() else []
            trgs = parts[1].strip().split(',') if parts[1].strip() else []

        # Parse sources first and then targets with the same loop body
        for key, values in (('srcs', srcs), ('trgs', trgs)):
            _dict = getattr(self, key)
            for val in values:
                name, *ftype = val.strip().split(':')
                ftype = ftype[0] if ftype else "Text"
                ds = DataSource(name, ftype,
                                src=(key == 'srcs'), trg=(key == 'trgs'))
                if name in self.all:
                    raise RuntimeError(
                        '"{}" already given as a data source.'.format(name))
                _dict[name] = ds
                self.all[name] = ds

        # Assign shortcuts. A side may legitimately be empty (e.g. the
        # documented "de:Text" direction has no targets), so fall back to
        # None instead of indexing into an empty list.
        self.first_src = next(iter(self.srcs), None)
        self.first_trg = next(iter(self.trgs), None)

    def is_included_in(self, t):
        """Return True if this topology is included in t, otherwise False."""
        if t is None:
            return False
        return (self.srcs.keys() <= t.srcs.keys()) and (self.trgs.keys() <= t.trgs.keys())

    def get_srcs(self, _type):
        """Return the list of sources whose ``_type`` equals ``_type``."""
        return [v for v in self.srcs.values() if v._type == _type]

    def get_trgs(self, _type):
        """Return the list of targets whose ``_type`` equals ``_type``."""
        return [v for v in self.trgs.values() if v._type == _type]

    def get_src_langs(self):
        """Return the list of sources of type ``Text``."""
        return self.get_srcs('Text')

    def get_trg_langs(self):
        """Return the list of targets of type ``Text``."""
        return self.get_trgs('Text')

    def __getitem__(self, key):
        """Return the source or target registered under the name ``key``."""
        return self.all[key]

    def __repr__(self):
        """Return a multi-line listing of the sources and the targets."""
        lines = ["Sources:\n"]
        lines.extend(" {}\n".format(repr(x)) for x in self.srcs.values())
        lines.append("Targets:\n")
        lines.extend(" {}\n".format(repr(x)) for x in self.trgs.values())
        return "".join(lines)

Classes

class DataSource (name, _type, src=False, trg=False)

All the operations on a read-only sequence.

Concrete subclasses must override `__new__` or `__init__`, `__getitem__`, and `__len__`.

Expand source code
class DataSource(UserString):
    """A string-like name of a data source, tagged with its dataset type.

    The instance behaves like the string ``name`` (via ``UserString``) and
    additionally records which ``<_type>Dataset`` class from ``datasets``
    handles it and whether it was declared as a source and/or a target.
    """

    def __init__(self, name, _type, src=False, trg=False):
        """Create the data source.

        Arguments:
            name: The name of the modality; becomes the string value.
            _type: Prefix of the dataset class, looked up in ``datasets``
                as ``'<_type>Dataset'``.
            src: Flag stored as-is in ``self.src``.
            trg: Flag stored as-is in ``self.trg``.
        """
        super().__init__(name)
        self._type = _type
        self.src = src
        self.trg = trg

        # Anything that is not flagged as a source is labelled 'trg'
        if self.src:
            self.side = 'src'
        else:
            self.side = 'trg'

        # Resolve the dataset class that knows how to turn a batch of this
        # type into a tensor
        dataset_cls = getattr(datasets, '{}Dataset'.format(_type))

        # ``kwargs`` is read on every ``torchify`` call, so entries added to
        # it after construction are forwarded to ``to_torch`` as well
        self.kwargs = {}
        self.torchify = lambda batch: dataset_cls.to_torch(
            batch, **self.kwargs)

    def __repr__(self):
        """Return a debug representation with the name and ``kwargs``."""
        template = "DataSource('{}', kwargs:{})"
        return template.format(self.data, self.kwargs)

Ancestors

  • collections.UserString
  • collections.abc.Sequence
  • collections.abc.Reversible
  • collections.abc.Collection
  • collections.abc.Sized
  • collections.abc.Iterable
  • collections.abc.Container
class Topology (direction)

A simple object that parses the direction string provided through the experiment configuration file.

A direction is a string with the following syntax: feat:<type>, feat:<type>, … -> feat:<type>, feat:<type>, …

where feat determines the name of the modality, i.e. 'en', 'image', etc., and type is the prefix of the actual Dataset class to be used with this modality, i.e. Text, ImageFolder, OneHot, etc. If type is omitted, the default is Text.

Example

  • de:Text (no target side)
  • de:Text -> en:Text
  • de:Text -> en:Text, en_pos:OneHot
  • de:Text, image:ImageFolder -> en:Text

Expand source code
class Topology:
    """A simple object that parses the direction string provided through the
    experiment configuration file.

    A direction is a string with the following syntax:
        feat:<type>, feat:<type>, ... -> feat:<type>, feat:<type>, ...

    where
        feat determines the name of the modality, i.e. 'en', 'image', etc.
        type is the prefix of the actual ``Dataset`` class to be used
            with this modality, i.e. Text, ImageFolder, OneHot, etc.
        if type is omitted, the default is Text.

    Example:
        de:Text (no target side)
        de:Text -> en:Text
        de:Text -> en:Text, en_pos:OneHot
        de:Text, image:ImageFolder -> en:Text

    Attributes:
        direction: The direction string exactly as it was given.
        srcs: ``OrderedDict`` of source name -> ``DataSource``.
        trgs: ``OrderedDict`` of target name -> ``DataSource``.
        all: ``OrderedDict`` with the sources followed by the targets.
        first_src: Name of the first source, ``None`` if there are none.
        first_trg: Name of the first target, ``None`` if there are none.
    """
    def __init__(self, direction):
        """Parse ``direction`` and fill in the source/target dictionaries.

        Arguments:
            direction: The direction string, see the class docstring.

        Raises:
            RuntimeError: If the same name is given more than once.
        """
        self.direction = direction
        self.srcs = OrderedDict()
        self.trgs = OrderedDict()
        self.all = OrderedDict()

        parts = direction.strip().split('->')
        if len(parts) == 1:
            # No '->' in the string: everything is a source
            srcs, trgs = parts[0].strip().split(','), []
        else:
            # Only the first two '->'-separated parts are considered
            srcs = parts[0].strip().split(',') if parts[0].strip() else []
            trgs = parts[1].strip().split(',') if parts[1].strip() else []

        # Parse sources first and then targets with the same loop body
        for key, values in (('srcs', srcs), ('trgs', trgs)):
            _dict = getattr(self, key)
            for val in values:
                name, *ftype = val.strip().split(':')
                ftype = ftype[0] if ftype else "Text"
                ds = DataSource(name, ftype,
                                src=(key == 'srcs'), trg=(key == 'trgs'))
                if name in self.all:
                    raise RuntimeError(
                        '"{}" already given as a data source.'.format(name))
                _dict[name] = ds
                self.all[name] = ds

        # Assign shortcuts. A side may legitimately be empty (e.g. the
        # documented "de:Text" direction has no targets), so fall back to
        # None instead of indexing into an empty list.
        self.first_src = next(iter(self.srcs), None)
        self.first_trg = next(iter(self.trgs), None)

    def is_included_in(self, t):
        """Return True if this topology is included in t, otherwise False."""
        if t is None:
            return False
        return (self.srcs.keys() <= t.srcs.keys()) and (self.trgs.keys() <= t.trgs.keys())

    def get_srcs(self, _type):
        """Return the list of sources whose ``_type`` equals ``_type``."""
        return [v for v in self.srcs.values() if v._type == _type]

    def get_trgs(self, _type):
        """Return the list of targets whose ``_type`` equals ``_type``."""
        return [v for v in self.trgs.values() if v._type == _type]

    def get_src_langs(self):
        """Return the list of sources of type ``Text``."""
        return self.get_srcs('Text')

    def get_trg_langs(self):
        """Return the list of targets of type ``Text``."""
        return self.get_trgs('Text')

    def __getitem__(self, key):
        """Return the source or target registered under the name ``key``."""
        return self.all[key]

    def __repr__(self):
        """Return a multi-line listing of the sources and the targets."""
        lines = ["Sources:\n"]
        lines.extend(" {}\n".format(repr(x)) for x in self.srcs.values())
        lines.append("Targets:\n")
        lines.extend(" {}\n".format(repr(x)) for x in self.trgs.values())
        return "".join(lines)

Methods

def get_src_langs(self)
Expand source code
def get_src_langs(self):
    """Return the sources of type ``Text``, as given by ``get_srcs``."""
    text_sources = self.get_srcs('Text')
    return text_sources
def get_srcs(self, _type)
Expand source code
def get_srcs(self, _type):
    """Return the sources whose ``_type`` attribute equals ``_type``.

    The result is a new list that keeps the iteration order of
    ``self.srcs``.
    """
    matching = []
    for source in self.srcs.values():
        if source._type == _type:
            matching.append(source)
    return matching
def get_trg_langs(self)
Expand source code
def get_trg_langs(self):
    """Return the targets of type ``Text``, as given by ``get_trgs``."""
    text_targets = self.get_trgs('Text')
    return text_targets
def get_trgs(self, _type)
Expand source code
def get_trgs(self, _type):
    """Return the targets whose ``_type`` attribute equals ``_type``.

    The result is a new list that keeps the iteration order of
    ``self.trgs``.
    """
    matching = []
    for target in self.trgs.values():
        if target._type == _type:
            matching.append(target)
    return matching
def is_included_in(self, t)

Return True if this topology is included in t, otherwise False.

Expand source code
def is_included_in(self, t):
    """Tell whether this topology is included in ``t``.

    Inclusion means that the source names of this topology are a subset
    of the source names of ``t`` and likewise for the target names.
    ``None`` never includes anything.
    """
    if t is None:
        return False

    # Sources are checked first; targets are only compared when they pass
    if not (self.srcs.keys() <= t.srcs.keys()):
        return False
    return self.trgs.keys() <= t.trgs.keys()