File tree Expand file tree Collapse file tree 2 files changed +8
-10
lines changed Expand file tree Collapse file tree 2 files changed +8
-10
lines changed Original file line number Diff line number Diff line change 1313namespace megablocks {
1414namespace construct_indices {
1515
16- // We expect the number of outputs per block to be
17- // small. With ffn_hidden_size=4096, we only need
18- // to write 32 elements per block per iteration.
19- // This is the largest we're every likely to use
20- // so we keep the blocks small.
16+ // We expect the number of outputs per block to be small. For
17+ // example, with ffn_hidden_size=4096, we only need to write
18+ // 32 elements per block per iteration.
2119const int kThreadsPerBlock = 32 ;
2220
2321__global__ void __launch_bounds__ (kThreadsPerBlock )
@@ -39,13 +37,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
3937
4038 // Write the indices to the output.
4139 int bin_offset = blockIdx.y ;
42- int tid = threadIdx.x ;
4340 int num_rows = end - start;
4441 for (; bin_offset < num_rows; num_rows -= gridDim.y ) {
45- int elements = num_columns;
4642 short *out = indices;
47- for (; tid < elements; elements -= kThreadsPerBlock) {
48- *out = threadIdx.x + (blockIdx.x * num_columns);
43+ for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
44+ *out = bid + (blockIdx.x * num_columns);
4945 out += kThreadsPerBlock ;
5046 }
5147 indices += gridDim.y * num_columns;
Original file line number Diff line number Diff line change 2626 (16384 , 768 , 128 ),
2727 (16384 , 768 , 256 ),
2828 (16384 , 768 , 512 ),
29- (16384 , 768 , 1024 ))
29+ (16384 , 768 , 1024 ),
30+ (8 , 14336 , 8 ),
31+ )
3032
3133
3234class TopologyTest (parameterized .TestCase ):
You can’t perform that action at this time.
0 commit comments