Skip to content

Small refactor to enhance code clarity #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions src/computesim.nim
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,19 @@ type
const
MaxConcurrentWorkGroups {.intdefine.} = 2

proc subgroupProc[A, B, C](wg: WorkGroupContext; numActiveThreads: uint32; barrier: BarrierHandle,
compute: ThreadGenerator[A, B, C]; buffers: A; shared: ptr B; args: C) =
proc subgroupProc[A, B, C](dsp: DispatchContext; wg: WorkGroupContext; numActiveThreads: uint32;
barrier: BarrierHandle, compute: ThreadGenerator[A, B, C]; buffers: A; shared: ptr B; args: C) =
var threads = default(SubgroupThreads)
var threadContexts {.noinit.}: ThreadContexts
let startIdx = wg.gl_SubgroupID * SubgroupSize
# Initialize coordinates from startIdx
var x = startIdx mod wg.gl_WorkGroupSize.x
var y = (startIdx div wg.gl_WorkGroupSize.x) mod wg.gl_WorkGroupSize.y
var z = startIdx div (wg.gl_WorkGroupSize.x * wg.gl_WorkGroupSize.y)
var x = startIdx mod dsp.gl_WorkGroupSize.x
var y = (startIdx div dsp.gl_WorkGroupSize.x) mod dsp.gl_WorkGroupSize.y
var z = startIdx div (dsp.gl_WorkGroupSize.x * dsp.gl_WorkGroupSize.y)
# Pre-compute global offsets
let globalOffsetX = wg.gl_WorkGroupID.x * wg.gl_WorkGroupSize.x
let globalOffsetY = wg.gl_WorkGroupID.y * wg.gl_WorkGroupSize.y
let globalOffsetZ = wg.gl_WorkGroupID.z * wg.gl_WorkGroupSize.z
let globalOffsetX = wg.gl_WorkGroupID.x * dsp.gl_WorkGroupSize.x
let globalOffsetY = wg.gl_WorkGroupID.y * dsp.gl_WorkGroupSize.y
let globalOffsetZ = wg.gl_WorkGroupID.z * dsp.gl_WorkGroupSize.z
# Setup thread contexts
for threadId in 0..<numActiveThreads:
threadContexts[threadId] = ThreadContext(
Expand All @@ -144,27 +144,26 @@ proc subgroupProc[A, B, C](wg: WorkGroupContext; numActiveThreads: uint32; barri
)
# Update coordinates
inc x
if x >= wg.gl_WorkGroupSize.x:
if x >= dsp.gl_WorkGroupSize.x:
x = 0
inc y
if y >= wg.gl_WorkGroupSize.y:
if y >= dsp.gl_WorkGroupSize.y:
y = 0
inc z
# Allocate all compute closures
for threadId in 0..<numActiveThreads:
threads[threadId] = compute(buffers, shared, args)
# Run threads in lockstep
runThreads(threads, wg, threadContexts, numActiveThreads, barrier)
runThreads(threads, dsp, wg, threadContexts, numActiveThreads, barrier)

proc workGroupProc[A, B, C](
workgroupID: UVec3,
wg: WorkGroupContext,
dsp: DispatchContext,
compute: ThreadGenerator[A, B, C],
ssbo: A, smem: ptr B, args: C) =
# Auxiliary proc for work group management
var wg = wg # Shadow for modification
wg.gl_WorkGroupID = workgroupID
let threadsInWorkgroup = wg.gl_WorkGroupSize.x * wg.gl_WorkGroupSize.y * wg.gl_WorkGroupSize.z
var wg = WorkGroupContext(gl_WorkGroupID: workgroupID)
let threadsInWorkgroup = dsp.gl_WorkGroupSize.x * dsp.gl_WorkGroupSize.y * dsp.gl_WorkGroupSize.z
let numSubgroups = ceilDiv(threadsInWorkgroup, SubgroupSize)
wg.gl_NumSubgroups = numSubgroups
# Initialize local shared memory
Expand All @@ -177,18 +176,18 @@ proc workGroupProc[A, B, C](
wg.gl_SubgroupID = subgroupId
# Calculate number of active threads in this subgroup
let threadsInSubgroup = min(remainingThreads, SubgroupSize)
master.spawn subgroupProc(wg, threadsInSubgroup, barrier.getHandle(), compute, ssbo, smem, args)
master.spawn subgroupProc(dsp, wg, threadsInSubgroup, barrier.getHandle(), compute, ssbo, smem, args)
dec remainingThreads, threadsInSubgroup

proc runCompute[A, B, C](
numWorkGroups, workGroupSize: UVec3,
compute: ThreadGenerator[A, B, C],
ssbo: A, smem: B, args: C) =
let wg = WorkGroupContext(
let dsp = DispatchContext(
gl_NumWorkGroups: numWorkGroups,
gl_WorkGroupSize: workGroupSize
)
let totalGroups = wg.gl_NumWorkGroups.x * wg.gl_NumWorkGroups.y * wg.gl_NumWorkGroups.z
let totalGroups = dsp.gl_NumWorkGroups.x * dsp.gl_NumWorkGroups.y * dsp.gl_NumWorkGroups.z
let numBatches = ceilDiv(totalGroups, MaxConcurrentWorkGroups)
var currentGroup: uint32 = 0
# Initialize workgroup coordinates
Expand All @@ -203,13 +202,13 @@ proc runCompute[A, B, C](
master.awaitAll:
var groupIdx: uint32 = 0
while currentGroup < endGroup:
master.spawn workGroupProc(uvec3(wgX, wgY, wgZ), wg, compute, ssbo, addr smemArr[groupIdx], args)
master.spawn workGroupProc(uvec3(wgX, wgY, wgZ), dsp, compute, ssbo, addr smemArr[groupIdx], args)
# Increment coordinates, wrapping when needed
inc wgX
if wgX >= wg.gl_NumWorkGroups.x:
if wgX >= dsp.gl_NumWorkGroups.x:
wgX = 0
inc wgY
if wgY >= wg.gl_NumWorkGroups.y:
if wgY >= dsp.gl_NumWorkGroups.y:
wgY = 0
inc wgZ
inc groupIdx
Expand Down
8 changes: 5 additions & 3 deletions src/computesim/core.nim
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,12 @@ type
groupMemoryBarrier, reconverge, invalid:
discard

WorkGroupContext* = object
gl_WorkGroupID*: UVec3
DispatchContext* = object
gl_WorkGroupSize*: UVec3
gl_NumWorkGroups*: UVec3

WorkGroupContext* = object
gl_WorkGroupID*: UVec3
gl_NumSubgroups*: uint32
gl_SubgroupID*: uint32

Expand Down Expand Up @@ -132,7 +134,7 @@ proc wait*(m: BarrierHandle) {.inline.} =
wait(m.x[])

type
ThreadClosure* = iterator (iterArg: SubgroupResult, wg: WorkGroupContext,
ThreadClosure* = iterator (iterArg: SubgroupResult, dsp: DispatchContext, wg: WorkGroupContext,
thread: ThreadContext, threadId: uint32): SubgroupCommand
SubgroupResults* = array[SubgroupSize, SubgroupResult]
SubgroupCommands* = array[SubgroupSize, SubgroupCommand]
Expand Down
6 changes: 3 additions & 3 deletions src/computesim/lockstep.nim
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type
ThreadState = enum
running, halted, atSubBarrier, atBarrier, finished

proc runThreads*(threads: SubgroupThreads; workGroup: WorkGroupContext,
proc runThreads*(threads: SubgroupThreads; dispatch: DispatchContext; workGroup: WorkGroupContext,
threadContexts: ThreadContexts; numActiveThreads: uint32; b: BarrierHandle) =
var
anyThreadsActive = true
Expand Down Expand Up @@ -64,7 +64,7 @@ proc runThreads*(threads: SubgroupThreads; workGroup: WorkGroupContext,
threadStates[threadId] == running or canReconverge or canPassBarrier:
madeProgress = true
{.cast(gcsafe).}:
commands[threadId] = threads[threadId](results[threadId], workGroup, threadContexts[threadId], threadId)
commands[threadId] = threads[threadId](results[threadId], dispatch, workGroup, threadContexts[threadId], threadId)
if finished(threads[threadId]):
threadStates[threadId] = finished
elif commands[threadId].kind == barrier:
Expand Down Expand Up @@ -137,7 +137,7 @@ proc runThreads*(threads: SubgroupThreads; workGroup: WorkGroupContext,
let firstThreadId = threadGroups[groupIdx][1]
let opKind = commands[firstThreadId].kind
let opId = commands[firstThreadId].id
case opKind:
case opKind
of subgroupBroadcast:
execSubgroupOp(execBroadcast)
of subgroupBroadcastFirst:
Expand Down
13 changes: 10 additions & 3 deletions src/computesim/transform.nim
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,15 @@ proc generateTemplates(sym: NimNode; fields: openArray[string]): NimNode =
result.add quote do:
template `fieldIdent`(): untyped {.used.} = `sym`.`fieldIdent`

proc generateDispatchTemplates(dspSym: NimNode): NimNode =
generateTemplates(dspSym, [
"gl_WorkGroupSize",
"gl_NumWorkGroups"
])

proc generateWorkGroupTemplates(wgSym: NimNode): NimNode =
generateTemplates(wgSym, [
"gl_WorkGroupID",
"gl_WorkGroupSize",
"gl_NumWorkGroups",
"gl_NumSubgroups",
"gl_SubgroupID"
])
Expand Down Expand Up @@ -270,19 +274,22 @@ macro computeShader*(prc: untyped): untyped =
# Apply optimization to remove unnecessary reconverge points
traversedBody = optimizeReconvergePoints(traversedBody)
# Create symbols for the dispatch, workgroup, and thread contexts
let dspSym = genSym(nskParam, "dsp")
let wgSym = genSym(nskParam, "wg")
let threadSym = genSym(nskParam, "thread")
let tidSym = genSym(nskParam, "threadId")
# Generate template declarations for each context
let dspTemplates = generateDispatchTemplates(dspSym)
let wgTemplates = generateWorkGroupTemplates(wgSym)
let threadTemplates = generateThreadTemplates(threadSym)

result = quote do:
proc `procName`(): ThreadClosure =
iterator (`iterArg`: SubgroupResult, `wgSym`: WorkGroupContext,
iterator (`iterArg`: SubgroupResult, `dspSym`: DispatchContext, `wgSym`: WorkGroupContext,
`threadSym`: ThreadContext, `tidSym`: uint32): SubgroupCommand =
template gl_SubgroupInvocationID(): uint32 {.used.} = `tidSym`
`threadTemplates`
`dspTemplates`
`wgTemplates`
`traversedBody`

Expand Down