Source code for torchoutil.optim.utils
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from typing import Dict, Iterable, List, Optional, Tuple, Union
from torch import nn
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer
pylog = logging.getLogger(__name__)
[docs]def get_lr(optim: Optimizer, idx: int = 0, key: str = "lr") -> float:
"""
Get the learning rate of the first group of an optimizer.
Args:
optim: The optimizer to get.
idx: The group index of the learning rate in the optimizer. defaults to 0.
"""
return get_lrs(optim, key)[idx]
[docs]def get_lrs(optim: Optimizer, key: str = "lr") -> List[float]:
"""
Get the learning rates in all groups of an optimizer.
Args:
optim: The optimizer to get.
"""
return [group[key] for group in optim.param_groups]
[docs]def create_params_groups_bias(
model: Union[nn.Module, Iterable[Tuple[str, Parameter]]],
weight_decay: float,
skip_list: Optional[Iterable[str]] = (),
verbose: int = 2,
) -> List[Dict[str, Union[List[Parameter], float]]]:
if isinstance(model, nn.Module):
params = model.named_parameters()
else:
params = model
del model
decay: List[Parameter] = []
no_decay: List[Parameter] = []
if skip_list is None:
skip_list = {}
else:
skip_list = dict.fromkeys(skip_list)
for name, param in params:
if not param.requires_grad:
continue
if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
no_decay.append(param)
if verbose >= 2:
pylog.debug(f"No wd for {name}")
else:
decay.append(param)
return [
{"params": decay, "weight_decay": weight_decay},
{"params": no_decay, "weight_decay": 0.0},
]