Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,45 @@ jobs:
cmake .. -G Ninja
ninja

linux_build_swiftpm:
name: Linux SwiftPM Build (${{ matrix.container }} ${{ matrix.arch }})
needs: lint
if: github.repository == 'ml-explore/mlx-swift'
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-24.04
arch: x86_64
container: swift:6.2.3-noble
- host: ubuntu-24.04-arm
arch: aarch64
container: swift:6.2.3-noble
runs-on: ${{ matrix.host }}
container:
image: ${{ matrix.container }}
steps:
- name: Checkout code
uses: actions/checkout@v6
with:
submodules: recursive

- name: Install dependencies
run: |
apt-get update -y
apt-get install -y \
build-essential \
libblas-dev \
liblapacke-dev \
libopenblas-dev \
libgfortran-13-dev

- name: Build (SwiftPM)
run: swift build

- name: Run tests
run: swift test

linux_build_cmake_container:
name: Linux Container CMake Swift Build (${{ matrix.container }} ${{ matrix.arch }})
needs: lint
Expand Down
108 changes: 55 additions & 53 deletions Source/MLXFast/MLXFastKernel.swift
Original file line number Diff line number Diff line change
@@ -1,57 +1,59 @@
// Copyright © 2024 Apple Inc.

import Cmlx
import MLX
#if canImport(Darwin)
import Cmlx
import MLX

/// Container for a kernel created by
/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:)``
///
/// The ``MLXFast/MLXFastKernel`` can be used to evaluate the kernel with inputs:
///
/// ```swift
/// let a = normal([2, 2])
/// let kernel = MLXFast.metalKernel(
/// name: "basic",
/// inputNames: ["a"],
/// outputNames: ["out1"],
/// source: """
/// uint elem = thread_position_in_grid.x;
/// out1[elem] = a[elem];
/// """,
/// grid: (4, 1, 1),
/// threadGroup: (2, 1, 1),
/// outputShapes: [[2, 2]],
/// outputDTypes: [.float32])
///
/// let out = kernel([a])
/// ```
@available(*, deprecated, renamed: "MLXFast.MLXFastKernel")
public typealias MLXFastKernel = MLXFast.MLXFastKernel
/// Container for a kernel created by
/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:)``
///
/// The ``MLXFast/MLXFastKernel`` can be used to evaluate the kernel with inputs:
///
/// ```swift
/// let a = normal([2, 2])
/// let kernel = MLXFast.metalKernel(
/// name: "basic",
/// inputNames: ["a"],
/// outputNames: ["out1"],
/// source: """
/// uint elem = thread_position_in_grid.x;
/// out1[elem] = a[elem];
/// """,
/// grid: (4, 1, 1),
/// threadGroup: (2, 1, 1),
/// outputShapes: [[2, 2]],
/// outputDTypes: [.float32])
///
/// let out = kernel([a])
/// ```
@available(*, deprecated, renamed: "MLXFast.MLXFastKernel")
public typealias MLXFastKernel = MLXFast.MLXFastKernel

/// A jit-compiled custom Metal kernel defined from a source string.
///
/// - Parameters:
/// - name: name for the kernel
/// - inputNames: parameter names of the inputs in the function signature
/// - outputNames: parameter names of the outputs in the function signature
/// - source: source code -- this is the body of a function in Metal,
/// the function signature will be automatically generated.
/// - header: header source code to include before the main function. Useful
/// for helper functions or includes that should live outside of the main function body.
/// - ensureRowContiguous: whether to ensure the inputs are row contiguous
/// before the kernel runs (at a performance cost)
/// - atomicOutputs: whether to use atomic outputs in the function signature,
/// e.g. `device atomic<float>`
/// - Returns: an ``MLXFastKernel`` -- see that for information on how to call it
public func metalKernel(
name: String, inputNames: [String], outputNames: [String],
source: String, header: String = "",
ensureRowContiguous: Bool = true,
atomicOutputs: Bool = false
) -> MLXFast.MLXFastKernel {
return MLX.MLXFast.metalKernel(
name: name, inputNames: inputNames, outputNames: outputNames,
source: source, header: header,
ensureRowContiguous: ensureRowContiguous, atomicOutputs: atomicOutputs
)
}
/// A jit-compiled custom Metal kernel defined from a source string.
///
/// - Parameters:
/// - name: name for the kernel
/// - inputNames: parameter names of the inputs in the function signature
/// - outputNames: parameter names of the outputs in the function signature
/// - source: source code -- this is the body of a function in Metal,
/// the function signature will be automatically generated.
/// - header: header source code to include before the main function. Useful
/// for helper functions or includes that should live outside of the main function body.
/// - ensureRowContiguous: whether to ensure the inputs are row contiguous
/// before the kernel runs (at a performance cost)
/// - atomicOutputs: whether to use atomic outputs in the function signature,
/// e.g. `device atomic<float>`
/// - Returns: an ``MLXFastKernel`` -- see that for information on how to call it
public func metalKernel(
name: String, inputNames: [String], outputNames: [String],
source: String, header: String = "",
ensureRowContiguous: Bool = true,
atomicOutputs: Bool = false
) -> MLXFast.MLXFastKernel {
return MLX.MLXFast.metalKernel(
name: name, inputNames: inputNames, outputNames: outputNames,
source: source, header: header,
ensureRowContiguous: ensureRowContiguous, atomicOutputs: atomicOutputs
)
}
#endif
Loading