
more streamlined APIs for typical use cases #330

Merged
davidkoski merged 6 commits into main from simple-api, Jun 13, 2025
Conversation

@davidkoski davidkoski requested review from angeloskath and awni June 12, 2025 16:58
  hub: HubApi, configuration: ModelConfiguration,
  progressHandler: @Sendable @escaping (Progress) -> Void
- ) async throws -> ModelContext {
+ ) async throws -> sending ModelContext {
davidkoski (Collaborator, Author):

This is the correct syntax for passing ownership back.
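A minimal standalone sketch of what `sending` means here (the type and function are hypothetical, not from this PR): a non-Sendable value may be returned across an isolation boundary because the callee transfers ownership of it to the caller.

```swift
// Hypothetical illustration of a `sending` result (Swift 6).
// `Model` is deliberately non-Sendable, yet the function can still hand it
// across an isolation boundary: `sending` transfers ownership of the value
// to the caller's isolation region.
final class Model {
    var name = ""
}

func makeModel() async throws -> sending Model {
    let model = Model()
    model.name = "example"
    return model  // the caller's region now owns `model`
}
```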


/// Default instance of HubApi to use. This is configured to save downloads into the caches directory.
public var defaultHubApi: HubApi = {
HubApi(downloadBase: FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first)
}()
davidkoski (Collaborator, Author):

Per feedback and experience, putting the downloaded weights in ~/Documents was problematic -- it synced to iCloud Documents and participated in backups on iOS devices.

Another option is to use ~/Downloads (like the MLXChatExample app does) but that requires specific entitlements. This will put them in ~/Library/Caches (or the equivalent in a container):

ls ~/Library/Caches/models/mlx-community 
Qwen3-0.6B-4bit Qwen3-4B-4bit   Qwen3-8B-4bit

davidkoski (Collaborator, Author):

See #332


public func load(
hub: HubApi = defaultHubApi, id: String,
progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
) async throws -> sending ModelContext {
davidkoski (Collaborator, Author):

All of these are missing documentation -- I will add that once we are in agreement on API.

davidkoski (Collaborator, Author):

This allows:

let model = try await LLMModelFactory.shared.load(id: "mlx-community/Qwen3-4B-4bit")

davidkoski (Collaborator, Author):

Question: we have had some feedback about load() doing both the download and the load. Do we want to address that here? It may be tricky as the current load() semantics cover both. We can easily add a download() method, for sure.

As far as simplicity goes, having this do both seems better, but it propagates the issue forward.
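A hedged sketch of what a separate download step might look like. No such method exists in this PR -- the comment only notes one could easily be added, so `download(id:)` is an assumed name and shape.

```swift
// Hypothetical split of the two phases: prefetch the weights first, then
// load from the local cache. `download(id:)` is an assumed name -- the PR
// only says such a method could be added.
let id = "mlx-community/Qwen3-4B-4bit"
try await LLMModelFactory.shared.download(id: id)         // fetch weights only
let model = try await LLMModelFactory.shared.load(id: id) // load from local cache
```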

Member:

I think we should keep it this way for now.

Member:

Sorry, naive question: why do we need LLMModelFactory.shared.load vs either just a free function load or simply LLMModelFactory.load?

davidkoski (Collaborator, Author):

Per discussion:

  • load() has layering issues as MLXLMCommon doesn't know how to load VLM/LLM (it is above those in the layering)
  • we can skip the .shared. part with a static function
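The second point could be sketched as a static convenience that simply forwards to the shared instance. This is an illustration of the idea, not the code that was merged:

```swift
// Hypothetical static convenience so callers can drop the `.shared.` hop.
// It only forwards to the existing instance method.
extension LLMModelFactory {
    public static func load(id: String) async throws -> sending ModelContext {
        try await LLMModelFactory.shared.load(id: id)
    }
}
```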

davidkoski (Collaborator, Author):

Though I wonder if we can do something dynamic ... let me try that.

davidkoski (Collaborator, Author):

Yep, that worked

Contributor:

Amazing!

public func load(
hub: HubApi = defaultHubApi, directory: URL,
progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
) async throws -> sending ModelContext {
davidkoski (Collaborator, Author):

This allows:

let model = try await LLMModelFactory.shared.load(directory: .homeDirectory.appending(component: "my-model"))

import Foundation
import MLX

private class Generator {
davidkoski (Collaborator, Author):

Code shared between the one-shot and session calls. It is a little more complex than just calling the methods as it handles some variants.

image: UserInput.Image? = nil, video: UserInput.Video? = nil,
processing: UserInput.Processing = .init(resize: CGSize(width: 512, height: 512)),
generateParameters: GenerateParameters = .init()
) async throws -> String {
davidkoski (Collaborator, Author):

A lot of arguments -- all but the prompt have defaults:

let model = try await LLMModelFactory.shared.load(id: "mlx-community/Qwen3-4B-4bit")
print(try await generate(model, "What are three things to see in Paris?"))

The others can be supplied if you are using a VLM and want to add an image, for example.
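For instance, the image argument might be supplied like this (a sketch; the model id and prompt are placeholders, and `.ciImage(CIImage.red)` mirrors the usage shown in the PR's test code):

```swift
// Hedged sketch: the same generate() call with an image attached, for a VLM.
// The model id here is a placeholder, not one used in this PR.
let vlm = try await VLMModelFactory.shared.load(id: "mlx-community/some-vlm-4bit")
let answer = try await generate(
    vlm, "Describe this image",
    image: .ciImage(CIImage.red))
print(answer)
```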

davidkoski (Collaborator, Author):

Do we like this API?

  • (pro) it matches the python API
  • (con) there are a lot of overloads of generate() -- is this confusing? we can highlight this one in the docs
  • (con) it doesn't match the naming of the FM api or the session api below (respond(to:))

davidkoski (Collaborator, Author):

Per discussion, removing these free functions -- they are covered by the ChatSession API, which has minimal overhead to create. You can just create & discard a session for the one-shot case.
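The one-shot pattern described here is then just the following, mirroring the ChatSession examples elsewhere in this PR:

```swift
// One-shot usage via ChatSession: create, ask once, discard.
let session = ChatSession(model)
print(try await session.respond(to: "What are three things to see in Paris?"))
```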

}

public func generate(
_ model: ModelContext, instructions: String? = nil, _ prompt: String,
davidkoski (Collaborator, Author):

Same as above but with a ModelContext (doesn't have the actor container). Actually the example code above uses this one.

return try await generator.generate()
}

public func stream(
davidkoski (Collaborator, Author):

Same as above but produce a streaming output:

for try await item in stream(model, prompt) {
    print(item, terminator: "")
}

Questions on the API:

  • should this be more like streamResponse(to:)?
  • stream() vs generate() -- should it be generateStream() or streamGenerate()?

return generator.stream()
}

public class ChatSession {
davidkoski (Collaborator, Author):

For chat sessions:

let session = ChatSession(model)

let questions = [
    "What are two things to see in San Francisco?",
    "How about a great place to eat?",
    "What city are we talking about?  I forgot!",
]

for question in questions {
    for try await item in session.streamResponse(to: question) {
        print(item, terminator: "")
    }
    print()
}

generateParameters: generateParameters)
}

public func respond(
davidkoski (Collaborator, Author):

These two methods closely match the FM API.

- hiddenSize: 128, hiddenLayers: 128, intermediateSize: 512, attentionHeads: 32,
- rmsNormEps: 0.00001, vocabularySize: 1500, kvHeads: 8)
+ hiddenSize: 64, hiddenLayers: 16, intermediateSize: 512, attentionHeads: 32,
+ rmsNormEps: 0.00001, vocabularySize: 100, kvHeads: 8)
davidkoski (Collaborator, Author):

I thought I reduced the size of these earlier -- no need for them to be that large, we just want to exercise the machinery.

)
}
)
}
davidkoski (Collaborator, Author):

The TestTokenizer gets a little more power -- it produces output like this:

rpxdjm twj rexpn tdrgdu tdrgdu xmrrds ldre lcowwy lcowwy lcowwy lcowwy nzlmfiz lmb lmb jkjkxz twj gefvypc lmb ldre klb ulipy cvvi tnxgjl oew cvvi xhqk unxxymp

print(try await session.respond(to: "what color is the sky?"))
print(try await session.respond(to: "why is that?"))
print(try await session.respond(to: "describe this image", image: .ciImage(CIImage.red)))
}
davidkoski (Collaborator, Author):

@awni @angeloskath some examples of the streamlined API

@@ -0,0 +1,50 @@
import MLXLLM
import MLXLMCommon
davidkoski (Collaborator, Author):

@angeloskath @awni this is meant as an example of the streamlined API with full integration

  • how does this look?
  • changes? improvements?
  • I tried to keep the code as simple as possible but still have legible output
  • if the output isn't considered it can be simpler:

let session = ChatSession(model)
print(try await session.respond(to: "What are three things to see in San Francisco?"))
print(try await session.respond(to: "How about a place to eat?"))

// add the assistant response to the chat messages
state.chat.append(.assistant(output))
// the kvcache now contains this context
state.chat.removeAll()
davidkoski (Collaborator, Author):

Unrelated to the API but I realized the messages/kvcache in the chat command line example were not right.


""")

for try await item in session.streamResponse(to: question) {
Member:

I slightly feel like this should be session.streamRespond(to: question) to match session.respond(to: question).

Member:

Or they could both be response instead of respond?

davidkoski (Collaborator, Author):

@awni
awni (Member) commented Jun 12, 2025:

This is really great. I left a couple inline comments / questions.

  • I think the main thing I'm unsure about (same as you) is if we should provide the generate API or not. It's a small improvement over the session in terms of usability. And yet it's also pretty nice and matches the Python version as well.
  • I'm wondering if we could simplify the loading to be ModelFactory.load("model/path") (or just a free function) and it figures out if it's an LLM or VLM dynamically?

Comment on lines +194 to +211
/// an `actor` providing an isolation context. Use this call when you control the isolation context
/// and can hold the `ModelContext` directly.
Member:

Should those be in double backticks?

Comment on lines +231 to +254
/// an `actor` providing an isolation context. Use this call when you control the isolation context
/// and can hold the `ModelContext` directly.
Member:

Same comment.


private let generator: Generator

/// Initialzie the ChatSession
Member:

Suggested change:
- /// Initialzie the ChatSession
+ /// Initialize the ChatSession

Also maybe ChatSession should be backticked?

generateParameters: generateParameters)
}

/// Initialzie the ChatSession
Member:

Suggested change:
- /// Initialzie the ChatSession
+ /// Initialize the ChatSession

/// - hub: optional HubApi -- by default uses ``defaultHubApi``
/// - directory: directory of configuration and weights
/// - progressHandler: optional callback for progress
/// - Returns: a ModelContainer
Member:

Maybe ModelContainer should be in backticks?

/// - hub: optional HubApi -- by default uses ``defaultHubApi``
/// - id: huggingface model identifier, e.g "mlx-community/Qwen3-4B-4bit"
/// - progressHandler: optional callback for progress
/// - Returns: a ModelContainer
Member:

Maybe ModelContainer should be in backticks?

/// - hub: optional HubApi -- by default uses ``defaultHubApi``
/// - id: huggingface model identifier, e.g "mlx-community/Qwen3-4B-4bit"
/// - progressHandler: optional callback for progress
/// - Returns: a ModelContext
Member:

ModelContext in backticks?

awni (Member) left a comment:

Minor nits in the docs. O/w looks awesome!

@davidkoski davidkoski merged commit 45563d4 into main Jun 13, 2025
1 check passed
@davidkoski davidkoski deleted the simple-api branch June 13, 2025 16:25