CUDA Graphs in PyTorch

Category : Deep_Learning

Introduction

Since CUDA Graphs debuted in CUDA 10, we have been waiting for PyTorch to support them. It is great to see that, by the end of 2021, PyTorch has incorporated CUDA Graphs, along with tutorials and documentation on how they work and how to use them.

For the sake of brevity, I will not repeat what is already covered in the official PyTorch tutorials; instead, I will focus on a few tricks that are useful for learning how CUDA Graphs behave. I strongly encourage readers to read the blog post on CUDA Graphs at https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/, which I found very helpful and informative. This post extends that tutorial a little further.

Prerequisite

  • PyTorch 1.12.1
  • NVIDIA GeForce GTX 1650 Ti
  • CUDA 11.3
  • Windows 10

Problems

While trying to understand the API example in the tutorial, I printed the torch.cuda.Stream object and the torch.cuda.CUDAGraph object and got the following:

<torch.cuda.Stream device=cuda:0 cuda_stream=0x20af329de50>

<torch.cuda.graphs.CUDAGraph object at 0x0000018E91AD4540>

All this tells me is that both objects have been created and live at particular memory addresses. It tells me nothing about how the graphs and streams change as I experiment with them.

I therefore add a dummy tensor, and a dummy operation on it, to reveal the subtle differences between the stages of the tutorial's API example and how each stage works.
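If you also want PyTorch itself to report which phase you are in, a minimal sketch like the one below may help. It assumes a CUDA-capable GPU and uses torch.cuda.is_current_stream_capturing(), which returns True only while the current stream is being captured. The function name probe_capture_state and the one-element tensor x are illustrative, not part of the tutorial. If the lessons below hold, capturing should be reported only inside the capture block, and x should read 1.0 after the warmup, still 1.0 after capture, and 2.0 after one replay.

```python
import torch


def probe_capture_state():
    """Report capture status and the value of a toy tensor at each phase."""
    report = {"capturing_before": torch.cuda.is_current_stream_capturing()}
    x = torch.zeros(1, device="cuda")  # toy "static" tensor reused by the graph

    # Warmup on a side stream, mirroring the tutorial's pattern.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        x.add_(1.0)
    torch.cuda.current_stream().wait_stream(s)
    report["after_warmup"] = x.item()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # Kernels issued here are recorded into g, not executed.
        report["capturing_inside"] = torch.cuda.is_current_stream_capturing()
        x.add_(1.0)
    report["capturing_after"] = torch.cuda.is_current_stream_capturing()
    report["after_capture"] = x.item()

    g.replay()  # executes the recorded add_ once
    report["after_replay"] = x.item()
    return report


if torch.cuda.is_available():
    print(probe_capture_state())
```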

Below is the original tutorial.


import torch

N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                            torch.nn.Dropout(p=0.2),
                            torch.nn.Linear(H, D_out),
                            torch.nn.Dropout(p=0.1)).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Placeholders used for capture
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')

# warmup
# Uses static_input and static_target here for convenience,
# but in a real setting, because the warmup includes optimizer.step()
# you must use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(static_input)
        loss = loss_fn(y_pred, static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# capture
g = torch.cuda.CUDAGraph()
# Sets grads to None before capture, so backward() will create
# .grad attributes with allocations from the graph's private pool
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_y_pred = model(static_input)
    static_loss = loss_fn(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()

real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    # Fills the graph's input memory with new data to compute on
    static_input.copy_(data)
    static_target.copy_(target)
    # replay() includes forward, backward, and step.
    # You don't even need to call optimizer.zero_grad() between iterations
    # because the captured backward refills static .grad tensors in place.
    g.replay()
    # Params have been updated. static_y_pred, static_loss, and .grad
    # attributes hold values from computing on this iteration's data.

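If this code looks unfamiliar, please visit https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/. The tutorial has three phases. First comes a "warmup": a few ordinary (eager) training iterations run on a side stream s, which, like any CUDA stream, is a linear sequence of execution on a specific device (here, cuda:0). Although the tutorial uses random data for convenience, a real setting should warm up with a few batches of real data, because the warmup calls optimizer.step() and therefore updates the weights. Second, the same sequence is recorded by graph capture inside torch.cuda.graph(g). Finally, the captured graph is replayed to train the model on real data, after each batch is copied into static_input and static_target. (The print statements in my scripts below label the warmup phase "stream capture".)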
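A commonly cited reason for the warmup, besides sanity checking, is that some PyTorch objects initialize lazily on first use, and it is cleaner to get that one-time work done eagerly, outside the capture. As a small illustration (it uses torch.optim.Adam, which the tutorial does not use, and runs on the CPU only to show the bookkeeping), an optimizer's per-parameter state can be empty until the first step() call:

```python
import torch

# Tiny CPU model: just enough to show when optimizer state gets created.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

n_params = len(list(model.parameters()))  # weight and bias
state_before = len(optimizer.state)       # no per-parameter state yet

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()  # the first step lazily allocates the exp_avg / exp_avg_sq buffers

state_after = len(optimizer.state)
first_param = next(model.parameters())
print(state_before, state_after, sorted(optimizer.state[first_param].keys()))
```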

Lesson 1: computation is performed during the warmup, but not during graph capture.

Here is my version of the same tutorial for Lesson 1.

import torch

N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                            torch.nn.Dropout(p=0.2),
                            torch.nn.Linear(H, D_out),
                            torch.nn.Dropout(p=0.1)).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Placeholders used for capture
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')
static_counter = torch.zeros((1,), device='cuda')

print ('Initial: static_counter ', static_counter)

# warmup
# Uses static_input and static_target here for convenience,
# but in a real setting, because the warmup includes optimizer.step()
# you must use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):

    optimizer.zero_grad(set_to_none=True)
    y_pred = model(static_input)
    loss = loss_fn(y_pred, static_target)
    loss.backward()
    optimizer.step()
    static_counter[0] += 1.0

# A CUDA stream is a linear sequence of execution that belongs to a specific
# device, independent from other streams.
torch.cuda.current_stream().wait_stream(s)

print ('After stream capture: static_counter ', static_counter)
# capture
g = torch.cuda.CUDAGraph()
# Sets grads to None before capture, so backward() will create
# .grad attributes with allocations from the graph's private pool
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):

    static_y_pred = model(static_input)
    static_loss = loss_fn(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()
    static_counter[0] +=  1.0

print ('After Graph Capture: static_counter ', static_counter)

ZERO = torch.zeros((1,), device='cuda')
static_counter.copy_(ZERO)

real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    # Fills the graph's input memory with new data to compute on
    static_input.copy_(data)
    static_target.copy_(target)
    # replay() includes forward, backward, and step.
    # You don't even need to call optimizer.zero_grad() between iterations
    # because the captured backward refills static .grad tensors in place.
    g.replay()
    # Params have been updated. static_y_pred, static_loss, and .grad
    # attributes hold values from computing on this iteration's data.

print ('After replay: static_counter ', static_counter)

static_counter is a dummy tensor that I added to track what happens during the warmup, graph capture and graph replay.

Console outputs are as follows:

Initial: static_counter tensor([0.], device='cuda:0')

After stream capture: static_counter tensor([1.], device='cuda:0')

After Graph Capture: static_counter tensor([1.], device='cuda:0')

After replay: static_counter tensor([10.], device='cuda:0')

We observe that static_counter increases by 1 during the warmup, but not during graph capture. We therefore conclude that no computation is performed during graph capture: the kernels are only recorded. Just before replay, we reset static_counter to zero. After replay, it has risen to 10, matching the number of iterations of the for loop, that is, one increment per g.replay().
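The behaviour in Lesson 1 can be mimicked with a pure-Python analogy. The class below (ToyGraph, a hypothetical name, not PyTorch code) records operations while "capturing" instead of running them, and replays the recording against the same static buffer. Real capture records GPU kernels rather than Python closures, so treat this only as a mental model of the counts observed above.

```python
class ToyGraph:
    """Pure-Python analogy of capture/replay; not how PyTorch implements it."""

    def __init__(self):
        self._ops = []
        self._capturing = False

    def __enter__(self):
        self._capturing = True
        return self

    def __exit__(self, *exc):
        self._capturing = False
        return False

    def run(self, op):
        """Execute op immediately (eager) or record it (while capturing)."""
        if self._capturing:
            self._ops.append(op)
        else:
            op()

    def replay(self):
        for op in self._ops:
            op()


counter = [0.0]  # plays the role of static_counter


def increment():
    counter[0] += 1.0


g = ToyGraph()
g.run(increment)          # "warmup": executes immediately
after_warmup = counter[0]
with g:
    g.run(increment)      # "capture": only recorded
after_capture = counter[0]

counter[0] = 0.0          # reset, like static_counter.copy_(ZERO)
for _ in range(10):
    g.replay()
after_replay = counter[0]
print(after_warmup, after_capture, after_replay)  # 1.0 1.0 10.0
```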

Lesson 2: only what happens during graph capture matters.

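The warmup itself is not part of what the graph replays; it runs eagerly, which makes it the place for debugging, sanity checks and any one-time initialization before capture. The Lesson 2 script is as follows: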


import torch

N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                            torch.nn.Dropout(p=0.2),
                            torch.nn.Linear(H, D_out),
                            torch.nn.Dropout(p=0.1)).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Placeholders used for capture
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')
static_counter = torch.zeros((1,), device='cuda')

print ('Initial: static_counter ', static_counter)

# warmup
# Uses static_input and static_target here for convenience,
# but in a real setting, because the warmup includes optimizer.step()
# you must use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):

    optimizer.zero_grad(set_to_none=True)
    y_pred = model(static_input)
    loss = loss_fn(y_pred, static_target)
    loss.backward()
    optimizer.step()

# A CUDA stream is a linear sequence of execution that belongs to a specific
# device, independent from other streams.
torch.cuda.current_stream().wait_stream(s)

print ('After stream capture: static_counter ', static_counter)
# capture
g = torch.cuda.CUDAGraph()
# Sets grads to None before capture, so backward() will create
# .grad attributes with allocations from the graph's private pool
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):

    static_y_pred = model(static_input)
    static_loss = loss_fn(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()
    static_counter[0] +=  1.0

print ('After Graph Capture: static_counter ', static_counter)

ZERO = torch.zeros((1,), device='cuda')
static_counter.copy_(ZERO)

real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    # Fills the graph's input memory with new data to compute on
    static_input.copy_(data)
    static_target.copy_(target)
    # replay() includes forward, backward, and step.
    # You don't even need to call optimizer.zero_grad() between iterations
    # because the captured backward refills static .grad tensors in place.
    g.replay()
    # Params have been updated. static_y_pred, static_loss, and .grad
    # attributes hold values from computing on this iteration's data.

print ('After replay: static_counter ', static_counter)

I deliberately remove the increment of static_counter from the warmup but keep it in the graph capture. Console outputs are as follows:

Initial: static_counter tensor([0.], device='cuda:0')

After stream capture: static_counter tensor([0.], device='cuda:0')

After Graph Capture: static_counter tensor([0.], device='cuda:0')

After replay: static_counter tensor([10.], device='cuda:0')

With the increment removed from the warmup, static_counter stays at zero after the warmup. Nevertheless, after replay, the graph has performed the addition 10 times, because the increment was recorded during capture.

Of course, in practice we should run the same sequence during the warmup and the graph capture: bugs can be caught only during the warmup, when the computation is actually executed on the device, whereas capture merely records it.
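One way to guarantee that the warmup and the capture run exactly the same sequence is to put the step in a single function and call it from both places. The sketch below assumes a CUDA GPU and reuses the tutorial's layers, loss and SGD settings at smaller sizes (chosen only so it runs quickly); the names build_graph and train_step are illustrative.

```python
import torch


def build_graph(n_warmup=3):
    N, D_in, H, D_out = 64, 128, 64, 32  # smaller than the tutorial, same structure
    model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                                torch.nn.Dropout(p=0.2),
                                torch.nn.Linear(H, D_out),
                                torch.nn.Dropout(p=0.1)).cuda()
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    static_input = torch.randn(N, D_in, device="cuda")
    static_target = torch.randn(N, D_out, device="cuda")
    static_counter = torch.zeros(1, device="cuda")

    def train_step():
        # The single definition used by both the warmup and the capture.
        y_pred = model(static_input)
        loss = loss_fn(y_pred, static_target)
        loss.backward()
        optimizer.step()
        static_counter.add_(1.0)
        return loss

    # Warmup: eager iterations on a side stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmup):
            optimizer.zero_grad(set_to_none=True)
            train_step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture: the very same function, now recorded instead of executed.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        static_loss = train_step()
    return g, static_input, static_target, static_counter, static_loss


if torch.cuda.is_available():
    g, static_input, static_target, static_counter, static_loss = build_graph()
    static_counter.zero_()  # discard the warmup increments
    for _ in range(5):
        static_input.copy_(torch.randn_like(static_input))
        static_target.copy_(torch.randn_like(static_target))
        g.replay()
    print(static_counter.item(), static_loss.item())
```

zero_grad() stays outside train_step because, as in the tutorial, grads are set to None once before capture and the captured backward then refills them in place on every replay.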

Lesson 3: we can explicitly capture a for loop in a graph.

In the script below, a for loop of 3 iterations runs during both the warmup and the graph capture.


import torch

N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                            torch.nn.Dropout(p=0.2),
                            torch.nn.Linear(H, D_out),
                            torch.nn.Dropout(p=0.1)).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Placeholders used for capture
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')
static_counter = torch.zeros((1,), device='cuda')

print ('Initial: static_counter ', static_counter)
# warmup
# Uses static_input and static_target here for convenience,
# but in a real setting, because the warmup includes optimizer.step()
# you must use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(static_input)
        loss = loss_fn(y_pred, static_target)
        loss.backward()
        optimizer.step()
        static_counter[0] +=  1.0

torch.cuda.current_stream().wait_stream(s)
print ('After stream capture: static_counter ', static_counter)
# capture
g = torch.cuda.CUDAGraph()
# Sets grads to None before capture, so backward() will create
# .grad attributes with allocations from the graph's private pool
optimizer.zero_grad(set_to_none=True)
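with torch.cuda.graph(g):
    # The Python loop runs once, at capture time, recording 3 copies of the step.
    # Without zero_grad() between them, the 2nd and 3rd backward passes add to
    # the grads of the previous ones: harmless for counting, but not a correct
    # training step.
    for i in range(3):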
        static_y_pred = model(static_input)
        static_loss = loss_fn(static_y_pred, static_target)
        static_loss.backward()
        optimizer.step()
        static_counter[0] +=  1.0

print ('After Graph Capture: static_counter ', static_counter)

ZERO = torch.zeros((1,), device='cuda')
static_counter.copy_(ZERO)  

real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    # Fills the graph's input memory with new data to compute on
    static_input.copy_(data)
    static_target.copy_(target)
    # replay() includes forward, backward, and step.
    # You don't even need to call optimizer.zero_grad() between iterations
    # because the captured backward refills static .grad tensors in place.
    g.replay()
    # Params have been updated. static_y_pred, static_loss, and .grad
    # attributes hold values from computing on this iteration's data.

print ('After replay: static_counter ', static_counter)

Console outputs are as follows:

Initial: static_counter tensor([0.], device='cuda:0')

After stream capture: static_counter tensor([3.], device='cuda:0')

After Graph Capture: static_counter tensor([3.], device='cuda:0')

After replay: static_counter tensor([30.], device='cuda:0')

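Interestingly, static_counter has risen by the product of the iteration counts of the two for loops (10 replays × 3 captured iterations = 30). Thus a "single" replay is in fact a series of 3 repeated executions: the Python loop runs only once, at capture time, and each of its iterations records another copy of the kernels into the graph.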
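A stripped-down version of Lesson 3 isolates the unrolling effect: capture k in-place increments in a Python loop, replay R times, and the counter should end at R × k. This sketch assumes a CUDA GPU; replay_count is an illustrative name.

```python
import torch


def replay_count(k, replays):
    """Capture a Python loop of k in-place increments, then replay the graph."""
    counter = torch.zeros(1, device="cuda")

    # Warm up the op on a side stream, following the tutorial's pattern.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        counter.add_(1.0)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for _ in range(k):      # runs once, in Python, at capture time
            counter.add_(1.0)   # each iteration records one more kernel

    counter.zero_()  # reset before replaying, as in the lessons
    for _ in range(replays):
        g.replay()
    return counter.item()


if torch.cuda.is_available():
    for k in (1, 3):
        print(k, replay_count(k, replays=10))
```

Because the loop is unrolled when the graph is captured, its trip count is fixed: running a different number of iterations per replay requires capturing a new graph.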

Speed

In this toy example, CUDA Graphs reduced the average training time per iteration by about 4% in my measurements. This speed-up is not nearly as impressive as those reported in the official examples and case studies.


import time

import torch

def setup():

    N, D_in, H, D_out = 640, 4096, 2048, 1024
    model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                                torch.nn.Dropout(p=0.2),
                                torch.nn.Linear(H, D_out),
                                torch.nn.Dropout(p=0.1)).cuda()
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Placeholders used for capture
    static_input = torch.randn(N, D_in, device='cuda')
    static_target = torch.randn(N, D_out, device='cuda')
    static_counter = torch.zeros((1,), device='cuda')

    # warmup
    # Uses static_input and static_target here for convenience,
    # but in a real setting, because the warmup includes optimizer.step()
    # you must use a few batches of real data.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for i in range(3):
            optimizer.zero_grad(set_to_none=True)
            y_pred = model(static_input)
            loss = loss_fn(y_pred, static_target)
            loss.backward()
            optimizer.step()
            static_counter[0] +=  1.0

    torch.cuda.current_stream().wait_stream(s)
    # capture
    g = torch.cuda.CUDAGraph()
    # Sets grads to None before capture, so backward() will create
    # .grad attributes with allocations from the graph's private pool
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        static_y_pred = model(static_input)
        static_loss = loss_fn(static_y_pred, static_target)
        static_loss.backward()
        optimizer.step()
        static_counter[0] +=  1.0

    return g, static_input, static_target, static_counter

def test_with_graph(g, static_input, static_target, static_counter, M=200):

    real_inputs = [torch.rand_like(static_input) for _ in range(M)]
    real_targets = [torch.rand_like(static_target) for _ in range(M)]

    ZERO = torch.zeros((1,), device='cuda')
    static_counter.copy_(ZERO) 

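    # Make sure pending GPU work (data generation, reset) has finished before timing.
    torch.cuda.synchronize()
    start_time = time.perf_counter()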
    for data, target in zip(real_inputs, real_targets):
        # Fills the graph's input memory with new data to compute on
        static_input.copy_(data)
        static_target.copy_(target)
        # replay() includes forward, backward, and step.
        # You don't even need to call optimizer.zero_grad() between iterations
        # because the captured backward refills static .grad tensors in place.
        g.replay()
        # Params have been updated. static_y_pred, static_loss, and .grad
        # attributes hold values from computing on this iteration's data.
    # Kernel launches are asynchronous; wait for the GPU before stopping the clock.
    torch.cuda.synchronize()
    time_elapsed = time.perf_counter() - start_time
    # print ('After training: static_counter ', static_counter)
    return time_elapsed / M


def test(M=200):

    N, D_in, H, D_out = 640, 4096, 2048, 1024
    model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                                torch.nn.Dropout(p=0.2),
                                torch.nn.Linear(H, D_out),
                                torch.nn.Dropout(p=0.1)).cuda()
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Placeholders used for capture
    static_input = [torch.randn(N, D_in, device='cuda') for _ in range(M)]
    static_target = [torch.randn(N, D_out, device='cuda') for _ in range(M)]
    static_counter = torch.zeros((1,), device='cuda')

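    # print ('Initial: static_counter ', static_counter)

    # One untimed step so that one-time initialization (e.g. cuBLAS setup)
    # is not charged to the eager baseline, then wait for the GPU to finish.
    optimizer.zero_grad(set_to_none=True)
    loss_fn(model(static_input[0]), static_target[0]).backward()
    optimizer.step()
    torch.cuda.synchronize()
    start_time = time.perf_counter()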
    for input, target in zip(static_input, static_target):

        optimizer.zero_grad(set_to_none=True)
        y_pred = model(input)
        loss = loss_fn(y_pred, target)
        loss.backward()
        optimizer.step()
        static_counter[0] +=  1.0

    # Wait for the queued GPU work to finish before stopping the clock.
    torch.cuda.synchronize()
    time_elapsed = time.perf_counter() - start_time
    # print ('After training: static_counter ', static_counter)
    return time_elapsed / M

if __name__ == "__main__":
    
    print (t0 := test())

    g, static_input, static_target, static_counter = setup()
    print (t1:= test_with_graph(g, static_input, static_target, static_counter))

    # Relative change in time per iteration (negative means the graph is faster).
    print ((t1 - t0) / t0)

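When measuring performance, it is crucial to keep CPU launch time separate from GPU computation time. CUDA kernel launches are asynchronous, so the CPU can move on long before the GPU has finished, and, as the official tutorial explains, CPU launch overhead can itself be significant. I therefore time only the training loop rather than the whole test() and test_with_graph() calls, which also include model construction and data generation. The timed region should end with torch.cuda.synchronize(); otherwise the clock may stop while the GPU is still working.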
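A common complement to time.perf_counter() for GPU work is a pair of CUDA events, which measure elapsed time on the stream itself. The helper below is a sketch assuming a CUDA GPU; gpu_time_ms and its parameters are illustrative names, and fn stands for any zero-argument callable, such as g.replay or an eager training step. It reports both the CPU-side enqueue time and the GPU-side elapsed time per call.

```python
import time

import torch


def gpu_time_ms(fn, iters=100, warmup=10):
    """Average per-call (CPU enqueue ms, GPU elapsed ms) for a zero-arg callable."""
    for _ in range(warmup):  # untimed calls absorb one-time initialization
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    cpu_start = time.perf_counter()
    start.record()  # records on the current stream
    for _ in range(iters):
        fn()
    end.record()
    cpu_enqueue_ms = (time.perf_counter() - cpu_start) * 1000.0  # mostly launch time
    torch.cuda.synchronize()  # wait until the GPU has actually finished
    gpu_ms = start.elapsed_time(end)  # milliseconds between the two events
    return cpu_enqueue_ms / iters, gpu_ms / iters


if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    print(gpu_time_ms(lambda: x @ x))
```

Comparing the two numbers for g.replay and for an eager step is one way to see whether a workload is launch-bound, which is where CUDA Graphs help most. The CPU figure is only an approximation: if too many launches are queued, the CPU may block and the number grows.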

Conclusion

CUDA Graphs are useful for speeding up training, and PyTorch makes them convenient to implement and deploy. As CUDA and PyTorch continue to improve together, we gain more tricks for training better AI models.

About

Hello, my name is Wilson Fok. I love extracting useful insights and knowledge from big data. Constructive feedback and insightful comments are very welcome!