
H-Net - Scaling Laws (Byte)

H-Net - This article is part of a series.
Part 6: This Article

To justify scaling up H-Nets, we must prove they scale well with compute.

Experiment setup
#

Prior to my investigation into μP, I conducted various compute optimality experiments with Byte-level H-Nets.

Expectations
#

What should a good BPE-killer do?

Opinions may differ, but there is one basic expectation that seems to echo through most papers:

Given a fixed pretraining compute budget, using the same dataset, the bits-per-byte (BPB) of the byte-level architecture should be lower than what a standard BPE Transformer LLM can achieve.

This is, strictly speaking, an unfair comparison that favors byte-level models over BPE models, due to underestimation of the latter’s performance via greedy tokenization.

If your byte-level architecture cannot win, even under this biased comparison, it should just be thrown into the garbage bin, regardless of whether it’s “cool” or can solve certain edge char-level tasks. Compute is always the constraining factor (or, nowadays, inference compute, for which training compute is an OK proxy here).

Blackbox setup
#

The first thing I did, after having efficient training implemented && basic exploration done, was to conduct blackbox hparam optimization for the best cross-entropy loss at each of several fixed FLOPs budgets.

To be pedantic:

  1. Using a byte-length sorted dataset of Fineweb-10BT for all experiments,

    • That is, with each document sorted by (bytelen, sha256hash)
    • Documents are packed up to a per-run token-limit-per-batch, mbs. At each step, this produces batches whose total token counts lie in range(mbs-max_seqlen+1, mbs+1), i.e. more than mbs-max_seqlen and at most mbs.
    • Each document has one BOS prepended, and docs are concatenated to a varlen batch without slicing. lbls = iids.roll(-1,0)
  2. Starting with pure torch implementations of H-Net, without any special init scheme / lr rules,

    • Varlen documents are always computed with no cross-document interaction.
    • The H-Net is designed such that transformer-based $S=0$ is llama2.
    • Use LR Modulation for $S\gt 0$. For $S=0$, LR Modulation is a constant factor that can be subsumed into sweeping, but I disable it anyway.
    • Always trained on 1 GPU, with amp bf16 & fp32 weights, foreach torch adamw $\beta=(0.9,0.95),\lambda=0$. WSD lr sched with 10%-80%-10% split.
    • Use fixed $N=[1,5]$ or $N=[1,3,9]$. In later experiments I vary this, but these are empirically good heuristic values.
  3. Add a flops calculation method to that module, which produces the theoretical Tensor Core FLOPs required for a given input batch.

    • In practice, I use microbatchsize & max_seqlen to approximate this. Due to sequence length sorting && microbatchsize being high, max_seqlen-min_seqlen is typically small, so this is a very close approximation.
    • I do not include scalar FLOPs, as it is not meaningful to add them to Tensor Core FLOPs at equal weight. This is a deviation from the H-Net paper, albeit an extremely minor one. Future work may consider combining them with a reasonable ratio to match hardware constraints.
    • The flops calculation is always done under the theoretical assumption of 0 recompute. In practice, both transformers and mamba use recompute by default (in flash-attn & mamba-ssm), but I think it is not scientifically helpful to include this.
  4. Using that method, all training runs have step counts determined by FLOPs consumption rather than fixed steps. The LR scheduler’s state is determined by the % of flops consumed.

    • For a BPE LLM with linear attention, the % of flops consumed per step is a constant, making the behavior identical to typical step-constant WSD.
    • For a transformer LLM, there is a very small increase in flops per step due to max_seqlen increases, but this is only non-negligible towards the end of the dataset, which is not encountered below 1e18 flops.
    • For an $S>0$ H-Net, FLOPs per step are inherently dynamic, and so are the LR warmup/cooldown.

    For reference, here’s an image of the 1e17 llama2 case:

    And here’s a reference for a 1e18 byte $S=1$ sweep:

    As you can see, the llama2 runs are ~constant steps throughout, whereas the $S=1$ case has target steps && lr curvature naturally adjusted to account for inc/dec in total FLOPs consumed (and also has LR Modulation)

  5. For all (architectures,FLOPs budget), sweep over reasonable bounds of width+depth+lr+bsz.

    • I use HEBO(..., model_name="gp") to sweep over a design space of all those constraints.
      • I use rand_sample=max(2*len(gpu_ids), int(trials*0.4)) to avoid overfitting on a small number of guesses. trials is typically 100, but is sometimes cut to 30/50/70 due to compute poverty.
      • I use base2 pow ranges for lr/mbs, int ranges for depths, and cat lists for model dimension. For $S>0$, I express each deeper $D_s$ as a non-negative delta from the previous $D_{s-1}$.
      • I concurrently consider up to 10 runs in parallel, depending on hardware constraints (primarily 8/9/10. Very rarely, 4), and asynchronously request suggestions / yield observations.
      • Note that the duration of each run is inherently random due to the dynamic step counts / batch size / model dims involved.
    • The metrics for each run are defined as the mean of the last 20 logged points of the run.
      • By default, I log once every 5 steps, so this averages the last 100 steps of training.
      • I estimate each step’s training bits-per-byte by interpreting each training batch as a ‘dataset’. This is not equivalent to a fixed eval bpb dataset (which I also implemented and computed, but realized was confounded by the terminating seqlen of each compute-optimal run and thus in practice a useless metric).
      • I optimize for best cross entropy across runs. Optimizing for cross-entropy is not identical to optimizing for the bpb proxy, but it is reasonably close.
    • I define bounds as “reasonable” if the best runs’ hparams are not obviously clustered around the edge of the provided design space. This is determined with my eyeballs.
    • Where any of the following occurs, I reject run results and request a new suggestion:
      • infeasibly long runs (>65k steps)
      • unknown crashes, exceptions, networking errors, etc (rare)
    • When OOMs occur, runs are re-executed with blockwise activation checkpointing.
  6. Please contact me if any statements in the above text feel ambiguous or understated. I believe the above text should be sufficiently accurate for a code agent to reconstruct the full nature of my setup.
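The batching and scheduling rules above (items 1 and 4) can be sketched in a few lines of plain Python. This is a minimal illustration, not the codebase used for these experiments: `pack_batches`, `make_varlen_batch` and `wsd_lr` are hypothetical names, document lengths are assumed to already include the BOS token, and linear warmup plus linear decay to zero are assumptions (the text above only fixes the 10%-80%-10% split).

```python
from typing import Iterable, Iterator


def pack_batches(doc_lens: Iterable[int], mbs: int) -> Iterator[list[int]]:
    """Greedily pack length-sorted documents into batches of at most `mbs` tokens.

    Documents are never sliced, so every batch except the last closes with a
    total token count in (mbs - next_doc_len, mbs].
    """
    batch: list[int] = []
    total = 0
    for n in doc_lens:
        if n > mbs:
            raise ValueError(f"document of length {n} exceeds mbs={mbs}")
        if total + n > mbs:  # next document does not fit: close this batch
            yield batch
            batch, total = [], 0
        batch.append(n)
        total += n
    if batch:
        yield batch  # final, possibly short, batch


def make_varlen_batch(docs: list[list[int]], bos: int) -> tuple[list[int], list[int], list[int]]:
    """Prepend one BOS per document and concatenate without padding or slicing.

    Returns (iids, lbls, cu_seqlens); cu_seqlens marks document boundaries so a
    varlen kernel can keep documents from interacting.
    """
    iids: list[int] = []
    cu_seqlens = [0]
    for doc in docs:
        iids.extend([bos] + doc)
        cu_seqlens.append(len(iids))
    lbls = iids[1:] + iids[:1]  # same result as iids.roll(-1, 0) on a 1-D tensor
    return iids, lbls, cu_seqlens


def wsd_lr(peak_lr: float, flops_used: float, flops_budget: float,
           warmup: float = 0.1, decay: float = 0.1) -> float:
    """WSD learning rate keyed on the fraction of the FLOPs budget consumed."""
    p = min(flops_used / flops_budget, 1.0)
    if p < warmup:
        return peak_lr * p / warmup
    if p < 1.0 - decay:
        return peak_lr
    return peak_lr * (1.0 - p) / decay


# Toy run whose per-step FLOPs grow over time (hypothetical numbers), as with an
# S>0 H-Net whose chunking changes during training: the run ends when the budget
# is spent, and the warmup/decay phases stretch or shrink accordingly.
budget = 1e6
used, step, lrs = 0.0, 0, []
while used < budget:
    lrs.append(wsd_lr(1e-3, used, budget))
    used += 1000.0 + 2.0 * step
    step += 1
```

Keying the schedule on FLOPs rather than steps is what makes a constant-cost model (such as the llama2 baseline) reduce to an ordinary step-based WSD schedule.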

So, in plots like these:

  • each individual training run is a dot
  • each connected line / logfile indicates one sweep for a particular flops budget
  • the entire graph is plotted for a specific architecture – the above being $S=0$ / llama2.

Result
#

After thousands of runs, this is the final aggregated end result:

That may be difficult to see, so here is a simplified version that only contains the best bpb end loss:

And here are some of the optimal runs found:

flopstypebpb$D_0$$D_1$$L_m$$L_t$$\eta$mbs(est) stepsceloss
1e+16bpe1.52477256100.004201911536953.772508
3e+16bpe1.42522256160.005872076673333.611341
3e+16byte1.43770256512350.00088332477143050.996538
1e+17bpe1.3021451280.005323278886373.357063
1e+17byte1.295882562565100.00069622614689460.898232
3e+17bpe1.20598512140.0046632811168003.178416
3e+17byte1.21250256512590.00035350011457250.840438
1e+18bpe1.11295512120.00617127283163503.017545
1e+18byte1.127245127686120.00021398163213790.78134

As you can see, none of the H-Net ($S>0$) runs ever achieve more compute efficiency than L2 (llama2).

In some cases, they achieve parity – I was quite excited when I obtained the 1e17 points – but the dreaded scaling curve bounce occurred, as FLOPs increased.
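As a sanity check on the table above, its bpb and celoss columns can be related directly. A minimal sketch, assuming celoss is the mean natural-log cross entropy per $L_0$ token: for byte-level rows one token is one byte, so bpb should be approximately celoss / ln 2 (the rows agree to within $10^{-4}$), while for BPE rows the ratio of the two columns gives the implied average bytes per token.

```python
import math

LN2 = math.log(2.0)

# (FLOPs budget, tokenization, bpb, celoss) copied from the table above.
rows = [
    (1e16, "bpe", 1.52477, 3.772508),
    (3e16, "bpe", 1.42522, 3.611341),
    (3e16, "byte", 1.43770, 0.996538),
    (1e17, "bpe", 1.30214, 3.357063),
    (1e17, "byte", 1.29588, 0.898232),
    (3e17, "bpe", 1.20598, 3.178416),
    (3e17, "byte", 1.21250, 0.840438),
    (1e18, "bpe", 1.11295, 3.017545),
    (1e18, "byte", 1.12724, 0.78134),
]


def bits_per_byte(ce_nats_per_token: float, bytes_per_token: float) -> float:
    """Convert mean cross entropy (nats per token) to bits per byte."""
    return ce_nats_per_token / (LN2 * bytes_per_token)


for flops, kind, bpb, ce in rows:
    if kind == "byte":
        # one L0 token == one byte
        print(f"{flops:.0e} byte: bpb={bpb:.5f}  celoss/ln2={bits_per_byte(ce, 1.0):.5f}")
    else:
        # solve bits_per_byte(ce, b) == bpb for the implied bytes per token b
        print(f"{flops:.0e} bpe:  implied bytes/token={ce / (LN2 * bpb):.3f}")
```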

Self-questioning
#

There are many copes I tried to investigate to avoid this painful conclusion.

Cope of Thought

For example, I considered the possibility that my base $S=0$ implementation was wrong, and created a separate pure llama2 model to ablate behavioral difference:

In my opinion, this does not show differences beyond random noise.

I considered the possibility that target compression factor needed to increase with scale. In principle, this is a reasonable assumption – if BPE vocab sizes should scale with model size, then the analogous chunk size in H-Nets should grow as well.

The correct answer to that possibility is complicated, but in short: it does grow, but accounting for that does not help the issue.

I considered if my FLOPs calculation could be broken. However, this consideration was a complete waste of time, as I had already considered (and fixed) its correctness many times in the past && had verified it with several independent sources, including Cartesia themselves.

I checked whether a preponderance of my sweeps were actually exploding, and therefore invalid.

The vast majority (>95%) of them were quite stable, so I concluded there ought to be nothing wrong with the general search space.

So, the analysis above convinced me my H-Nets were failing due to uncontrolled growth in compression ratio, leading to various poor attempts at architectural interventions to control it. Eventually, I concluded my assumptions were wrong, and that I needed to look elsewhere.

So, I took a step backwards, and tried to work from first principles, developing a ground-up understanding of activation statistics throughout H-Net training. Unfortunately, I realized it is actually quite nonsensical to expect the learned chunking behavior to remain consistent across compute scales, putting me back to square one.

Conclusion
#

I started off with the following statement:

If your byte-level architecture cannot win, even under this biased comparison, it should just be thrown into the garbage bin, regardless of whether it’s “cool” or can solve certain edge char-level tasks. Compute is always the constraining factor…

I am sticking to my guns. It can’t be helped, BPE wins again… 😢

But, as mentioned previously, this still leaves the opportunity for BPE H-Nets to win where byte H-Nets failed. The merits of that thesis will be presented in a future post.


Squaring the circle
#

A reasonably intelligent reader may, at this point, ask:

Don’t your findings contradict the paper?

The authors show in the paper, repeatedly, that byte-level H-Net should beat BPE in the FLOPs-equal setting.

How do you explain your results?

First of all, let’s understand: how does Cartesia’s baseline work?

  1. They start with a (BPE) Transformer baseline. Its aspect ratio & tokenizer are a copy of GPT-2-Large, and the rest of the architecture matches llama2.
  2. They use the 100B fineweb-edu dataset for all experiments, and aim to process ~8192 bytes per sequence, with a batch size of 256.
  3. They control for compute in all differing architectures by adjusting them to be equal on a FLOPs-per-byte basis.

Let me start off by saying: I find the above to be a reasonable & good faith effort at comparing different architectures, without too complex of a search space.

But, notably, it is not equivalent to a true compute optimal search. Cartesia notes this themselves:

We didn’t formally run out true scaling laws in the paper, which would require sweeping over many more model sizes and compute horizons. But as I mentioned in the previous part, I have some reasons to believe that they will have better scaling (i.e. shift the scaling law curve by a non-trivial amount)…

So, it is entirely reasonable to assume the setup could be biased in either direction (in favor of or against BPE), and we must work to consider which biases win out.

Consider the following observations:

FLOPs standardization
#

It does not take a lot of pretraining experience to be aware that it is rather numerically impossible to 100% match the FLOPs-per-byte used across architectures. Just to highlight the two obvious issues:

  1. When different blocks have different arithmetic expressions for FLOPs, it becomes incredibly likely that all correct solutions for 100% equivalence have fractional values of $L$ or $D$, which cannot exist.
  2. The actual FLOPs used at runtime by the H-Net are routing dependent, and no static estimator can ever capture this behavior.

Of course, being pedantic about this stuff is not helpful. It doesn’t matter much if the FLOPs-per-byte is only off by – let’s say – ±5%, because that alone wouldn’t explain the gulf between our results.

So, let’s assess the damage.

Arch estimators
#

For (1), what the authors actually do is:

For calculating FLOPs, we follow standard methodology (Hoffmann et al. 2022) with an extension for Mamba-2 layers (see Appendix B). We use the BPE-tokenized Transformer’s #FLOPs as a reference, and the number of layers of the other models is adjusted accordingly to match the reference #FLOPs.

I take that to mean:

  • starting with an oracle get_seq_flops(num_layers, l0, rest_of_arch),
  • obtain f = get_seq_flops(24, 1792, ...) for the GPT-2 baseline
  • for all archs, minimize abs(f-get_seq_flops(l, 8192, ...)) over l

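What does that imply? Per-sequence FLOPs grow linearly with the layer count, and rounding the ideal (fractional) layer count to the nearest integer moves $l$ by at most $\frac{1}{2}$. So the FLOPs-per-byte in Cartesia’s baseline should be off by at most a relative $\pm\frac{1}{2l}$. Since $l\approx 25$, this is not more than 2%.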
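The matching procedure and its error bound can be illustrated with a toy estimator. This is a sketch under stated assumptions: `seq_flops` is a hypothetical stand-in for `get_seq_flops` that reuses the per-token attention and GLU formulas from the appendix (LM head and embeddings ignored), and the widths below are illustrative, not the paper’s configurations.

```python
def layer_flops_per_token(d: int, h: int, msl: int) -> float:
    # attention block (QKVO projections + causal attention) + GLU MLP,
    # same expressions as _attn_per_token and _glu_per_token in the appendix
    return 2 * d * (msl + 4 * d) + 2 * d * h * 3


def seq_flops(num_layers: int, seqlen: int, d: int, h: int) -> float:
    """Hypothetical get_seq_flops: forward FLOPs for one sequence."""
    return seqlen * num_layers * layer_flops_per_token(d, h, seqlen)


def match_layers(target: float, seqlen: int, d: int, h: int) -> tuple[int, float]:
    """Integer layer count whose FLOPs are closest to `target`, plus the ideal l*."""
    per_layer = seq_flops(1, seqlen, d, h)
    l_star = target / per_layer
    best = min(range(1, int(l_star) + 3), key=lambda l: abs(target - l * per_layer))
    return best, l_star


# Reference: a 24-layer BPE model over 1792 tokens (illustrative width).
target = seq_flops(24, 1792, d=1536, h=4096)
# Candidate: a narrower model over 8192 positions (illustrative width).
l, l_star = match_layers(target, 8192, d=1024, h=2816)
rel_err = abs(l * seq_flops(1, 8192, 1024, 2816) - target) / target
```

With the small layer count in this toy case the mismatch can reach a few percent; it shrinks as $\frac{1}{2l}$ for deeper models.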

Therefore, (1) is not a significant issue.

Seq lengths
#

How about (2)? How does the paper handle dynamic seqlen FLOPs?

This isn’t ever explicitly stated, but by looking at this table:

You can see that the number of transformer layers from H-Net pool/space –> 1-stage is decremented from 28–>22. This is presumably because the authors use empirical BPIC (bytes per innermost chunk) to obtain the average FLOPs used by the main network in training.

To me, that doesn’t really make sense. I feel like there should be a catch-22:

  • For an H-Net, it is impossible to know the average BPIC of a specific arch under a dataset, unless you have already conducted a run.
  • For an H-Net, it is impossible to adjust the number of layers in the arch to match FLOPs-per-byte unless you have average BPIC.
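The circularity can be made concrete with a toy FLOPs-per-byte model of a 1-stage H-Net. Every number here is hypothetical: `outer_per_byte` lumps together the encoder, decoder and routing costs that run on every byte, and `main_layer_per_token` is the cost of one main-network layer per chunk. Since the main network runs once per chunk, its cost per byte is divided by BPIC, so the depth needed to hit a FLOPs-per-byte target scales with a BPIC that is not known before training.

```python
def hnet_flops_per_byte(outer_per_byte: float, main_layer_per_token: float,
                        main_layers: int, bpic: float) -> float:
    """Toy 1-stage H-Net cost: outer stack on every byte, main network once per chunk."""
    return outer_per_byte + main_layers * main_layer_per_token / bpic


def layers_to_match(target_per_byte: float, outer_per_byte: float,
                    main_layer_per_token: float, bpic: float) -> int:
    """Main-network depth whose FLOPs per byte is closest to the target, given a BPIC guess."""
    l_star = (target_per_byte - outer_per_byte) * bpic / main_layer_per_token
    return max(1, round(l_star))


# Hypothetical, normalized costs: the BPE baseline costs 1.0 per byte.
for bpic in (4.0, 5.0, 6.0):
    l = layers_to_match(1.0, outer_per_byte=0.2, main_layer_per_token=0.15, bpic=bpic)
    print(bpic, l, hnet_flops_per_byte(0.2, 0.15, l, bpic))
```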

But, you know, maybe the authors have a learned sense for how much BPIC their models will get under certain hparams, and are thus able to guess ahead-of-time to stunning accuracy.

And I say stunning accuracy, because in Table 2 they state that their H-Nets use (almost) exactly the same GFLOPs/byte as the BPE baseline at evaluation time.

Note that the table only proves FLOPs are identical at eval time, which (based on the rest of the table) should refer purely to GFLOPs/byte at end of training.

Ideally, BPIC should be defined as the average BPIC across the entire training run. But this is never asserted in the paper; if BPIC was only measured at the end of the run, it would be impossible to determine how far off their DC FLOPs are, without access to the exact curvature of their compression ratio across the whole run.
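A toy calculation shows why this distinction matters. Assuming a fixed number of bytes per step, the main network’s average cost per byte over a run is proportional to the mean of $1/\text{BPIC}_t$, not to $1/\text{BPIC}_{\text{end}}$. The linear BPIC ramp below is purely hypothetical and only illustrates the size of the effect: if BPIC rises during training, an end-of-run estimate undercounts the FLOPs actually spent, and if it falls, it overcounts.

```python
import math


def end_estimate_error(bpic_per_step: list[float]) -> float:
    """Ratio of the true average main-network FLOPs per byte to the estimate
    implied by the end-of-run BPIC (1.0 means the estimate is exact)."""
    true_avg = sum(1.0 / b for b in bpic_per_step) / len(bpic_per_step)
    return true_avg / (1.0 / bpic_per_step[-1])


# Hypothetical: BPIC ramps linearly from 3 to 5 bytes per chunk over training.
steps = 1000
ramp = [3.0 + 2.0 * i / (steps - 1) for i in range(steps)]
ratio = end_estimate_error(ramp)
# Continuous limit: (5 / 2) * ln(5 / 3), roughly 28% more main-network FLOPs
# than the end-of-run BPIC would suggest.
print(f"true / end-of-run estimate = {ratio:.3f}")
```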

So, I think there’s some significant uncertainty here about the accuracy of their inner sequence length estimation, but I nonetheless find it moderately unlikely to be the cause of gigantic divergence in results.

Hyperparameter tuning
#

This is much more likely to be the bigger problem.

The paper’s setting, to reiterate, is:

All tokenizer-free models process 8192 utf-8 encoded bytes per sequence, while the Transformer uses 1792 tokens from the GPT2 tokenizer (roughly equivalent to 8192 bytes). We use batch size 256 for all models;

Following Hägele et al. (2024) which recommends WSD schedulers with half the maximum learning rates as a cosine schedule, we adopt learning rates 2.5× higher than GPT-3 (Radford et al. 2019) standards; this corresponds to half of the maximum learning rate used in Gu and Dao (2024), yielding $6.25\times10^{-4}$ for Large-scale models and $5.0\times10^{-4}$ for XL-scale models.

That is, they use fixed LR and fixed batch size for all experiments.

This is not intrinsically bad. Arguably, this is preferable to the varying batch, flops-dynamic LR abomination I cooked up for my experiments…

In any case, without empirical evidence, it is impossible to determine whether this setting benefits BPE Transformer or H-Net more.

Thankfully, I have an abundance of empirical evidence to dissect this issue with.

Batch size
#

Claim:

Under the compute optimal setting, the optimal token batch size tends to be similar for both BPE Transformer && byte H-Net.

That is, the optimal number of tokens at L₀, not the number of bytes.

Justification?

It’s what I empirically find…

So, given that the paper normalizes runs by bytes-per-batch, the following corollary should be true:

Due to the use of identical bytes-per-batch in the paper, its BPE Transformer baseline is trained on a suboptimal (low) batch size relative to the H-Net experiments.
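The size of the gap follows directly from the quoted setup. Taking the claim above at face value (a similar optimal batch size when measured in $L_0$ tokens), the arithmetic below counts $L_0$ tokens per batch for each model; the constants are the paper’s, and treating one byte as one $L_0$ token for byte-level models is the only other assumption.

```python
SEQS_PER_BATCH = 256       # "We use batch size 256 for all models"
BPE_TOKENS_PER_SEQ = 1792  # GPT-2 tokenizer, roughly 8192 bytes
BYTES_PER_SEQ = 8192       # byte-level models: one L0 token per byte


def l0_tokens_per_batch(seqs: int, l0_tokens_per_seq: int) -> int:
    return seqs * l0_tokens_per_seq


bpe_batch = l0_tokens_per_batch(SEQS_PER_BATCH, BPE_TOKENS_PER_SEQ)  # 458,752
byte_batch = l0_tokens_per_batch(SEQS_PER_BATCH, BYTES_PER_SEQ)      # 2,097,152
print(f"byte-level batch is {byte_batch / bpe_batch:.2f}x larger in L0 tokens")  # 4.57x
```

If both architectures share roughly the same optimum in $L_0$ tokens, an equal-bytes batch cannot put both at that optimum, and the BPE baseline is the one left on the small side.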

I’m not familiar enough with the numbers to tell you how damaging this might be, but it should be non-negligible.

Learning rate
#

On learning rates, the paper claims in appendix C:

Empirically, we find that a reasonable set of multipliers (e.g., $\lambda^0=2.0$, $\lambda^1=1.5$, $\lambda^2=1.0$) works well in general.

By proxy, this implies the learning rate of H-Nets is generally increased by LR Modulation, relative to the BPE baseline.
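Concretely, combining the paper’s Large-scale base LR with the example multipliers gives the per-stage learning rates below. This is a sketch only: it assumes stage 0 is the outermost (byte-level) stage, matching the $D_0$/$D_1$ notation used earlier, and that each multiplier simply scales the base LR for that stage’s parameters. `stage_param_groups` is a hypothetical helper that emits the list-of-dicts format torch.optim optimizers accept as param groups; the strings stand in for parameter tensors.

```python
BASE_LR = 6.25e-4                 # paper's LR for Large-scale models
LR_MULTIPLIERS = (2.0, 1.5, 1.0)  # lambda^0, lambda^1, lambda^2 (appendix C example)


def stage_lrs(base_lr: float, multipliers: tuple[float, ...]) -> list[float]:
    """Effective learning rate of each H-Net stage under LR modulation."""
    return [base_lr * m for m in multipliers]


def stage_param_groups(params_by_stage: list[list], base_lr: float,
                       multipliers: tuple[float, ...]) -> list[dict]:
    """Hypothetical helper: one optimizer param group per stage, each with its own lr."""
    return [{"params": p, "lr": lr}
            for p, lr in zip(params_by_stage, stage_lrs(base_lr, multipliers))]


lrs = stage_lrs(BASE_LR, LR_MULTIPLIERS)  # approx [1.25e-3, 9.375e-4, 6.25e-4]
# placeholder strings instead of real parameter tensors
groups = stage_param_groups([["enc0", "dec0"], ["enc1", "dec1"], ["main"]],
                            BASE_LR, LR_MULTIPLIERS)
```

Under these values every H-Net stage trains at or above the base LR that the BPE Transformer receives.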

Ipso facto, this is already bad: Given the absence of any observed loss spikes in the paper, it’s likely the paper’s global LR is far from the edge of stability, and that the BPE Transformer is thus learning ‘slower’ than it could otherwise be, even if the different architectures had the same optimal LR.

But even that much doesn’t seem to be true:

Although I’m modestly doubtful about my own optimal LRs, I think it is generally true that, when used with LR Modulation, $S=1$ byte-level H-Nets should have a significantly lower base learning rate than a similarly sized BPE transformer.

So, I claim:

It is likely the H-Net paper’s BPE Transformer LR is too low.

I don’t have the compute to determine exactly how much too low it is, but it could be anywhere from 2x to 10x too low, depending on how things shift with scale, and whether they enabled a redundant LR modulation rule for their BPE baseline or not.

Conclusion
#

Dynamic Chunking for End-to-End Hierarchical Sequence Modeling fails to provide an accurate picture of H-Net’s performance under the compute optimal regime.

It’s still possible that,

  • H-Nets are more compute efficient in the data overtrained regime
  • My own experiments are scientifically invalid in various ways

But I think, even if so, it’s unlikely byte-level H-Nets will ever achieve drastic wins on proper compute-optimal ablations.

Appendix
#

FLOPs calc

Here are some code snippets that should sufficiently explain the FLOPs calculation used. They are copied directly from the codebase I used for all experiments.

```python
# mh/config_hnet.py:62
def _mamba_per_token(d: int, ssm_cfg: dict) -> float:
    expand   = ssm_cfg["expand"]
    d_state  = ssm_cfg["d_state"]
    k   = ssm_cfg["d_conv"]

    headdim = 64
    d_inner = d_ssm = expand * d
    n_heads = d_ssm//headdim

    d_inproj = 2*d_ssm + 2*d_state + n_heads
    d_conv = d_ssm + 2*d_state

    flops_inproj = 2 * d * d_inproj
    flops_dwconv = 2 * d_conv * k # <-- NOTE: in causal-conv1d, this is actually implemented as scalar flops. but theoretically you could unfold and do a s[dx4xd] bmm
    flops_ssd = 2 * 3 * d_ssm * d_state
    flops_outproj = 2 * d * d_ssm

    return flops_inproj + flops_dwconv + flops_ssd + flops_outproj


def _attn_per_token(d: int, msl: int) -> float:
    return 2 * d * (msl + 4 * d) # causal non-recompute
def _glu_per_token(d: int, h: int) -> float:
    return 2 * d * h * 3
def layer_per_tok(is_mamba: bool, has_mlp: bool, d: int, h: int, msl: int, ssm_cfg: dict) -> float:
    layer = _mamba_per_token(d, ssm_cfg) if is_mamba else _attn_per_token(d, msl)
    if has_mlp: layer += _glu_per_token(d, h)
    return layer
```

```python
# mh/xf.py:328
class Isotropic(nn.Module):
    def flops_per_token(self, msl: int):
        if self.window_size != -1 and any(isinstance(l.mixer, CausalMHA) for l in self.layers):
            raise NotImplementedError("window_size != -1 FLOPs calculation is not implemented")
        return sum(
            layer_per_tok(
                is_mamba=isinstance(l.mixer, Mamba2Simple),
                has_mlp=isinstance(l.mlp, GLU),
                d=self.d, h=self.h, msl=msl, ssm_cfg=self.ssm_cfg
            ) for l in self.layers
        )
```

```python
# mh/modeling_hnet.py:203
class HNet(nn.Module):
    def flops(self, extra_b: list["SeqInfo"]):
        # for modules that have total flops dependent on per-doc seqlen, we overestimate a little bit by the packed batch's max seqlen.
        # this is OK bc our seqlen sorted dataset has very close min vs max seqlen on average.
        d = self.d
        total_slen, max_seqlen = extra_b[0].mbs, extra_b[0].msl

        # hnet children: residual & routing mod have 3 dxd linears
        aux_pertok = 0 if self.is_innermost else 2*3*d*d
        iso_pertok = sum(
            m.flops_per_token(max_seqlen)
            for m in ([self.main_network] if self.is_innermost else [self.encoder, self.decoder])
        )
        # NOTE: we do not include dechunk/scan flops as elementwise scan does not touch tensor core

        our_flops = total_slen * (aux_pertok+iso_pertok)
        child_flops = 0 if self.is_innermost else self.main_network.flops(extra_b[1:])
        return our_flops + child_flops
```

```python
# mh/modeling_hnet.py:302
class HNetLM(BlockBoundaryMixin, nn.Module):
    def flops(self, extra_b: list["SeqInfo"]):
        # H-Net flops estimator, from batch of NJT selection data obtained from fwd.
        # we do not account for indexing/scalar costs here (such as embedding or biases or dechunk layer)
        d = self.c.d_model[0]
        v = self.c.vocab_size
        lmh = 2*v*d*extra_b[0].mbs
        return lmh + self.backbone.flops(extra_b)
```