more streamlined APIs for typical use cases #330
Conversation
davidkoski
commented
Jun 12, 2025
- inspired by https://developer.apple.com/documentation/foundationmodels
| hub: HubApi, configuration: ModelConfiguration, | ||
| progressHandler: @Sendable @escaping (Progress) -> Void | ||
| ) async throws -> ModelContext { | ||
| ) async throws -> sending ModelContext { |
This is the correct syntax for passing ownership back.
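For reference, a minimal sketch of the `sending` result convention (toy names, not from this PR): a non-Sendable value can be returned across an isolation boundary when the function transfers ownership of it to the caller.

```swift
// A non-Sendable type, standing in for ModelContext.
final class Context {
    var values: [String] = []
}

// `sending` transfers ownership of the returned value to the caller,
// so a non-Sendable result can safely cross the isolation boundary.
func makeContext() async throws -> sending Context {
    let context = Context()
    context.values.append("loaded")
    return context
}
```

Without `sending`, returning a non-Sendable value out of a concurrent context is a diagnostic under strict concurrency checking.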
|
|
||
| /// Default instance of HubApi to use. This is configured to save downloads into the caches directory. | ||
| public var defaultHubApi: HubApi = { | ||
| HubApi(downloadBase: FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first) |
Per feedback and experience, putting the downloaded weights in ~/Documents was problematic -- it synced to iCloud Documents and participated in backups on iOS devices.
Another option is to use ~/Downloads (like the MLXChatExample app does) but that requires specific entitlements. This will put them in ~/Library/Caches (or the equivalent in a container):
ls ~/Library/Caches/models/mlx-community
Qwen3-0.6B-4bit Qwen3-4B-4bit Qwen3-8B-4bit
|
|
||
| public func load( | ||
| hub: HubApi = defaultHubApi, id: String, | ||
| progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } |
All of these are missing documentation -- I will add that once we are in agreement on API.
This allows:
let model = try await LLMModelFactory.shared.load(id: "mlx-community/Qwen3-4B-4bit")
Question: we have had some feedback about load being both download and load. Do we want to address that here? It may be tricky as the current load() semantics cover both. We can easily add a download() method, for sure.
As far as simplicity goes, having this do both seems better, but it propagates the issue forward.
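If we did split them, one hypothetical shape (all names here are invented for illustration and are not part of this PR) would keep `download` as fetch-only and make the combined call a convenience on top of it:

```swift
import Foundation

// Hypothetical split: `download` only fetches, `load(directory:)` only
// reads from disk, and the combined `load(id:)` preserves today's semantics.
struct ModelHandle {}

protocol ModelLoading {
    /// Fetch the model weights and return the local directory.
    func download(id: String) async throws -> URL

    /// Build a model from an already-downloaded directory.
    func load(directory: URL) async throws -> ModelHandle
}

extension ModelLoading {
    /// Convenience preserving the current combined `load()` behavior.
    func load(id: String) async throws -> ModelHandle {
        try await load(directory: download(id: id))
    }
}
```

This would let callers who care about the distinction use the two-step form, while the one-liner stays available.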
I think we should keep it this way for now.
Sorry, naive question: why do we need to do `LLMModelFactory.shared.load` vs either just a free function `load` or simply `LLMModelFactory.load`?
Per discussion:
- a free `load()` has layering issues, as MLXLMCommon doesn't know how to load VLM/LLM (it is above those in the layering)
- we can skip the `.shared.` part with a static function
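A static convenience could be a thin forwarder to the shared instance; a sketch, assuming the `LLMModelFactory.shared` singleton and the instance `load(hub:id:progressHandler:)` signature from this PR:

```swift
extension LLMModelFactory {
    /// Static convenience so callers can write `LLMModelFactory.load(id:)`
    /// instead of `LLMModelFactory.shared.load(id:)`.
    public static func load(
        hub: HubApi = defaultHubApi, id: String,
        progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
    ) async throws -> sending ModelContext {
        // Forward to the shared instance; ownership of the result is
        // transferred to the caller via `sending`.
        try await shared.load(hub: hub, id: id, progressHandler: progressHandler)
    }
}
```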
Though I wonder if we can do something dynamic ... let me try that.
Yep, that worked
| public func load( | ||
| hub: HubApi = defaultHubApi, directory: URL, | ||
| progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } | ||
| ) async throws -> sending ModelContext { |
This allows:
let model = try await LLMModelFactory.shared.load(directory: .homeDirectory.appending(component: "my-model"))
| import Foundation | ||
| import MLX | ||
|
|
||
| private class Generator { |
Code shared between the one-shot and session calls. It is a little more complex than just calling the methods as it handles some variants.
| image: UserInput.Image? = nil, video: UserInput.Video? = nil, | ||
| processing: UserInput.Processing = .init(resize: CGSize(width: 512, height: 512)), | ||
| generateParameters: GenerateParameters = .init() | ||
| ) async throws -> String { |
A lot of arguments -- all but the prompt have defaults:
let model = try await LLMModelFactory.shared.load(id: "mlx-community/Qwen3-4B-4bit")
print(try await generate(model, "What are three things to see in Paris?"))
The others can be supplied if you are using a VLM and want to add an image, for example.
Do we like this API?
- (pro) it matches the python API
- (con) there are a lot of overloads of `generate()` -- is this confusing? we can highlight this one in the docs
- (con) it doesn't match the naming of the FM api or the session api below (`respond(to:)`)
Per discussion, removing these free functions -- they are covered by the ChatSession API, which is minimal overhead to create. You can just create & discard one for the one-shot case.
| } | ||
|
|
||
| public func generate( | ||
| _ model: ModelContext, instructions: String? = nil, _ prompt: String, |
Same as above but with a ModelContext (doesn't have the actor container). Actually the example code above uses this one.
| return try await generator.generate() | ||
| } | ||
|
|
||
| public func stream( |
Same as above but produce a streaming output:
for try await item in stream(model, prompt) {
print(item, terminator: "")
}

Questions on the API:
- should this be more like `streamResponse(to:)`?
- `stream()` vs `generate()` -- should it be `generateStream()` or `streamGenerate()`?
| return generator.stream() | ||
| } | ||
|
|
||
| public class ChatSession { |
For chat sessions:
let session = ChatSession(model)
let questions = [
"What are two things to see in San Francisco?",
"How about a great place to eat?",
"What city are we talking about? I forgot!",
]
for question in questions {
for try await item in session.streamResponse(to: question) {
print(item, terminator: "")
}
print()
}
| generateParameters: generateParameters) | ||
| } | ||
|
|
||
| public func respond( |
These two methods closely match the FM API
| hiddenSize: 128, hiddenLayers: 128, intermediateSize: 512, attentionHeads: 32, | ||
| rmsNormEps: 0.00001, vocabularySize: 1500, kvHeads: 8) | ||
| hiddenSize: 64, hiddenLayers: 16, intermediateSize: 512, attentionHeads: 32, | ||
| rmsNormEps: 0.00001, vocabularySize: 100, kvHeads: 8) |
I thought I reduced the size of these earlier -- no need for them to be that large, we just want to exercise the machinery.
| ) | ||
| } | ||
| ) | ||
| } |
The TestTokenizer gets a little more power -- it produces output like this:
rpxdjm twj rexpn tdrgdu tdrgdu xmrrds ldre lcowwy lcowwy lcowwy lcowwy nzlmfiz lmb lmb jkjkxz twj gefvypc lmb ldre klb ulipy cvvi tnxgjl oew cvvi xhqk unxxymp
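A minimal sketch of how a tokenizer like that might detokenize ids into deterministic nonsense words (hypothetical; the real TestTokenizer in this PR may work differently):

```swift
struct ToyTestTokenizer {
    // Fixed nonsense vocabulary: output is deterministic for a given
    // id sequence, which is all a test needs to exercise the machinery.
    let vocabulary = ["rpxdjm", "twj", "rexpn", "tdrgdu", "xmrrds", "ldre", "lcowwy", "lmb"]

    func decode(tokens: [Int]) -> String {
        tokens.map { vocabulary[$0 % vocabulary.count] }.joined(separator: " ")
    }
}

// ToyTestTokenizer().decode(tokens: [0, 1, 2]) -> "rpxdjm twj rexpn"
```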
| print(try await session.respond(to: "what color is the sky?")) | ||
| print(try await session.respond(to: "why is that?")) | ||
| print(try await session.respond(to: "describe this image", image: .ciImage(CIImage.red))) | ||
| } |
@awni @angeloskath some examples of the streamlined API
| @@ -0,0 +1,50 @@ | |||
| import MLXLLM | |||
| import MLXLMCommon | |||
@angeloskath @awni this is meant as an example of the streamlined API with full integration
- how does this look?
- changes? improvements?
- I tried to keep the code as simple as possible but still have legible output
- if the output isn't considered it can be simpler:
let session = ChatSession(model)
print(try await session.respond(to: "What are three things to see in San Francisco?"))
print(try await session.respond(to: "How about a place to eat?"))
| // add the assistant response to the chat messages | ||
| state.chat.append(.assistant(output)) | ||
| // the kvcache now contains this context | ||
| state.chat.removeAll() |
Unrelated to the API but I realized the messages/kvcache in the chat command line example were not right.
|
|
||
| """) | ||
|
|
||
| for try await item in session.streamResponse(to: question) { |
I slightly feel like this should be `session.streamRespond(to: question)` to match `session.respond(to: question)`.
Or they could both be `response` instead of `respond`?
|
This is really great. I left a couple inline comments / questions.
|
| /// an `actor` providing an isolation context. Use this call when you control the isolation context | ||
| /// and can hold the `ModelContext` directly. |
Should those be in double backticks?
| /// an `actor` providing an isolation context. Use this call when you control the isolation context | ||
| /// and can hold the `ModelContext` directly. |
|
|
||
| private let generator: Generator | ||
|
|
||
| /// Initialzie the ChatSession |
| /// Initialzie the ChatSession | |
| /// Initialize the ChatSession |
Also maybe ChatSession should be backticked?
| generateParameters: generateParameters) | ||
| } | ||
|
|
||
| /// Initialzie the ChatSession |
| /// Initialzie the ChatSession | |
| /// Initialize the ChatSession |
| /// - hub: optional HubApi -- by default uses ``defaultHubApi`` | ||
| /// - directory: directory of configuration and weights | ||
| /// - progressHandler: optional callback for progress | ||
| /// - Returns: a ModelContainer |
Maybe ModelContainer should be in backticks?
| /// - hub: optional HubApi -- by default uses ``defaultHubApi`` | ||
| /// - id: huggingface model identifier, e.g "mlx-community/Qwen3-4B-4bit" | ||
| /// - progressHandler: optional callback for progress | ||
| /// - Returns: a ModelContainer |
Maybe ModelContainer should be in backticks?
| /// - hub: optional HubApi -- by default uses ``defaultHubApi`` | ||
| /// - id: huggingface model identifier, e.g "mlx-community/Qwen3-4B-4bit" | ||
| /// - progressHandler: optional callback for progress | ||
| /// - Returns: a ModelContext |
awni
left a comment
Minor nits in the docs. O/w looks awesome!