In October 2019, I created SequenceLayers to make the process of building streaming neural networks easier. In 2023 we open-sourced the library, and in 2025 we published a tech report to explain the rationale behind the design decisions in the library.
SequenceLayers enables declarative definitions of sequence-processing models that can be executed layer-wise over a whole sequence or block-by-block over the time dimension, producing identical results in either mode. The API is akin to Keras in its declarative style, but streaming and sequence modeling are supported as first-class features.
Here is an example of a decoder-only Transformer block in SequenceLayers.
import jax
import numpy as np
import sequence_layers.jax as sl
# Model hyperparameters (example values; d_model matches the input feature size below).
d_model = 1024
dropout_rate = 0.1

config = sl.Serial.Config([
    # Self attention.
    sl.Residual.Config([
        sl.RMSNormalization.Config(name='pre_norm'),
        sl.DotProductSelfAttention.Config(
            num_heads=16,
            units_per_head=64,
            # Global causal attention.
            max_past_horizon=-1,
            max_future_horizon=0,
            name='self_attention'
        ),
        sl.DenseShaped.Config([d_model], name='output_projection'),
        sl.RMSNormalization.Config(name='post_norm'),
        sl.Dropout.Config(dropout_rate),
    ], name='attention_block'),
    # Gated GeLU FFN.
    sl.Residual.Config([
        sl.RMSNormalization.Config(name='pre_norm'),
        sl.Dense.Config(4 * d_model, name='dense1'),
        sl.GatedUnit.Config(jax.nn.gelu, None),
        sl.Dense.Config(d_model, name='dense2'),
        sl.RMSNormalization.Config(name='post_norm'),
        sl.Dropout.Config(dropout_rate),
    ], name='ffn_block')
], name='transformer_block')
transformer = config.make()
k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
# Random input sequence:
x = sl.Sequence(
    values=jax.random.normal(k1, (2, 4096, 1024)),
    mask=jax.random.uniform(k2, (2, 4096)) > 0.5,
)
# Run Flax layer initialization.
params = transformer.init(k3, x, training=False)
# Bind the layer for imperative/example usage.
transformer = transformer.bind(params)
# Process x layer-wise:
y_layer = transformer.layer(x, training=False)
# Process x 8 steps at a time:
block_size = 8
num_blocks = (x.shape[1] + block_size - 1) // block_size
state = transformer.get_initial_state(x.shape[0], x.channel_spec, training=False)
y_step = []
for i in range(num_blocks):
    x_i = x[:, i * block_size : (i + 1) * block_size]
    y_i, state = transformer.step(x_i, state, training=False)
    y_step.append(y_i)
y_step = sl.Sequence.concatenate_sequences(y_step)
np.testing.assert_allclose(y_layer.values, y_step.values)
np.testing.assert_array_equal(y_layer.mask, y_step.mask)
Building streaming sequence models is surprisingly tricky. There are four common pitfalls I kept running into; a representative one is conflating independent concerns, for example writing an AutoregressiveTransformer when architecture and autoregressive modeling are independent choices.

SequenceLayers addresses these with three core features built into every SequenceLayer.
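To make that concrete, here is a minimal sketch, continuing from the example above (it reuses transformer, x, and d_model), of how the uniform layer/step contract decouples streaming logic from the architecture. The run_streaming helper and the ffn_only config are illustrative, not part of the library.

def run_streaming(layer, x, block_size, training=False):
    """Streams any bound SequenceLayer over x, block_size steps at a time."""
    state = layer.get_initial_state(x.shape[0], x.channel_spec, training=training)
    num_blocks = (x.shape[1] + block_size - 1) // block_size
    outputs = []
    for i in range(num_blocks):
        x_i = x[:, i * block_size : (i + 1) * block_size]
        y_i, state = layer.step(x_i, state, training=training)
        outputs.append(y_i)
    return sl.Sequence.concatenate_sequences(outputs)

# The same driver streams the transformer block defined above...
y_transformer = run_streaming(transformer, x, block_size=8)

# ...or a completely different stack, without touching the streaming code.
ffn_config = sl.Serial.Config([
    sl.Dense.Config(4 * d_model, name='dense1'),
    sl.GatedUnit.Config(jax.nn.gelu, None),
    sl.Dense.Config(d_model, name='dense2'),
], name='ffn_only')
ffn = ffn_config.make()
ffn = ffn.bind(ffn.init(jax.random.key(0), x, training=False))
y_ffn = run_streaming(ffn, x, block_size=8)

Because the driver depends only on the SequenceLayer interface, swapping architectures is a configuration change rather than new streaming code.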
SequenceLayers has been used to abstract architectural details in a wide range of model families, across a wide variety of tasks.
Additionally, SequenceLayers is used extensively in production at Google for many streaming applications.
For more detail on the library and the rationale behind its design, see the tech report and the code on GitHub, available under the Apache 2.0 license.
I’m deeply thankful that Google was willing to let me open source this library. My hope is that it serves as a useful example to the community of a simple abstraction that more than pulls its weight when working with neural networks that operate over sequences.