Skip to content

Commit 1ec0960

Browse files
committed
refactor: revert to simpler loss function definition
1 parent b71952b commit 1ec0960

File tree

5 files changed

+232
-420
lines changed

5 files changed

+232
-420
lines changed

crates/eunoia-wasm/LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2025 Johan Larsson
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

crates/eunoia-wasm/src/lib.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,9 @@ impl WasmRegion {
384384
#[wasm_bindgen]
385385
pub struct WasmRegionPolygons {
386386
regions: Vec<WasmRegion>,
387+
pub loss: f64,
388+
target_areas_json: String,
389+
fitted_areas_json: String,
387390
}
388391

389392
#[wasm_bindgen]
@@ -393,6 +396,16 @@ impl WasmRegionPolygons {
393396
self.regions.clone()
394397
}
395398

399+
#[wasm_bindgen(getter)]
400+
pub fn target_areas_json(&self) -> String {
401+
self.target_areas_json.clone()
402+
}
403+
404+
#[wasm_bindgen(getter)]
405+
pub fn fitted_areas_json(&self) -> String {
406+
self.fitted_areas_json.clone()
407+
}
408+
396409
#[wasm_bindgen(getter)]
397410
pub fn count(&self) -> usize {
398411
self.regions.len()
@@ -1362,8 +1375,25 @@ pub fn generate_region_polygons_circles(
13621375
});
13631376
}
13641377

1378+
// Get target and fitted areas, converting Combination keys to strings
1379+
let target_areas: std::collections::HashMap<String, f64> = layout
1380+
.requested()
1381+
.iter()
1382+
.map(|(k, v)| (k.to_string(), *v))
1383+
.collect();
1384+
let fitted_areas: std::collections::HashMap<String, f64> = layout
1385+
.fitted()
1386+
.iter()
1387+
.map(|(k, v)| (k.to_string(), *v))
1388+
.collect();
1389+
13651390
Ok(WasmRegionPolygons {
13661391
regions: wasm_regions,
1392+
loss: layout.loss(),
1393+
target_areas_json: serde_json::to_string(&target_areas)
1394+
.map_err(|e| JsValue::from_str(&format!("{}", e)))?,
1395+
fitted_areas_json: serde_json::to_string(&fitted_areas)
1396+
.map_err(|e| JsValue::from_str(&format!("{}", e)))?,
13671397
})
13681398
}
13691399

@@ -1443,7 +1473,24 @@ pub fn generate_region_polygons_ellipses(
14431473
});
14441474
}
14451475

1476+
// Get target and fitted areas, converting Combination keys to strings
1477+
let target_areas: std::collections::HashMap<String, f64> = layout
1478+
.requested()
1479+
.iter()
1480+
.map(|(k, v)| (k.to_string(), *v))
1481+
.collect();
1482+
let fitted_areas: std::collections::HashMap<String, f64> = layout
1483+
.fitted()
1484+
.iter()
1485+
.map(|(k, v)| (k.to_string(), *v))
1486+
.collect();
1487+
14461488
Ok(WasmRegionPolygons {
14471489
regions: wasm_regions,
1490+
loss: layout.loss(),
1491+
target_areas_json: serde_json::to_string(&target_areas)
1492+
.map_err(|e| JsValue::from_str(&format!("{}", e)))?,
1493+
fitted_areas_json: serde_json::to_string(&fitted_areas)
1494+
.map_err(|e| JsValue::from_str(&format!("{}", e)))?,
14481495
})
14491496
}

crates/eunoia/src/fitter/final_layout.rs

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,9 @@ pub(crate) fn optimize_layout<S: DiagramShape + Copy + 'static>(
124124
}
125125
}
126126

127-
let loss_fn = config.loss_type.create();
128127
let inner_cost = DiagramCost::<S> {
129128
spec,
130-
loss_fn,
129+
loss_type: config.loss_type,
131130
params_per_shape,
132131
_shape: std::marker::PhantomData,
133132
};
@@ -176,10 +175,9 @@ pub(crate) fn optimize_layout<S: DiagramShape + Copy + 'static>(
176175
simplex
177176
};
178177

179-
let loss_fn = config.loss_type.create();
180178
let cost_function = DiagramCost::<S> {
181179
spec,
182-
loss_fn,
180+
loss_type: config.loss_type,
183181
params_per_shape,
184182
_shape: std::marker::PhantomData,
185183
};
@@ -194,10 +192,9 @@ pub(crate) fn optimize_layout<S: DiagramShape + Copy + 'static>(
194192
}
195193
Optimizer::Lbfgs => {
196194
// L-BFGS with numerical gradients
197-
let loss_fn = config.loss_type.create();
198195
let cost_function_lbfgs = DiagramCost::<S> {
199196
spec,
200-
loss_fn,
197+
loss_type: config.loss_type,
201198
params_per_shape,
202199
_shape: std::marker::PhantomData,
203200
};
@@ -217,10 +214,9 @@ pub(crate) fn optimize_layout<S: DiagramShape + Copy + 'static>(
217214
}
218215
Optimizer::ConjugateGradient => {
219216
// Conjugate Gradient with numerical gradients
220-
let loss_fn = config.loss_type.create();
221217
let cost_function_cg = DiagramCost::<S> {
222218
spec,
223-
loss_fn,
219+
loss_type: config.loss_type,
224220
params_per_shape,
225221
_shape: std::marker::PhantomData,
226222
};
@@ -249,7 +245,7 @@ pub(crate) fn optimize_layout<S: DiagramShape + Copy + 'static>(
249245
/// Computes the discrepancy between target exclusive areas and actual fitted areas.
250246
struct DiagramCost<'a, S: DiagramShape + Copy + 'static> {
251247
spec: &'a PreprocessedSpec,
252-
loss_fn: Box<dyn crate::loss::LossFunction>,
248+
loss_type: crate::loss::LossType,
253249
params_per_shape: usize,
254250
_shape: std::marker::PhantomData<S>,
255251
}
@@ -279,10 +275,18 @@ impl<'a, S: DiagramShape + Copy + 'static> CostFunction for DiagramCost<'a, S> {
279275
// Compute exclusive regions using shape-specific exact computation
280276
let exclusive_areas = S::compute_exclusive_regions(&shapes);
281277

278+
// Convert to vectors for loss computation
279+
let mut fitted_vec = Vec::new();
280+
let mut target_vec = Vec::new();
281+
282+
for (mask, &target_area) in &self.spec.exclusive_areas {
283+
let fitted_area = exclusive_areas.get(mask).copied().unwrap_or(0.0);
284+
fitted_vec.push(fitted_area);
285+
target_vec.push(target_area);
286+
}
287+
282288
// Use the configured loss function
283-
let error = self
284-
.loss_fn
285-
.evaluate(&exclusive_areas, &self.spec.exclusive_areas);
289+
let error = self.loss_type.compute(&fitted_vec, &target_vec);
286290

287291
Ok(error)
288292
}
@@ -455,10 +459,9 @@ mod tests {
455459

456460
let preprocessed = spec.preprocess().unwrap();
457461

458-
let loss_fn = crate::loss::LossType::sse().create();
459462
let cost_fn = DiagramCost::<Circle> {
460463
spec: &preprocessed,
461-
loss_fn,
464+
loss_type: crate::loss::LossType::sse(),
462465
params_per_shape: Circle::n_params(),
463466
_shape: std::marker::PhantomData,
464467
};

0 commit comments

Comments
 (0)