Step 6: Vectorized Memory Accesses

Using vectorization to move multiple elements in parallel with a single thread.

Throughout this mini-project, my focus has been on optimizing memory accesses. That makes sense because memory is the bottleneck for most GPU applications. In this final blog post, I will show how vectorization can speed up the data transfer from global memory to shared memory and from shared memory to registers. Vectorization uses a single thread to transfer multiple variables at once; the catch is that those variables must sit next to each other in memory. As an example, consider a task where we need to retrieve four variables (\(n_0, n_1, n_2, n_3\)). The standard way is to write a loop and access the variables one at a time. If these variables are adjacent in memory (for example, consecutive elements of an array), the compiler can use vectorization and fetch all four values in a single iteration (Figure 1).

Figure 1: Vectorized memory accesses
💡
When utilizing vectorization, four variables are accessed using a single instruction in parallel.
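
To make the idea concrete, here is a minimal copy kernel (my own illustration, not code from the xGeMM repository): each thread moves four consecutive floats with a single float4 load and store. It assumes the array length is a multiple of 4 and the pointers are 16-byte aligned.

__global__ void vec_copy_kernel(const float *src, float *dst, int n)
{
    // Each thread handles one group of four consecutive floats
    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i * 4 < n)
    {
        // One 128-bit load fetches four floats at once...
        const float4 v = reinterpret_cast<const float4 *>(src)[i];
        // ...and one 128-bit store writes them back out
        reinterpret_cast<float4 *>(dst)[i] = v;
    }
}

A scalar version would need four separate loads and stores here; with float4 the compiler can emit a single 128-bit load/store pair instead.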

Figure 2 shows an example involving 4x4 matrix multiplication. The strategy is to use vectorization so that a single thread accesses multiple elements of the input matrix A in parallel. The 4x2 tile is then stored transposed in shared memory (I will discuss why shortly). Note that with vectorization I don't need multiple iterations to populate the shared memory! The 2x4 tile of matrix B is accessed similarly. The tile of matrix A is stored transposed so that vectorization can again be used when moving multiple elements at once into the registers; the same is done with the elements of B in shared memory. Finally, the calculations are performed as usual, and the process is repeated for all the other elements of the input matrices. A minimal sketch of the transposed store follows Figure 2.

💡
To keep things simple, the figure showcases parallel access with just two elements at a time (the actual code moves four).
Figure 2: Using vectorization to move data between different GPU memory units
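
Here is a small standalone sketch of just that transposed store (my own illustration; the 8x8 tile size and kernel name are made up for this example, and it uses full 4-wide loads rather than the 2-wide ones in the figure). A float4 read from a row of the A tile is scattered across rows of the shared array, so a column of the tile ends up contiguous in shared memory.

#define TILE 8

__global__ void transpose_tile_kernel(const float *A, float *A_T)
{
    // Shared tile holds A transposed
    __shared__ float sh_A[TILE][TILE];

    // Launch with TILE*TILE/4 threads: one float4 per thread
    const int tx = threadIdx.x;
    const int row = tx / (TILE / 4);
    const int col4 = tx % (TILE / 4);

    // One vectorized load from a row of A...
    const float4 a = reinterpret_cast<const float4 *>(&A[row * TILE + col4 * 4])[0];

    // ...scattered so that sh_A[c][r] = A[r][c]
    sh_A[col4 * 4 + 0][row] = a.x;
    sh_A[col4 * 4 + 1][row] = a.y;
    sh_A[col4 * 4 + 2][row] = a.z;
    sh_A[col4 * 4 + 3][row] = a.w;
    __syncthreads();

    // Write the transpose out so the result can be verified
    for (int i = 0; i < 4; i++)
        A_T[row * TILE + col4 * 4 + i] = sh_A[row][col4 * 4 + i];
}

This is exactly the pattern the full kernel below uses when loading the A tile.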

The code with vectorized memory accesses is very similar to the one in Step 5. The first big difference is that the matrix dimensions must now be multiples of 4.

__global__ void coarse_2d_vec_mat_mul_kernel(float *d_A_ptr, float *d_B_ptr, float *d_C_ptr, int C_n_rows, int C_n_cols, int A_n_cols)
{
    // Number of threads per block
    const int n_threads_per_block = tiles_A_rows * tiles_B_cols / (COARSE_FACTOR_X*COARSE_FACTOR_Y);
    static_assert(n_threads_per_block % tiles_A_cols == 0);
    static_assert(n_threads_per_block % tiles_B_cols == 0);
    static_assert(tiles_A_cols % 4 == 0);
    static_assert(tiles_B_cols % 4 == 0);
    assert(C_n_rows % 4 == 0);
    assert(C_n_cols % 4 == 0);
    assert(A_n_cols % 4 == 0);

    // .
    // .
    // .
}
💡
If the matrices aren't multiples of four, we can always pad them with zeros.
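
As a rough host-side sketch of that padding (my own example; the helper name is made up), cudaMemcpy2D can copy each original row into a zero-initialized device buffer whose dimensions are rounded up to the next multiple of 4:

#include <cuda_runtime.h>

// Copies h_A (n_rows x n_cols, row-major) into a zero-padded device buffer
// whose dimensions are rounded up to multiples of 4.
float *copy_with_padding(const float *h_A, int n_rows, int n_cols,
                         int *padded_rows, int *padded_cols)
{
    *padded_rows = (n_rows + 3) / 4 * 4;
    *padded_cols = (n_cols + 3) / 4 * 4;

    float *d_A_padded;
    cudaMalloc(&d_A_padded, (*padded_rows) * (*padded_cols) * sizeof(float));
    cudaMemset(d_A_padded, 0, (*padded_rows) * (*padded_cols) * sizeof(float));

    // Row-by-row copy so each original row starts at the beginning of a padded row
    cudaMemcpy2D(d_A_padded, (*padded_cols) * sizeof(float),
                 h_A, n_cols * sizeof(float),
                 n_cols * sizeof(float), n_rows,
                 cudaMemcpyHostToDevice);
    return d_A_padded;
}

Multiplying the padded matrices gives a padded C whose top-left n_rows x n_cols block is the original result.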

Everything else is exactly the same until we get to the point where data moves from global memory to shared memory. This is where we can use float4 and move 4 elements in parallel with a single thread.

__global__ void coarse_2d_vec_mat_mul_kernel(float *d_A_ptr, float *d_B_ptr, float *d_C_ptr, int C_n_rows, int C_n_cols, int A_n_cols)
{
    // Number of threads per block
    const int n_threads_per_block = tiles_A_rows * tiles_B_cols / (COARSE_FACTOR_X*COARSE_FACTOR_Y);
    static_assert(n_threads_per_block % tiles_A_cols == 0);
    static_assert(n_threads_per_block % tiles_B_cols == 0);
    static_assert(tiles_A_cols % 4 == 0);
    static_assert(tiles_B_cols % 4 == 0);
    assert(C_n_rows % 4 == 0);
    assert(C_n_cols % 4 == 0);
    assert(A_n_cols % 4 == 0);

    // Details regarding this thread
    const int by = blockIdx.y;
    const int bx = blockIdx.x; 

    const int tx = threadIdx.x;

    // 1D -> 2D while loading A
    const int A_view_ty = tx / (tiles_A_cols / 4);
    const int A_view_tx = tx % (tiles_A_cols / 4);
    const int stride_A = n_threads_per_block/(tiles_A_cols/4);

    // 1D -> 2D while loading B
    const int B_view_ty = tx / (tiles_B_cols / 4);
    const int B_view_tx = tx % (tiles_B_cols / 4);
    const int stride_B = n_threads_per_block/(tiles_B_cols / 4);

    // Working on C[row, col]
    const int row = COARSE_FACTOR_Y * (tx / (tiles_B_cols/COARSE_FACTOR_X));
    const int col = COARSE_FACTOR_X * (tx % (tiles_B_cols/COARSE_FACTOR_X));
    
    // Allocating shared memory (A is transposed)
    __shared__ float sh_A[tiles_A_cols][tiles_A_rows];
    __shared__ float sh_B[tiles_A_cols][tiles_B_cols];

    // Parallel mat mul
    float value[COARSE_FACTOR_Y][COARSE_FACTOR_X] = {0.0f};
    float register_A[COARSE_FACTOR_Y] = {0.0f};
    float register_B[COARSE_FACTOR_X] = {0.0f};

    // Phases
    const int phases = ceil((float)A_n_cols/tiles_A_cols);

    for (int phase = 0; phase < phases; phase++)
    {
        // Load Tiles into shared memory
        for (int load_offset = 0; load_offset < tiles_A_rows; load_offset+=stride_A)
        {
            if ((by*tiles_A_rows + load_offset+A_view_ty < C_n_rows) && (((phase*tiles_A_cols+A_view_tx*4)) < A_n_cols))
            {
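                // One float4 read pulls four consecutive elements from a row of A;
                // the components are then scattered into sh_A in transposed order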
                float4 A_tmp = reinterpret_cast<float4 *>(&d_A_ptr[(by*tiles_A_rows + load_offset+A_view_ty)*A_n_cols + ((phase*tiles_A_cols+A_view_tx*4))])[0];
                sh_A[A_view_tx*4+0][load_offset+A_view_ty] = A_tmp.x;
                sh_A[A_view_tx*4+1][load_offset+A_view_ty] = A_tmp.y;
                sh_A[A_view_tx*4+2][load_offset+A_view_ty] = A_tmp.z;
                sh_A[A_view_tx*4+3][load_offset+A_view_ty] = A_tmp.w;
            }
            else
            {
                sh_A[A_view_tx*4+0][load_offset+A_view_ty] = 0.0f;
                sh_A[A_view_tx*4+1][load_offset+A_view_ty] = 0.0f;
                sh_A[A_view_tx*4+2][load_offset+A_view_ty] = 0.0f;
                sh_A[A_view_tx*4+3][load_offset+A_view_ty] = 0.0f;
            }
            
        }
        
        for (int load_offset = 0; load_offset < tiles_A_cols; load_offset+=stride_B)
        {
            if (((phase*tiles_A_cols + B_view_ty+load_offset) < A_n_cols) && (((bx*tiles_B_cols + B_view_tx*4)) < C_n_cols))
            {
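                // The B tile keeps its row-major layout: the four loaded elements
                // land in consecutive columns of the same row of sh_B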
                float4 B_tmp = reinterpret_cast<float4 *>(&d_B_ptr[(phase*tiles_A_cols + B_view_ty+load_offset)*C_n_cols + ((bx*tiles_B_cols + B_view_tx*4))])[0];
                sh_B[B_view_ty+load_offset][B_view_tx*4+0] = B_tmp.x;
                sh_B[B_view_ty+load_offset][B_view_tx*4+1] = B_tmp.y;
                sh_B[B_view_ty+load_offset][B_view_tx*4+2] = B_tmp.z;
                sh_B[B_view_ty+load_offset][B_view_tx*4+3] = B_tmp.w;
            }
            else
            {
                sh_B[B_view_ty+load_offset][B_view_tx*4+0] = 0.0f;
                sh_B[B_view_ty+load_offset][B_view_tx*4+1] = 0.0f;
                sh_B[B_view_ty+load_offset][B_view_tx*4+2] = 0.0f;
                sh_B[B_view_ty+load_offset][B_view_tx*4+3] = 0.0f;
            }
            
        }
        __syncthreads();

        // .
        // .
        // .
    }

    // .
    // .
    // .
}

We don't need to explicitly request vectorization when moving data from shared memory to registers because the compiler takes care of that (storing the A tile transposed is what makes those reads contiguous). After the data is loaded into the registers, the calculations are performed as usual.

__global__ void coarse_2d_vec_mat_mul_kernel(float *d_A_ptr, float *d_B_ptr, float *d_C_ptr, int C_n_rows, int C_n_cols, int A_n_cols)
{
    // Number of threads per block
    const int n_threads_per_block = tiles_A_rows * tiles_B_cols / (COARSE_FACTOR_X*COARSE_FACTOR_Y);
    static_assert(n_threads_per_block % tiles_A_cols == 0);
    static_assert(n_threads_per_block % tiles_B_cols == 0);
    static_assert(tiles_A_cols % 4 == 0);
    static_assert(tiles_B_cols % 4 == 0);
    assert(C_n_rows % 4 == 0);
    assert(C_n_cols % 4 == 0);
    assert(A_n_cols % 4 == 0);

    // Details regarding this thread
    const int by = blockIdx.y;
    const int bx = blockIdx.x; 

    const int tx = threadIdx.x;

    // 1D -> 2D while loading A
    const int A_view_ty = tx / (tiles_A_cols / 4);
    const int A_view_tx = tx % (tiles_A_cols / 4);
    const int stride_A = n_threads_per_block/(tiles_A_cols/4);

    // 1D -> 2D while loading B
    const int B_view_ty = tx / (tiles_B_cols / 4);
    const int B_view_tx = tx % (tiles_B_cols / 4);
    const int stride_B = n_threads_per_block/(tiles_B_cols / 4);

    // Working on C[row, col]
    const int row = COARSE_FACTOR_Y * (tx / (tiles_B_cols/COARSE_FACTOR_X));
    const int col = COARSE_FACTOR_X * (tx % (tiles_B_cols/COARSE_FACTOR_X));
    
    // Allocating shared memory (A is transposed)
    __shared__ float sh_A[tiles_A_cols][tiles_A_rows];
    __shared__ float sh_B[tiles_A_cols][tiles_B_cols];

    // Parallel mat mul
    float value[COARSE_FACTOR_Y][COARSE_FACTOR_X] = {0.0f};
    float register_A[COARSE_FACTOR_Y] = {0.0f};
    float register_B[COARSE_FACTOR_X] = {0.0f};

    // Phases
    const int phases = ceil((float)A_n_cols/tiles_A_cols);

    for (int phase = 0; phase < phases; phase++)
    {
        // Load Tiles into shared memory
        for (int load_offset = 0; load_offset < tiles_A_rows; load_offset+=stride_A)
        {
            if ((by*tiles_A_rows + load_offset+A_view_ty < C_n_rows) && (((phase*tiles_A_cols+A_view_tx*4)) < A_n_cols))
            {
                float4 A_tmp = reinterpret_cast<float4 *>(&d_A_ptr[(by*tiles_A_rows + load_offset+A_view_ty)*A_n_cols + ((phase*tiles_A_cols+A_view_tx*4))])[0];
                sh_A[A_view_tx*4+0][load_offset+A_view_ty] = A_tmp.x;
                sh_A[A_view_tx*4+1][load_offset+A_view_ty] = A_tmp.y;
                sh_A[A_view_tx*4+2][load_offset+A_view_ty] = A_tmp.z;
                sh_A[A_view_tx*4+3][load_offset+A_view_ty] = A_tmp.w;
            }
            else
            {
                sh_A[A_view_tx*4+0][load_offset+A_view_ty] = 0.0f;
                sh_A[A_view_tx*4+1][load_offset+A_view_ty] = 0.0f;
                sh_A[A_view_tx*4+2][load_offset+A_view_ty] = 0.0f;
                sh_A[A_view_tx*4+3][load_offset+A_view_ty] = 0.0f;
            }
            
        }
        
        for (int load_offset = 0; load_offset < tiles_A_cols; load_offset+=stride_B)
        {
            if (((phase*tiles_A_cols + B_view_ty+load_offset) < A_n_cols) && (((bx*tiles_B_cols + B_view_tx*4)) < C_n_cols))
            {
                float4 B_tmp = reinterpret_cast<float4 *>(&d_B_ptr[(phase*tiles_A_cols + B_view_ty+load_offset)*C_n_cols + ((bx*tiles_B_cols + B_view_tx*4))])[0];
                sh_B[B_view_ty+load_offset][B_view_tx*4+0] = B_tmp.x;
                sh_B[B_view_ty+load_offset][B_view_tx*4+1] = B_tmp.y;
                sh_B[B_view_ty+load_offset][B_view_tx*4+2] = B_tmp.z;
                sh_B[B_view_ty+load_offset][B_view_tx*4+3] = B_tmp.w;
            }
            else
            {
                sh_B[B_view_ty+load_offset][B_view_tx*4+0] = 0.0f;
                sh_B[B_view_ty+load_offset][B_view_tx*4+1] = 0.0f;
                sh_B[B_view_ty+load_offset][B_view_tx*4+2] = 0.0f;
                sh_B[B_view_ty+load_offset][B_view_tx*4+3] = 0.0f;
            }
            
        }
        __syncthreads();

        // calculate per-thread results
        for (int k = 0; k < tiles_A_cols; ++k) 
        {
            // block into registers
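            // Because sh_A holds the A tile transposed, these COARSE_FACTOR_Y reads
            // hit consecutive shared-memory addresses and the compiler can vectorize them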
            for (int i = 0; i < COARSE_FACTOR_Y; ++i)
                register_A[i] = sh_A[k][row+i];
            
            for (int i = 0; i < COARSE_FACTOR_X; ++i)
                register_B[i] = sh_B[k][col+i];
            
            for (int cy = 0; cy < COARSE_FACTOR_Y; ++cy) 
            {
                for (int cx = 0; cx < COARSE_FACTOR_X; ++cx) 
                    value[cy][cx] += register_A[cy] * register_B[cx];
            }
        }
        __syncthreads();
    }

    // Assigning calculated value
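    // Written as 1*value + 0*C to mirror the general GEMM form C = alpha*A*B + beta*C
    // with alpha = 1 and beta = 0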
    for (int cy = 0; cy < COARSE_FACTOR_Y; ++cy)
    {
        for (int cx = 0; cx < COARSE_FACTOR_X; cx++)
        {
            if ((by*tiles_A_rows+row+cy < C_n_rows) && (bx*tiles_B_cols+col+cx < C_n_cols))
                d_C_ptr[(by*tiles_A_rows+row+cy)*C_n_cols + (bx*tiles_B_cols+col+cx)] = 1*value[cy][cx] + 0*d_C_ptr[(by*tiles_A_rows+row+cy)*C_n_cols + (bx*tiles_B_cols+col+cx)];
        }
    } 
}
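
For completeness, here is a sketch of how a kernel like this might be launched (my own example; tiles_A_rows, tiles_B_cols, COARSE_FACTOR_X and COARSE_FACTOR_Y are the same compile-time constants the kernel uses, defined in the repository's header file):

void launch_coarse_2d_vec(float *d_A_ptr, float *d_B_ptr, float *d_C_ptr,
                          int C_n_rows, int C_n_cols, int A_n_cols)
{
    // One thread computes a COARSE_FACTOR_Y x COARSE_FACTOR_X patch of C, so a
    // block covering tiles_A_rows x tiles_B_cols outputs needs this many threads
    dim3 dim_block(tiles_A_rows * tiles_B_cols / (COARSE_FACTOR_X * COARSE_FACTOR_Y));

    // One block per output tile of C, rounding up for partial tiles
    dim3 dim_grid((C_n_cols + tiles_B_cols - 1) / tiles_B_cols,
                  (C_n_rows + tiles_A_rows - 1) / tiles_A_rows);

    coarse_2d_vec_mat_mul_kernel<<<dim_grid, dim_block>>>(d_A_ptr, d_B_ptr, d_C_ptr,
                                                          C_n_rows, C_n_cols, A_n_cols);
}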

Benchmark

Figure 3: cuBLAS vs 2D Thread Coarsening vs 2D Thread Coarsening (with vectorized memory accesses)

With vectorized memory accesses we get a decent increase in performance, and the code now sits at around 70% of cuBLAS performance. There is still some room for improvement, but that's a topic for the future.


References

  1. 2D thread coarsened (with vectorized memory accesses) matrix multiplication
     Header file: xGeMM/include/06_coarse_2d_vec_xgemm.cuh at master · tgautam03/xGeMM
     Source file: xGeMM/src/06_coarse_2d_vec_xgemm.cu at master · tgautam03/xGeMM
  2. Benchmarking 2D thread coarsened (with vectorized memory accesses) matrix multiplication
     Test file: xGeMM/test/06_benchmark_coarse_2d_vec.cu at master · tgautam03/xGeMM
