Shared Memory

Workgroup shared variables

The Cooperation Problem

Threads within a workgroup often need to share data. Consider summing an array: each thread computes a partial sum, but you need to combine those partial sums into a final result. Or image convolution: each pixel needs to read neighboring pixels, and neighbors of neighbors.

Without shared memory, threads would pass data through global memory—reading and writing to the same slow VRAM repeatedly. This works, but it is painfully slow. What threads need is a fast, local scratch space where they can exchange data.

That is exactly what shared memory provides. It is a small, fast memory region visible to all threads in a workgroup. Reads and writes cost 10-20 cycles, compared to 200-400+ cycles for global memory.

Interactive: Shared Memory Architecture

[Figure: a compute unit's shared memory (var<workgroup>, ~10-20 cycles latency, 16KB, fast but small) sits beside the workgroup's threads; global memory (VRAM, gigabytes, ~200-400+ cycles) is slow but large. Click a thread to select it. All threads can access all of shared memory—that is what makes cooperation possible.]

Each compute unit has its own shared memory. When a workgroup runs on that compute unit, it gets exclusive access to that memory. Other workgroups cannot see it—the memory is private to the workgroup.

Declaring Shared Variables

In WGSL, you declare shared memory using the var<workgroup> address space:

```wgsl
var<workgroup> shared_data: array<f32, 256>;
var<workgroup> tile: array<array<f32, 16>, 16>;
var<workgroup> counter: atomic<u32>;
```

These variables exist for the lifetime of the workgroup's execution. All threads in the workgroup see the same shared_data array. When the workgroup finishes, the memory is released for the next workgroup.

There are limits. WebGPU guarantees at least 16KB of shared memory per workgroup. Some GPUs provide more, but writing portable code means staying within this limit. A 256-element array of f32 uses 1KB—well within bounds.

The Synchronization Problem

Shared memory introduces a subtle but critical problem: race conditions. When multiple threads read and write the same memory, the order of operations matters.

Consider this flawed code:

```wgsl
var<workgroup> data: array<f32, 64>;

@compute @workgroup_size(64)
fn broken(@builtin(local_invocation_id) local_id: vec3u) {
  let idx = local_id.x;

  // Thread writes its value
  data[idx] = f32(idx);

  // Thread tries to read neighbor's value
  let neighbor = data[(idx + 1) % 64];  // BUG: neighbor might not have written yet!
}
```

Thread 5 writes to data[5], then reads from data[6]. But thread 6 might not have written yet. The result is unpredictable—sometimes you get the correct value, sometimes garbage.

Interactive: Why Barriers Are Needed

[Figure: four threads (T0-T3) racing on shared memory. Each thread writes its value, then reads its neighbor's value. Without a barrier, reads may happen before writes complete.]

Threads in a GPU do not execute in lockstep. Some threads run ahead, others lag behind. The hardware schedules threads in batches (warps/wavefronts), and even within a batch, memory operations can complete out of order.

The workgroupBarrier

The workgroupBarrier() function solves this problem. It forces all threads in the workgroup to reach the barrier before any thread proceeds past it.

```wgsl
var<workgroup> data: array<f32, 64>;

@compute @workgroup_size(64)
fn correct(@builtin(local_invocation_id) local_id: vec3u) {
  let idx = local_id.x;

  // Phase 1: All threads write
  data[idx] = f32(idx);

  // Barrier: wait for all writes to complete
  workgroupBarrier();

  // Phase 2: Now safe to read any element
  let neighbor = data[(idx + 1) % 64];  // Guaranteed to have the correct value
}
```

After workgroupBarrier(), every thread can safely read any element of data. The barrier guarantees that all writes from the previous phase are visible.

Think of it as a meeting point. All threads must arrive at the barrier before any can leave. This synchronization ensures a consistent view of shared memory.
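The same write-then-barrier-then-read discipline can be mimicked on the CPU with Python's threading.Barrier (a host-side analogy, not GPU code—names and sizes here are mine):

```python
import threading

N = 8
data = [0.0] * N                 # stands in for var<workgroup> data
barrier = threading.Barrier(N)   # analogous to workgroupBarrier()
results = [0.0] * N

def worker(idx: int) -> None:
    # Phase 1: every thread writes its own slot.
    data[idx] = float(idx)
    # No thread proceeds until all N have arrived.
    barrier.wait()
    # Phase 2: reading a neighbor's slot is now safe.
    results[idx] = data[(idx + 1) % N]

threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)  # each slot holds its right neighbor's value
```

Remove the `barrier.wait()` call and the reads may race the writes, exactly as in the broken WGSL example above.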

The Tile Loading Pattern

The most common shared memory pattern is tile loading: collaboratively load a chunk of global memory into shared memory, synchronize, then process.

For a 1D convolution with a 5-tap kernel:

```wgsl
var<workgroup> tile: array<f32, 68>;  // 64 elements + 2 on each side for halo

@compute @workgroup_size(64)
fn convolve(
  @builtin(global_invocation_id) global_id: vec3u,
  @builtin(local_invocation_id) local_id: vec3u
) {
  let idx = local_id.x;
  let global_idx = global_id.x;

  // Load main elements
  tile[idx + 2] = input[global_idx];

  // Load halo elements (edges of the tile);
  // clamping at the ends of the input array is omitted for brevity
  if (idx < 2) {
    tile[idx] = input[global_idx - 2];
    tile[idx + 66] = input[global_idx + 64];
  }

  workgroupBarrier();

  // Now apply convolution using only shared memory
  let result = kernel[0] * tile[idx] +
               kernel[1] * tile[idx + 1] +
               kernel[2] * tile[idx + 2] +
               kernel[3] * tile[idx + 3] +
               kernel[4] * tile[idx + 4];

  output[global_idx] = result;
}
```

Each element in the tile is loaded once from global memory. Without tiling, each output element would require 5 global memory reads. With tiling, the workgroup collaboratively loads 68 elements once, then each thread reads 5 values from fast shared memory.

Interactive: Tiled Memory Access

[Figure: a 10-element shared tile (8-element main tile plus 2 halo elements) feeding a 3-tap convolution over 8 outputs. The workgroup loads 10 elements from global memory once; the 8 outputs then make 24 reads from fast shared memory instead of slow global memory—a 58% reduction in global reads. Hover over output elements to see which input elements they need. Larger kernels make tiling more valuable.]

The speedup compounds with larger kernels. A 9-tap kernel would need 9 global reads per output without tiling; with tiling, a 64-thread workgroup still loads only 72 elements (64 plus an 8-element halo) once.
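A quick back-of-envelope in Python makes the scaling concrete (the helper function is mine, not part of any API):

```python
def global_reads(workgroup_size: int, kernel_taps: int) -> tuple[int, int]:
    """Global-memory reads per workgroup, without vs. with tiling."""
    halo = kernel_taps - 1                   # extra elements split across the tile edges
    untiled = workgroup_size * kernel_taps   # every output re-reads its whole window
    tiled = workgroup_size + halo            # each tile element loaded exactly once
    return untiled, tiled

for taps in (3, 5, 9):
    untiled, tiled = global_reads(64, taps)
    saved = 100 * (1 - tiled / untiled)
    print(f"{taps}-tap: {untiled} -> {tiled} global reads ({saved:.0f}% fewer)")
```

For the 5-tap kernel above this gives 320 untiled reads versus 68 tiled; at 9 taps the savings approach 9x.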

Reduction in Shared Memory

Summing an array demonstrates shared memory's power for reduction operations—combining many values into one.

```wgsl
var<workgroup> partial_sums: array<f32, 256>;

@compute @workgroup_size(256)
fn reduce(
  @builtin(global_invocation_id) global_id: vec3u,
  @builtin(local_invocation_id) local_id: vec3u,
  @builtin(workgroup_id) group_id: vec3u
) {
  let idx = local_id.x;

  // Each thread loads one element
  partial_sums[idx] = input[global_id.x];
  workgroupBarrier();

  // Parallel reduction: halve the active threads each iteration
  for (var stride: u32 = 128; stride > 0; stride >>= 1) {
    if (idx < stride) {
      partial_sums[idx] += partial_sums[idx + stride];
    }
    workgroupBarrier();
  }

  // Thread 0 writes this workgroup's sum. WGSL atomics only support
  // i32/u32, so f32 partials go to a storage buffer (one slot per
  // workgroup) to be summed in a second pass.
  if (idx == 0) {
    partial_out[group_id.x] = partial_sums[0];
  }
}
```

The reduction proceeds in log₂(256) = 8 steps. In the first step, threads 0-127 add elements 128-255 to elements 0-127. In the next step, threads 0-63 add elements 64-127 to elements 0-63. This continues until thread 0 holds the sum of all 256 elements.

Each step requires a barrier. Without it, threads might read stale values from the previous iteration.
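The stride loop is easy to sanity-check with a serial Python model (a sketch—the per-step barrier is implicit because the steps run sequentially here):

```python
def workgroup_reduce(values: list[float]) -> float:
    """Simulate the stride-halving reduction for a power-of-two-sized workgroup."""
    sums = list(values)
    stride = len(sums) // 2
    while stride > 0:
        # Threads with idx < stride each add the element `stride` slots away.
        for idx in range(stride):
            sums[idx] += sums[idx + stride]
        # (On the GPU, workgroupBarrier() would sit here.)
        stride //= 2
    return sums[0]

data = [float(i) for i in range(256)]
print(workgroup_reduce(data))   # 32640.0, the sum of 0..255, after 8 halving steps
```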

Matrix Multiplication Tiles

The canonical example of shared memory is tiled matrix multiplication. Instead of each thread loading entire rows and columns from global memory, workgroups collaboratively load tiles.

For matrices A (M×K) and B (K×N), producing C (M×N):

```wgsl
const TILE_SIZE: u32 = 16;
var<workgroup> tileA: array<array<f32, TILE_SIZE>, TILE_SIZE>;
var<workgroup> tileB: array<array<f32, TILE_SIZE>, TILE_SIZE>;

@compute @workgroup_size(TILE_SIZE, TILE_SIZE)
fn matmul(
  @builtin(global_invocation_id) global_id: vec3u,
  @builtin(local_invocation_id) local_id: vec3u
) {
  // A, B, C are storage buffers and K a uniform; 2D indexing is
  // shown for clarity (real buffers are usually flattened to 1D)
  let row = global_id.y;
  let col = global_id.x;
  let local_row = local_id.y;
  let local_col = local_id.x;

  var sum: f32 = 0.0;

  // Iterate over tiles along K dimension
  for (var t: u32 = 0; t < K / TILE_SIZE; t++) {
    // Load tiles collaboratively
    tileA[local_row][local_col] = A[row][t * TILE_SIZE + local_col];
    tileB[local_row][local_col] = B[t * TILE_SIZE + local_row][col];

    workgroupBarrier();

    // Compute partial dot product using tiles
    for (var k: u32 = 0; k < TILE_SIZE; k++) {
      sum += tileA[local_row][k] * tileB[k][local_col];
    }

    workgroupBarrier();
  }

  C[row][col] = sum;
}
```

Each thread computes one element of C. Without tiling, it would read an entire row of A and column of B from global memory—potentially thousands of accesses. With tiling, the workgroup loads two 16×16 tiles (512 elements) per iteration, and each thread reads 32 values from shared memory.
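The per-thread arithmetic behind that comparison, as a small Python sketch (helper name is mine):

```python
def loads_per_thread(K: int, tile: int) -> tuple[int, int]:
    """Global-memory loads per output element: naive vs. tiled matmul."""
    naive = 2 * K              # one full row of A plus one full column of B
    tiled = 2 * (K // tile)    # one element of tileA and one of tileB per iteration
    return naive, tiled

naive, tiled = loads_per_thread(K=1024, tile=16)
print(naive, tiled)   # 2048 vs. 128 global loads per thread: a 16x reduction
```

The ratio of naive to tiled loads equals the tile size, which is why larger tiles (within the shared-memory budget) pay off.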

The second barrier before loading the next tile is critical. Without it, fast threads might overwrite tile data that slow threads have not yet read.

Shared Memory Limits

WebGPU guarantees 16KB minimum shared memory per workgroup. This constrains your tile sizes and data structures.

A 16×16 tile of f32 uses 1KB: 16 × 16 × 4 = 1,024 bytes.

Two such tiles (for A and B) use 2KB. You could fit larger tiles—two 32×32 tiles would use 8KB—but remember that larger tiles also mean larger workgroups, which might reduce occupancy.
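A quick budget check in Python (the limit is WebGPU's guaranteed minimum; the helper is mine):

```python
LIMIT = 16 * 1024   # WebGPU's guaranteed minimum shared memory, in bytes

def matmul_tile_bytes(tile: int, bytes_per_elem: int = 4) -> int:
    """Shared memory consumed by the two square f32 tiles (A and B)."""
    return 2 * tile * tile * bytes_per_elem

for tile in (16, 32, 48):
    used = matmul_tile_bytes(tile)
    print(f"{tile}x{tile}: {used} bytes ({'fits' if used <= LIMIT else 'over budget'})")
```

Two 48×48 tiles would need 18,432 bytes—already past the guaranteed 16KB.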

Beyond size, there are alignment requirements. Arrays should start at natural boundaries. WGSL handles most alignment automatically, but complex nested structures might need explicit padding.

Common Pitfalls

Forgetting barriers is the most common bug. The shader compiles fine but produces incorrect results intermittently. Some runs work, some fail—depending on how threads happen to be scheduled.

Bank conflicts can reduce performance. Shared memory is divided into banks (typically 32). If multiple threads access the same bank simultaneously, accesses serialize. Strided access patterns often trigger conflicts:

```wgsl
// Potential bank conflict: with 32 banks, threads 0, 4, 8, ... all hit bank 0
let value = shared_data[local_id.x * 8];

// Better: consecutive threads access consecutive elements
let value = shared_data[local_id.x];
```
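You can count the conflicts by mapping indices to banks in Python (the 32-bank, one-word-per-bank layout is a typical hardware assumption—WebGPU does not expose it):

```python
NUM_BANKS = 32   # typical hardware value; an assumption, not a WebGPU guarantee

def bank(index: int) -> int:
    """Bank holding a given 4-byte element, assuming 32 word-wide banks."""
    return index % NUM_BANKS

# Strided access: shared_data[tid * 8]
strided = [bank(tid * 8) for tid in range(32)]
# Sequential access: shared_data[tid]
sequential = [bank(tid) for tid in range(32)]

print(len(set(strided)))     # 4 distinct banks -> 8 threads pile onto each (8-way conflict)
print(len(set(sequential)))  # 32 distinct banks -> conflict-free
```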

Divergent barriers cause deadlocks. All threads must reach the barrier, or none can proceed:

```wgsl
// WRONG: Only some threads hit barrier
if (local_id.x < 32) {
  workgroupBarrier();  // Deadlock! Threads 32-63 never arrive
}

// CORRECT: All threads reach barrier
workgroupBarrier();
if (local_id.x < 32) {
  // Only these threads do extra work
}
```

Key Takeaways

  • Shared memory is fast (10-20 cycles) memory visible to all threads in a workgroup
  • Declare with var<workgroup> in WGSL; WebGPU guarantees at least 16KB per workgroup
  • workgroupBarrier() synchronizes all threads, ensuring writes are visible before reads
  • The tile loading pattern reduces global memory accesses by loading data collaboratively
  • Reduction operations use shared memory to combine values efficiently
  • Matrix multiplication uses tiles to amortize global memory latency
  • All threads must reach barriers; divergent barriers cause deadlocks