Skip to content

Commit 4660337

Browse files
committed
first commit
0 parents  commit 4660337

File tree

7 files changed

+356
-0
lines changed

7 files changed

+356
-0
lines changed

.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
.DS_Store
2+
/.build
3+
/Packages
4+
xcuserdata/
5+
DerivedData/
6+
.swiftpm/configuration/registries.json
7+
.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
8+
.netrc

Package.resolved

Lines changed: 42 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// swift-tools-version: 6.0
2+
// The swift-tools-version declares the minimum version of Swift required to build this package.
3+
4+
import PackageDescription
5+
6+
let package = Package(
7+
name: "OtosakuKWS",
8+
platforms: [
9+
.iOS(.v16),
10+
.macOS(.v12)
11+
],
12+
products: [
13+
// Products define the executables and libraries a package produces, making them visible to other packages.
14+
.library(
15+
name: "OtosakuKWS",
16+
targets: ["OtosakuKWS"]),
17+
],
18+
dependencies: [
19+
.package(url: "https://github.com/Otosaku/OtosakuFeatureExtractor-iOS.git", from: "1.0.1"),
20+
.package(url: "https://github.com/ZipArchive/ZipArchive.git", from: "2.6.0")
21+
],
22+
targets: [
23+
// Targets are the basic building blocks of a package, defining a module or a test suite.
24+
// Targets can depend on other targets in this package and products from dependencies.
25+
.target(
26+
name: "OtosakuKWS",
27+
dependencies: [
28+
.product(name: "OtosakuFeatureExtractor", package: "OtosakuFeatureExtractor-iOS"),
29+
.product(name: "ZipArchive", package: "ZipArchive")
30+
],
31+
path: "Sources"
32+
),
33+
.testTarget(
34+
name: "OtosakuKWSTests",
35+
dependencies: ["OtosakuKWS"]
36+
),
37+
]
38+
)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
//
2+
// ModelDownloader.swift
3+
// OtosakuKWS
4+
//
5+
// Created by Marat Zainullin on 12/06/2025.
6+
//
7+
//
8+
//
9+
10+
#if canImport(UIKit)
11+
import UIKit
12+
import Foundation
13+
import ZipArchive
14+
15+
16+
final class ModelDownloader: NSObject, URLSessionDownloadDelegate {
17+
static let shared = ModelDownloader()
18+
19+
private var progressHandler: ((Float) -> Void)?
20+
private var completionHandler: ((Result<URL, Error>) -> Void)?
21+
private var destinationFolder: URL?
22+
private var remoteURL: URL?
23+
private var currentTask: URLSessionDownloadTask?
24+
25+
private lazy var session: URLSession = {
26+
let config = URLSessionConfiguration.background(withIdentifier: "otosaku.otusaku-kws.download")
27+
return URLSession(configuration: config, delegate: self, delegateQueue: nil)
28+
}()
29+
30+
func downloadAndUnzip(from remoteURL: URL,
31+
to destinationFolder: URL,
32+
progress: @escaping (Float) -> Void,
33+
completion: @escaping (Result<URL, Error>) -> Void) {
34+
self.progressHandler = progress
35+
self.completionHandler = completion
36+
self.destinationFolder = destinationFolder
37+
self.remoteURL = remoteURL
38+
39+
session.getAllTasks { tasks in
40+
if let existing = tasks.first(where: { $0.originalRequest?.url == remoteURL }) as? URLSessionDownloadTask {
41+
self.currentTask = existing
42+
UserDefaults.standard.set(remoteURL.absoluteString, forKey: "ActiveDownloadURL")
43+
return
44+
}
45+
46+
let task = self.session.downloadTask(with: remoteURL)
47+
self.currentTask = task
48+
UserDefaults.standard.set(remoteURL.absoluteString, forKey: "ActiveDownloadURL")
49+
task.resume()
50+
}
51+
}
52+
53+
func restorePendingDownloadIfNeeded(progress: @escaping (Float) -> Void,
54+
completion: @escaping (Result<URL, Error>) -> Void) {
55+
guard let urlString = UserDefaults.standard.string(forKey: "ActiveDownloadURL"),
56+
let url = URL(string: urlString) else { return }
57+
58+
session.getAllTasks { tasks in
59+
if let existing = tasks.first(where: { $0.originalRequest?.url == url }) as? URLSessionDownloadTask {
60+
self.remoteURL = url
61+
self.progressHandler = progress
62+
self.completionHandler = completion
63+
self.currentTask = existing
64+
}
65+
}
66+
}
67+
68+
func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask,
69+
didWriteData bytesWritten: Int64,
70+
totalBytesWritten: Int64,
71+
totalBytesExpectedToWrite: Int64) {
72+
guard totalBytesExpectedToWrite > 0 else { return }
73+
let fractionCompleted = Float(totalBytesWritten) / Float(totalBytesExpectedToWrite)
74+
DispatchQueue.main.async {
75+
self.progressHandler?(fractionCompleted)
76+
}
77+
}
78+
79+
func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
80+
guard let destinationFolder else {
81+
completionHandler?(.failure(NSError(domain: "Missing destination folder", code: 0)))
82+
return
83+
}
84+
85+
do {
86+
let fileManager = FileManager.default
87+
88+
if fileManager.fileExists(atPath: destinationFolder.path) {
89+
try fileManager.removeItem(at: destinationFolder)
90+
}
91+
try fileManager.createDirectory(at: destinationFolder, withIntermediateDirectories: true)
92+
93+
let success = SSZipArchive.unzipFile(atPath: location.path,
94+
toDestination: destinationFolder.deletingLastPathComponent().path)
95+
if success {
96+
DispatchQueue.main.async {
97+
UserDefaults.standard.removeObject(forKey: "ActiveDownloadURL")
98+
self.completionHandler?(.success(destinationFolder))
99+
self.callBackgroundSessionCompletionIfNeeded()
100+
}
101+
} else {
102+
throw NSError(domain: "ModelDownloader", code: 1,
103+
userInfo: [NSLocalizedDescriptionKey: "Unzipping failed"])
104+
}
105+
} catch {
106+
DispatchQueue.main.async {
107+
UserDefaults.standard.removeObject(forKey: "ActiveDownloadURL")
108+
self.completionHandler?(.failure(error))
109+
self.callBackgroundSessionCompletionIfNeeded()
110+
}
111+
}
112+
}
113+
114+
func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
115+
if let error {
116+
DispatchQueue.main.async {
117+
UserDefaults.standard.removeObject(forKey: "ActiveDownloadURL")
118+
self.completionHandler?(.failure(error))
119+
self.callBackgroundSessionCompletionIfNeeded()
120+
}
121+
}
122+
}
123+
124+
private func callBackgroundSessionCompletionIfNeeded() {
125+
DispatchQueue.main.async {
126+
if let appDelegate = UIApplication.shared.delegate as? AppDelegate,
127+
let completion = appDelegate.backgroundSessionCompletionHandler {
128+
completion()
129+
appDelegate.backgroundSessionCompletionHandler = nil
130+
}
131+
}
132+
}
133+
}
134+
#endif
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// The Swift Programming Language
2+
// https://docs.swift.org/swift-book
3+
4+
#if canImport(UIKit)
5+
import Foundation
6+
import OtosakuFeatureExtractor
7+
import CoreML
8+
9+
10+
public class OtosakuKWS {
11+
private let TAG: String = "OtosakuKWS"
12+
private let featureExtractor: OtosakuFeatureExtractor
13+
private let model: OtosakuKWSModel
14+
private let classes: [String]
15+
16+
private var buffer: [Double] = []
17+
private let totalFrameLimit = 16000
18+
private var threshold: Float = 0.9
19+
20+
21+
public init (
22+
modelRootURL: URL,
23+
featureExtractorRootURL: URL,
24+
configuration: MLModelConfiguration
25+
) throws {
26+
self.featureExtractor = try OtosakuFeatureExtractor(directoryURL: featureExtractorRootURL)
27+
self.model = try OtosakuKWSModel(
28+
url: modelRootURL.appendingPathComponent("CRNNKeywordSpotter.mlmodelc"),
29+
configuration: configuration
30+
)
31+
self.classes = try readClasses(
32+
from: modelRootURL.appendingPathComponent("classes.txt")
33+
)
34+
}
35+
36+
public func handleAudioBuffer(_ buff: [Double]) throws {
37+
buffer.append(contentsOf: buff)
38+
if buffer.count > totalFrameLimit {
39+
let overflow = buffer.count - totalFrameLimit
40+
buffer.removeFirst(overflow)
41+
}
42+
if buffer.count != totalFrameLimit {
43+
return
44+
}
45+
Task {
46+
do {
47+
var feats = try featureExtractor.processChunk(chunk: buffer)
48+
feats = try featureExtractor.expandDims2D(array: feats)
49+
let probsArray = try model.predict(x: feats)
50+
var (max, idx) = (-1.0, -1)
51+
for i in 0..<probsArray.count {
52+
let value = probsArray[i].doubleValue
53+
if value > max {
54+
max = value
55+
idx = i
56+
}
57+
}
58+
if max > threshold {
59+
let currentClass = self.classes[idx]
60+
}
61+
} catch {
62+
print(TAG, "handleAudioBuffer: error: \(error)")
63+
}
64+
}
65+
66+
}
67+
68+
public func setProbabilityThreshold(threshold: Float) {
69+
self.threshold = threshold
70+
}
71+
72+
public static func downloadAndUnzip(
73+
from remoteURL: URL,
74+
to destinationFolder: URL,
75+
progress: @escaping (Float) -> Void,
76+
completion: @escaping (Result<URL, Error>) -> Void
77+
) {
78+
let downloader = ModelDownloader.shared
79+
return downloader.downloadAndUnzip(
80+
from: remoteURL,
81+
to: destinationFolder,
82+
progress: progress,
83+
completion: completion
84+
)
85+
}
86+
87+
88+
private func readClasses(from url: URL) throws -> [String] {
89+
let text = try String(contentsOf: url, encoding: .utf8)
90+
let lines = text.components(separatedBy: .newlines)
91+
return lines
92+
}
93+
}
94+
95+
#endif
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//
2+
// OtosakuKWSModel.swift
3+
// OtosakuKWS
4+
//
5+
// Created by Marat Zainullin on 12/06/2025.
6+
//
7+
8+
import CoreML
9+
import Foundation
10+
11+
enum OtosakuKWSModelPredictError: Error {
12+
case outputExtractionFailed
13+
}
14+
15+
class OtosakuKWSModel {
16+
private let model: MLModel
17+
18+
public init (url: URL, configuration: MLModelConfiguration) throws {
19+
model = try MLModel(contentsOf: url, configuration: configuration)
20+
}
21+
22+
public func predict(x: MLMultiArray) throws -> MLMultiArray {
23+
let featureProvider = try MLDictionaryFeatureProvider(dictionary: [
24+
"input": x
25+
])
26+
let out = try model.prediction(from: featureProvider)
27+
guard let probs = out.featureValue(for: "probs")?.multiArrayValue else {
28+
throw OtosakuKWSModelPredictError.outputExtractionFailed
29+
}
30+
31+
return probs
32+
}
33+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import Testing
2+
@testable import OtosakuKWS
3+
4+
@Test func example() async throws {
5+
// Write your test here and use APIs like `#expect(...)` to check expected conditions.
6+
}

0 commit comments

Comments
 (0)