Projects tagged with machine-learning:

  • SequenceLayers: Streaming made simple.

    2019 - present tags: google machine-learning open-source

    In October 2019, I created SequenceLayers to make the process of building streaming neural networks easier. In 2023 we open-sourced the library, and in 2025 we published a tech report to explain the rationale behind the design decisions in the library.

    SequenceLayers enables declarative definitions of sequence processing models that can be run layer-wise over a full sequence or block-by-block over the time dimension, producing identical results in either mode. It is akin to Keras in its declarative API, but supports streaming and sequence modeling as first-class features.

    Here is an example of a decoder-only Transformer block in SequenceLayers.

    import jax
    import numpy as np
    import sequence_layers.jax as sl
    
    # Model hyperparameters (illustrative values).
    d_model = 1024
    dropout_rate = 0.1
    
    config = sl.Serial.Config([
    
      # Self attention.
      sl.Residual.Config([
        sl.RMSNormalization.Config(name='pre_norm'),
        sl.DotProductSelfAttention.Config(
          num_heads=16,
          units_per_head=64,
          # Global causal attention.
          max_past_horizon=-1,
          max_future_horizon=0,
          name='self_attention'
        ),
        sl.DenseShaped.Config([d_model], name='output_projection'),
        sl.RMSNormalization.Config(name='post_norm'),
        sl.Dropout.Config(dropout_rate),
      ], name='attention_block'),
    
      # Gated GeLU FFN.
      sl.Residual.Config([
        sl.RMSNormalization.Config(name='pre_norm'),
        sl.Dense.Config(4 * d_model, name='dense1'),
        sl.GatedUnit.Config(jax.nn.gelu, None),
        sl.Dense.Config(d_model, name='dense2'),
        sl.RMSNormalization.Config(name='post_norm'),
        sl.Dropout.Config(dropout_rate),
      ], name='ffn_block')
    
    ], name='transformer_block')
    
    transformer = config.make()
    
    k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
    
    # Random input sequence:
    x = sl.Sequence(
      values=jax.random.normal(k1, (2, 4096, 1024)),
      mask=jax.random.uniform(k2, (2, 4096)) > 0.5
    )
    
    # Run Flax layer initialization.
    params = transformer.init(k3, x, training=False)
    
    # Bind the parameters to the layer for imperative usage.
    transformer = transformer.bind(params)
    
    # Process x layer-wise:
    y_layer = transformer.layer(x, training=False)
    
    # Process x 8 steps at a time:
    block_size = 8
    num_blocks = (x.shape[1] + block_size - 1) // block_size
    
    state = transformer.get_initial_state(x.shape[0], x.channel_spec, training=False)
    y_step = []
    for i in range(num_blocks):
      x_i = x[:, i * block_size : (i + 1) * block_size]
      y_i, state = transformer.step(x_i, state, training=False)
      y_step.append(y_i)
    
    y_step = sl.Sequence.concatenate_sequences(y_step)
    
    np.testing.assert_allclose(y_layer.values, y_step.values)
    np.testing.assert_array_equal(y_layer.mask, y_step.mask)
    

    Building streaming sequence models is surprisingly tricky. There are four common pitfalls I kept running into:

    • Batching unequal-length sequences requires tracking invalid timesteps and verifying that every layer handles padding correctly, including pooling and sampling operations (see the sketch after this list).
    • Causality constraints mean modern architectures need separate code paths for efficient parallel training and autoregressive sampling, both of which must avoid causality violations.
    • Offline vs. streaming mismatch: converting a parallel (masking-based) inference implementation to a streaming one typically requires a re-implementation to handle lookahead windows and memory constraints.
    • Unnecessary coupling: architectural details become entangled with algorithms; for example, a class like AutoregressiveTransformer ties an architecture to a sampling algorithm even though the two are independent choices.
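
    To make the first pitfall concrete, here is a minimal sketch of how a batch of unequal-length sequences can be represented (using only the sl.Sequence constructor from the example above; the shapes and lengths are made up for illustration). The mask marks which timesteps are valid:

    import jax.numpy as jnp
    import sequence_layers.jax as sl
    
    # Two sequences of lengths 3 and 5, padded to a common length of 5.
    values = jnp.zeros((2, 5, 8))  # [batch, time, channels]
    lengths = jnp.array([3, 5])
    
    # mask[b, t] is True for valid timesteps and False for padding.
    mask = jnp.arange(5)[jnp.newaxis, :] < lengths[:, jnp.newaxis]
    
    x = sl.Sequence(values=values, mask=mask)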

    SequenceLayers addresses these with three core features. Each SequenceLayer is:

    • Streamable: SequenceLayers gives you streaming for free, in a production-friendly way. Every layer implements explicit state and a step method alongside the traditional layer-wise call.
    • Correct: SequenceLayers is correct by default, making entire classes of bugs impossible. Layer and step methods are tested to produce identical results, and mask-aware Sequence objects track padding throughout.
    • Composable: An easy-to-understand declarative API enforces these guarantees, enabling sequence models with concise definitions that read like block diagrams.

    SequenceLayers has been used to abstract architectural details in:

    • Classifiers.
    • Contrastive / distance metric learning models.
    • Regression models.
    • Probabilistic models (autoregressive models, normalizing flows, diffusion, VAEs, GANs).

    across a wide variety of tasks:

    • Audio / speech classification.
    • Image classification.
    • Contextualized word embedding.
    • Text-to-speech synthesis.
    • Speech and phoneme recognition.
    • Speech translation.
    • Speech vocoding.
    • Audio tokenization and synthesis.
    • Real-time music synthesis.
    • Video understanding.
    • Language modeling.

    Additionally, SequenceLayers is used extensively in production at Google for many streaming applications.

    For more detail on the library’s design and the rationale behind it, see the tech report and the code on GitHub, available under the Apache 2.0 license.

    I’m deeply thankful that Google was willing to let me open source this library. My hope is that it serves as a useful example to the community of a simple abstraction that more than pulls its weight when working with neural networks that operate over sequences.

  • Tacotron: End-to-end Speech Synthesis

    2016 - present tags: google machine-learning sound

    General architecture of Tacotron.

    The most exciting work I’ve been involved with on the Sound Understanding team has been the development of Tacotron, an end-to-end speech synthesis system that produces speech directly from characters to waveform. The initial system took verbalized characters as input and produced a log-magnitude mel spectrogram, which was then synthesized to a waveform via standard signal processing methods (Griffin-Lim and an inverse Short-time Fourier Transform). In Tacotron 2, we replaced this hand-designed synthesis method with a neural vocoder, initially based on WaveNet.
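
    As a rough illustration of that signal processing step, here is a sketch of Griffin-Lim reconstruction built on tf.signal. The frame parameters and iteration count are illustrative rather than the ones Tacotron used, and it assumes the mel spectrogram has already been mapped back to a linear-frequency magnitude spectrogram:

    import numpy as np
    import tensorflow as tf
    
    def griffin_lim(magnitude, frame_length=1024, frame_step=256, iterations=60):
      """Estimates a waveform from a magnitude spectrogram of shape
      [..., frames, frame_length // 2 + 1] by iteratively refining the phase."""
      # Start from a random phase estimate.
      phase = tf.random.uniform(tf.shape(magnitude), 0.0, 2.0 * np.pi)
      stft = tf.cast(magnitude, tf.complex64) * tf.exp(tf.complex(tf.zeros_like(phase), phase))
      for _ in range(iterations):
        # Invert the current complex spectrogram estimate to a waveform...
        waveform = tf.signal.inverse_stft(stft, frame_length, frame_step, frame_length)
        # ...then re-analyze it, keeping its phase but restoring the known magnitude.
        estimate = tf.signal.stft(waveform, frame_length, frame_step, frame_length)
        unit_phase = estimate / tf.cast(tf.maximum(tf.abs(estimate), 1e-8), tf.complex64)
        stft = tf.cast(magnitude, tf.complex64) * unit_phase
      return tf.signal.inverse_stft(stft, frame_length, frame_step, frame_length)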

    One line of research on my team is direct-to-waveform acoustic models, skipping the intermediate spectrogram representation. In the Wave-Tacotron paper, we published a model that does just that.

    Check out our publications page for the full list of research that my team has co-authored with the Google Brain, DeepMind, and Google Speech teams.

  • Google Sound Understanding

    2015 - present tags: google sound machine-learning research

    In 2015, I joined the Sound Understanding team within Google Perception. We focus on building systems that can both analyze and synthesize sound. Being able to work on my hobby (sound and digital signal processing) as my full time job has been a dream come true. We operate as a hybrid research team, which means we both publish our work and deploy it to improve Alphabet’s products and services.

    I’ve had the opportunity to work on some neat tasks and projects during my time on the team, but speech synthesis has been what I’ve spent the most time working on.

  • Google TensorFlow

    2015 - present tags: google machine-learning

    TensorFlow's OG logo.

    In 2015, I joined the Sound Understanding team within Google Perception. Our main tool for machine learning research is TensorFlow. Over the years I’ve contributed a number of features to TensorFlow that have been crucial to the research work my team has done.

    The highlights include:

    • Added the tf.signal module of signal processing components (see the sketch after this list).
    • Extended real and complex FFT support, with implementations built on Eigen for CPU and cuFFT for GPU, plus TPU support.
    • Significantly expanded complex number support.
    • Bugfixes and contributions to various parts of the runtime, libraries and more.
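
    As a small illustration of what tf.signal enables, here is a sketch of computing a log-mel spectrogram; the audio and all of the parameters are arbitrary placeholders:

    import tensorflow as tf
    
    # One second of (random) 16 kHz audio, batch of 1.
    waveform = tf.random.normal([1, 16000])
    
    # Short-time Fourier transform and magnitude spectrogram.
    stft = tf.signal.stft(waveform, frame_length=400, frame_step=160, fft_length=512)
    magnitude = tf.abs(stft)  # [1, frames, 257]
    
    # Map the linear-frequency bins onto an 80-band mel scale, then take the log.
    mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=80,
        num_spectrogram_bins=magnitude.shape[-1],
        sample_rate=16000,
        lower_edge_hertz=125.0,
        upper_edge_hertz=7600.0)
    log_mel = tf.math.log(tf.tensordot(magnitude, mel_matrix, axes=1) + 1e-6)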

    Check out the full list of commits on GitHub.