Workspace in cuDNN
Allen is LHCb’s GPU-based real-time trigger framework, responsible for running event reconstruction during data taking. Unlike offline frameworks, it operates under strict latency and throughput constraints: tens of microseconds per event, with throughput targets around 130 kHz (on an NVIDIA A5000). Integrating inference engines into such an environment forces you to think differently about memory. That’s where workspace in cuDNN becomes a real design variable, not just an API detail.
When working on the inference engine integration in Allen, I had to get comfortable with how cuDNN handles workspace. It’s one of those things that doesn’t look important at first: just a couple of extra arguments in function signatures, void* workspace and size_t workspaceSizeInBytes. But once I started pushing the GPU to meet LHCb trigger requirements, workspace turned out to be one of the key things we have to manage carefully if we don’t want performance to fall off a cliff.
cuDNN expects the user to query the size of the workspace buffer needed by a given algorithm and allocate it. There’s no fallback or auto-allocation. No guardrails. If we pass in a null pointer when the algorithm needs memory, it either errors or crashes (and debugging that in production kernels is a pain). If we allocate it inside the event loop, we’re wasting time and introducing fragmentation. So we handle it up front: allocate once, and reuse.
Every layer in the inference engine (conv, activation, batchnorm, etc.) has to be queried separately, because each has its own workspace needs. Here’s what this looks like in practice:
// Let cuDNN benchmark candidate algorithms and report the best one first
cudnnConvolutionFwdAlgoPerf_t perf;
int ret_count = 0;
cudnnFindConvolutionForwardAlgorithm(
    handle, x_desc, w_desc, conv_desc, y_desc, 1, &ret_count, &perf);

// Query how much scratch memory the chosen algorithm needs
size_t ws_size = 0;
cudnnGetConvolutionForwardWorkspaceSize(
    handle, x_desc, w_desc, conv_desc, y_desc, perf.algo, &ws_size);

// Allocate once at initialization; skip if the algorithm needs no workspace
void* workspace_ptr = nullptr;
if (ws_size > 0) cudaMalloc(&workspace_ptr, ws_size);
This happens once per layer at initialization. We store the pointer and size in the layer struct and reuse it for every event. It’s predictable, low overhead, and avoids runtime memory management. If the layer ends up using an algorithm that doesn’t require a workspace, cuDNN just returns 0, and we skip allocation.
In Allen, we can’t afford to scatter allocations across layers and streams. At 70 kHz (e.g., what we get for the full default HLT1 sequence on an RTX 2080), even small stalls from GPU heap fragmentation or repeated cudaMalloc calls add up. So instead of allocating workspace per layer, we reserve a contiguous chunk per engine thread or stream and assign sub-ranges to each layer based on its requirements. This model scales better and makes memory pressure easier to track when we’re running multiple models or larger inputs.
One thing that becomes obvious after working with this setup for a while is that the fastest cuDNN algorithms aren’t always usable unless we have enough memory to give them their preferred workspace. So we end up having to choose: either allocate the memory and get the performance, or fall back to a smaller, slower algorithm. This trade-off becomes especially relevant when we experiment with inference at higher batch sizes or with multiple concurrent events (which is exactly what we want in trigger environments).
I sometimes benchmark multiple algorithms and compare their workspace sizes. For example:
- IMPLICIT_PRECOMP_GEMM: fast, but may require 20–50 MB of workspace
- IMPLICIT_GEMM: less memory, but noticeably slower
- FFT_TILING: very memory-hungry, not worth it for small inputs
- WINOGRAD: good compromise, but sensitive to layout
These are tunable, and we cache the best-performing one per layer during setup. Once chosen, we don’t revisit the decision unless input shapes change significantly.
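One way to make that selection is to ask cuDNN for all benchmarked candidates and take the fastest one that fits a memory budget. The sketch below is illustrative; the 64 MB cap and variable names are assumptions, not Allen’s actual values:

// Pick the fastest algorithm whose workspace fits the budget (illustrative sketch).
const size_t budget = 64 << 20;  // assumed 64 MB cap per layer
cudnnConvolutionFwdAlgoPerf_t results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
int returned = 0;
cudnnFindConvolutionForwardAlgorithm(
    handle, x_desc, w_desc, conv_desc, y_desc,
    CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned, results);

// Fallback that typically needs no workspace
cudnnConvolutionFwdAlgo_t chosen_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
for (int i = 0; i < returned; ++i) {  // results come back sorted fastest-first
  if (results[i].status == CUDNN_STATUS_SUCCESS && results[i].memory <= budget) {
    chosen_algo = results[i].algo;
    break;
  }
}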
In our case, the actual workspace usage per layer tends to be in the hundreds of kilobytes, but we care about predictability more than raw size. That’s why we log the numbers explicitly during initialization:
std::cout << "Layer " << layer_name
<< ": workspace = " << workspace_size / 1024.0f << " KB" << std::endl;
We also track the global peak workspace per engine stream and provision accordingly. On GPUs with tighter memory budgets, like when running inference alongside tracking and matching, knowing what eats memory helps make better trade-offs upstream.
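If all layers on a stream reuse a single buffer, tracking that peak is just a running maximum over the per-layer numbers collected at initialization (a sketch; `layers` and the field name follow the hypothetical struct above):

// Largest single-layer workspace requirement on this engine stream (illustrative).
size_t peak_workspace_size = 0;
for (const auto& layer : layers) {
  peak_workspace_size = std::max(peak_workspace_size, layer.workspace_size);
}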
Another subtle issue is what happens when layers run concurrently. In some parallel inference configurations, we need distinct workspace buffers per active stream. That means either pre-allocating N buffers, or using something like a stack allocator to carve out slices from a larger block. Both options work, but the former is simpler and safer in our case.
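For the pre-allocated option, a minimal sketch might look like this, assuming `num_streams` and the `peak_workspace_size` from above are known:

// One private workspace buffer per active stream (illustrative sketch).
std::vector<void*> stream_workspaces(num_streams, nullptr);
for (int s = 0; s < num_streams; ++s) {
  cudaMalloc(&stream_workspaces[s], peak_workspace_size);  // sized to the per-stream peak
}
// Each inference call running on stream s hands stream_workspaces[s] to cuDNN.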
I tried a shared allocator at one point, but it added locking overhead and didn’t help much. We reverted to fixed per-thread allocations. Since Allen already uses a fixed thread pool, and each thread runs one model at a time, this worked without issues.
The principle I follow now is: workspace should be silent. It should never be the thing we debug at runtime. If it is, something’s already gone wrong, be it a bad allocation policy, broken assumptions about concurrency, or a regression in algorithm selection.
It reminds me of __shared__ memory in CUDA: fast, local, and very easy to misuse if we don’t plan. I don’t think of workspace as “extra memory” anymore. It’s part of the algorithm contract: not something to treat as an afterthought, but something we model and provision just like thread layout or tensor shape.