-
Notifications
You must be signed in to change notification settings - Fork 369
Expand file tree
/
Copy pathEmbeddingModel.swift
More file actions
115 lines (104 loc) · 3.72 KB
/
EmbeddingModel.swift
File metadata and controls
115 lines (104 loc) · 3.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Copyright © 2024 Apple Inc.
import Foundation
@preconcurrency import Hub
import MLX
import MLXNN
import MLXLMCommon
import Tokenizers
/// Container for models that guarantees single threaded access.
///
/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access
/// the model, tokenizer, and pooler:
///
/// ```swift
/// let promptTokens = await modelContainer.perform { _, tokenizer, _ in
///     tokenizer.encode(text: prompt)
/// }
/// ```
///
/// or:
///
/// ```swift
/// let result = await modelContainer.perform { model, tokenizer, pooler in
///     // run the model on the encoded tokens and pool the hidden states;
///     // eval any MLXArray before returning, since MLXArray is not Sendable
///     ...
/// }
/// ```
public actor ModelContainer {
    /// The embedding model protected by this actor's isolation.
    let model: EmbeddingModel
    /// Tokenizer used to encode text for `model`.
    let tokenizer: Tokenizer
    /// Pooling applied to model outputs; defaults to `.none` in the memberwise init.
    let pooler: Pooling
    /// Memberwise initializer for callers that already hold a loaded model and tokenizer.
    public init(
        model: EmbeddingModel, tokenizer: Tokenizer, pooler: Pooling = Pooling(strategy: .none)
    ) {
        self.model = model
        self.tokenizer = tokenizer
        self.pooler = pooler
    }
    /// build the model and tokenizer without passing non-sendable data over isolation barriers
    ///
    /// Loads the model weights synchronously from `modelDirectory`, fetches the
    /// tokenizer configuration via `hub`, and reads the pooling configuration
    /// from the model directory.
    public init(
        hub: HubApi, modelDirectory: URL, configuration: ModelConfiguration
    ) async throws {
        self.model = try loadSynchronous(modelDirectory: modelDirectory)
        // Tokenizer config/data are fetched here (inside the actor init) so the
        // non-Sendable pieces never cross an isolation boundary.
        let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig(
            configuration: configuration, hub: hub)
        self.tokenizer = try PreTrainedTokenizer(
            tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
        // NOTE(review): an earlier version coalesced with `?? Pooling(strategy: .none)` —
        // confirm loadPooling now returns a non-optional Pooling and handles the
        // missing-config case itself.
        self.pooler = loadPooling(modelDirectory: modelDirectory)
    }
    /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
    /// `MLXArray` is not `Sendable`.
    public func perform<R>(_ action: @Sendable (EmbeddingModel, Tokenizer, Pooling) throws -> R)
        rethrows
        -> R
    {
        try action(model, tokenizer, pooler)
    }
}
extension Module {
    /// Compute the number of parameters in a possibly quantized model.
    ///
    /// Quantized layers store packed weights, so their true parameter count is
    /// recovered from the scales tensor (`scales.size * groupSize`); all other
    /// leaf modules are counted by summing the sizes of their flattened parameters.
    public func numParameters() -> Int {
        var total = 0
        for leaf in leafModules().flattenedValues() {
            switch leaf {
            case let quantized as QuantizedLinear:
                total += quantized.scales.size * quantized.groupSize
            case let quantized as QuantizedEmbedding:
                total += quantized.scales.size * quantized.groupSize
            default:
                for parameter in leaf.parameters().flattenedValues() {
                    total += parameter.size
                }
            }
        }
        return total
    }
}
/// Result of an `EmbeddingModel` forward pass.
public struct EmbeddingModelOutput {
    /// Hidden states from the model, if produced — presumably per-token
    /// final-layer activations; confirm against the concrete model.
    public let hiddenStates: MLXArray?
    /// Pooled output from the model, if produced.
    public let pooledOutput: MLXArray?

    /// Explicit public initializer: the implicit memberwise initializer of a
    /// `public` struct is only `internal`, which would prevent clients outside
    /// this module from constructing an output. Defaults keep call sites terse.
    public init(hiddenStates: MLXArray? = nil, pooledOutput: MLXArray? = nil) {
        self.hiddenStates = hiddenStates
        self.pooledOutput = pooledOutput
    }
}
/// A module that maps token id sequences to embedding model outputs.
public protocol EmbeddingModel: Module {
    /// Size of the model's token vocabulary.
    var vocabularySize: Int { get }
    /// Run the forward pass over `inputs` (token ids), with optional position
    /// ids, token type ids, and attention mask.
    func callAsFunction(
        _ inputs: MLXArray, positionIds: MLXArray?, tokenTypeIds: MLXArray?,
        attentionMask: MLXArray?
    ) -> EmbeddingModelOutput
    /// Optionally preprocess the weights and modify / remove values as needed.
    func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
    /// Variant of `sanitize(weights:)` that also receives the quantization
    /// configuration, for models whose weight layout depends on quantization.
    func sanitize(weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?) -> [String: MLXArray]
}
extension EmbeddingModel {
    /// Convenience overload that defaults the optional inputs to `nil` and
    /// forwards to the protocol requirement.
    func callAsFunction(
        _ inputs: MLXArray, positionIds: MLXArray? = nil, tokenTypeIds: MLXArray? = nil,
        attentionMask: MLXArray? = nil
    ) -> EmbeddingModelOutput {
        callAsFunction(
            inputs,
            positionIds: positionIds,
            tokenTypeIds: tokenTypeIds,
            attentionMask: attentionMask)
    }
}