Making GPUs Actually Fast: A Deep Dive into Training Performance
Sylvain Gugger & Corwin de Zahr
Jane Street
This talk dives into the performance details of GPUs and why GPUs are useful for training neural network models. We’ll cover the common bottlenecks and how to defeat them. We’ll also show techniques that enable you to use the GPU more efficiently like overlapping CPU work with GPU work and kernel fusion. Doing kernel fusion often requires writing custom kernels, and we’ll go through an example of that.
Transcript
00:04 | Corwin
Hello, everyone. I’m Corwin.
00:05 | Sylvain
I’m Sylvain.
00:06 | Corwin
Today, we’re gonna be talking about how to train a model on GPUs efficiently. So can I get a show of hands of who has used a machine learning model before? Great. If you haven’t used a machine learning model before, I’d highly recommend you going to check out one of the many ones available. There’s Claude, there’s ChatGPT, DeepSeek. Definitely recommend to boost your workflow. But I guess maybe the better question is, who has trained a machine learning model before?
00:35 | Corwin
Still a good number of hands. What about trained a machine learning model using a GPU? So maybe I don’t actually need to give this talk. Seems like you guys already know everything that’s going on. But we can still talk a little bit about it. So I guess the first question is, why would you wanna use a GPU? I mean, we have CPUs. CPUs are great. They’re what we’ve been using forever. At Jane Street, CPUs are really good for our trading systems. They have very good sequential performance, very low latency, a decent number of cores nowadays,
01:06 | Corwin
decent number of threads per core, one or two. And they do that sequential programming extremely fast. GPUs, on the other hand, are just extremely parallelizable and have a lot of throughput. To illustrate that point, what do you guys think? How long does it take to do one floating point operation on a GPU versus a CPU? Anyone have any ideas? How long does it take to do one floating point operation on a CPU? Yeah.
01:36 | Corwin
So a nanosecond. That’s approximately right. What about on a GPU? Anyone have any ideas? Yeah. Two nanoseconds. So two nanoseconds sounds like a good answer. But actually, GPUs are really slow at doing one floating point operation. It’s around a microsecond. So this is a little bit tongue in cheek ‘cause doing one floating point operation is never a thing that you’d actually wanna do on a GPU. But this is like 1,000x different. CPUs are 1,000 times faster than GPUs.
02:07 | Corwin
Why would we use a GPU ever? And the answer is the parallelism. So in doing machine learning, a thing that you often wanna do is matrix multiplication. And matrix multiplication is an operation that is highly parallelizable. You can split up the work entirely. So each thread can deal with only one of the output elements of your matrix entirely independently from all the other threads. And this massive parallelism enables the GPU to do computation extremely quickly. So for doing 2048 by 2048 matrix multiplication,
02:39 | Corwin
how long do you think this takes on a GPU versus a CPU? Anyone have any guesses? Yeah. I don’t know. Maybe a second on the CPU. A second on the CPU. So CPUs are actually a little bit faster than that. That is a good ballpark estimate. But there is stuff like SIMD and multiple cores on the CPU that enable us to get some parallelism. And then it’s actually like 28 milliseconds. And then if you think about the GPU, it’s 200 microseconds.
03:11 | Corwin
So before we had 1,000x faster on the CPU side. Now we’re 100x faster on the GPU side. So this is kind of just absurd. The GPU is getting so much speed up for doing this matrix multiplication. And that’s why we wanna use the GPUs. They have this massive parallelism. They enable us to do this amazing computation that we need, and that’s why we wanna use them. So in this talk, first thing we’re gonna talk about is how to use GPUs to train neural net models.
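As a back-of-the-envelope check on those timings, the sketch below converts them into throughput. It assumes the standard count of 2n³ floating point operations for an n-by-n matrix multiplication; the 28 ms and 200 µs figures are the ones quoted in the talk, and the variable names are only illustrative.

```python
# Back-of-the-envelope throughput from the timings quoted in the talk.
n = 2048
flops = 2 * n**3  # one multiply and one add per inner-product term

cpu_seconds = 28e-3   # ~28 ms on the CPU (from the talk)
gpu_seconds = 200e-6  # ~200 us on the GPU (from the talk)

cpu_tflops = flops / cpu_seconds / 1e12
gpu_tflops = flops / gpu_seconds / 1e12
speedup = cpu_seconds / gpu_seconds

print(f"total work: {flops / 1e9:.1f} GFLOP")
print(f"CPU: {cpu_tflops:.2f} TFLOP/s, GPU: {gpu_tflops:.1f} TFLOP/s")
print(f"GPU speedup: {speedup:.0f}x")
```

With these inputs the ratio comes out at about 140x, which is in line with the "100x" order of magnitude mentioned above.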
03:42 | Corwin
And then we’re gonna talk about removing the bottlenecks and doing that training. And then finally, we’re gonna dive a little bit lower level and understand the GPU kernel operations that we do and how to optimize those as well. So I’ll hand it over to Sylvain to talk about neural net training.
03:58 | Sylvain
Thanks. And yeah, maybe to get a sense of how familiar you are with this diagram, can you tell me, with a show of hands, does this make sense or not at all? Just so that I know how much time I should spend on that slide. Please raise your hand if it makes sense.
04:14 | Sylvain
Okay, kind of. So maybe I’ll go a little bit over it, but quickly. So when you train a machine learning model, you do something that looks like this. Mainly, you have some data, usually that comes in the form of some inputs and some desired outputs for your model and you feed them through your model. From the output of your model, you compare them with what you expected, the targets that you had in your data. And you compute the loss, which tells you how wrong your model is. If you are at the start of training, your model is gonna be very wrong because it’s been randomly initialized.
04:46 | Sylvain
And then you compute the derivative of that loss with respect to every parameter of your model, which gives you the gradients. And we do this because, as an analogy, if we’re at the top of a mountain and we wanna go to the bottom, we need to take a step in the downward direction. And this downward direction is gonna be given to us by the gradients. Once we have those gradients, we take that little step in the direction of the gradients to make our model a tiny bit better. And we rinse and repeat the whole process thousands, if not millions of times, until we have a model
05:16 | Sylvain
that’s trained and outputs sensible results. And to do this, we usually use PyTorch here at Jane Street, which is a deep learning framework that is very good for two things. One, it provides us with GPU kernels. So we will be able to execute all of those operations on the GPU and leverage that high parallelism. And two, it comes with the automatic differentiation. So the diagram that we had before looks like this in PyTorch.
05:46 | Sylvain
This is the data part. We iterate through some stream, which is called a data loader in PyTorch that gives us those inputs and targets. We compute the outputs of our model on the batch of inputs and then a loss from this thing. And during those two operations, we execute potentially a lot of matmuls, leveraging the power of the GPU. PyTorch gives us the gradients kind of automatically. We just have to call this magic incantation, loss.backward(), to have them be computed for us. We don’t have to manually code the derivative
06:16 | Sylvain
of every operation we just did. And then we take that little step in the right direction that’s gonna make our model better. And I just added some stuff to basically compute some metrics to figure out how good our model is, and then print the final value once the training is done. And then, so this doesn’t use the GPU for now. But if you wanna use the GPU in PyTorch, it’s actually very easy. You just have to add a couple of lines here and there. Basically, like those .to(“cuda”) kind of everywhere.
06:48 | Sylvain
This is called to(“cuda”) because when PyTorch was first created, it only supported NVIDIA GPUs. And CUDA is the name of the programming language that’s proprietary at NVIDIA to program GPUs. Now PyTorch supports many kinds of GPUs, but they didn’t update this specific API. So sorry, AMD. So you move your model to the GPU. You move your inputs and targets to the GPU. Then all of the heavy computation that we saw before, so the forward of the model and the backward of the model, where you have lots of matmuls, you use the GPU for that. And at the very end, we move back our outputs
07:19 | Sylvain
and targets to the CPU to be able to compute our metrics. And so yeah, we’ll now take a deep dive into what’s actually happening behind the scenes when you execute that code in PyTorch to understand where the bottlenecks might appear. And I’ll hand it back over to Corwin to go over this.
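The slides are not part of the transcript, so here is a minimal sketch of the kind of training loop being described. The toy data, the `nn.Linear` model, the SGD optimizer and the mean-absolute-error metric are all assumptions chosen for illustration, not the speakers' code; the device falls back to the CPU when no GPU is present.

```python
import math

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Fall back to the CPU so the sketch still runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy data and model (hypothetical stand-ins for the ones on the slide).
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(256, 16), torch.randn(256, 1))
dataloader = DataLoader(dataset, batch_size=32)
model = nn.Linear(16, 1).to(device)           # move the parameters to the device
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

total_abs_error, count = 0.0, 0
for inputs, targets in dataloader:
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = model(inputs)                    # forward pass
    loss = loss_fn(outputs, targets)           # how wrong the model is
    optimizer.zero_grad()
    loss.backward()                            # autograd computes the gradients
    optimizer.step()                           # small step against the gradient
    # Metrics are computed on the CPU, as in the version described above.
    errors = outputs.detach().cpu() - targets.cpu()
    total_abs_error += errors.abs().sum().item()
    count += targets.numel()

mean_abs_error = total_abs_error / count
print(f"mean absolute error: {mean_abs_error:.4f}")
```

This version is deliberately naive: it copies results back to the CPU on every step, which is exactly the pattern examined in the next part of the talk.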
07:38 | Corwin
So CPU/GPU pipeline, we wanna understand what the GPU is doing when we’re running kernels in PyTorch. So to do that, we’ll look at a little bit of a simpler piece of code just to make it easier to understand what’s going on.
07:50 | Corwin
So first, we’re gonna execute this getting data operation. And that runs on the CPU, because we’re using the CPU to compute some stuff. Then we’ll send that data over to the GPU. And this operation occurs on both the CPU and the GPU, ‘cause the CPU is sending the data, and the GPU is receiving that data. And so you can see here on the pipeline, we’ll put an operation on both the CPU and the GPU lines. After that, we’ll run a kernel. So a kernel is kind of like a function that does one operation on the GPU.
08:21 | Corwin
And so you might expect that this kernel only executes on the GPU side, because it’s a GPU operation. But in fact, it needs to be run on the CPU as well. So the CPU does some computation to figure out the launch parameters and then communicates that to the driver. And then the driver tells the GPU, please execute this operation. And then the GPU actually does the execution. So all of the kernels take place on both the CPU and the GPU. So after we do this linear layer, which is a matrix multiplication, we then move on to ReLU. ReLU is just another operation. You don’t really need to know what it means. It’s just another thing that is a kernel
08:52 | Corwin
that is launched on the GPU. And finally, we can go back and get data again, and send that data, and then do some more layers. So if you look at the GPU line, you’ll notice that there are a lot of gaps. And that is bad. All of the time that the GPU is not doing work, you’re burning money. Because GPUs are super expensive. You wanna keep them as full as possible. And so this huge set of gaps is really embarrassing. Fortunately, I just lied to you. This is not how PyTorch works. What happens is there is this pipeline.
09:23 | Corwin
So the GPU and the CPU are executing at different times. So going back to our linear layer, the CPU tells the GPU, please launch this operation. And then the CPU can just continue executing while the GPU runs that operation asynchronously. And so immediately, the CPU can continue on and hit the ReLU operation. And I think we’ll need to click the mouse to remove that. Perfect. So it’ll run the ReLU operation, and that will get launched on the GPU side.
09:54 | Corwin
And you’ll notice, because the GPU is still running the linear layer, we will not actually start the execution of the ReLU at that time. The ReLU is queued up. There’s a big queue of operations that the GPU is waiting to execute. And so the GPU will just pull things off that queue whenever it’s ready. But when we launch the ReLU, we don’t immediately start it on the GPU. Instead, we can continue executing on the CPU side, going back to the get data. And then eventually, the linear layer on the GPU side finishes. And we’re able to launch the ReLU. And we can continue on, so on and so forth.
10:27 | Corwin
And you’ll notice that there are no more gaps between the kernels. We’ve eliminated a huge amount of inefficiency when we’re not doing any work. And this is great. This is a much more reasonable way of doing the GPU work. And in particular, what this allows us to do is it allows us to hide the extra computation of the get data. No longer does that get data need to occupy space on the GPU line, because we’ve overlapped it with the kernels that the GPU is executing. Great.
10:57 | Corwin
So we can look at what this looks like in the profiler. So this is the Nsight Systems profiler from NVIDIA. And you can see the CUDA API line describes all the things that are happening on the CPU to launch kernels. And then you can see in the CUDA hardware line the kernels that are actually being launched. And you’ll see the correspondence between the launching of the kernel and the kernel being run. And so you can see that the GPU is far behind the CPU, and the CPU is far ahead. And this is basically the behavior that we want. So the thing that you’re trying to do when you’re executing a GPU program is you’re trying to keep the blue bar at the top
11:28 | Corwin
in the CUDA hardware line full. If your blue bar is not full, your GPU is not being efficient. And so the first step when you’re trying to think about stuff is figure out how to keep the blue bar full. So one thing that sort of gets in our way is usually when you’re doing a computation on the GPU, you wanna actually know what that computation did. And so to do that, you need to copy data back from the GPU to the CPU. Unfortunately, this synchronization point where you copy data from the GPU to the CPU causes your pipeline to stall.
11:59 | Corwin
So you can see here, we’re doing this output.cpu to get the results back. And our CPU is not able to launch more kernels while it’s waiting for the GPU to execute. And so that flushes the pipeline, and that causes us to have really bad performance. So in the profile, you can see this GPU sync is causing the CPU to stall. And this stall takes so long to wait for the GPU to finish, it’s just completely off the screen. So going back to this diagram, you can see we stalled a really long time to wait for this copy to happen. And then that means that the get data,
12:30 | Corwin
which used to be covered up by the asynchronous launching of the GPU kernels, is now exposed on the GPU line. And so you have that time where you’re getting data. That makes your GPU just wait, and this is wasted compute. You’re wasting money. You don’t wanna do that. Before I continue on, does anyone have any questions about how this works? Hang on, can you throw into the mic. Audience (12:52): Why does the CPU have to wait for the GPU to finish before it issues the LIN instruction?
13:00 | Corwin
That’s a good question. And we will talk about that later on. But it’s some complex details of how the copying actually happens. Other questions? Cool.
13:13 | Corwin
So going back to the profile, we can see that synchronizing took a long time. And then you can see, immediately after the synchronize finishes, we have the gap on the GPU side. Looking at that CUDA hardware line, you’ll see that gap in the blue line, which means that we’re not efficient.
13:30 | Corwin
So synchronizing, bad. So the question here is, we had this code that we’re using to do our training loop. Anyone have any ideas where the synchronizations are? Audience (13:41): It’ll be at the .cpu calls at metrics accumulate.
13:44 | Corwin
Yeah, that’s a great one. So yes, we have a synchronization. This .cpu makes it clear to you what’s going on. We don’t wanna do that. Any other ideas where synchronizations are? Audience (14:02): Yeah. When it’s calculating isNaN?
14:05 | Corwin
Where it’s calculating isNaN. That’s a great point. And this one is a little bit sneakier. It doesn’t say .cpu anywhere. So you’re like, why is it synchronizing? Well, what’s happening here is that in order to run the if statement, we need to know what the value of the tensor is. In particular, when you’re doing the conversion from the .isNaN tensor into a Boolean, that’s where the synchronization happens. And that is super pernicious,
14:33 | Corwin
because .isNaN doesn’t cause a synchronization. It’s the if that causes a synchronization. And so this is one of the things that when you’re writing PyTorch code, you ought to be really careful about and make sure you’re checking for any tensor coercions, because those are places where you’re going to have accidental synchronizations. Okay, any ideas for the last synchronization? Maybe there’s more than just one. I know of one. Yeah. Audience (15:03): When you step your optimizer, you’d have to reload your weights onto the GPU to do your next cycle. So there’s probably some synchronization that needs to happen.
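To make the tensor coercion point concrete, the sketch below separates the NaN test itself from the reads that need the value on the host. It uses a hand-made NaN scalar as a stand-in for a real loss, and it falls back to the CPU when no GPU is available, in which case nothing actually stalls.

```python
import math

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
loss = torch.tensor(float("nan"), device=device)  # stand-in for a real loss

# This only launches a kernel; the result is a tensor that stays on the device.
nan_flag = loss.isnan()
flag_device = nan_flag.device.type

# Each of the following needs the value on the host, so on a GPU the CPU
# has to wait for every queued kernel that the value depends on.
as_bool = bool(nan_flag)   # what `if loss.isnan():` does implicitly
as_float = loss.item()     # explicit scalar read
on_cpu = loss.cpu()        # explicit copy back to the host
```

Other implicit conversions, such as `float(tensor)` or printing a tensor, need the value on the host in the same way.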
15:08 | Corwin
Sorry, can you use the mic? Audience (15:08): When you step your optimizer, you’d have to reload your weights onto the GPU to do your next cycle. So there’s probably some synchronization that needs to happen.
15:17 | Corwin
That’s a good idea. But I think PyTorch is a little bit smarter than that. It doesn’t need you to reload the weights on the GPU every step. The weights are just stored on the GPU, and they remain on the GPU. And the optimizer is a bunch of kernels that also just run on the GPU. So in fact, this is a good non-example of a synchronization
15:38 | Corwin
where, in fact, operations that you might think need to occur on the CPU side are actually just launching a whole bunch of kernels. And so they do occur on the CPU side to launch those kernels, but the actual execution and the storage of the data is all held on the GPU, so we don’t need any syncs.
15:56 | Corwin
Any other ideas? You guys actually spotted this last one earlier. So this copying of the data from the CPU to the GPU,
16:09 | Corwin
you would think, wait, that’s the correct direction. CPU to GPU is the direction we’re launching stuff in. Why do we need to worry about that? And this is some details of how the memory copying works. So when you think about what is the GPU doing in order to do this copy, you have some data on the CPU, and you’re sending it to the GPU. But what happens if the kernel decides to swap out the page containing that data on the CPU? Or if in Python, we decide to deallocate
16:41 | Corwin
the tensor that contains that data? All of these would be really bad, and then the GPU would be trying to read memory that no longer points to the right thing. And so what we have to do is we have to wait on the CPU side for the GPU to actually have finished doing the copying before we can continue executing, ‘cause we have to hold that memory in place while the GPU is copying it into GPU memory. So this is not entirely true for very small tensors. You’re able to do the copy entirely asynchronously,
17:12 | Corwin
where you just squeeze the little data inside of the buffer. So if you have one number and you copy it, it’ll just fit inside the buffer, and so no synchronization needs to happen. But for most things that you’re trying to copy from the CPU to the GPU, you need to have a synchronization in the way the code is currently written in order to keep that data around.
17:32 | Corwin
So one thing is that synchronization is not always so bad. If immediately after finishing your synchronization you start launching kernels again, you’re not gonna wait that long.
17:43 | Corwin
You’re only gonna wait a millisecond. And so if your training loop takes one second to execute, that’s like 10 bps of inefficiency, so it’s not too bad. So not all synchronizations are bad, but some synchronizations are really bad. So this is an example of the CPU sync at the end of the loop. And right after we do that CPU sync, we do the work to get the next batch of data. And that work of getting the next batch of data is really expensive. And so that takes a huge amount of time on the GPU side to get that data, even though the GPU should have been doing work for a previous batch
18:14 | Corwin
while we were getting the next batch. And so this synchronization is extremely bad because it makes us not be able to do the overlap of CPU work and GPU work. And so you really want to avoid these synchronizations.
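The cost of a stall can be put into basis points with a one-line calculation. In the sketch below the 1 ms stall and 1 s step come from the talk, while the 200 ms figure for fetching the next batch is a made-up number used only to show how different the two cases are.

```python
def stall_in_bps(stall_seconds: float, step_seconds: float) -> float:
    """Fraction of a training step lost to a stall, in basis points (1 bp = 0.01%)."""
    return stall_seconds / step_seconds * 10_000


# Cheap sync: ~1 ms stall in a 1 s step (numbers from the talk).
cheap = stall_in_bps(1e-3, 1.0)

# Expensive sync: the GPU idles while the next batch is fetched.
# The 200 ms fetch time is a made-up figure for illustration.
expensive = stall_in_bps(200e-3, 1.0)

print(f"cheap sync: {cheap:.0f} bps, expensive sync: {expensive:.0f} bps")
```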
18:26 | Corwin
Okay, so how do we fix this? Well, to fix the first synchronization of copying the data back to the CPU to do the metrics, all we need to do is just not do the metrics on the CPU. We can just have the metrics be computed instead using kernels. And then only at the very end of our loop
18:46 | Corwin
when we actually need the metrics do we send them back and do that copy. And what this allows us to do is avoid doing the copy on the hot loop of the training loop and instead only do it rarely. For the next one, we can just compute the loss asynchronously. We can check whether the loss is NaN asynchronously. So instead of blocking the main loop, waiting for the loss isNaN check to run,
19:17 | Corwin
we do the loss isNaN check and eventually stop the loop at some point in the future. So this does change the behavior of our code. But if the loss was NaN, the training is gonna blow up anyway. And so we’re doing a little bit of extra useless work. But in the grand scheme of things, we made our training loop so much faster by skipping the synchronization that the extra work that we did at the end when our training did blow up is just not as big of a deal. And so this is a great way. You just wanna do all these checks to make sure things are stable in an asynchronous
19:47 | Corwin
background fashion as much as you can. Okay, so the last one, copying the data from the CPU to the GPU. Well, yeah, people thought about this. They’re like, this copying, it should be done asynchronously. And so the main thing that you need to do to make that happen is you need to tell all of the various systems on the CPU side that this memory should not move. So that’s what this pin memory does. So you tell the tensor, you say, put it into place and make sure that we don’t swap it out to disk. We don’t deallocate the tensor.
20:19 | Corwin
We just leave the memory there. And so the GPU can do DMA, which stands for direct memory access. It basically means the GPU can reach into the CPU memory, grab the data out, pull it back into GPU memory without the CPU being involved. And because we’ve pinned it into a particular place, we know that it’s not gonna move. We know it’s not gonna be deallocated. And so then it’s safe to have this be done asynchronously by the GPU. And so the last thing is you need to pass this non-blocking equals true to the copy. Because a thing that you could do if you didn’t pass non-blocking equals true is you could go in and mess with the memory.
20:52 | Corwin
You pin it to a particular place. And you say, it’s not gonna be moved. It’s not gonna be copied. It’s not gonna be swapped out. But you yourself could go and then touch it and change the values. And this would cause some non-deterministic execution. And so non-blocking equals true says, I know what I’m doing. I am not gonna mess with the memory after I’ve started the copy. It’s basically a promise that you know what you’re doing. Cool. Anyone have any questions about how we fixed all these synchronizations?
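The three fixes can be combined into one loop, sketched below under several assumptions: the model, optimizer, metric and random batches are toy stand-ins, and the NaN flag is only read once at the end rather than by an asynchronous checker. `pin_memory()` and `non_blocking=True` are the PyTorch calls referred to above; pinning is skipped when no GPU is present.

```python
import math

import torch
from torch import nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

torch.manual_seed(0)
model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Accumulators live on the device, so updating them only launches kernels.
abs_error_sum = torch.zeros((), device=device)
saw_nan = torch.zeros((), dtype=torch.bool, device=device)
count = 0

for _ in range(8):
    # Toy batch standing in for whatever the data loader would produce.
    inputs, targets = torch.randn(32, 16), torch.randn(32, 1)
    if use_cuda:
        # Page-locked host memory lets the GPU copy via DMA ...
        inputs, targets = inputs.pin_memory(), targets.pin_memory()
    # ... and non_blocking=True lets the CPU move on without waiting.
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)

    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    saw_nan |= loss.isnan()            # recorded on the device, no `if` here
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    abs_error_sum += (outputs.detach() - targets).abs().sum()
    count += targets.numel()           # shape metadata is already on the host

# One synchronization at the very end, outside the hot loop.
diverged = bool(saw_nan)
mean_abs_error = abs_error_sum.item() / count
print(f"diverged: {diverged}, mean absolute error: {mean_abs_error:.4f}")
```

In practice the pinning is often left to the data loader (`DataLoader(..., pin_memory=True)`) rather than done by hand on each batch.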
21:23 | Corwin
Yeah, question? You can pass the… Audience (21:28): So what does check_isnan_async actually do then? Will it still break out of the loop?
21:33 | Corwin
Yeah, so that’s a good question. So there’s a bunch of different things that you can do. One thing is you could have another thread on the CPU side be waiting for the copy from the GPU to the CPU. And so your main thread can continue launching kernels, while that other thread is the one waiting for the copy to finish.
21:54 | Corwin
So that’s one technique. You could also do it using a pipeline mechanism where you are copying stuff. But basic point is you don’t wanna have the main thread stalled waiting for the copy to finish. Audience (22:06): Right, and then the other thread that does the check will preempt your thread?
22:11 | Corwin
Yeah, it just kills the whole process. Audience (22:13): Oh, okay. Makes sense.
22:13 | Corwin
Yeah, I think if you wanna do something intelligent with the data and affect the behavior of the main thread in some way that’s not just killing it, then you need to do some complicated pipeline or inter-thread communication
22:25 | Corwin
with shared memory or something. But if all you’re doing is killing the process, you’re just raising.
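One way to picture the "other thread" approach is the sketch below. It is a guess at the shape of such a checker, not the speakers' `check_isnan_async`: the queue, the `watch_for_nan` function and the hand-written loss values are all hypothetical, and instead of killing the process the watcher only sets an event that the main loop polls.

```python
import queue
import threading

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
nan_flags = queue.Queue()
diverged = threading.Event()


def watch_for_nan() -> None:
    """Runs off the main thread; only this thread blocks on the device."""
    while True:
        flag = nan_flags.get()
        if flag is None:           # sentinel: training finished
            return
        if bool(flag):             # the synchronizing read happens here
            diverged.set()         # a real loop might kill the process instead
            return


watcher = threading.Thread(target=watch_for_nan, daemon=True)
watcher.start()

steps_run = 0
for loss_value in (0.5, 0.25, float("nan"), 0.1):
    if diverged.is_set():          # cheap host-side check, never waits on the GPU
        break
    loss = torch.tensor(loss_value, device=device)  # stand-in for a real loss
    nan_flags.put(loss.isnan())    # hand the device-side flag to the watcher
    steps_run += 1

nan_flags.put(None)
watcher.join(timeout=5)
print(f"ran {steps_run} steps, diverged: {diverged.is_set()}")
```

Because the check is asynchronous, the main loop may run a step or more past the bad one before it notices, which is the "extra useless work" trade-off described earlier.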
22:28 | Corwin
Cool, so now that we’ve fixed all the problems, we see this beautiful blue bar with no synchronizations in our profile. Well, almost. There’s no synchronizations, but there is this gap in when we’re doing memory traffic of doing the copy. It is actually possible to overlap the copy with the computation. I didn’t bother doing that for this plot, but it is, in fact, possible to do that. Just a little bit trickier on the code side.
22:58 | Corwin
But then you would get a fully beautiful blue bar. Cool, so now that we’ve talked a little bit about the high level, about how to remove synchronizations and get your pipeline looking good, and you have that beautiful blue bar, the beautiful blue bar does not mean that you’re using the GPU well. It just means that you’re using the GPU all the time. And so in order to understand if you’re using the GPU well, we need to understand more about how the GPU actually works. So what we’re gonna do is look at the GPU architecture
23:30 | Corwin
a little bit. And I’m gonna do it a little bit high level and simplified, but mostly this is gonna explain what’s going on. So you can imagine, what is a GPU? It has massively parallel computation. It has lots and lots of threads doing lots and lots of computation. And so this is a picture. We have all these cores that are doing computation, and they have all these registers that they’re using to do that computation. And so this seems great, except there’s lots of computations that involve communication between threads. Because if you’re trying to do a reduction
24:01 | Corwin
where you have each thread do a little piece of the computation and then reduce the answers to get the final result, well, they need some way to communicate that. And so actually, we have the compute with the registers, and then we also have this shared memory thing, which is memory that can be accessed by multiple threads at the same time, for multiple threads to communicate with one another. And so this seems great. You have all of the compute, lots and lots of threads. You have the registers for doing that computation, and then you have shared memory that allows them to communicate with each other. And this would be amazing if GPUs worked like this,
24:32 | Corwin
because this is sort of like the fanciful everything is fast world. That’s not how it actually looks. Instead, what a GPU has is hundreds of these individual blocks of compute with shared memory. These are called SMs for NVIDIA GPUs. And an H100 has 132 of these SMs. And the SMs can’t really talk to each other. They can, but it’s really expensive, and you kind of don’t wanna ever do it. So you basically wanna think about you have all these SMs that each can do lots of good computation,
25:04 | Corwin
and there’s lots of them, so you can do lots of stuff in parallel by splitting among the SMs. And within one SM, you have lots of threads, so you can do lots of parallel computation there. And then they can use shared memory to do some amount of reductions. And so the main thing that you wanna do when you’re trying to think about how do you program a GPU is how do you divide up the work of your problem to chunks that each go onto one SM? So you have your big problem. You divide it up into chunks. You assign each SM one chunk. Within that SM, it does some computation. It does some sharing to cooperate between the threads.
25:35 | Corwin
And then you store that result after you’ve computed everything back out to memory. And that’s the basic flow of how a program on a GPU kernel works. Cool. So there’s also these other two things. There’s the L2 cache and global memory. So global memory is where all the results are stored and where all the inputs come from. And the L2 cache also sort of interposes and is a shared thing across all the SMs and interposes to global memory to make things faster. And to give you an idea of how fast things are,
26:08 | Corwin
global memory is like 3 terabytes per second on an H100. So that sounds pretty fast, but it’s kind of slow in comparison with L2 cache with the 7 terabytes per second. And then the shared memory is like 26 terabytes per second. So if you think about how fast it is to access memory, you really only wanna be using shared memory as much as possible. And then, okay, if you can’t use shared memory, then try to fit it in L2 cache as much as possible. And if you can’t fit it in L2 cache, then okay, global memory would be okay. But for the most part, trying to stay in the higher layers
26:39 | Corwin
of the memory hierarchy is kind of the key component of thinking about how to write an efficient kernel. You need to think about, okay, what are the things that I can load in shared memory and do my computation only using those things? And that’s the whole thing about chunking up the work, is what can I fit into shared memory? So when thinking about the performance of a kernel, the two things that you need to think about are how fast the compute is, like how many matmuls can you do or how many floating point operations can you do versus how fast is it to get the memory in from global memory to your tile of computation?
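The compute-versus-memory comparison can be sketched with the numbers from the talk. The bandwidths are the H100 figures quoted above; the compute rate is an assumption, taken as the rate implied by the earlier 2048 by 2048 matmul in about 200 µs rather than a spec-sheet peak, and the matrices are assumed to be float32.

```python
# Bandwidths quoted in the talk for an H100, in bytes per second.
BANDWIDTH = {"global": 3e12, "l2": 7e12, "shared": 26e12}

# Assumed compute rate: the ~86 TFLOP/s implied by the talk's
# 2048 x 2048 matmul taking ~200 us. Not a spec-sheet peak.
COMPUTE_RATE = 2 * 2048**3 / 200e-6


def bottleneck(flops: float, bytes_moved: float) -> str:
    """Compare time spent computing with time spent on global-memory traffic."""
    compute_time = flops / COMPUTE_RATE
    memory_time = bytes_moved / BANDWIDTH["global"]
    return "compute" if compute_time > memory_time else "memory"


n = 2048
matrix_bytes = n * n * 4  # one float32 matrix, 16 MiB

for tier, bandwidth in BANDWIDTH.items():
    print(f"{tier:>6}: {matrix_bytes / bandwidth * 1e6:.2f} us to stream one matrix")

# Matmul: 2*n^3 FLOPs, reads two matrices and writes one.
matmul = bottleneck(2 * n**3, 3 * matrix_bytes)
# Element-wise add: n^2 FLOPs for the same memory traffic.
add = bottleneck(n * n, 3 * matrix_bytes)
print(f"matmul is {matmul}-bound, add is {add}-bound")
```

The same amount of memory traffic supports thousands of times more arithmetic in the matmul than in the element-wise add, which is why the two land on opposite sides of the comparison.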
27:12 | Corwin
And so those are the two main things that you need to consider. And you think about how big is your total problem and how much total computation or how big is your tile and how much compute for that tile and how much memory traffic for that tile and where is my bottleneck? And there is actually one other one, which is the kernel overhead. So here, we’re launching a bunch of kernels. It takes actually some decent time to launch a kernel. It’s like 5 to 10 milliseconds on the CPU side. Sorry, 5 to 10 microseconds on the CPU side. And then on the GPU side, there’s usually about a microsecond of gap.
27:44 | Corwin
So if you’re launching kernels that are too small, that take under a microsecond to execute, well, then you’re just entirely bottlenecked by overhead. And so those are the three main things that you need to worry about. And you can see here in this plot, we’re doing all these CUDA API launches. And then the GPU just gets completely starved. We used to have that beautiful blue bar. And then we started launching these small kernels. And things are still mostly blue but getting a little jagged. And then eventually, we just gap, gap, gap, gap,
28:15 | Corwin
waiting for the CPU to launch kernels. And this GPU starvation is like, I want you to think about Gromit here. You’ve got to lay down enough kernels for the GPU to execute before the GPU gets there. Otherwise, you’re gonna crash. And we don’t want Gromit to be sad and crash. So got to make sure your work is long enough, your rails are long enough for the GPU to continue executing. So I’ll hand it over to Sylvain to talk about these bottlenecks.
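A very simplified model shows why small kernels starve the GPU. The function below is a sketch, not a measurement: it assumes the CPU needs a fixed time to enqueue each kernel (5 µs, the low end of the range quoted above), that the GPU has a fixed 1 µs gap between kernels, and that whichever side is slower sets the pace.

```python
def gpu_utilization(kernel_us: float, launch_us: float = 5.0, gap_us: float = 1.0) -> float:
    """Steady-state fraction of time the GPU spends inside kernels.

    Simplified model: the CPU needs launch_us to enqueue each kernel, and the
    GPU needs kernel_us plus a fixed gap_us between kernels. Whichever side is
    slower sets the pace.
    """
    period_us = max(launch_us, kernel_us + gap_us)
    return kernel_us / period_us


for kernel_us in (1, 10, 100, 1000):
    print(f"{kernel_us:>5} us kernels -> {gpu_utilization(kernel_us):.0%} busy")
```

Under these assumptions, 1 µs kernels keep the GPU busy only about a fifth of the time, while 100 µs kernels keep it busy about 99% of the time.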
28:40 | Sylvain
So yeah, as Corwin said, we have three types of bottlenecks, compute bottleneck, memory bottleneck,
28:45 | Sylvain
and kernel overhead. And fortunately, there are ways to mitigate all of those. And I’ll explain high level what they are, and then we’ll dive a little bit more into some of those. For the compute bottleneck, NVIDIA GPUs come with these things called tensor cores. Especially for kernels that are very compute bottlenecked, like matmuls, where you have to do lots and lots of compute for not that many memory reads, tensor cores are great. Because tensor cores are specialized cores on the GPU that do very small tiles of matmuls.
29:16 | Sylvain
And this is where all of the teraflops of training machine learning models come from nowadays. We’re not gonna dive too much into that, because it would take an hour-long talk, and we don’t have another hour to talk about them. But if you’re interested, I definitely recommend checking out talks, for instance, talking about flash attention kernels. I know they kind of change the algorithm of the attention to leverage the tensor cores and make sure they use them 100% of the time. For the memory bottlenecks, the main thing you wanna do,
29:46 | Sylvain
usually, is something called kernel fusion. Basically, if you’re bottlenecked in your kernel by loading stuff from memory, you need to do more stuff inside of your kernel. So instead of launching one small kernel that does, I don’t know, an addition, and then another small kernel that does a multiplication, like try to fuse those two things together so that you load the memory, do all of your compute, hopefully enough compute to kind of amortize the cost of loading the memory, and then save the results. And this is the thing we’re gonna look deeper into for the rest of the talk.
30:18 | Sylvain
And then for the kernel overhead, again kernel fusion can help because if you have fewer kernels and they do more stuff, you will amortize the cost of that overhead over like more operations. Another thing that can help is a technology called CUDA Graphs on NVIDIA GPUs. The idea there is that if you want to launch lots of small kernels, you can record that once and for all. So the first time you do it, it’s still gonna be slow and look like the profile trace that Corwin showed earlier. But the second time you replay that graph and the second time you do the same thing, everything’s gonna be like stitched together on the GPU
30:48 | Sylvain
because it will have remembered like all the launch parameters that were needed for all of those kernels, so you won’t pay that overhead more than once. And so this is a technology that’s particularly helpful when you want to launch lots of small kernels and you can’t really like fuse them all together for some reason. So yeah, I kind of said the idea at a high level, but let’s say we have this string of computation. So let’s say we have a particular operation that’s often appearing in neural nets called LayerNorm and then we want to do like some activation function like a ReLU and then we want to like add stuff back.
31:21 | Sylvain
This is a skip connection. Again, all that stuff is like very common in neural nets and all of those, I’ve chosen them because like all of those operations are heavily memory bottlenecked. Like basically, this is just like element-wise vector operations that you want to do. And so for each of those kernels, you’ll spend some time like loading the inputs, doing some compute and then storing the output just so that you load them again. So like we start with LayerNorm. LayerNorm like loads inputs and weights. It’s gonna split them across the SMs that we showed earlier.
31:51 | Sylvain
Each SM is gonna do some work, store the results back in the global memory and then immediately after we load potentially like the same output from global memory back into our SM to do the ReLU and store the results back to the global output. And then finally, we reload again the inputs that maybe were loaded previously from global memory into our shared memory to do the addition. And that’s kind of inefficient. Like each time you have a red followed by an orange, you kind of wanna remove that just because that seems pretty inefficient and like the thing you really wanna do is this like loading memory once,
32:24 | Sylvain
doing all of the operations at once, LayerNorm, ReLU and addition and then store the results back. The problem there is that then you need like an infinite number of kernels. Like each of those operations, LayerNorm, ReLU, add, this is a kernel that someone hand-coded and if you want to have like fused kernels for everything, you basically need to like write those fused kernels for everything and you need like kind of an infinite number of people coding an infinite number of kernels. Fortunately, PyTorch comes with something that helps which is called torch.compile.
32:55 | Sylvain
So the torch compiler is something that is gonna trace what you’re doing with your inputs and try to find opportunities for fusions. And it’s actually extremely good. Like I chose those LayerNorm, ReLU and add because they are memory bottlenecked operations but also because that’s something that torch.compile is really good at kind of automatically fusing together. So yeah, it’s gonna analyze. The first time you run a model with torch.compile it’s gonna be very slow because it’s gonna trace what’s happening to your inputs and like what operations are done inside of your models to those inputs.
33:26 | Sylvain
And then it’s gonna try to find opportunities for fusion and run some micro benchmarks to find like is the fusion here helpful or not? Like is it making things faster or slower? And once it’s run all of that and decided like you come up with a model with some fused kernel that’s gonna be much faster the second time you try to use it. One thing to know when you’re doing things that use torch.compile is that the PyTorch compiler is becoming better and better but it doesn’t necessarily know everything. And so if you have an operation that’s kind of fancy
33:56 | Sylvain
in the middle of your model that PyTorch doesn’t know about it’s gonna create what we call a graph break. Like PyTorch is not gonna try to fuse that operation it doesn’t know with anything because it can’t do that, it doesn’t know it. So it’s gonna try to fuse things before, it’s gonna try to fuse things after. But basically if you have lots of graph breaks you’re kind of removing lots of opportunities for kernel fusion inside of your model which makes the torch.compile model a bit less efficient. So yeah, kernel fusion, fine, torch.compile does that automatically for us. Maybe this is the end of the talk, but no.
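A minimal sketch of the LayerNorm, ReLU and skip-connection chain handed to `torch.compile`. The function name `norm_relu_skip`, the tensor shapes and the hidden size are assumptions made for illustration; they are not taken from the slides.

```python
import torch
import torch.nn.functional as F


def norm_relu_skip(x, weight, bias, skip):
    # Three memory-bound steps. Run eagerly, each step is its own kernel
    # and each result round-trips through global memory.
    y = F.layer_norm(x, x.shape[-1:], weight, bias)
    y = torch.relu(y)
    return y + skip


# torch.compile only wraps the function here; tracing and compilation
# happen on the first call, which is why that first call is slow.
fused = torch.compile(norm_relu_skip)

hidden = 64  # hypothetical hidden size
x = torch.randn(8, hidden)
skip = torch.randn(8, hidden)
weight = torch.ones(hidden)
bias = torch.zeros(hidden)

# Eager reference result, computed without the compiler
eager_out = norm_relu_skip(x, weight, bias, skip)
```

Calling `fused(x, weight, bias, skip)` should give the same values as `eager_out` up to floating-point tolerance. Whether the three operations end up in a single kernel depends on the PyTorch version and the backend.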
34:27 | Sylvain
Like sometimes the PyTorch compiler is not… Oh, you have a question, sorry.
34:31 | Audience
Do you have an example of these bad operations that might come up?
34:36 | Corwin
So there’s some bad examples. I think at one point PyTorch didn’t know how to represent dictionary comprehensions. And so if you had a dictionary comprehension then it would break the fusion through the dictionary comprehension. And this is just like a pure Python feature so it should be totally fine but actually broke the compilation. There’s like other things like if you write
34:58 | Corwin
a custom kernel and that like you’ve put that in as a custom operator then PyTorch is like not gonna know how to fuse that custom kernel that you wrote with the kernels that are happening around it.
35:09 | Sylvain
And another thing that could happen is that you have an accidental GPU/CPU sync inside of your model that you hadn’t detected before. This is also gonna create a graph break because PyTorch is like, yeah, the operation is… everything that requires synchronizing the CPU and the GPU basically is considered a graph break for the compiler. And it can’t fuse that with anything.
35:30 | Sylvain
Which is actually a great way to detect if you have CPU/GPU sync in your model. Like try to compile it with the option fullgraph=True and it’s gonna error if you have kind of those hidden synchronizations as we had before.
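A small sketch of that check, assuming a hypothetical function `step` that hides a sync behind `.item()`. The `backend="eager"` argument is a choice made here to keep the example cheap: the function is still traced, but no fused kernels are generated.

```python
import torch


def step(x):
    y = torch.relu(x)
    # .item() copies a value back to Python. On a CUDA tensor that forces the
    # CPU to wait for the GPU, and the compiler cannot trace through it.
    if y.sum().item() > 0:
        y = y * 2
    return y


# fullgraph=True asks the compiler to fail instead of silently splitting
# the function into several graphs.
checked = torch.compile(step, fullgraph=True, backend="eager")

try:
    checked(torch.randn(16))
    found_graph_break = False
except Exception as err:  # the exact exception type varies across PyTorch versions
    found_graph_break = True
    print(f"{type(err).__name__}: compilation refused")
```

Without `fullgraph=True` the same function would still run, just with a graph break at the `.item()` call.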
35:45 | Sylvain
And so yeah, sometimes the PyTorch compiler is great and does everything right and sometimes it’s not so great and does everything wrong. And so we do need to like get our hands a bit dirty and like write those custom kernels that are gonna fuse a lot of operations. And I’m gonna chat a little bit
36:00 | Sylvain
about that for the end of our talk. So the operation I’m gonna take an example of, is this operation where we wanna sum many tensors. Like the actual use case is a researcher wants to log the mean of every gradient in the model because they want to like kind of understand like why the training is not going as they expected. And so to do that you need to like sum all of the elements in lots of tensors.
36:32 | Sylvain
And this is something that’s particularly bad when you want to use your GPUs. In general like each time you have a for loop and you’re programming stuff on the GPU that is probably extremely bad. Because like here for instance, this for loop could be done in parallel because there is no interaction between the values t.sum. And so like we are not leveraging the parallelism of the GPU because the for loop is gonna be sequential. And then the other thing that happens is that we kind of are launching lots and lots of very small kernels and we’re back to the picture that Corwin showed before about the kernel overhead problem. Where like if we zoom in on why this blue line
37:04 | Sylvain
is like so small is basically because our GPU is completely starved. Like the CPU takes five microseconds to enqueue a kernel that takes less than a microsecond to execute on the GPU. So basically we’re spending like 20 milliseconds of just having the GPU waiting for some work from the CPU. And this is… so like we have lots of small kernels. PyTorch, torch.compile is great at fusing kernels together. Maybe if we apply torch.compile it’s all gonna go away. Except not. It’s actually a bit worse.
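The numbers quoted here can be tied together with a little arithmetic. The launch count below is inferred from the two quoted figures rather than stated in the talk, and the one-microsecond GPU time is an upper bound taken from "less than a microsecond".

```python
# Figures quoted in the talk
enqueue_us = 5.0       # CPU time to enqueue one small kernel
total_wait_ms = 20.0   # time spent with the GPU waiting on the CPU

# Inferred, not stated: how many launches would account for that time
implied_launches = total_wait_ms * 1000 / enqueue_us
print(f"about {implied_launches:.0f} kernel launches")

# Assumed upper bound on the GPU time of each small kernel
gpu_us = 1.0
gpu_busy_fraction = gpu_us / enqueue_us
print(f"GPU busy at most {gpu_busy_fraction:.0%} of the time")
```

Under these assumptions the GPU sits idle for most of the loop, however fast each individual sum kernel is.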
37:34 | Sylvain
So why is it a bit worse? It’s just because when you launch kernels without torch.compile, PyTorch is gonna take the absolute best kernel that exists for the operation you want. Usually coded by NVIDIA. Whereas the PyTorch compiler relies on some templates, usually written in a language called Triton, which are maybe 80 or 90% of the performance of the best NVIDIA kernel but not like the most performant thing. And the idea is that it’s still gonna be faster at the end because you fuse things together. So you will gain way more than what you lose
38:05 | Sylvain
with this template thing. But when the torch.compile is not really able to fuse things, you actually end up with something that takes a little bit longer. And so that’s the case here. Basically the torch compiler is not really a real compiler. Like it’s not like saying oh this is a for loop. This for loop could occur in parallel. I’m gonna create a kernel that’s gonna do exactly that. It’s just like it’s tracing what’s happening to your input so like it sees the for loop completely unrolled and it’s not able to like, yeah, it’s gonna fuse things here and there inside of that for loop but it’s not able to do that thing very efficiently.
38:36 | Sylvain
So this is a situation where we actually need to write a custom kernel to go way faster. We’ll see at the end of the talk how much faster we get with that. So yeah, when we wanna write a custom kernel, the first question that we need to ask ourselves is how do I split the work between my SMs? So Corwin like pointed over like the GPU architecture a little bit earlier, we have some compute and shared memory grouped in those streaming multiprocessors and each of them cannot talk to the other. So you need to split the work in independent chunks.
39:07 | Sylvain
So here for our problem it’s very easy because we have lots of tensors. Computing the sum of individual tensors are problems where you don’t need to talk to each other so like we can just assign one tensor to one SM. You might think like there’s a problem. We have 132 SMs on an H100 and we have like thousands of tensors because this is a very big model that has like lots of layers. That’s actually not that problematic because when you assign work on the SM it’s very dynamic. So like as soon as one SM is finished it’s gonna request more work and it’s gonna be assigned a new tensor and even if like our tensor shapes
39:38 | Sylvain
are kind of different it’s gonna even out in the long term. Like for instance here SM2 is probably gonna be done before the other SMs because it has a smaller tensor. So it’s gonna like be the first one to request more work. Maybe the next work will be a bigger tensor. Maybe it’s gonna be another small tensor again but like if it just keeps getting small tensors it will do more work than, it will like compute the sum of more tensors than the other SMs. And so in the grand scheme of things it’s just gonna be fine. So yeah, we split our work this way, one tensor per SM
40:09 | Sylvain
and then inside of each SM we also have lots of parallelism. So we need to like compute the sum of one tensor on an SM using that parallelism, using those mini-cores that we have on each SM. So like here it’s just the idea of this is a sum. Sum has a very parallel implementation which is just like we kind of take each thread and assign them like we’re gonna compute partial sums and then we’re gonna reduce the partial sums together using the shared memory. One thing we need to worry about is that we can’t just say like thread one is gonna take the first four elements
40:40 | Sylvain
and then thread two is the next four elements just because we want threads to be loading data that sits next to each other. This is just an optimization because on GPUs if each thread wants to read things that are consecutive in memory they’ll be able to do a much wider load together that’s gonna go much faster. So that’s why we split the work this way and not like saying just compute the sum of the first four elements for instance. And this is assuming we have eight threads in the picture because I was lazy but in actuality we’ll have like 1,024
41:11 | Sylvain
threads on each SM for instance. And so once each thread has computed the partial sums we use the shared memory and we can like reduce until like for instance like the first thread has the final result and then this first thread will store that final result into the global memory where we wanted the result. Looking briefly at code I know it’s been a long talk and you’re a bit tired but like let’s just look at some big blocks of how we would code this.
41:41 | Sylvain
So for the custom kernel you need to go to like the language in which we program GPUs which is called CUDA and which is a variant of C++. So writing a custom kernel looks like this. You have a function that is executed on the CPU that launches a kernel. So here it’s called like sum_many_kernel and when you launch it you say how many blocks you want that’s gonna be like how many SMs and then in each block how many threads you want. So like here we request 1,024 threads in each block, and we launch like one block per tensor because we said like we would split the work
42:11 | Sylvain
giving one tensor to one SM. Then the sum_many_kernel is going to look like this roughly. The first thing is finding out who you are. Like every thread on the GPU is gonna execute the exact same code so we just need to make sure that they specialize to like the particular tensor they are supposed to be taking the sum of. And we have this global constant called blockIdx which tells us which SM we are. So by basically taking the value of this constant we can know, oh, I’m supposed to compute like the sum of this particular tensor.
42:44 | Sylvain
And then this is the loop where each thread computes partial sums. So each thread begins at its threadIdx. So threadIdx is the same thing as blockIdx except it tells you like which thread you are among like the 1,024 threads of the block. So like thread zero begins at zero and then goes to like the element 1,024, while thread one begins at one and then goes to the element 1,025 and so forth. And so this for loop computes the partial sums and then we have like to reduce all those partial sums together to compute the final result.
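The indexing just described can be mimicked on the CPU. This is a plain Python simulation, not the CUDA kernel from the slides: `sum_many_simulated` is a hypothetical name, the two outer loops stand in for what the GPU runs in parallel, and the block-level reduction is replaced by a plain `sum`.

```python
import random

BLOCK_DIM = 1024  # threads per block, as requested at launch


def sum_many_simulated(tensors):
    """CPU simulation: one block per tensor, strided partial sums per thread."""
    out = [0.0] * len(tensors)
    for block_idx, data in enumerate(tensors):   # blockIdx selects the tensor
        partial = [0.0] * BLOCK_DIM
        for thread_idx in range(BLOCK_DIM):      # threadIdx selects the start element
            # Thread 0 reads elements 0, 1024, 2048, ...; thread 1 reads
            # 1, 1025, 2049, ... so neighbouring threads touch neighbouring
            # addresses at every step of the loop.
            for i in range(thread_idx, len(data), BLOCK_DIM):
                partial[thread_idx] += data[i]
        # Stand-in for the block-level reduction of the partial sums
        out[block_idx] = sum(partial)
    return out


random.seed(0)
# Integer-valued floats keep the sums exact regardless of summation order
tensors = [[float(random.randint(-5, 5)) for _ in range(n)] for n in (10, 1024, 3000)]
results = sum_many_simulated(tensors)
```

Tensors shorter than the block simply leave most threads with an empty loop and a partial sum of zero.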
43:15 | Sylvain
I’ll go over that in a minute. And finally the thread number zero will store the final result at the right place in the output. And again to find like what is the right place in the output we use this blockIdx which tells us for which tensor we’re responsible of computing the sum. And yeah, the sum block function going roughly over this is, oh we talked about that, okay. I forgot this part. The sum block function is going to use
43:48 | Sylvain
another thing called warps. One thing we didn’t tell you about is that the code on the GPU is gonna be executed like not all the threads are executing in lock step. They are executing in lock step by groups called warps. So 32 threads, that’s why WARP_SIZE is 32 at the top. 32 threads always execute everything in lock step. And inside of a warp we have like some way to share things even faster than shared memory. We don’t even need to go to shared memory. Inside a warp we can just use registers to share values.
44:19 | Sylvain
So that’s why like this distinction here is important. So we requested like a block of size 1,024. So we have 32 warps of 32 threads. So we can then compute like what’s my warp_Idx and what index I have inside of the warp. And we have our shared memory where we’re gonna store all the partial results. And the way we do this is we only store the partial results for 32, the number of warps, 32. Like not the partial results for all of the threads because we can use those warp_reduce things
44:49 | Sylvain
that are much faster. And so this way we only have like to store like NUM_WARPS partial sums. So we do that and then we store like lane_Idx zero. So the threadIdx zero inside of the warp will store the results in the shared memory. And very importantly, we need to synchronize everyone to make sure that since different warps don’t necessarily execute in lock step, they might arrive at this line at different times and we want to make sure that everyone is here before we continue. Because we want to make sure that everyone sees the same shared memory before we continue.
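The talk does not show how the warp-level reduction is implemented. One common way is a shuffle-down tree, which in CUDA would use a register shuffle intrinsic such as `__shfl_down_sync`. The Python below is a CPU simulation of that pattern under this assumption; the per-thread values are made up.

```python
WARP_SIZE = 32
NUM_WARPS = 32  # 1,024 threads per block / 32 threads per warp


def warp_reduce(lanes):
    """Simulate a shuffle-down tree: after five steps lane 0 holds the warp's sum."""
    vals = list(lanes)
    offset = WARP_SIZE // 2
    while offset > 0:
        # Every lane adds the value held by the lane `offset` positions above it.
        # All lanes read before any lane writes, as in lock-step execution.
        # Lanes with no partner add 0 here; their results are never used.
        vals = [
            vals[i] + (vals[i + offset] if i + offset < WARP_SIZE else 0.0)
            for i in range(WARP_SIZE)
        ]
        offset //= 2
    return vals[0]


# Made-up per-thread partial sums (integer-valued so the check is exact)
partials = [float(t % 7) for t in range(WARP_SIZE * NUM_WARPS)]

shared = [0.0] * NUM_WARPS  # stand-in for the block's shared memory
for thread_idx in range(len(partials)):
    warp_idx, lane_idx = divmod(thread_idx, WARP_SIZE)
    if lane_idx == 0:
        # Only lane 0 of each warp writes its warp's sum to shared memory
        start = warp_idx * WARP_SIZE
        shared[warp_idx] = warp_reduce(partials[start:start + WARP_SIZE])
```

On the GPU the 32 warps run concurrently, which is why a block-wide barrier is needed before any thread reads `shared`.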
45:22 | Sylvain
And then if we are the warp_Idx zero, we do one final warp reduce because since our thread block size 1,024 was 32 by 32, we have 32 values to reduce which is exactly the size of a warp. So we can just use another warp_reduce to do the final warp reduction. And if we’re not warp zero, we just return something. We don’t care because if you remember the code of the previous slide, only threadIdx zero is gonna store the results. So only threadIdx zero needs to have the right results returned to it. So that was a lot of content. But the main result is that it used to take
45:56 | Sylvain
like 20 milliseconds to do that sum of lots of tensors. And with the custom kernel that takes a week to write if you’re new, an afternoon to write if you’re more experienced you get that time down to 30 microseconds. So almost like a thousand times faster. So it’s definitely worth learning how to sometimes go deeper into what the GPU programming looks like. More questions?
46:28 | Audience
Yeah. When you’re writing custom kernels at Jane Street, how often do you write like CUDA C++ directly or something higher level or lower level like PTX?
46:36 | Sylvain
Yeah, so that’s a great question. So first thing we try not to go deep at all. Like if we can express the computation in pure PyTorch like maybe vectorizing the code a little bit better, maybe like accessing some function that can do some stuff for us a bit better. We do that. And then the next thing we try is usually Triton which is a DSL in Python that gives you a high-level language
47:03 | Sylvain
to program the GPU but doesn’t give you access to all of the features of the hardware. So usually you can code things there but like usually you will maybe have access to 80 to 90% of the performance of the GPU. And then, either because our kernel is not suited at all to Triton and we can’t write it there or because we want the absolute best performance, we’ll reach deeper into CUDA. We don’t write a lot of PTX. Sometimes you have a CUDA kernel and like you need to like inject a particular PTX instruction
47:33 | Sylvain
because you really want the compiler to do that but we don’t like write the whole kernel in PTX.