Step 1: Getting Started with CUDA Programming

Parallel matrix multiplication using CUDA C++.

Introduction

If we look at the matrix multiplication algorithm carefully (Figure 1), it's obvious that each element of the output matrix can be computed independently. Each output element requires a unique combination of a row of the first input matrix and a column of the second, and most importantly, no output element depends on any other output element.

Figure 1: Matrix multiplication algorithm
void cpu_xgemm(MatrixFP32 A_mat, MatrixFP32 B_mat, MatrixFP32 C_mat)
{
    // Getting A Matrix Dimension
    int A_n_rows = A_mat.n_rows; 
    int A_n_cols = A_mat.n_cols;

    // Getting B Matrix Dimension
    int B_n_rows = B_mat.n_rows; 
    int B_n_cols = B_mat.n_cols;

    // Getting C Matrix Dimension
    int C_n_rows = C_mat.n_rows; 
    int C_n_cols = C_mat.n_cols;

    // Asserting dimensions
    assert (A_n_cols == B_n_rows && "Matrices A & B must have one common dimension");
    assert (A_n_rows == C_n_rows && "A rows must be equal to C rows");
    assert (B_n_cols == C_n_cols && "B cols must be equal to C cols");

    // Matrix Multiplication
    for (int row = 0; row < A_n_rows; row++)
    {
        for (int col = 0; col < B_n_cols; col++)
        {
            float val = 0.0f;
            for (int k = 0; k < A_n_cols; k++)
            {
                val += A_mat.ptr[row*A_n_cols + k] * B_mat.ptr[k*B_n_cols + col];
            }
            C_mat.ptr[row*C_n_cols + col] = val;
        }
    }
}

Sequential matrix multiplication on a CPU

This means we can parallelize the row and col loops and reduce the computation time significantly. GPUs are designed specifically for problems like this. Figure 2 shows a screenshot from NVIDIA's website. The first thing they mention about their GPUs is the number of CUDA cores. CUDA cores are the basic computational units in NVIDIA GPUs that handle parallel processing tasks.

Figure 2: NVIDIA GPU specifications
💡
There is a lot more to a GPU than the number of cores, and we will discuss other details whenever they are necessary. The last thing I want to do is overwhelm you with details at the very start and lose sight of the goal.

CPU and GPU

CPUs and GPUs have execution units (called cores) that perform the arithmetic operations and an off-chip memory unit (RAM and VRAM, respectively) to store the required data. The big difference is that a CPU has far fewer cores than a GPU.

Figure 3: CPU and GPU

A GPU can't function independently. It's the job of the CPU to move data between RAM and VRAM (also known as global memory). In other words, the CPU can be seen as an instructor that manages most of the tasks and assigns specific tasks to the GPU (the ones where the GPU has an advantage).

On the software side of things, we don't really control the processing units directly. Instead, the hardware spawns threads, which a programmer can work with. A thread can be seen as an individual worker, and the execution of a thread is sequential as far as a user is concerned, i.e., a worker can only do one task at a time, but we can have multiple workers that all work in parallel.

💡
Threads are at the heart of modern computing. A thread is a simplified view of how a processor executes a sequential program in modern computers.

GPUs are suited for parallel programming because the hardware can spawn millions of threads (way more than the number of physical cores). CPUs, on the other hand, can only run a handful of threads at a time (~8 to 128).

GPU Programming Model

When we write a program in a high-level language like C/C++, by default, the CPU executes it using a single thread. All the matrices are in RAM (at this point), and it's the CPU thread that copies the data to global memory. The GPU can then spawn multiple threads that work in parallel to reduce computation time. Once the execution finishes, the CPU thread copies the results back from global memory to RAM.

Figure 4: GPU programming model
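
Before going through each step in detail, here is a rough preview of the host-side flow just described. This is only a sketch: my_kernel, h_data, num_blocks, threads_per_block, and N are placeholder names (not from the repository), and error checking is omitted.

// Rough host-side skeleton of the CUDA programming model (placeholder names)
float *d_data;                                                  // pointer to global (device) memory
size_t n_bytes = N * N * sizeof(float);

cudaMalloc((void**)&d_data, n_bytes);                           // 1. Allocate global memory
cudaMemcpy(d_data, h_data, n_bytes, cudaMemcpyHostToDevice);    // 2. Copy inputs: RAM -> global memory
my_kernel<<<num_blocks, threads_per_block>>>(d_data);           // 3. GPU spawns threads and runs the kernel
cudaMemcpy(h_data, d_data, n_bytes, cudaMemcpyDeviceToHost);    // 4. Copy results: global memory -> RAM
cudaFree(d_data);                                               // 5. Free global memory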

A CPU thread can't just copy data straight away. The very first step is to allocate the exact amount of memory required in global memory. This is done using the cudaMalloc function (which is provided by CUDA). This function accepts two parameters:

  • The first parameter is the address of the pointer variable for the select matrix. The address of this pointer must be cast to (void **) because the function expects a generic pointer (not restricted to a specific type).
  • The second parameter is the size of the data to be allocated (in number of bytes).
// Device array pointers for N x N matrix
float* d_A;

// Device memory allocation
cudaError_t err_A = cudaMalloc((void**) &d_A, N*N*sizeof(float));
// CUDA error checking code (see code repository for more details)
CUDA_CHECK(err_A);
💡
The first parameter is the address of the pointer so that cudaMalloc can write the address of the allocated global memory into the pointer variable. This leaves the function's return value (of type cudaError_t) free to report errors during global memory allocation (which can then be passed to the CUDA error-checking code).

As I'm using a custom class MatrixFP32 to handle matrices, I can embed the above-mentioned code so that global memory gets allocated whenever an object destined for the GPU is created. To do this, I made two modifications:

  1. Added a member variable on_device that keeps track of whether the object is destined for CPU or GPU memory (this is set to true for the GPU).
  2. Added a separate initialization/memory allocation code for global memory in the constructor and modified the free_mat method appropriately.
class MatrixFP32
{
public:
    const int n_rows;        // Number of rows
    const int n_cols;        // Number of cols

    // Pointer to dynamic array
    float* ptr;
    
    // Matrix in device memory: true; else: false
    const bool on_device; 
    
    // Constructor to initialize n_rows x n_cols matrix
    MatrixFP32(int n_rows_, int n_cols_, bool on_device_);
    
    // Free memory
    void free_mat();
};

MatrixFP32::MatrixFP32(int n_rows_, int n_cols_, bool on_device_) : n_rows(n_rows_), n_cols(n_cols_), on_device(on_device_)
{
    if (on_device_ == false)
    {
        // Initialize dynamic array
        ptr = new float[n_rows*n_cols];
    }
    else
    {
        // Allocate device memory
        cudaError_t err = cudaMalloc((void**) &ptr, n_rows*n_cols*sizeof(float));
        cuda_check(err);
    }
}

void MatrixFP32::free_mat()
{
    if (on_device == false)
        delete[] ptr;
    else
        cudaFree(ptr);
}

Once the memory has been allocated in global memory, data can be transferred between RAM and global memory using the cudaMemcpy function. This function accepts four parameters:

  • Pointer to the destination
  • Pointer to the source
  • Number of bytes to be copied
  • The direction of transfer: 
    • When copying data from RAM to global memory, this will be cudaMemcpyHostToDevice
    • When copying data from global memory to RAM, this will be cudaMemcpyDeviceToHost
    • Note that these are predefined symbolic constants of the CUDA programming environment.
// Copying A to device memory
cudaError_t err_A_ = cudaMemcpy(d_A, A, N*N*sizeof(float), cudaMemcpyHostToDevice);
CUDA_CHECK(err_A_);

// Copy C to host memory
cudaError_t err_C_ = cudaMemcpy(C, d_C, N*N*sizeof(float), cudaMemcpyDeviceToHost);
CUDA_CHECK(err_C_);
💡
In accelerator language, host refers to the CPU and device is the accelerator, here the GPU.

To automate a lot of this stuff, I added two more methods to the class that can copy data between CPU and GPU (while checking the size compatibility and other details).

void MatrixFP32::copy_to_device(MatrixFP32 d_mat)
{
    // Make sure that ptr is on host 
    assert(on_device == false && "Matrix must be in host memory");
    assert(d_mat.on_device == true && "Input Matrix to this function must be in device memory");

    // Copying from host to device memory
    cudaError_t err = cudaMemcpy(d_mat.ptr, ptr, n_rows*n_cols*sizeof(float), cudaMemcpyHostToDevice);
    cuda_check(err);
}

void MatrixFP32::copy_to_host(MatrixFP32 h_mat)
{
    // Make sure that ptr is on device
    assert(on_device == true && "Matrix must be in device memory");
    assert(h_mat.on_device == false && "Input Matrix to this function must be in host memory");

    // Copying from device to host memory
    cudaError_t err = cudaMemcpy(h_mat.ptr, ptr, n_rows*n_cols*sizeof(float), cudaMemcpyDeviceToHost);
    cuda_check(err);
}

This allows us to easily transfer data between the CPU and GPU with very few visible lines of code.

// Define MatrixFP32
MatrixFP32 A_FP32 = MatrixFP32(n, n, false);
MatrixFP32 B_FP32 = MatrixFP32(n, n, false);
MatrixFP32 C_FP32 = MatrixFP32(n, n, false);

// Initialize Matrix
// .
// .
// .

// Move matrix to device
MatrixFP32 d_A_FP32 = MatrixFP32(n, n, true); 
A_FP32.copy_to_device(d_A_FP32);

MatrixFP32 d_B_FP32 = MatrixFP32(n, n, true); 
B_FP32.copy_to_device(d_B_FP32);

MatrixFP32 d_C_FP32 = MatrixFP32(n, n, true); 
C_FP32.copy_to_device(d_C_FP32);

// Perform computations on GPU
// .
// .
// .

// Copy results to host
d_C_FP32.copy_to_host(C_FP32);

// Free Memory
A_FP32.free_mat();
B_FP32.free_mat();
C_FP32.free_mat();

d_A_FP32.free_mat();
d_B_FP32.free_mat();
d_C_FP32.free_mat();

Parallel Matrix Multiplication

Now that the data transfer is all taken care of, let's see how we can program the GPU to perform matrix multiplication in parallel. The algorithm for parallel matrix multiplication involving matrix A of size (M x K) and B of size (K x N) is as follows:

  1. \(M \cdot N\) threads are generated on the GPU (one to compute each element of matrix \(C\) which is of size M x N).
  2. Each thread:
    1. Retrieves a row of matrix \(A\) and a column of matrix \(B\).
    2. Loops over their elements.
    3. Multiplies each pair of numbers and adds the result to a running total.
Figure 5: Parallel Matrix Multiplication
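
In other words, the single thread responsible for C[row, col] only has to perform the dot product sketched below. This is a plain C sketch for intuition (thread_work is a hypothetical helper, not part of the repository); the actual CUDA kernel appears later in this post.

// What one thread computes: the dot product of row `row` of A and column `col` of B.
// A and B are row-major arrays as in Figure 1; K = A_n_cols, N = C_n_cols.
float thread_work(const float *A, const float *B, int row, int col, int K, int N)
{
    float val = 0.0f;
    for (int k = 0; k < K; k++)
        val += A[row*K + k] * B[k*N + col];
    return val; // the caller stores this into C[row*N + col]
}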

Now, who decides which thread works on which element of the matrix \(C\)? To answer this question, we need to understand how threads are organized on a GPU.


When the kernel function is called, it creates a grid of threads on the GPU. The grid is subdivided into multiple blocks, each containing the same number of threads. The blocks in a grid and the threads in a block can be arranged in a 1D/2D/3D manner, with coordinates (x,y,z) assigned to the blocks and to the threads inside each block.

  • The 1D case is the simplest, with only an x index. For example, consider a grid with 3 blocks and 4 threads in each block. The blocks will then have x coordinates ranging from 0 to 2 and threads in each block will have x coordinates ranging from 0 to 3. It is important to note that the thread index is local for each block.
Figure 6: 1D grid and 1D block
  • In the 2D case, there are x and y indices. For example, consider a grid with 3 blocks in the x direction and 2 blocks in the y direction. Each block has 4 and 2 threads in the x and y directions, respectively. In the case of a 2D organization, it is important to note that (in the figures) the y index precedes x, i.e., coordinates are written as (y,x), with y for the vertical direction and x for the horizontal (see the dim3 snippet after Figure 7).
Figure 7: 2D grid and 2D block
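
For reference, here is how the layouts in Figures 6 and 7 would be declared with CUDA's dim3 type (kernel launches using these variables are covered later in this post). dim3 takes dimensions in (x, y, z) order, and unspecified dimensions default to 1.

// Figure 6: 1D grid of 3 blocks, each with 4 threads
dim3 grid_1d(3);        // blockIdx.x ranges over 0..2
dim3 block_1d(4);       // threadIdx.x ranges over 0..3

// Figure 7: 2D grid with 3 blocks in x and 2 in y; each block has 4 x 2 threads
dim3 grid_2d(3, 2);     // blockIdx.x in 0..2, blockIdx.y in 0..1
dim3 block_2d(4, 2);    // threadIdx.x in 0..3, threadIdx.y in 0..1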

As an example, consider matrix multiplication involving 5 x 5 matrices (\(C = A \cdot B\)). To perform matrix multiplication in parallel, let's define a 2 x 3 block. To cover all the elements of the output matrix, we then need a 3 x 2 grid of these blocks.

Figure 8: Thread Blocks and Grid

Threads shown in Figure 8 have local indices similar to the ones shown in Figure 7. Using these local indices together with the block indices, we can calculate a global index for each thread; this global index is unique across the entire grid.

$$\text{Global x} = \text{blockDim}.x \cdot \text{blockIdx}.x + \text{threadIdx}.x$$
$$\text{Global y} = \text{blockDim}.y \cdot \text{blockIdx}.y + \text{threadIdx}.y$$

  • blockDim.x and blockDim.y are the number of threads per block along the x and y axes, respectively.
  • blockIdx.x and blockIdx.y are the block indices along the x and y axes, respectively.
  • threadIdx.x and threadIdx.y are the (local) thread indices within a block along the x and y axes, respectively.
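
To see these formulas in action, here is a small self-contained kernel (not part of the xGeMM repository) that prints each thread's block index, local thread index, and the resulting global index. The launch configuration shown in the comments is just an example.

#include <cstdio>

__global__ void print_global_idx()
{
    // Global index computed from the block index, block dimensions, and local thread index
    int global_x = blockDim.x*blockIdx.x + threadIdx.x;
    int global_y = blockDim.y*blockIdx.y + threadIdx.y;
    printf("block=(%d,%d) thread=(%d,%d) -> global=(%d,%d)\n",
           blockIdx.x, blockIdx.y, threadIdx.x, threadIdx.y, global_x, global_y);
}

// Example launch (dim3 is (x, y, z)):
// dim3 dim_block(3, 2, 1);
// dim3 dim_grid(2, 2, 1);
// print_global_idx<<<dim_grid, dim_block>>>();
// cudaDeviceSynchronize(); // wait for the kernel so the printf output is flushed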
Figure 9: Global Thread Indices

Figure 9 shows the global indices for all the threads alongside the row and column indices for each element of matrix C. One way to solve this in parallel is to match each thread's global x index with the row index and its global y index with the column index.

Figure 10: Thread to Element Mapping

Figure 10 shows the thread-to-element mapping for the first column of C. The same procedure maps the rest of the elements to threads (Figure 11). Note that there are additional threads that map to non-existent elements. This is common, and in code we make sure that these threads stay idle.

Figure 11: Animation showing thread to element mapping

Looking back at the general case for parallel matrix multiplication involving matrix A of size (M x K) and B of size (K x N), the resulting matrix C will be of size (M x N). The first thing we need to decide is the dimensionality of the block (i.e., the number of threads and their organization in a block). CUDA restricts the total number of threads in a block to at most 1024. This means the organization can be (512,1,1) or (8,16,4), just as long as the total is not greater than 1024. As an example, let's go with (32, 32, 1), i.e., 1024 threads per block. Based on this, we can then figure out the total number of blocks required in the x and y dimensions.

CUDA provides a built-in type dim3 that defines

  • The number of threads in each block: dim3 dim_block(32, 32, 1). Note that in this case, the z dimension is set to 1.
  • The number of blocks in the grid can then be decided based on the matrix size and the total threads in each block: dim3 dim_grid(ceil(C_n_rows/32.0), ceil(C_n_cols/32.0), 1);.
void naive_xgemm(float *d_A_ptr, float *d_B_ptr, float *d_C_ptr, int C_n_rows, int C_n_cols, int A_n_cols)
{
    // Kernel execution
    dim3 dim_block(32, 32, 1);
    dim3 dim_grid(ceil(C_n_rows/(float)(32)), ceil(C_n_cols/(float)(32)), 1);
    //.
    //.
    //.
}

Now that a grid of threads is defined, the next step is to code the instructions these threads need to execute. This is done with the kernel function. It is a normal-looking function that each thread executes independently. In short, all threads perform the same instructions; the only thing that differs between threads is the data those instructions operate on. As the same function is executed by every thread, each thread can use its local and block indices to compute its global index and then perform the necessary calculations.

__global__ void naive_mat_mul_kernel(float *d_A_ptr, float *d_B_ptr, float *d_C_ptr, int C_n_rows, int C_n_cols, int A_n_cols)
{
    // Working on C[row,col]
    const int row = blockDim.x*blockIdx.x + threadIdx.x;
    const int col = blockDim.y*blockIdx.y + threadIdx.y;

    // Parallel mat mul
    if (row < C_n_rows && col < C_n_cols)
    {
        // Value at C[row,col]
        float value = 0;
        for (int k = 0; k < A_n_cols; k++)
        {
            value += d_A_ptr[row*A_n_cols + k] * d_B_ptr[k*C_n_cols + col];
        }

        // Assigning calculated value (SGEMM is C = α*(A @ B)+β*C and in this repo α=1, β=0)
        d_C_ptr[row*C_n_cols + col] = 1*value + 0*d_C_ptr[row*C_n_cols + col];
    }
}
💡
Notice how the row and col loops are gone from the kernel function. Instead, each thread takes the place of one (row, col) iteration and operates in parallel with the other threads; only the inner loop over k remains. __global__ is a qualifier keyword that indicates this function is called (launched) from the host and executed on the device.

The final step is to launch this kernel function so that it is executed by all the threads. This is done by specifying the grid and block organization inside <<< >>> and passing the device pointers.

void naive_xgemm(float *d_A_ptr, float *d_B_ptr, float *d_C_ptr, int C_n_rows, int C_n_cols, int A_n_cols)
{
    // Kernel execution
    dim3 dim_block(32, 32, 1);
    dim3 dim_grid(ceil(C_n_rows/(float)(32)), ceil(C_n_cols/(float)(32)), 1);
    naive_mat_mul_kernel<<<dim_grid, dim_block>>>(d_A_ptr, d_B_ptr, d_C_ptr, C_n_rows, C_n_cols, A_n_cols);
}

Benchmark

Figure 12: cuBLAS vs My Code

Figure 12 shows the GFLOPS for the above-mentioned code against NVIDIA's SGEMM implementation (cuBLAS). As you can see, we are achieving only around 1% of what cuBLAS can do for large matrices.
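
For context on how such numbers are obtained: kernel time is commonly measured with CUDA events, and GFLOPS for GEMM follows from the 2*M*N*K floating-point operations it performs. The sketch below is illustrative (assuming square N x N matrices and already-initialized device pointers); the repository's benchmark may differ in its details.

// Illustrative timing with CUDA events (assumes d_A_ptr, d_B_ptr, d_C_ptr are
// allocated and initialized, and the matrices are N x N).
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

cudaEventRecord(start);
naive_xgemm(d_A_ptr, d_B_ptr, d_C_ptr, N, N, N);
cudaEventRecord(stop);
cudaEventSynchronize(stop);                 // wait until the kernel has finished

float ms = 0.0f;
cudaEventElapsedTime(&ms, start, stop);     // elapsed time in milliseconds

// GEMM does 2*M*N*K floating-point operations (one multiply + one add per term)
double gflops = (2.0 * N * N * N) / (ms * 1e-3) / 1e9;

cudaEventDestroy(start);
cudaEventDestroy(stop);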


References

  1. class MatrixFP32
     • Header File: xGeMM/include/MatrixFP32.cuh at master · tgautam03/xGeMM
     • Source File: xGeMM/src/MatrixFP32.cu at master · tgautam03/xGeMM
  2. Naive matrix multiplication on a GPU
     • Header File: xGeMM/include/01_naive_xgemm.cuh at master · tgautam03/xGeMM
     • Source File: xGeMM/src/01_naive_xgemm.cu at master · tgautam03/xGeMM
  3. Benchmarking naive matrix multiplication on a GPU
     • Source File: xGeMM/test/01_benchmark_naive.cu at master · tgautam03/xGeMM
