from experimaestro import Param
from xpm_torch.parameters import ParametersIterator
from xpm_torch.trainers.context import TrainState, InitializationTrainingHook
import logging
# Module-level logger. Its level is set explicitly to INFO so that this
# module's info messages pass the logger's own level check even when an
# ancestor logger is configured higher (attached handlers may still filter).
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
[docs]
class LayerFreezer(InitializationTrainingHook):
    """Training hook that freezes a subset of the model parameters.

    On the first call of ``after``, every parameter reported by
    :attr:`selector` as "to freeze" gets ``requires_grad`` set to
    ``False``; later calls do nothing.
    """

    selector: Param[ParametersIterator]
    """How to select the layers to freeze"""

    def __init__(self):
        self._initialized = False

    def after(self, state: TrainState):
        # Only the first invocation freezes anything
        if self._initialized:
            return
        self._initialized = True

        # iter() yields (name, module, parameter, to_freeze); the module
        # itself is not needed here
        for name, _module, param, to_freeze in self.selector.iter():
            if not to_freeze:
                continue
            logger.info("Freezing layer %s", name)
            param.requires_grad = False
[docs]
class LayerSharer(InitializationTrainingHook):
    """Training hook that shares (ties) parameters between layers.

    On the first call of ``after``, the parameters selected by
    :attr:`source` and :attr:`target` are paired up in iteration order,
    and each target element is given the corresponding source parameter
    through its ``set`` method. Later calls do nothing.
    """

    source: Param[ParametersIterator]
    """The parameters to share"""

    target: Param[ParametersIterator]
    """The parameters to be shared"""

    def __init__(self):
        self._initialized = False

    def after(self, state: TrainState):
        """Share the selected parameters (first call only).

        :param state: the current training state (unused here)
        :raises ValueError: if ``source`` and ``target`` do not select the
            same number of parameters; in that case nothing is shared.
        """
        if self._initialized:
            return
        self._initialized = True

        # Snapshot both selections before touching the model: sharing
        # mutates it, so we avoid iterating while mutating, and checking the
        # counts up front guarantees a mismatch cannot leave the model
        # half-shared.
        sources = list(self.source.selected())
        targets = list(self.target.selected())
        if len(sources) != len(targets):
            raise ValueError(
                f"LayerSharer: source selects {len(sources)} parameter(s) "
                f"but target selects {len(targets)}"
            )

        for source, target in zip(sources, targets, strict=True):
            logger.info("Sharing layer %s -> %s", source.name, target.name)
            target.set(source.parameter)