Skip to content

Optimizer

The optimizer module contains utility functions for automatically building optimizers and learning-rate schedulers, as well as an implementation of the Muon optimizer.

build_optimizers(model_params, config)

Build a list of optimizers based on the model parameters and configuration.

Parameters:

Name Type Description Default
model_params List[Parameter]

List of model parameters to optimize.

required
config Mapping

Configuration dictionary containing optimizer settings.

required

Returns:

Type Description
list[Optimizer]

List[optim.Optimizer]: List of optimizers.

Source code in bfm/training/optimizer/builders.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def build_optimizers(
    model_params: list[torch.nn.Parameter], config: Mapping
) -> list[optim.Optimizer]:
    """
    Create the optimizers for a model from the training configuration.

    Only parameters that require gradients and hold floating-point data are
    optimized. When ``config["training"]["optimizer"] == "Muon"``, 2-D
    (matrix) parameters go to a ``Muon`` optimizer and all remaining
    parameters go to ``AdamW``. For any other optimizer name, every
    trainable parameter goes to ``AdamW``. Both optimizers share the
    configured learning rate and weight decay.

    Args:
        model_params (List[torch.nn.Parameter]): List of model parameters to optimize.
        config (Mapping): Configuration dictionary; reads ``training.optimizer``,
            ``training.learning_rate`` and ``training.weight_decay``.

    Returns:
        List[optim.Optimizer]: Zero, one or two optimizers (Muon first when present).
    """
    trainable = [
        param
        for param in model_params
        if param.requires_grad and param.is_floating_point()
    ]
    train_cfg = config["training"]

    if train_cfg["optimizer"] == "Muon":
        # Muon orthogonalizes matrix-shaped updates, so only 2-D tensors qualify;
        # everything else (biases, norms, ...) falls back to AdamW.
        muon_group = [param for param in trainable if param.ndim == 2]
        adamw_group = [param for param in trainable if param.ndim != 2]
    else:
        muon_group, adamw_group = [], trainable

    built: list[optim.Optimizer] = []

    if muon_group:
        built.append(
            Muon(
                muon_group,
                lr=train_cfg["learning_rate"],
                momentum=0.95,
                nesterov=True,
                backend="newtonschulz5",
                backend_steps=5,
                weight_decay=train_cfg["weight_decay"],
            )
        )

    if adamw_group:
        built.append(
            optim.AdamW(
                adamw_group,
                lr=train_cfg["learning_rate"],
                weight_decay=train_cfg["weight_decay"],
                betas=(0.9, 0.95),
            )
        )

    return built

build_schedulers(optimizers, config, training_setup)

Build learning rate schedulers for the given optimizers. Both warmup and falloff schedules are supported (both optional).

Parameters:

Name Type Description Default
optimizers list[Optimizer]

List of optimizers to build schedulers for.

required
config Mapping

Configuration dictionary containing scheduler settings.

required
training_setup

Training setup object containing dataloaders.

required

Returns:

Type Description
list[_LRScheduler]

List[_LRScheduler]: List of learning rate schedulers.

Source code in bfm/training/optimizer/builders.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def build_schedulers(
    optimizers: list[optim.Optimizer], config: Mapping, training_setup
) -> list[_LRScheduler]:
    """
    Build learning rate schedulers for the given optimizers.
    Both warmup and falloff schedules are supported (both optional).

    The total number of scheduler steps is
    ``training.n_epochs * len(training_setup.train_dataloader)``. This assumes
    the scheduler is stepped once per batch (TODO confirm against the
    training loop). Warmup ramps the LR linearly from ``1e-5 * lr`` up to
    ``lr`` over ``training.warmup_steps`` steps. The main schedule
    (``"linear"`` or ``"cosine"``) then covers the remaining steps, so that
    warmup plus decay together span the whole run.

    Args:
        optimizers (list[optim.Optimizer]): Optimizers to build schedulers for.
        config (Mapping): Configuration dictionary; reads ``training.n_epochs``,
            ``training.warmup_steps`` and ``training.lr_schedule``.
        training_setup: Object exposing ``train_dataloader`` (anything with ``len``).

    Returns:
        List[_LRScheduler]: At most one scheduler per optimizer. No scheduler is
        added for an optimizer when there is neither warmup nor a recognized
        main schedule.
    """
    train_cfg = config["training"]
    total_steps = train_cfg["n_epochs"] * len(training_setup.train_dataloader)
    warmup_steps = train_cfg["warmup_steps"]
    lr_schedule = train_cfg["lr_schedule"]

    # Under SequentialLR the main schedule only starts at the warmup milestone,
    # so it must cover the remaining steps rather than the full run; otherwise
    # the decay never reaches its end value. Clamp to >= 1 because both
    # LinearLR and CosineAnnealingLR divide by their length.
    main_steps = max(1, total_steps - max(warmup_steps, 0))

    schedulers: list[_LRScheduler] = []
    for optimizer in optimizers:
        # Warmup schedule
        warmup = None
        if warmup_steps > 0:
            warmup = LinearLR(
                optimizer,
                start_factor=1e-5,
                end_factor=1.0,
                total_iters=warmup_steps,
            )

        # Main schedule; any other value of lr_schedule means no main schedule.
        main = None
        if lr_schedule == "linear":
            main = LinearLR(
                optimizer, start_factor=1.0, end_factor=1e-5, total_iters=main_steps
            )
        elif lr_schedule == "cosine":
            main = CosineAnnealingLR(optimizer, T_max=main_steps)

        if warmup is not None and main is not None:
            schedulers.append(
                SequentialLR(optimizer, [warmup, main], milestones=[warmup_steps])
            )
        elif warmup is not None:
            schedulers.append(warmup)
        elif main is not None:
            schedulers.append(main)

    return schedulers

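Muon is a recently proposed optimizer for matrix-shaped (2-D) parameters: it applies momentum and then orthogonalizes each update with a Newton–Schulz iteration before taking the step.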

Muon

Bases: Optimizer

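Muon is momentum SGD whose per-parameter update is orthogonalized with a Newton–Schulz iteration; in this project it is used for 2-D (matrix) parameters, with AdamW handling the rest.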

Parameters:

Name Type Description Default
params Parameter

Parameters to optimize.

required
lr float

Learning rate.

0.02
momentum float

Momentum factor.

0.95
nesterov bool

Whether to use Nesterov momentum.

True
weight_decay float

Decoupled weight decay: each step scales the parameters by 1 - lr * weight_decay (skipped when 0).

0.0
backend str

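Name of the orthogonalization backend. The value is stored in the parameter group, but step() currently always uses the Newton–Schulz ('newtonschulz5') iteration.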

'newtonschulz5'
backend_steps int

Number of Newton–Schulz iterations used to orthogonalize each update.

5
Source code in bfm/training/optimizer/muon.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
class Muon(torch.optim.Optimizer):
    """
    Muon: momentum SGD whose per-parameter update is orthogonalized.

    For every parameter with a gradient, ``step``:

    1. updates a momentum buffer ``buf = momentum * buf + grad``;
    2. forms the update as ``grad + momentum * buf`` (Nesterov) or ``buf``;
    3. orthogonalizes it with ``_zeroth_power_via_newtonschulz5``;
    4. applies decoupled weight decay ``p *= 1 - lr * weight_decay``;
    5. takes the step ``p -= lr * update``.

    NOTE(review): the orthogonalization treats the update as a matrix, and
    ``build_optimizers`` only routes 2-D parameters here. Whether other
    shapes work depends on the Newton-Schulz helper — confirm before use.

    Args:
        params: Iterable of parameters (or parameter-group dicts) to optimize.
        lr (float): Learning rate.
        momentum (float): Momentum factor.
        nesterov (bool): Whether to use Nesterov momentum.
        weight_decay (float): Decoupled weight decay coefficient (not an L2
            penalty on the gradient).
        backend (str): Name of the orthogonalization backend. It is stored in
            the param group but not consulted: ``step`` always uses the
            Newton-Schulz ("newtonschulz5") iteration.
        backend_steps (int): Number of Newton-Schulz iterations per update.
    """

    def __init__(
        self,
        params: torch.nn.Parameter,
        *,
        lr: float = 0.02,
        momentum: float = 0.95,
        nesterov: bool = True,
        weight_decay: float = 0.0,
        backend: str = "newtonschulz5",
        backend_steps: int = 5,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            backend=backend,
            backend_steps=backend_steps,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        Args:
            closure (callable, optional): Re-evaluates the model and returns
                the loss, following the ``torch.optim.Optimizer.step`` convention.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure is given.
        """
        loss = None
        if closure is not None:
            # The closure typically runs forward/backward, so re-enable autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            # Only the Newton-Schulz backend is implemented; group["backend"] is ignored.
            zeropower_backend = _zeroth_power_via_newtonschulz5

            for p in group["params"]:
                g = p.grad
                if g is None:
                    continue

                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(g)

                if group["nesterov"]:
                    g = g.add(buf, alpha=momentum)
                else:
                    # Classical momentum: the update is the buffer itself. Clone
                    # so the backend can never modify the persistent state.
                    g = buf.clone()

                g = zeropower_backend(g, steps=group["backend_steps"])

                # NOTE: no shape-dependent rescaling of the orthogonalized update is
                # applied; alternatives that were tried include 0.2 * sqrt(max(rows, cols))
                # (Moonlight) and sqrt(max(1, rows / cols)).
                if group["weight_decay"] > 0:
                    p.data.mul_(1 - group["weight_decay"] * lr)
                p.data.add_(g, alpha=-lr)

        return loss

orthogonalize(G)

Orthogonalize G using the Newton-Schulz zeroth-power iteration (iterative inverse square root method).

Parameters:

Name Type Description Default
G Tensor

Input tensor to orthogonalize.

required

Returns:

Type Description
Tensor

torch.Tensor: Orthogonalized tensor, ≈ G (GᵀG)^(-1/2).

Source code in bfm/training/optimizer/muon.py
26
27
28
29
30
31
32
33
34
35
36
37
def orthogonalize(G: torch.Tensor) -> torch.Tensor:
    """
    Approximate the orthogonal (polar) factor of ``G``.

    Delegates to the Newton-Schulz zeroth-power iteration with a fixed
    schedule: 10 iterations, ``eps=1e-8`` and polynomial coefficients
    ``(3, -3.2, 1.2)``. This is more iterations than the optimizer's
    default of 5, trading compute for accuracy.

    Args:
        G (torch.Tensor): Input tensor to orthogonalize.

    Returns:
        torch.Tensor: Orthogonalized tensor, ≈ G (GᵀG)^(-1/2).
    """
    n_iterations = 10
    coefficients = (3, -3.2, 1.2)
    return _zeroth_power_via_newtonschulz5(
        G, steps=n_iterations, eps=1e-8, abc=coefficients
    )