
H-Net - Inference

This page assumes you’ve read H-Net’s paper and code. Please do so first.

The H-Net architecture is a promising new end-to-end approach to byte-level language modelling, which, in the paper's experiments, outperforms compute-matched tokenized (BPE) Transformer baselines.

What happens during an H-Net’s forward pass?

%%{init: { "themeVariables":{"fontFamily":"Inter, sans-serif", "edgeLabelBackground":"#bbbbbb"}, "flowchart":{"nodeSpacing":30,"rankSpacing":20} }}%%
flowchart TD
  A(["x"])
  A --> B{{"model dim increased?"}}
  B -->|✅| C["x = cat(x, .pad_dimension)"]
  C --> D{{"innermost h-net?"}}
  B --> D
  D -->|✅| X["y = .main_network(x,...)"]
  X --> Z[["return y[...,:D]"]]
  D --> Outer
  Outer --> Z
  subgraph Outer [Outer H-Net Behavior]
    E["r = .encoder(x,...)"] --> F["bp = .routing_module(r,...)"]
    F --> G["h = .chunk_layer(r,bp,...)"]
    G --> Q{{"any tokens in h?"}}
    Q -->|✅| Recurse["h = .main_network(h, ...)"]
    Recurse -->|h| I[".dechunk_layer(h,bp,...)"]
    Q -->|h| I
    F --> I
    E --> R[".residual_proj(r)"]
    R --> circleId(("`\+`"))
    I --> circleId(("`\+`"))
    circleId(("`\+`")) -->|h| Y["y = .decoder(h,...)"]
  end
  %% Node styles
  style A fill:#CBA6F7,stroke:#7F5FA3,stroke-width:2px
  style Z fill:#CBA6F7,stroke:#7F5FA3,stroke-width:2px
  style B fill:#BAC2DE,stroke:#6C7086,stroke-width:2px
  style D fill:#BAC2DE,stroke:#6C7086,stroke-width:2px
  style X fill:#89B4FA,stroke:#486CAA,stroke-width:2px
  style E fill:#89B4FA,stroke:#486CAA,stroke-width:2px
  style Y fill:#89B4FA,stroke:#486CAA,stroke-width:2px
  %% Gold = recursive H-Net
  style Recurse fill:#FFE082,stroke:#FFB300,stroke-width:2px
  %% Optional: rounded subgraph box
  style Outer rx:12px,ry:12px

Construction

To load an H-Net for inference, the code must:

  • create Modules & load weights
  • initialize inference caches

Model

H-Nets are mostly made up of Isotropics – transformer/mamba blocks that map $(B,L,D)\rightarrow(B,L,D)$ sequences.

pdb++ dump of hnet_1stage_XL
(Pdb++) model
HNetForCausalLM(
  (embeddings): Embedding(256, 1024)
  (backbone): HNet(
    (encoder): Isotropic(...)
    (main_network): HNet(
      (main_network): Isotropic(...)
    )
    (decoder): Isotropic(...)
    (routing_module): RoutingModule(
      (q_proj_layer): Linear(in_features=1024, out_features=1024, bias=False)
      (k_proj_layer): Linear(in_features=1024, out_features=1024, bias=False)
    )
    (chunk_layer): ChunkLayer()
    (dechunk_layer): DeChunkLayer()
    (residual_proj): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (lm_head): Linear(in_features=1024, out_features=256, bias=False)
)

At a high level, it has:

  1. a top-level CausalLM wrapper, providing the embedding and lm_head.
    • Note the lack of a final norm, as H-Nets put norms inside each Isotropic (see §2.3.1 of the paper)
    • the wrapper’s .backbone is always an H-Net here, but could be substituted with any conventional transformer (or Isotropic in general)
  2. an $S$-stage H-Net, which is recursively defined by two cases:
    • the innermost ($s=S$) H-Net, which is just a pure Isotropic. Technically, a transformer is an $S=0$ H-Net.
    • all other H-Net levels ($s\in[0,S)$), which contain
      • Isotropic encoder/decoder networks ($\mathcal{E, D}$),
      • an inner main_network H-Net ($\mathcal{M}$),
      • $\text{Chunk}$ modules (chunk_layer, routing_module)
      • $\text{Dechunk}$ modules (dechunk_layer, residual_proj)
    • There is also a pad_dimension vector between any H-Net levels with differing model dim $D$. It is a Parameter rather than a submodule, so PyTorch’s module printout doesn’t show it.

Throughout the rest of the article, I ignore (1) and focus on (2).
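The recursive definition above (and the flowchart at the top of the article) can be condensed into a few lines of plain Python. This is a toy model under loudly stated assumptions, not the released code: sequences are unbatched lists of vectors, every module is a stand-in callable, routing returns only a boolean mask, and the names `ToyHNet` and `hold_last_chunk` are hypothetical. The toy dechunk skips the EMA entirely; it only shows where each piece sits in the recursion.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional

Seq = List[List[float]]  # one unbatched (L, D) sequence of vectors


@dataclass
class ToyHNet:
    """Hypothetical, simplified H-Net stage: every module is a plain callable."""
    d_model: int
    main_network: Callable[[Seq], Seq]  # an Isotropic (innermost) or an inner ToyHNet
    pad_dimension: List[float] = field(default_factory=list)
    # Only outer stages (s < S) set the fields below:
    encoder: Optional[Callable[[Seq], Seq]] = None
    routing_module: Optional[Callable[[Seq], List[bool]]] = None
    dechunk_layer: Optional[Callable[[Seq, List[bool]], Seq]] = None
    residual_proj: Optional[Callable[[Seq], Seq]] = None
    decoder: Optional[Callable[[Seq], Seq]] = None

    def __call__(self, x: Seq) -> Seq:
        d_in = len(x[0])
        if self.d_model > d_in:  # model dim increased: append pad_dimension
            x = [v + self.pad_dimension for v in x]
        if self.encoder is None:  # innermost stage: a pure Isotropic
            y = self.main_network(x)
        else:
            r = self.encoder(x)
            b = self.routing_module(r)
            h = [v for v, keep in zip(r, b) if keep]  # chunk_layer
            if h:  # "any tokens in h?" (a real dechunk_layer falls back to its cache)
                h = self.main_network(h)
            h = self.dechunk_layer(h, b)
            res = self.residual_proj(r)
            y = self.decoder([[a + c for a, c in zip(u, w)] for u, w in zip(res, h)])
        return [v[:d_in] for v in y]  # return y[..., :D]


def hold_last_chunk(h: Seq, b: List[bool]) -> Seq:
    """Toy dechunk without EMA: each position reads the most recent chunk vector."""
    out, idx = [], -1
    for keep in b:
        idx += keep
        out.append(h[idx])
    return out


identity = lambda s: s
inner = ToyHNet(d_model=4, main_network=identity, pad_dimension=[0.0, 0.0])
outer = ToyHNet(
    d_model=2, main_network=inner, encoder=identity,
    routing_module=lambda r: [t % 2 == 0 for t in range(len(r))],  # toy: boundary every 2nd byte
    dechunk_layer=hold_last_chunk, residual_proj=identity, decoder=identity,
)
y = outer([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # [[2.0, 4.0], [4.0, 6.0], [10.0, 12.0]]
```

The inner stage pads its 2-wide inputs to 4 with `pad_dimension` and slices back to 2 on return, which is exactly the `cat(x, .pad_dimension)` / `y[..., :D]` pair from the flowchart.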

Caches

In addition to the inference caches required by transformer/mamba blocks (which dominate memory use), H-Nets need to track inference state in a few other places:

Memory breakdown for H-Net caching

For each sequence in the batch (transformer/Mamba2 sizes are per layer):

  • Transformer blocks use a KV cache holding K and V, each of shape $L\times N_h\times D_h$ (so $2\times L\times N_h\times D_h$ in total).
    • For 1stage_XL’s main network, $D=2048\land N_h=16\implies D_h=128$
  • Mamba2 blocks use a conv_state of Shape[conv_dim, d_conv] and an $N_h\times D_h\times d_{state}$ ssm_state. For 1stage_XL’s outer stage ($D=1024$),
    • d_conv=4, d_state=128, expand=2.
    • d_inner=expand*d_model=2048, so $D_h=64\implies N_h=32$.
    • conv_dim = d_inner + 2*ngroups*d_state = 2304 (with ngroups=1).
  • routing_module needs to track the last_hidden_state vector – when given $\hat{x}^t$, routing_module will have $\hat x^{t-1}$ cached, so the cosine similarity between the projections of the two consecutive positions can be computed. It also stores a has_seen_tokens flag, but its behavior is redundant under generate.py.
  • dechunk_layer stores a similar last_value vector (the EMA $\bar z_t$)

The rest of the layers in an H-Net are purely functional.
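The per-layer sizes above are easy to get wrong, so here is the arithmetic as a small sketch. It counts elements, not bytes (multiply by the cache dtype's size), assumes `ngroups=1` and the 1stage_XL hyperparameters quoted above, and the helper names are hypothetical. One thing worth keeping in mind: only selected positions ever reach the main network, so the inner transformer's $L$ is the number of chunks seen so far, not the number of bytes, while the Mamba2 state stays constant-size.

```python
def kv_cache_elems(num_positions: int, n_heads: int, head_dim: int) -> int:
    """K and V, each (L, N_h, D_h), for one attention layer and one sequence."""
    return 2 * num_positions * n_heads * head_dim


def mamba2_cache_elems(d_model: int, expand: int = 2, headdim: int = 64,
                       d_state: int = 128, d_conv: int = 4, ngroups: int = 1) -> dict:
    """conv_state (conv_dim, d_conv) and ssm_state (N_h, D_h, d_state) for one Mamba2 layer."""
    d_inner = expand * d_model
    n_heads = d_inner // headdim
    conv_dim = d_inner + 2 * ngroups * d_state
    return {"conv_state": conv_dim * d_conv, "ssm_state": n_heads * headdim * d_state}


# 1stage_XL figures from above: Mamba2 at the outer D=1024, attention at the inner D=2048.
mamba = mamba2_cache_elems(d_model=1024)                             # {'conv_state': 9216, 'ssm_state': 262144}
kv_per_chunk = kv_cache_elems(1, n_heads=16, head_dim=2048 // 16)    # 4096 per chunk position
```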


Generation

Like normal LLMs, there’s a separation of concerns between Prefill (.forward()) and Decode (.step()) in the H-Net inference code.

For ease of understanding, this article walks through both processes in isolation, for the specific case of hnet_1stage_XL. This should provide sufficient context to understand all other released H-Net variants.

Prefill

Similar to SpaceByte, we start with a sequence of UTF-8 bytes (with 0xFE as BOS) as input, which are passed through an nn.Embedding to produce the first hidden_states.

So, those hidden_states (shortened to x) are passed to the first (0th) H-Net stage:

Same diagram as at the top of the article, except that the main network is always executed.
%%{init: { "themeVariables":{"fontFamily":"Inter, sans-serif", "edgeLabelBackground":"#bbbbbb"}, "flowchart":{"nodeSpacing":30,"rankSpacing":20} }}%%
flowchart TD
  A(["x"])
  A --> B{{"model dim increased?"}}
  B -->|✅| C["x = cat(x, .pad_dimension)"]
  C --> D{{"innermost h-net?"}}
  B --> D
  D -->|✅| X["y = .main_network(x,...)"]
  X --> Z[["return y[...,:D]"]]
  D --> Outer
  Outer --> Z
  subgraph Outer [Outer H-Net Behavior]
    E["r = .encoder(x,...)"] --> F["bp = .routing_module(r,...)"]
    F --> G["h = .chunk_layer(r,bp,...)"]
    G -->|always| Recurse["h = .main_network(h, ...)"]
    Recurse --> I[".dechunk_layer(h,bp,...)"]
    F --> I
    E --> R[".residual_proj(r)"]
    R --> circleId(("`\+`"))
    I --> circleId(("`\+`"))
    circleId(("`\+`")) -->|h| Y["y = .decoder(h,...)"]
  end
  %% Node styles
  style A fill:#CBA6F7,stroke:#7F5FA3,stroke-width:2px
  style Z fill:#CBA6F7,stroke:#7F5FA3,stroke-width:2px
  style B fill:#BAC2DE,stroke:#6C7086,stroke-width:2px
  style D fill:#BAC2DE,stroke:#6C7086,stroke-width:2px
  style X fill:#89B4FA,stroke:#486CAA,stroke-width:2px
  style E fill:#89B4FA,stroke:#486CAA,stroke-width:2px
  style Y fill:#89B4FA,stroke:#486CAA,stroke-width:2px
  %% Gold = recursive H-Net
  style Recurse fill:#FFE082,stroke:#FFB300,stroke-width:2px
  %% Optional: rounded subgraph box
  style Outer rx:12px,ry:12px

Walking through the call stack,

  1. at the outermost stage ($s=0$), the input x will match the current model dim, so no padding is needed.
  2. $s=0$ is not the innermost stage (that is $s=1$), so we follow the outer H-Net behavior.
  3. The encoder ($\mathcal{E}^0$) is a Mamba2 stack (it could be any Isotropic, but the paper’s results clearly favor Mamba over Transformer layers here, $\text{M}\gg\text{T}$). It transforms x and caches SSM state for later decoding.
  4. The routing module needs to produce bpred_outputs (from $p_t,b_t$). It:
    • (matmul) computes query/keys $q_t,k_t$ from input $x_t$,
    • does cossim = F.cosine_similarity(q[:,:-1],k[:,1:],dim=-1)
    • assigns $p_0 = 1.0$, $p_{t+1} = \frac{1-\text{cossim}_t}{2}$, and b = p>=0.5
    • caches the last $x_t$ for use during decode.
    • returns boundary_prob=[1-p,p], boundary_mask=b, selected_probs=max(1-p,p)
  5. The chunking layer drops tokens from x to produce h. In the unmasked, unbatched case this is just h = x[b] (the masked_select equivalent is x.masked_select(b.unsqueeze(-1)).view(-1, D), since masked_select returns a flattened 1-D tensor).
  6. h is both the input to and the output of the inner $s=1$ H-Net. For 1stage_XL, that is a transformer, which caches KV state for later decoding.
  7. h (renamed $\hat z$) is expanded back to x’s length by the dechunk layer, which:
    • obtains P = p.masked_select(b)
    • borrows Mamba’s parallel scan kernel to quickly obtain the EMA $\bar z$ from $\hat z,P$
    • implements $\tilde z$ as out.gather(dim=1, index=b.cumsum(dim=1)...)
    • caches the last vector in $\tilde z$ as last_value
  8. compute h = residual_proj(x) + h
  9. run h through the decoder ($\mathcal{D}^0$) to obtain output y.
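Steps 4, 5 and 7 can be sketched for a single unbatched sequence in pure Python. This is an illustrative sketch, not the released kernels: the EMA runs as a plain loop instead of a parallel scan, the q/k projections are passed in as explicit matrices, the main network is replaced by the identity, $p$ is clamped to $[0,1]$ as a rounding guard, and all helper names are hypothetical.

```python
import math
from typing import List, Tuple

Vec = List[float]


def matvec(W: List[Vec], x: Vec) -> Vec:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def cosine(a: Vec, b: Vec) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def route(x: List[Vec], Wq: List[Vec], Wk: List[Vec]) -> Tuple[List[float], List[bool]]:
    """Step 4: p_0 = 1, p_{t+1} = (1 - cossim_t) / 2, b = p >= 0.5."""
    q = [matvec(Wq, v) for v in x[:-1]]  # like q[:, :-1]
    k = [matvec(Wk, v) for v in x[1:]]   # like k[:, 1:]
    p = [1.0] + [min(max((1 - cosine(qi, ki)) / 2, 0.0), 1.0) for qi, ki in zip(q, k)]
    return p, [pt >= 0.5 for pt in p]


def chunk(x: List[Vec], b: List[bool]) -> List[Vec]:
    """Step 5: keep boundary positions only (x[b] for a single sequence)."""
    return [v for v, keep in zip(x, b) if keep]


def dechunk(z_hat: List[Vec], p: List[float], b: List[bool]) -> List[Vec]:
    """Step 7: EMA over chunks (a sequential loop, not a scan), then plug back per byte."""
    P = [pt for pt, keep in zip(p, b) if keep]  # P = p[b]
    prev, z_bar = [0.0] * len(z_hat[0]), []
    for Pc, zc in zip(P, z_hat):  # z_bar_c = P_c * z_hat_c + (1 - P_c) * z_bar_{c-1}
        prev = [Pc * a + (1 - Pc) * c for a, c in zip(zc, prev)]
        z_bar.append(prev)
    z_tilde, idx = [], -1
    for keep in b:  # idx = cumsum(b) - 1
        idx += keep
        z_tilde.append(z_bar[idx])
    return z_tilde


I2 = [[1.0, 0.0], [0.0, 1.0]]  # identity q/k projections
x = [[1.0, 0.0], [1.0, 0.0], [-1.0, 1.0], [-1.0, 1.0]]
p, b = route(x, I2, I2)        # b == [True, False, True, False]
z_hat = chunk(x, b)            # pretend the main network is the identity
out = dechunk(z_hat, p, b)
```

Since $p_0 = 1$, the first chunk's EMA value is exactly $\hat z_0$, so the zero initial state never leaks into the output.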

Some thoughts:

On batching

To maximize understanding, I used ops like masked_select in my explanation, even though they only make sense in the batch size 1 scenario (they produce data-dependent, ragged shapes).

The code in the released repo does not use them, and much of it seems to be written with batched training/inference in mind; I do not describe those constraints here, and am ill-qualified to describe them fully.

Multi-turn issue

When comparing the forward and step methods of HNet, I found that the guard against a length-0 hidden_states_inner is only required in step, as forward is guaranteed (by $p_0=1.0$ at the BOS position) to have at least one true boundary_mask position.

However, in the case of multi-turn chat models, a prefill of a new user chat input has no $p_0=1.0$ to guarantee boundary_mask.any(). It then becomes possible for very short user messages to encounter 0 selected tokens in deeper hierarchies (even if this is unlikely to be learned behavior).

The simplest way to engineer around that would be to just pad prefill batches and execute anyway. But an alternative that makes slightly more sense to me is the possibility of guaranteeing $p_t=1$ for all special tokens, rather than just the BOS.

Though I am not sure how this would impact training, I think it makes sense that there should be some way for an end user to ‘force’ an H-Net to start a new chunk at a particular byte. I can imagine this feature being useful for some multimodal cases as well.
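The special-token idea can be sketched as a post-processing step on the routing probabilities at inference time. This is a hypothetical sketch: `force_boundaries`, the separator id `SEP` and the probabilities are illustrative, and it says nothing about how to train with such a constraint.

```python
from typing import List, Sequence, Set

SEP = 0xFE  # hypothetical special-token byte marking the start of a new turn


def force_boundaries(ids: Sequence[int], p: Sequence[float], special_ids: Set[int]) -> List[float]:
    """Pin p_t = 1 wherever the input byte is a special token, not just at the BOS."""
    return [1.0 if tok in special_ids else pt for tok, pt in zip(ids, p)]


def boundary_mask(p: Sequence[float]) -> List[bool]:
    return [pt >= 0.5 for pt in p]


# A short new user turn whose routing probabilities would select no chunk at all.
turn = [SEP, ord("h"), ord("i")]
p_raw = [0.2, 0.1, 0.3]
p_forced = force_boundaries(turn, p_raw, special_ids={SEP})
b_forced = boundary_mask(p_forced)  # [True, False, False]
```

A side effect is that with $P_t = 1$ the dechunk EMA takes its value entirely from the new chunk at that position, because the $(1-P_t)$ weight on the previous state is zero.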

Decode

The generation loop of an H-Net LLM is identical to that of any other byte LLM:

$\text{prev byte}\xrightarrow[\text{cache}]{\text{LM}}\text{logits}\xrightarrow{\text{sampler}}\text{next byte}$

So, like any other LM, the H-Net starts each decode step with a single vector $x$ (the embedding of the previous byte), plus the cache state left by prefill.
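The loop itself is model-agnostic; a minimal sketch, assuming hypothetical `prefill`/`step`/`sample` callables rather than the repo's actual API, looks like this. The toy model at the bottom just predicts "previous byte + 1" so the loop can be run end to end.

```python
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar

C = TypeVar("C")  # whatever cache object the model keeps between calls
Logits = List[float]


def generate(
    prefill: Callable[[Sequence[int]], Tuple[Logits, C]],
    step: Callable[[int, C], Tuple[Logits, C]],
    sample: Callable[[Logits], int],
    prompt: Sequence[int],
    max_new_bytes: int,
    eos: Optional[int] = None,
) -> List[int]:
    logits, cache = prefill(prompt)  # Prefill: one .forward() over the whole prompt
    out: List[int] = []
    for _ in range(max_new_bytes):
        nxt = sample(logits)  # sampler: logits -> next byte
        if nxt == eos:
            break
        out.append(nxt)
        logits, cache = step(nxt, cache)  # Decode: one .step() per byte, reusing the cache
    return out


def one_hot(i: int) -> Logits:
    return [1.0 if j == i else 0.0 for j in range(256)]


# Toy "model": always predicts (previous byte + 1); its cache just counts bytes seen.
toy_prefill = lambda ids: (one_hot((ids[-1] + 1) % 256), len(ids))
toy_step = lambda tok, n: (one_hot((tok + 1) % 256), n + 1)
greedy = lambda logits: max(range(len(logits)), key=logits.__getitem__)

result = generate(toy_prefill, toy_step, greedy, b"a", max_new_bytes=3)  # [98, 99, 100], i.e. b"bcd"
```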

Much like Prefill’s call stack,

  1. (same) padding happens if dim changes
  2. (same) stage $s=0$ always follows the outer H-Net behavior
  3. (cached) The encoder reads/updates its stored ssm_state/conv_state when transforming input x
  4. (cached) The routing module uses the cached $x_{t-1}$ together with $x_t$ to compute $p_t,b_t$, and then writes $x_t$ to the cache.
  5. The chunking layer still does the same op, but because there is no guarantee of a non-zero $b_t$ (unlike prefill, where $b_0=1$), the selected h can be an empty (length 0) sequence.
  6. Since h can be empty, the inner $s=1$ H-Net can be bypassed entirely. This is how H-Nets implicitly implement a form of adaptive compute (loosely analogous to speculative decoding). When it does execute, it is equivalent to a single (cached) decode step in a normal transformer.
  7. The dechunk layer simply returns $\tilde z_{t-1}$ if h is empty, and otherwise computes and caches $\tilde z_t = P_t\hat z_t+(1-P_t)\tilde z_{t-1}$. It may be difficult to convince yourself that this is what occurs from the batched implementation, but it is true.
  8. (same) add residual. This is $\tilde z_{t-1} + W_r\hat x$ when the main net is not executed.
  9. (cached) run decoder with h and cache to obtain output byte logits y.
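Steps 4 to 8 for a single byte can be sketched together with the two small caches they need. Assumptions, all hypothetical: pure Python lists, identity encoder, q/k and residual projections (so $\hat x_t = x_t$ and $W_r = I$), a `StageCache` dataclass named after the fields described earlier, and a `main_step` callable standing in for the inner network's cached decode step. The encoder/decoder SSM caches are left out.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple

Vec = List[float]


def cosine(a: Vec, b: Vec) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


@dataclass
class StageCache:
    last_hidden_state: Vec  # routing_module: previous encoder output x_{t-1}
    last_value: Vec         # dechunk_layer: previous output z~_{t-1}


def stage_step(x_t: Vec, cache: StageCache, main_step: Callable[[Vec], Vec]) -> Tuple[Vec, bool]:
    """One decode step of an outer stage; returns (residual + dechunk output, inner net ran?)."""
    p_t = min(max((1 - cosine(cache.last_hidden_state, x_t)) / 2, 0.0), 1.0)
    cache.last_hidden_state = x_t  # write x_t to the routing cache
    ran = p_t >= 0.5               # b_t
    if ran:                        # h holds one token: run the inner network
        z_hat = main_step(x_t)
        cache.last_value = [p_t * a + (1 - p_t) * c for a, c in zip(z_hat, cache.last_value)]
    # else: h is empty, the inner network is skipped and z~_t = z~_{t-1}
    y = [a + c for a, c in zip(x_t, cache.last_value)]  # residual (W_r = I) + dechunk output
    return y, ran


calls = []


def main_step(v: Vec) -> Vec:  # stand-in for one cached decode step of the inner transformer
    calls.append(v)
    return v


cache = StageCache(last_hidden_state=[1.0, 0.0], last_value=[1.0, 0.0])
outs, ran = [], []
for x_t in ([1.0, 0.0], [-1.0, 1.0]):
    y, r = stage_step(x_t, cache, main_step)
    outs.append(y)
    ran.append(r)
```

The first byte repeats the previous one, so $p_t = 0$ and `main_step` is never called; the output is just the cached $\tilde z_{t-1}$ plus the residual.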

H-Net Simple

I found the published H-Net implementation, which is half batched and half training-oriented, a bit confusing, so I decided to implement my own lighter variant.

Here are some things I learned in the process.

Character Repeats

For visibility, I implemented a black/white visualizer for chunk boundaries in H-Net generation.
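The visualizer itself isn't reproduced here. As a rough plain-text stand-in (hypothetical helpers, not the tool used for the observations in this section), boundaries can be rendered by writing a separator before every byte with $b_t=1$, and the bytes-per-chunk ratio is one line:

```python
from typing import Sequence


def render_chunks(data: bytes, b: Sequence[bool], sep: str = "|") -> str:
    """Plain-text boundary view: write `sep` before every byte that starts a new chunk."""
    out = []
    for i, (byte, starts_chunk) in enumerate(zip(data, b)):
        if starts_chunk and i > 0:
            out.append(sep)
        # Boundaries are byte-level and may split a multi-byte UTF-8 character,
        # so non-printable / non-ASCII bytes are shown as escapes.
        out.append(chr(byte) if 32 <= byte < 127 else f"\\x{byte:02x}")
    return "".join(out)


def bytes_per_chunk(b: Sequence[bool]) -> float:
    return len(b) / max(sum(b), 1)


mask = [True, False, True, False, True, False]
print(render_chunks(b"AAAAAA", mask), bytes_per_chunk(mask))  # AA|AA|AA 2.0
```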

So, while playing with some H-Net completions, I realized hnet_1stage_L was actually incredibly bad at compressing repeated bytes:

On a sequence of pure repeated AAAAAAAAAA..., it would start a new chunk at almost every letter. I tested hnet_1stage_XL and saw mostly the same behavior, with almost no improvement (sometimes 2 chars per chunk instead of 1).

On hnet_2stage_XL, this completely flipped. Instead of getting a chunk for every char, I was getting chunks after huge spans of repeats had been generated.

This is a great demonstration of the power of deeper hierarchies. I strongly support advertising this fact.

Misc other thoughts

Current models are certainly capable of understanding H-Net much faster than I can. But I did not attempt to AI-generate an H-Net implementation: it seems quite hard to get any educational value out of code that you have not manually read and edited yourself.

Even though PyTorch’s NestedTensor would be really useful for H-Nets, I don’t expect future optimized implementations to make use of the feature. It seems much more likely that handcrafted cu_seqlens-like approaches will win on time-to-market…
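For readers unfamiliar with the cu_seqlens style: it packs ragged sequences into one flat buffer plus cumulative offsets, the convention used by varlen attention kernels such as FlashAttention's `flash_attn_varlen_func`. A minimal pure-Python sketch of applying it to chunk selection follows; `pack_chunks` and `unpack` are hypothetical helpers, not code from the H-Net repo. Note how a sequence that selects no chunk this step simply becomes a zero-length span, which sidesteps the empty-h special case.

```python
from typing import List, Sequence, Tuple

Vec = List[float]


def pack_chunks(batch: Sequence[Sequence[Vec]], masks: Sequence[Sequence[bool]]) -> Tuple[List[Vec], List[int]]:
    """Select boundary tokens per sequence and concatenate them into one flat
    (total_chunks, D) list; cu_seqlens[i]:cu_seqlens[i+1] is sequence i's span."""
    packed: List[Vec] = []
    cu_seqlens = [0]
    for x, b in zip(batch, masks):
        kept = [v for v, keep in zip(x, b) if keep]
        packed.extend(kept)
        cu_seqlens.append(cu_seqlens[-1] + len(kept))
    return packed, cu_seqlens


def unpack(packed: List[Vec], cu_seqlens: List[int]) -> List[List[Vec]]:
    return [packed[s:e] for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]


batch = [[[1.0], [2.0], [3.0]], [[4.0], [5.0]]]
masks = [[True, False, True], [False, False]]  # the 2nd sequence selects no chunk this step
packed, cu_seqlens = pack_chunks(batch, masks)  # [[1.0], [3.0]], [0, 2, 2]
```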

Conclusion

Most information written here is sourced from the paper/code, with the bulk of my words a personal verbalization of it.

  • Some minor novel observations were made regarding special tokens && char repetition, but they are not large contributions.
  • There is also my reimplementation of H-Net, which may be fun to read if you enjoy the compressed style of Python I default to.

Ideally, a few dear readers will find H-Nets slightly more interesting due to this post, and will be motivated to work on:

  • covering the problem space of H-Net batched inference
  • coding something analogous to HDP Balance for efficient H-Net training
  • scaling up H-Net experiments, even if not very efficiently implemented

This post took 1 person 3 days of study to write. I believe H-Nets deserve a lot more than that.

Regarding quality

In this article, I have made various strategic errors in writing, like:

  1. word pollution of the term ‘H-Net’, where I fail to disambiguate
    • the drop-in replacement for an arbitrary isotropic backbone
    • the (series of) causal LM model(s) released under the H-Net collection on huggingface
  2. strange repeated emphasis of specific simple concepts to anchor understanding
  3. unreadable levels of code formatting, listing, text formatting
  4. undefined variables, name shadowing, various abuses of notation, etc.

This is my indication to the reader that the contents of the article are authentic & human written.