"""Source code for tltorch.base"""
import torch
from collections import OrderedDict
from torch import nn
import weakref
class RemovableHandle(object):
    """Handle that lets a registered hook be detached by name.

    Adapted from Pytorch: torch.utils.hooks.py:7
    """

    def __init__(self, hooks_dict, name):
        # A weak reference: the handle must never be the thing that keeps
        # the hooks dict alive once its owner is gone.
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.name = name

    def remove(self):
        """Delete the named hook if its dict is still alive; safe to call repeatedly."""
        target = self.hooks_dict_ref()
        if target is None:
            return
        # pop with a default: a no-op when the hook was already removed
        target.pop(self.name, None)
class TensorModule(nn.Module):
    """A PyTorch module augmented for tensor parametrization.

    Subclasses provide their parametrizing tensor decomposition through
    :meth:`get_decomposition`; hooks registered with
    :meth:`register_decomposition_forward_pre_hook` are applied to that
    decomposition (in registration order) before each forward pass.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps hook name -> hook(module, decomposition) -> decomposition.
        # OrderedDict so hooks are applied in registration order.
        self._decomposition_forward_pre_hooks = OrderedDict()

    def register_decomposition_forward_pre_hook(self, hook, name=None):
        """Attach a new hook to be applied to the decomposition parametrizing the layer, before the forward.

        Decomposition hooks are functions called before the forward pass
        that take as input the module and the decomposition
        and return a modified decomposition.

        Decomposition hooks must be function with the following signature::

            hook(module, decomposition) -> modified decomposition

        Parameters
        ----------
        hook : callable
            Function of signature ``hook(module, decomposition)``.
        name : str, optional
            Key under which the hook is stored. If None, defaults to
            ``hook.name`` when present, otherwise ``hook.__class__.__name__``.
            Note that registering a second hook under the same name silently
            replaces the first.

        Returns
        -------
        RemovableHandle
            Handle whose ``remove()`` detaches the hook.
        """
        if name is None:
            # Prefer an explicit `name` attribute on the hook; fall back to
            # its class name (for plain functions this is 'function').
            name = getattr(hook, 'name', hook.__class__.__name__)
        handle = RemovableHandle(self._decomposition_forward_pre_hooks, name)
        self._decomposition_forward_pre_hooks[name] = hook
        return handle

    def get_decomposition(self):
        """Returns the tensor decomposition parametrizing the layer.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override.
        """
        raise NotImplementedError()

    def _process_decomposition(self):
        """Applies all the decomposition_forward_pre_hooks before returning the decomposition.

        This function should be used by all the Tensor layers to get their
        decomposition in the forward pass.

        Returns
        -------
        The decomposition from :meth:`get_decomposition`, transformed by each
        registered pre-hook in registration order.
        """
        decomposition = self.get_decomposition()
        for hook in self._decomposition_forward_pre_hooks.values():
            decomposition = hook(self, decomposition)
        return decomposition