Skip to content

Change grouping method #17

@planetis-m

Description

@planetis-m

The benchmark shows that storing each group's length (in slot 0) is more efficient than terminating each group with an `InvalidId` sentinel.

import std/[times, stats]

const
  SubgroupSize = 32 # simulated threads per subgroup
  NumCommands = 1000 # groupThreads calls timed in one benchmark sample
  NumIterations = 10000 # timed samples collected by main

type
  # Lifecycle of a simulated thread. `groupThreads` never groups `finished`
  # threads and groups `atSubBarrier` ones only when reconvergence or
  # barrier passing is allowed.
  ThreadState = enum finished, running, atSubBarrier

  # The command a thread is about to execute.
  Command = object
    id: uint32 # operation id; threads with equal ids end up in one group

  # Thread ids of one group followed by an `InvalidId` terminator; the extra
  # slot leaves room for the terminator when every thread is in the group.
  SubgroupThreadIDs = array[SubgroupSize + 1, uint32]

const
  InvalidId = uint32.high # marks an empty slot / end of a group's id list

var
  threadGroups: array[SubgroupSize, SubgroupThreadIDs] # at most one per thread
  numGroups = 0'u32 # groups filled by the most recent groupThreads call
  commands: array[SubgroupSize, Command] # pending command of each thread
  threadStates: array[SubgroupSize, ThreadState] # state of each thread

iterator threadsInGroup(group: SubgroupThreadIDs): uint32 =
  ## Yields the thread ids stored in `group`, in order, stopping at the first
  ## `InvalidId` sentinel or at the end of the array, whichever comes first.
  ##
  ## The end-of-array bound matters for a group that was never filled with
  ## the sentinel (e.g. a zero-initialized `SubgroupThreadIDs`): a loop that
  ## only looks for the sentinel would index past the last slot.
  for id in group:
    if id == InvalidId:
      break
    yield id

# iterator threadsInGroup(group: SubgroupThreadIDs): uint32 =
#   let length = int32(group[0])
#   for i in 1..length:
#     yield group[i]

proc initializeState() =
  ## Resets the simulation: every thread becomes `running`, thread `i` gets
  ## operation id `i mod 4`, and every slot of every group (including the
  ## extra terminator slot) is set to `InvalidId`.
  for tid, state in threadStates.mpairs:
    state = running
    commands[tid].id = uint32(tid mod 4) # ids cycle through 0..3

  # Sentinel-fill all SubgroupSize + 1 slots of each group.
  for group in threadGroups.mitems:
    for slot in group.mitems:
      slot = InvalidId

  # Alternative length-prefixed layout (slot 0 holds the group length):
  # for i in 0..<SubgroupSize:
  #   threadGroups[i][0] = 0 # Initialize length to 0

proc groupThreads(numActiveThreads: uint32, canReconverge, canPassBarrier: bool): uint32 =
  ## Groups the eligible threads among the first `numActiveThreads` by the
  ## operation id of their pending command.
  ##
  ## A thread is eligible when it is not `finished` and it is `running`, or
  ## it is `atSubBarrier` while `canReconverge` is set, or `canPassBarrier`
  ## is set. Each group is written to `threadGroups[g]` as thread ids
  ## terminated by `InvalidId`. The group count is stored in the global
  ## `numGroups` and also returned.
  numGroups = 0

  # Length of each group built in this call. The sentinel layout is kept,
  # but knowing the length avoids rescanning a group for its first empty
  # slot on every append (O(1) instead of O(group length) per thread).
  var groupLens: array[SubgroupSize, uint32]

  # Alternative length-prefixed layout (slot 0 holds the length):
  # # Group by operation id
  # for threadId in 0..<numActiveThreads:
  #   if threadStates[threadId] != finished and
  #       (threadStates[threadId] == running or
  #       (threadStates[threadId] == atSubBarrier and canReconverge) or canPassBarrier):
  #     var found = false
  #     for groupIdx in 0..<numGroups:
  #       let firstThreadId = threadGroups[groupIdx][1] # First thread is at index 1
  #       if commands[firstThreadId].id == commands[threadId].id:
  #         # Add thread to group if there's space
  #         let currentLen = int(threadGroups[groupIdx][0])
  #         if currentLen < SubgroupSize:
  #           threadGroups[groupIdx][currentLen + 1] = threadId
  #           threadGroups[groupIdx][0] = uint32(currentLen + 1)
  #         found = true
  #         break
  #     if not found:
  #       # Create new group
  #       threadGroups[numGroups][0] = 1 # Length is 1
  #       threadGroups[numGroups][1] = threadId
  #       inc numGroups

  # Group by operation id
  for threadId in 0..<numActiveThreads:
    let state = threadStates[threadId]
    if state != finished and
        (state == running or (state == atSubBarrier and canReconverge) or
        canPassBarrier):
      let opId = commands[threadId].id
      var found = false
      for groupIdx in 0..<numGroups:
        # A group's first member decides its operation id.
        if commands[threadGroups[groupIdx][0]].id == opId:
          let count = groupLens[groupIdx]
          # As before, a thread is silently dropped if its group is full.
          if count < uint32(SubgroupSize):
            threadGroups[groupIdx][count] = threadId
            threadGroups[groupIdx][count + 1] = InvalidId
            groupLens[groupIdx] = count + 1
          found = true
          break
      if not found:
        threadGroups[numGroups][0] = threadId
        threadGroups[numGroups][1] = InvalidId
        groupLens[numGroups] = 1
        inc numGroups

  result = numGroups

proc runBenchmark(): float =
  ## Returns the CPU time, in seconds, spent calling `groupThreads`
  ## `NumCommands` times over all `SubgroupSize` threads, with reconvergence
  ## allowed and barrier passing disallowed.
  let t0 = cpuTime()
  for _ in 1..NumCommands:
    discard groupThreads(SubgroupSize, true, false)
  cpuTime() - t0

proc main() =
  ## Collects `NumIterations` benchmark samples, resetting the simulation
  ## state before each one, and prints summary statistics in milliseconds.
  var statistics = RunningStat()

  echo "Running benchmark..."
  for _ in 0..<NumIterations:
    initializeState()
    statistics.push(runBenchmark())

  # Samples are in seconds: scale by 1e3 for ms, and by (1e3)^2 for ms².
  const msPerSec = 1000.0
  let varianceMs2 = statistics.varianceS() * msPerSec * msPerSec

  echo "Results over ", NumIterations, " iterations:"
  echo "Mean time: ", statistics.mean() * msPerSec, " ms"
  echo "Std dev:  ", statistics.standardDeviationS() * msPerSec, " ms"
  echo "Sample variance: ", varianceMs2, " ms²"
  echo "Number of samples: ", statistics.n

when isMainModule: main() # run the benchmark only when built as a program

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions