In this blog, I will show how to load (Sec. 3), prefetch (Sec. 4) and multicast load (Sec. 5) a tensor using TMA in Cute. All the code in this blog can be found here.
The Tensor Memory Accelerator (TMA) is a hardware unit introduced in the NVIDIA Hopper architecture to accelerate tensor data movement. To formally motivate TMA, we need to wind the clock back a bit to Volta.
As the tensor core gets faster, it needs enough data to keep it busy. Using Little's law, $\text{throughput} = \frac{\text{buffer\_size}}{\text{latency}}$: if we increase the (tensor core) throughput and keep the latency the same, we'd need more staging buffer capacity.
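To make that concrete with made-up numbers (these are illustrative, not Hopper specifications), rearranging Little's law gives the required in-flight buffering:

$$\text{buffer\_size} = \text{throughput} \times \text{latency}, \qquad \text{e.g. } 2\,\text{TB/s} \times 1\,\mu\text{s} = 2\,\text{MB}, \quad 4\,\text{TB/s} \times 1\,\mu\text{s} = 4\,\text{MB}.$$

Doubling the consumption rate at the same latency doubles the amount of data that must be staged somewhere between DRAM and the tensor core.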
The figure above shows the sequence of operations when feeding data to the tensor core. In Volta, the load sequence is DRAM->L2->L1->RF->shared memory(smem)->RF->tensor core. The RF/smem/L1/L2 all serve as staging buffers when feeding the data to the tensor core. While the L2 and smem hold the loaded data tiles, the RF and L1 purely serve the role of staging buffers. Increasing tensor core throughput adds significant staging buffer (RF/L1) usage. So Ampere introduced async memory copy to eliminate the extra staging buffer usage in RF/L1. As shown in the figure above, the load completely bypasses RF/L1 such that the new sequence becomes DRAM->L2->shared memory(smem)->RF->tensor core. This frees up the RF/L1 so that they can be used for other computation.
At this point, it seems like we have solved the bandwidth and capacity issues in the memory subsystem, and we can make the tensor core go even faster. However, the throughput of the other stages needs to keep up with the tensor core. At a high level, the computation of a typical gemm kernel can roughly be described as address generation->load->tensor core (MMA)->epilogue (e.g. ReLU, softmax)->address generation->store. Now let's try to make the kernel run faster. We bump up the tensor core throughput, so the MMA stage becomes faster. We bump up the memory bandwidth along with the async memory copy optimization, so the load/store stage becomes faster. The throughput of all the stages needs to match, so we also need to bump up the throughput of the address generation and epilogue stages. Notice that these two stages both use the CUDA cores, whose throughput stays largely the same. Then we hit a problem: the throughput of the CUDA cores limits the overall throughput of the kernel. This is where TMA comes in. TMA offloads address generation from the CUDA cores and can by itself generate addresses at a high rate, matching the throughput of the rest of the stages. With this offloading, the CUDA cores can be dedicated entirely to the epilogue stage, raising its throughput as well. With the introduction of TMA, every stage of the gemm kernel can now run at a high throughput.
Using the TMA can be tricky, and there are several ways to program it.
In this blog, I will show how to load (Sec. 3), prefetch (Sec. 4) and multicast load (Sec. 5) a tensor using TMA in Cute. Some basic understanding of Cute is required. We leave the more advanced topics like store, reduction, and swizzle for future blogs.
We will walk through an example of how to use TMA to load a matrix (2d tensor) from global memory (gmem) to shared memory (smem). The figure below shows the shape and layout of the matrix gmem_tensor: it is [N, K] ([6, 8] in the example) and rowMajor. A common thing one would want to do is for one CTA (aka threadblock) to load a tile of the matrix. In our example, we want to tile the gmem_tensor into tiles of shape [CTA_N, CTA_K] ([2, 4]) and rowMajor, and each CTA loads one tile into smem. Then we need [N/CTA_N, K/CTA_K]=[gridDim.x, gridDim.y] ([6/2, 8/4]=[3, 2]) CTAs to achieve this.
In the example figure below, the CTA at (1, 1) in the grid loads the blue tile of gmem_tensor into smem.
You can totally imagine the tile-to-CTA mapping being different for different applications/implementations. For example, CTA0 could load tiles (0, 0) and (2, 1), CTA1 tiles (0, 1) and (2, 0), and CTA2 tiles (1, 0) and (1, 1). Here we just showcase one example mapping, and it's straightforward to modify the code for other mappings (a small sketch follows).
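As a minimal, hypothetical sketch of what changing the mapping amounts to: the mapping is just the tile coordinate that the kernel later hands to the tiling call (Sec. 3.3), so walking the tiles in a transposed order, for instance, is a one-line change.

```cpp
#include <cute/tensor.hpp>

// Hypothetical helper (not in the blog's code): pick which tile this CTA loads.
// The blog's kernels use make_coord(blockIdx.x, blockIdx.y); swapping the two
// components would walk the tiles in a transposed order instead.
__device__ auto tile_coord_for_this_cta() {
  return cute::make_coord(blockIdx.y, blockIdx.x);
}
```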
Now we have all the information we need to construct the host side code.
template <typename T, int CTA_N, int CTA_K>
void cute_host_load(T* data, int N, int K) {
  using namespace cute;
  // 1. create the gmem tensor, row major
  auto gmem_layout = make_layout(make_shape(N, K), make_stride(K, 1));
  auto gmem_tensor = make_tensor(make_gmem_ptr(data), gmem_layout);
  // 2. create the smem layout, row major
  // smem_layout needs to use static integers;
  // using dynamic integers would cause a compilation error
  auto smem_layout = make_layout(make_shape(Int<CTA_N>{}, Int<CTA_K>{}), make_stride(Int<CTA_K>{}, _1{}));
  // 3. create the TMA object
  auto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout);
  // 4. invoke the kernel
  cute_tma_load_kernel<T, CTA_N, CTA_K>
      <<<dim3{N / CTA_N, K / CTA_K, 1}, 32>>>
      (tma_load, gmem_tensor, smem_layout);
}
Let's break it down:
- We create the gmem_tensor with shape [N, K] and stride [K, 1] (i.e. rowMajor).
- We create the smem_layout with shape [CTA_N, CTA_K] and stride [CTA_K, 1] (i.e. rowMajor). The only thing to note here is that the smem related objects should use static integers like cute::_1{} to avoid compilation errors.
- With the gmem_tensor pointer and layout along with the smem_layout, we create a TMA Copy_Atom using make_tma_copy(). Underneath, it creates the TMA descriptor (which the user can also create with lower level CUDA APIs; see the sketch right after this list). The type of TMA operation (e.g. load/store/prefetch) is specified by the CopyOp SM90_TMA_LOAD. The name is pretty straightforward: it's a load that uses TMA on SM90 (i.e. Hopper). We will discuss the TMA Copy_Atom construction process in detail in Sec. 3.4.1.
- We invoke the kernel with gridDim (dim3{N / CTA_N, K / CTA_K, 1}) and blockDim (32, since TMA only needs 1 thread to drive it).
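Since the breakdown above mentions that the TMA descriptor can also be created with lower level CUDA APIs, here is a rough, hedged sketch of what that looks like with the CUDA driver API. This is an illustration of the concept, not what make_tma_copy() literally emits; the data type, swizzle and cache choices below are assumptions for our plain row-major FP8 example.

```cpp
#include <cuda.h>     // CUDA driver API (cuTensorMapEncodeTiled), CUDA 12+
#include <cstdint>

// Hypothetical stand-alone TMA descriptor for an [N, K] row-major, 1-byte-per-element
// matrix, tiled into [CTA_N, CTA_K] boxes. Dimension 0 is the fastest moving one (K).
CUtensorMap make_tma_desc_2d(void* gmem_ptr, int N, int K, int CTA_N, int CTA_K) {
  CUtensorMap tma_desc{};
  cuuint64_t global_dim[2]    = {(cuuint64_t)K, (cuuint64_t)N};          // {fast, slow}
  cuuint64_t global_stride[1] = {(cuuint64_t)K * sizeof(uint8_t)};       // stride of the slow dim, in bytes
  cuuint32_t box_dim[2]       = {(cuuint32_t)CTA_K, (cuuint32_t)CTA_N};  // smem tile shape, {fast, slow}
  cuuint32_t elem_stride[2]   = {1, 1};
  cuTensorMapEncodeTiled(&tma_desc,
                         CU_TENSOR_MAP_DATA_TYPE_UINT8,   // e4m3 is 1 byte per element
                         /*tensorRank=*/2, gmem_ptr,
                         global_dim, global_stride, box_dim, elem_stride,
                         CU_TENSOR_MAP_INTERLEAVE_NONE,
                         CU_TENSOR_MAP_SWIZZLE_NONE,
                         CU_TENSOR_MAP_L2_PROMOTION_NONE,
                         CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);  // error checking omitted
  return tma_desc;
}
```

The descriptor would then be passed to the kernel (e.g. as a __grid_constant__ parameter, see below) and consumed by the TMA PTX instructions; make_tma_copy() hides all of this plumbing for us.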
// assume load a [N, K] row major weight matrix
template <typename T, int CTA_N, int CTA_K, class TmaLoad, class GmemTensor, class SmemLayout>
__global__ void cute_tma_load_kernel(__grid_constant__ const TmaLoad tma_load, GmemTensor gmem_tensor, SmemLayout smem_layout) {
  using namespace cute;
  constexpr int tma_transaction_bytes = CTA_N * CTA_K * sizeof(T);
  // 1. allocate smem for the tile and memory barrier
  __shared__ T smem_data[CTA_N * CTA_K];
  __shared__ uint64_t tma_load_mbar;
  // 2. create the smem tensor
  auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout); // [CTA_N, CTA_K]
  // 3. only need 1 thread to drive TMA
  if (threadIdx.x == 0) {
    // 4. initialize the barrier
    initialize_barrier(tma_load_mbar, /* arrival count */ 1);
    set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);
    // 5. gets the coordinate of the smem tile in gmem tensor
    auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));
    auto gmem_tensor_coord_cta = local_tile( // [CTA_N, CTA_K]
        gmem_tensor_coord,
        Tile<Int<CTA_N>, Int<CTA_K>>{},
        make_coord(blockIdx.x, blockIdx.y));
    // 6. get the slice of TMA work assigned to this CTA in a threadblock cluster
    auto tma_load_per_cta = tma_load.get_slice(0);
    // 7. issue TMA load
    copy(tma_load.with(tma_load_mbar),
         tma_load_per_cta.partition_S(gmem_tensor_coord_cta), // [[ATOM_N, ATOM_K], CTA_N/ATOM_N, CTA_K/ATOM_K]
         tma_load_per_cta.partition_D(smem_tensor));          // [[ATOM_N, ATOM_K], CTA_N/ATOM_N, CTA_K/ATOM_K]
  }
  // 8. wait for TMA to finish
  __syncthreads();
  wait_barrier(tma_load_mbar, /* phase */ 0);
  // 9. after this line, the TMA load is finished
  if (threadIdx.x == 0) {
    printf("block: (%d, %d), value: %f, %f\n", blockIdx.x, blockIdx.y, float(smem_tensor(make_coord(0, 0))), float(smem_tensor(make_coord(0, 1))));
  }
}
Note that the tma_load needs to be __grid_constant__ since the TMA descriptor is created on the host and passed to the device; it can't be modified on the device.
Let's break the code down with the help of the figure above:
- We allocate the smem for the tile (smem_data) and the mbarrier tma_load_mbar.
- We create the smem_tensor using the smem pointer and layout.
- We get the gmem_tensor_coord through get_tma_tensor(). This gives us a tensor with the same layout as gmem_tensor, but each entry, instead of the value, is the coordinate in the tensor, as shown in the figure above. Then local_tile() tiles gmem_tensor_coord and returns us the tile gmem_tensor_coord_cta (e.g. the blue tile) that the CTA wants to load from. The second argument specifies the tile size (i.e. [CTA_N, CTA_K]) we want to tile the big tensor with. The third argument specifies which tile we want to get; it does that by passing the coordinate ([blockIdx.x, blockIdx.y]) of the tile in the tile space to local_tile(). The numpy way to write it would be gmem_tensor_coord[CTA_N * blockIdx.x : CTA_N * (blockIdx.x + 1), CTA_K * blockIdx.y : CTA_K * (blockIdx.y + 1)]. To be more concrete, to get the blue tile in the figure, we do local_tile(gmem_tensor_coord, Tile<Int<2>, Int<4>>{}, make_coord(1, 1)): we tile gmem_tensor_coord with a [2, 4] tile size and want the tile at coordinate (1, 1).
- We get the slice of the TMA Copy_Atom (using get_slice()) that is assigned to this CTA. This is only relevant if you are in a threadblock cluster and the TMA Copy_Atom is responsible for the load of the entire cluster (with multicasting). Here we only have 1 CTA for the TMA and no threadblock cluster, so there is just 1 slice. This is similar to how you get a thread slice of an MMA from a TiledMMA (and set the fragments etc.); here we are getting a CTA slice of the TMA from a tiled TMA in a threadblock cluster. We describe this hierarchy in more detail in Sec. 3.4.1.
- We issue the TMA load. with() passes the mbarrier to the TMA Copy_Atom, and partition_S and partition_D reshape the gmem and smem tiles into the shape the copy function expects. We will have a detailed discussion of how each argument is constructed in Sec. 3.4.2.
- We __syncthreads() so that every thread arrives at the wait point. Then all the threads wait_barrier() on the tma_load_mbar. The wait completes when the phase of the barrier flips; the initial phase is 0 during initialization. The second argument of wait_barrier() is the current phase bit the barrier is waiting to flip. We are waiting for phase 0 to flip, hence 0 is passed in. The TMA unit decrements the transaction count of the barrier as the loads come back until it reaches 0, meaning all loads have finished. Then the barrier flips, every thread's wait is done, and we resume execution.

For the curious minds, this entire section is dedicated to explaining how the copy function works underneath. You can skip to the harness code (Sec. 3.5) if you are only interested in how to use TMA.
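As an aside before digging into the internals: the kernel above issues a single TMA load and waits on phase 0 exactly once. When the same mbarrier is reused across iterations (e.g. looping over K tiles in a GEMM mainloop), the phase bit to wait on flips every iteration. Below is a minimal, hypothetical sketch of that pattern; num_k_tiles, next_tile_coord, and the per-stage smem handling are assumed names, not part of this blog's example.

```cpp
// Hypothetical mainloop sketch: the mbarrier is initialized once (arrival count 1),
// then every iteration re-arms the expected transaction bytes, issues the TMA load,
// and waits for the barrier phase to flip. The phase bit alternates between 0 and 1.
int phase = 0;
for (int k = 0; k < num_k_tiles; ++k) {
  if (threadIdx.x == 0) {
    set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);
    copy(tma_load.with(tma_load_mbar),
         tma_load_per_cta.partition_S(next_tile_coord(k)),  // assumed helper: k-th tile's coordinates
         tma_load_per_cta.partition_D(smem_tensor));
  }
  __syncthreads();
  wait_barrier(tma_load_mbar, phase);  // wait for the current phase to complete
  phase ^= 1;                          // the next wait targets the flipped phase
  // ... consume the tile in smem ...
  __syncthreads();                     // don't let thread 0 overwrite smem while others still read it
}
```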
The copy() function is one of the Cute built-in algorithms; it copies a tensor from a source location to a destination location. We are using the three-argument variant, where we specify a Copy_Atom, the src tensor and the dst tensor. By specifying the Copy_Atom, we tell the compiler explicitly which type of copy we want (e.g. Ampere cp.async, Hopper TMA, etc.). There is also a two-argument copy variant where only the source and destination tensors are passed in, so that the copy operation falls back to the default algorithm.

Copy_Atom Construction

make_tma_copy() constructs the Copy_Atom. The sequence of operations is similar to creating a TiledMMA, i.e. from the bottom up: CopyOp->Copy_Traits->Copy_Atom->TiledCopy. The CopyOp, Copy_Traits, Copy_Atom are all wrappers over a PTX copy operation but gradually embed more metadata in the struct (CopyOp has the least and Copy_Atom is the most complete). And then TiledCopy (which, confusingly, is a derived class of Copy_Atom) stacks multiple Copy_Atoms together to create a larger macro-operation. In the particular case of TMA, a Copy_Atom is responsible for the TMA operation of one CTA (in a threadblock cluster), and a TiledCopy is responsible for all TMA operations of the entire threadblock cluster. In our simple example, since there is only 1 CTA in a threadblock cluster, the Copy_Atom size ([CTA_N, CTA_K]) is identical to the TiledCopy size.
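One handy way to see what make_tma_copy() actually built is to print the returned object, for example right after creating it in cute_host_load(). This is just a debugging aid; the exact output format is CuTe's own and nothing in this blog depends on it.

```cpp
// Host-side inspection: cute::print() has an overload for TiledCopy objects,
// so this dumps the constructed tiled TMA copy (its Copy_Atom, thread and value layouts).
cute::print(tma_load);
```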
The actual call stack is shown below:
- Starting from the CopyOp (i.e. SM90_TMA_LOAD), we construct the Copy_Traits. Then we construct the Copy_Atom from Copy_Traits.
- Finally, we stack the Copy_Atom of each CTA into a TiledCopy for the entire threadblock cluster.

As we presented above, the copy() function takes three arguments: the Copy_Atom, the src tensor and the dst tensor. However, we can't pass what we currently have to the copy function; there are some nuances that we explain below.
- Copy_Atom: Remember that the TMA unit needs the mbarrier to synchronize with the CTA, and we have never passed the mbarrier handle to the TMA Copy_Atom yet! It is called a non-executable Copy_Atom. This is where tma_load.with(tma_load_mbar) comes into play. The with() (which eventually gets dispatched here) embeds the mbarrier information into the Copy_Traits and therefore creates an executable Copy_Atom. As you can see from the function definition, you can also embed information like an L2 cache hint into the Copy_Atom.
- src tensor: We do a source partition of the tensor (i.e. partition_S()). This means we reshape the source tensor layout from [CTA_N, CTA_K] to [[ATOM_N, ATOM_K], CTA_N/ATOM_N, CTA_K/ATOM_K]. This layout is what the copy() function expects. [ATOM_N, ATOM_K] is the TMA Copy_Atom size (also commonly referred to as [TMA] in the code comments), and [CTA_N/ATOM_N, CTA_K/ATOM_K] is referred to as [TMA_N, TMA_K] in the code comments. So in theory, the copy is broken down into CTA_N/ATOM_N*CTA_K/ATOM_K steps, with each step copying [ATOM_N, ATOM_K]. In practice, the Copy_Atom size is the entire tile, i.e. [TMA_N, TMA_K]=[CTA_N, CTA_K], so there is only 1 atom/step in the copy, i.e. the src tensor has been reshaped to [[ATOM_N, ATOM_K], 1, 1]. You can also verify this by printing out the partitioned src tensor using cute::print().
- dst tensor: We do a destination partition of the tensor (i.e. partition_D()), similar to the src tensor. This means we reshape the destination tensor layout from [CTA_N, CTA_K] to [[ATOM_N, ATOM_K], CTA_N/ATOM_N, CTA_K/ATOM_K].

The copy() call then dispatches roughly as follows:
- copy() breaks the copy into CTA_N/ATOM_N*CTA_K/ATOM_K steps, with each step finishing 1 Copy_Atom size of work.
- Each step dispatches to the Copy_Atom, which dispatches to the Copy_Traits and passes some additional traits into the copy function.
- The Copy_Traits calls CopyOp::copy, using explode() for passing in additional arguments.
- The coordinates handed to the TMA instruction are in the gmem_tensor coordinate space, [2, 4] in the figure.

Take SM90_TMA_LOAD_2D::copy() for example: it takes the TMA descriptor, mbarrier, etc. as arguments, which are unpacked in copy_unpack(). Here we only focus on explaining how the coordinates (crd) are unpacked, and leave the rest as an exercise to the reader.
The coordinate of the TMA tile is obtained from src.data().coord_, and explode() eventually unpacks it as (..., int32_t const& crd0, int32_t const& crd1). Remember that the src tensor is gmem_tensor_coord, which holds all the coordinates of the smem tile in gmem. It is an ArithTuple, and the Cute doc explains it in greater detail.

Basically, gmem_tensor_coord is a tensor purely for getting the coordinates of the TMA tile (e.g. the blue tile) and passing them to the TMA instruction. In our Cute code, we might want to reshape (i.e. partition_D()) or tile the smem_tensor. It is very error-prone to manually update the gmem coordinates to keep track of every smem_tensor manipulation. So we simply create a gmem coordinate tensor and manipulate it in exactly the same way; the resulting coordinates then always stay in sync with smem_tensor. We never materialize the coordinate tensor in gmem though, we only manipulate it (reshape/tile, etc.), hence the prefix Arith.
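A tiny host-side illustration of the same idea, using CuTe's generic make_identity_tensor() (the TMA path builds its coordinate tensor via get_tma_tensor() instead, but the behavior is analogous; this snippet is only a sketch):

```cpp
#include <cute/tensor.hpp>

// A coordinate tensor's "values" are coordinates, so any tiling applied to it
// yields the gmem coordinates of the selected elements, with nothing materialized.
void coord_tensor_demo() {
  using namespace cute;
  auto coord = make_identity_tensor(make_shape(6, 8));               // [6, 8] coordinate tensor
  auto tile  = local_tile(coord, Tile<_2, _4>{}, make_coord(1, 1));  // the blue [2, 4] tile
  print(tile(0, 0));  // prints the gmem coordinate of the tile's first element, i.e. (2,4)
}
```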
// 1. Define TMA load tile size
static constexpr int TILE_N = 64;
static constexpr int TILE_K = 128;

int main() {
  // 2. Define problem size and tensors
  int N = 256;
  int K = 256;
  // we assume this is a [N, K] row major matrix
  cutlass::HostTensor<cutlass::float_e4m3_t, cutlass::layout::RowMajor> B({N, K});
  // 3. init some value on host for B tensor and copy it to GPU memory
  // ...
  B.sync_device();
  // 4. do TMA load to smem
  cute_host_load<cutlass::float_e4m3_t, TILE_N, TILE_K>(B.device_data(), N, K);
  // 5. wait for kernel to complete
  cudaDeviceSynchronize();
  return 0;
}
The harness function is pretty straightforward:
- We define the TMA load tile size (CTA_N and CTA_K) we want for each CTA's TMA load.
- We define the problem size and the gmem_tensor layout. Here we use the slightly older HostTensor API (tutorial here) to define an FP8 (e4m3) row-major matrix with shape [N, K].
- We call cute_host_load to launch the kernel.
- We call cudaDeviceSynchronize() to wait for the kernel to complete.

The TMA unit can also prefetch the same tile from gmem to the L2 cache (instead of loading the tile into smem). This can be useful in cases where the data loading latency is exposed.
The code for the example described in Sec. 3.1 only needs to be slightly tweaked to prefetch instead of load.
template <typename T, int CTA_N, int CTA_K>
void cute_host_prefetch(T* data, int N, int K) {
  using namespace cute;
  // 1. create the gmem tensor, row major
  auto gmem_layout = make_layout(make_shape(N, K), make_stride(K, 1));
  auto gmem_tensor = make_tensor(make_gmem_ptr(data), gmem_layout);
  // 2. create the smem layout, row major
  // smem_layout needs to use static integers;
  // using dynamic integers would cause a compilation error
  auto smem_layout = make_layout(make_shape(Int<CTA_N>{}, Int<CTA_K>{}), make_stride(Int<CTA_K>{}, _1{}));
  // 3. create the TMA object
  auto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout);
  // 4. invoke the kernel
  cute_tma_prefetch_kernel<T, CTA_N, CTA_K>
      <<<dim3{N / CTA_N, K / CTA_K, 1}, 32>>>
      (tma_load, gmem_tensor, smem_layout);
}
You can see the host side code is exactly the same as for the TMA load (Sec. 3.2), other than the function names! This is the power of the Cute abstraction. The tensor layouts obviously are the same. Even the TMA Copy_Atom is constructed the same way. We specify whether we want to do a load or a prefetch on the device side, and the TMA Copy_Atom gets dispatched to the corresponding PTX instruction for load or prefetch.
// assume load a [N, K] row major weight matrix
template <typename T, int CTA_N, int CTA_K, class TmaLoad, class GmemTensor, class SmemLayout>
__global__ void cute_tma_prefetch_kernel(__grid_constant__ const TmaLoad tma_load, GmemTensor gmem_tensor, SmemLayout smem_layout) {
  using namespace cute;
  // 1. only need 1 thread to drive TMA
  if (threadIdx.x == 0) {
    // 2. gets the coordinate of the smem tile in gmem tensor
    auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));
    auto gmem_tensor_coord_cta = local_tile(
        gmem_tensor_coord,
        Tile<Int<CTA_N>, Int<CTA_K>>{},
        make_coord(blockIdx.x, blockIdx.y));
    // 3. get the slice of TMA work assigned to this CTA in a threadblock cluster
    auto tma_load_per_cta = tma_load.get_slice(0);
    // 4. issue TMA prefetch
    prefetch(tma_load,
             tma_load_per_cta.partition_S(gmem_tensor_coord_cta)); // [[ATOM_N, ATOM_K], CTA_N/ATOM_N, CTA_K/ATOM_K]
  }
  // 5. no need to wait for prefetch, just make sure all threads converge
  __syncthreads();
  // 6. after this line, the TMA prefetch is finished
}
The code is a subset of the TMA load code; here we only highlight the differences:
- There is no smem_tensor, because we only prefetch to L2; no smem is involved.
- There is no mbarrier. There isn't hardware support for signaling prefetch arrival to the CTA, according to the PTX doc. Prefetch is non-blocking and runs behind the scenes for most applications, so there is no need to synchronize the prefetch arrival with the CTA. After issuing the prefetch, we simply __syncthreads(); to eliminate thread divergence.
- Instead of the copy() function, we use the prefetch() function (another Cute built-in algorithm) to issue the TMA prefetch (a usage sketch follows this list). We still need to reshape the gmem_tensor_coord_cta into the [[ATOM_N, ATOM_K], CTA_N/ATOM_N, CTA_K/ATOM_K] shape, but there is no need for an smem tensor reshape.
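A common way to put this to work (a hypothetical sketch, not part of the blog's example) is to let the producer thread prefetch the tile needed a few iterations ahead while issuing the TMA load for the current one, so the later load mostly hits in L2. next_tile_coord and curr_tile_coord are assumed per-iteration coordinate tiles, built with local_tile() as in Sec. 3.3.

```cpp
// Hypothetical mainloop fragment combining TMA prefetch and TMA load.
if (threadIdx.x == 0) {
  // warm up L2 for a future iteration; no mbarrier and no smem involved
  prefetch(tma_load, tma_load_per_cta.partition_S(next_tile_coord));
  // load the current iteration's tile into smem, exactly as in Sec. 3.3
  copy(tma_load.with(tma_load_mbar),
       tma_load_per_cta.partition_S(curr_tile_coord),
       tma_load_per_cta.partition_D(smem_tensor));
}
```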
The underlying mechanism to support prefetching in Cute is quite similar to TMA load, and they share most of the infrastructure.
Copy_Atom Construction

When defining the CopyOp (e.g. SM90_TMA_LOAD_2D), Cute defines a variant called PREFETCH. The rest of the Copy_Traits, Copy_Atom, TiledCopy definitions are unchanged. We always define the same load CopyOp, SM90_TMA_LOAD_2D, in the host code; Cute just substitutes the CopyOp from SM90_TMA_LOAD_2D to SM90_TMA_LOAD_2D::PREFETCH underneath if we want to use prefetch.
The dispatch goes roughly as follows:
- CopyOp substitution: if the CopyOp has a prefetch variant, we substitute it with the prefetch variant and call copy() to start reusing the same infra.
- We drop the dst tensor pointer, since it's not needed in prefetch.
- The Copy_Traits dispatch to the PREFETCH variant of the copy function.
- Eventually the PREFETCH variant of the copy function is called, which lowers to the 2D variant of the TMA prefetch PTX instruction.

We still use the example in Sec. 3.1, but instead of a single CTA loading a tile of gmem_tensor, we now have an entire threadblock cluster (containing multiple CTAs) load the same tile of data. This means each CTA in the threadblock cluster gets its own copy of the data tile in smem. We only load the data tile once from gmem/L2, and the tile gets multicast from the L2 to all CTAs in the cluster. Multicasting saves L2 bandwidth and energy.
The figure above shows this new scenario. We have a grid of [3, 4] CTAs and a clusterDim of 1x2; therefore, the grid contains [3, 2] clusters. We have 2 CTAs of the same cluster (in blue) trying to load the same blue tile into their own smem.
template <typename T, int CTA_N, int CTA_K, int CLUSTER_N, int CLUSTER_K>
void cute_host_multicast(T* data, int N, int K) {
  using namespace cute;
  // 1. create the GMEM tensor, row major
  auto gmem_layout = make_layout(make_shape(N, K), make_stride(K, 1));
  auto gmem_tensor = make_tensor(make_gmem_ptr(data), gmem_layout);
  // 2. create the SMEM layout, row major
  // smem_layout needs to use static integers;
  // using dynamic integers would cause a compilation error
  auto smem_layout = make_layout(make_shape(Int<CTA_N>{}, Int<CTA_K>{}), make_stride(Int<CTA_K>{}, _1{}));
  // 3. create the TMA object
  auto tma_load = make_tma_copy(SM90_TMA_LOAD_MULTICAST{}, gmem_tensor, smem_layout, Int<CLUSTER_N * CLUSTER_K>{});
  // 4. invoke the kernel using the extensible launch interface
  cudaLaunchConfig_t config;
  cudaLaunchAttribute attrs[1];
  config.gridDim = dim3{N / CTA_N * CLUSTER_N, K / CTA_K * CLUSTER_K, 1};
  config.blockDim = 32;
  config.dynamicSmemBytes = 0;
  config.stream = 0x0;
  config.numAttrs = 1;
  attrs[0].id = cudaLaunchAttributeClusterDimension;
  attrs[0].val.clusterDim.x = CLUSTER_N;
  attrs[0].val.clusterDim.y = CLUSTER_K;
  attrs[0].val.clusterDim.z = 1;
  config.attrs = attrs;
  auto* kernel = &cute_tma_multicast_kernel<T, CTA_N, CTA_K, decltype(tma_load), decltype(gmem_tensor), decltype(smem_layout)>;
  gpuErrChk(cudaLaunchKernelEx(&config, kernel, tma_load, gmem_tensor, smem_layout));
}
The host code is almost the same as for the normal TMA load; here we only highlight two differences:
- The TMA Copy_Atom now has the CopyOp SM90_TMA_LOAD_MULTICAST{} to denote that this is a TMA multicast load, and we pass in the additional argument of the cluster size (i.e. CLUSTER_N * CLUSTER_K) to make_tma_copy(). This allows Cute to divide up the TMA load work among the CTAs in a threadblock cluster (more details in Sec 5.3).
- Because the kernel must be launched with a threadblock cluster, we use the extensible launch interface (cudaLaunchKernelEx with the cluster dimension attribute) instead of the usual triple-chevron <<<...>>> interface; a sketch of an alternative launch method follows this list.
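For completeness, a hedged sketch of the alternative launch method: CUDA also lets you bake a compile-time cluster shape into the kernel with the __cluster_dims__ attribute, after which a plain triple-chevron launch forms the clusters. This is not what the blog's code does (keeping the cluster shape out of the kernel definition is more flexible), and the 1x2x1 shape below is just an assumed example.

```cpp
// Hypothetical variant with a compile-time cluster shape of 1x2x1.
template <typename T, int CTA_N, int CTA_K, class TmaLoad, class GmemTensor, class SmemLayout>
__global__ void __cluster_dims__(1, 2, 1)
cute_tma_multicast_kernel_fixed_cluster(__grid_constant__ const TmaLoad tma_load,
                                        GmemTensor gmem_tensor, SmemLayout smem_layout) {
  // same body as cute_tma_multicast_kernel below
}

// Launch: the grid must still be a multiple of the cluster shape, e.g.
// cute_tma_multicast_kernel_fixed_cluster<T, CTA_N, CTA_K, decltype(tma_load), decltype(gmem_tensor), decltype(smem_layout)>
//     <<<dim3{N / CTA_N * 1, K / CTA_K * 2, 1}, 32>>>(tma_load, gmem_tensor, smem_layout);
```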
// assume load a [N, K] row major weight matrix
template <typename T, int CTA_N, int CTA_K, class TmaLoad, class GmemTensor, class SmemLayout>
__global__ void cute_tma_multicast_kernel(__grid_constant__ const TmaLoad tma_load, GmemTensor gmem_tensor, SmemLayout smem_layout) {
  using namespace cute;
  constexpr int tma_transaction_bytes = CTA_N * CTA_K * sizeof(T);
  // 1. allocate smem for the tile and memory barrier
  __shared__ T smem_data[CTA_N * CTA_K];
  __shared__ uint64_t tma_load_mbar;
  // 2. create the smem tensor
  auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout); // [CTA_N, CTA_K]
  // 3. only need 1 thread to drive TMA
  if (threadIdx.x == 0) {
    // 4. initialize the barrier
    initialize_barrier(tma_load_mbar, /* arrival count */ 1);
    set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);
  }
  // 5. synchronize to make sure all barrier inits in the cluster are done
  __syncthreads();
  cluster_sync();
  cutlass::arch::fence_barrier_init();
  if (threadIdx.x == 0) {
    // 6. gets the coordinate of the smem tile in gmem tensor
    auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));
    dim3 clusterIdx = cluster_id_in_grid();
    auto gmem_tensor_coord_cluster = local_tile( // [CTA_N, CTA_K]
        gmem_tensor_coord,
        Tile<Int<CTA_N>, Int<CTA_K>>{},
        make_coord(clusterIdx.x, clusterIdx.y));
    // 7. get the cluster related meta data
    uint32_t _block_rank_in_cluster = block_rank_in_cluster();
    dim3 clusterDim = cluster_shape();
    int cluster_size = clusterDim.x * clusterDim.y * clusterDim.z;
    uint16_t tma_mcast_mask = (uint16_t(1) << cluster_size) - 1;
    // 8. get the slice of TMA work assigned to this CTA in a threadblock cluster
    auto tma_load_per_cta = tma_load.get_slice(_block_rank_in_cluster);
    // 9. issue TMA multicast
    copy(tma_load.with(tma_load_mbar, tma_mcast_mask),
         tma_load_per_cta.partition_S(gmem_tensor_coord_cluster), // [[ATOM_N, ATOM_K], CTA_N/ATOM_N, CTA_K/ATOM_K]
         tma_load_per_cta.partition_D(smem_tensor));              // [[ATOM_N, ATOM_K], CTA_N/ATOM_N, CTA_K/ATOM_K]
  }
  // 10. wait for TMA to finish
  __syncthreads();
  wait_barrier(tma_load_mbar, /* phase */ 0);
  // 11. after this line, the TMA multicast is finished
  if (threadIdx.x == 0) {
    printf("block: (%d, %d), value: %f, %f\n", blockIdx.x, blockIdx.y, float(smem_tensor(make_coord(0, 0))), float(smem_tensor(make_coord(0, 1))));
  }
}
The device code is largely the same as the TMA load code, with a few additions related to the cluster and multicast:
- We get the gmem_tensor_coord_cluster (the blue tile) for the entire cluster (instead of for one CTA as in the TMA load example). The only difference is that we index the tiles by clusterIdx rather than blockIdx in local_tile().
- tma_mcast_mask specifies which CTAs participate in this TMA multicast; in this example, all CTAs (CLUSTER_K * CLUSTER_N, or 2) in the cluster. It's a bitmask, with bit value 1 specifying participation and 0 specifying non-participation. For our 2-CTA example, it's 2'b11. If you set the mask to 2'b01 for instance, the multicast load result is only written to CTA 0.
- Cute divides the smem_tensor into smaller chunks and lets each CTA multicast load one chunk. We show this in the figure above. The entire cluster is responsible for loading the [2, 4]-shaped smem_tensor (this is a TiledCopy amount of work). We have 2 CTAs in the cluster, so each CTA is responsible for multicast loading a single row ([1, 4]) of smem_tensor. This means CTA 0 issues a multicast load of row 0 (light blue) and the data comes back multicast to the smem of both CTA 0 and 1, and CTA 1 issues a multicast load of row 1 (dark blue) and the data comes back multicast to the smem of both CTA 0 and 1. This division of work makes sure we utilize the TMA units of all CTAs so that we are not TMA issue bound (alternatively, you can imagine CTA 0 being the leader CTA and issuing a multicast load of the entire [2, 4] tensor, but then only the TMA unit in CTA 0 is used and we might be TMA issue throughput bound). To achieve this division of work, we slice the TiledCopy object for the whole cluster into a ThrCopy object for each CTA using the index _block_rank_in_cluster. ThrCopy encapsulates the amount of work to be done per CTA (i.e. a row, or one Copy_Atom).
- As before, we reshape the source and destination with partition_S and partition_D, and we embed the barrier and multicast mask information into the Copy_Atom using the with() function.
- Each CTA knows its smem_tensor tile is done when its CTA-local barrier phase flips. We don't need a cluster_sync() to wait for the other CTA's barrier to flip. This is because, although we break the smem_tensor load into 2 Copy_Atoms and let each CTA work on its own share, when the data arrives, TMA updates the transaction_bytes of both barriers in the cluster. This means each barrier independently tracks whether the TMA of the entire smem_tensor is done, not only its own share.

TMA multicast load largely follows the same dispatch path as TMA load. In this section we only focus on the difference, which is how the load work is divided among CTAs in multicast.

TiledCopy Construction

To illustrate better, here we use a more complex example. We have a 2x2 (CLUSTER_N x CLUSTER_K) cluster multicast loading a [4, 4]-shaped ([CTA_N, CTA_K]) row-major smem_tensor.

As we briefly touched upon in Sec. 5.3 step 8, the amount of multicast load work of a cluster is evenly divided among all CTAs in the cluster. In our example, each CTA should be responsible for multicast loading 4*4/4=4 (CTA_N * CTA_K / CLUSTER_N / CLUSTER_K) elements of smem_tensor. But along which dimension does Cute slice smem_tensor? The answer is the slowest moving dimension. Since smem_tensor is row-major, the CTA_N dimension is the slowest moving. Hence, for the 4 (CLUSTER_N x CLUSTER_K) CTAs, Cute slices smem_tensor into 4 rows (one per CTA), and each CTA multicast loads a row. The reason why Cute slices the slowest moving dimension is that after the slice, each chunk of data is still contiguous in memory, which is TMA load friendly. (If you sliced the fastest moving dimension, each chunk would have strided accesses, which is bad for memory coalescing.)

Some side notes:
- Exactly how the tile is sliced is determined by the smem_tensor layout; tma_partition() contains the logic.

Putting it all together, a TiledCopy object is responsible for multicast loading the entire smem_tensor for the cluster. It can be further sliced into multiple ThrCopy objects. Each CTA gets a ThrCopy amount of work, which is multicast loading a row of smem_tensor. For the purposes of TMA, a ThrCopy consists of a single TMA Copy_Atom.
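If you want to see each CTA's ThrCopy share concretely, one option (a debugging sketch, assuming the multicast kernel from Sec. 5.3) is to print the partitioned source tensor right after step 8, inside the same if (threadIdx.x == 0) block so the variables are in scope. Each cluster rank should report a different chunk of gmem coordinates, while the destination smem partition is identical everywhere.

```cpp
// Hypothetical debug lines added after step 8 of cute_tma_multicast_kernel:
printf("cluster rank %u source partition: ", _block_rank_in_cluster);
cute::print(tma_load_per_cta.partition_S(gmem_tensor_coord_cluster));  // this CTA's chunk of gmem coordinates
printf("\n");
```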