Skip to content

Commit 15a2574

Browse files
committed
refactor: delegate handling of hash maps to helper
1 parent 1ec0960 commit 15a2574

File tree

2 files changed

+124
-52
lines changed

2 files changed

+124
-52
lines changed

crates/eunoia/src/fitter/final_layout.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -275,18 +275,10 @@ impl<'a, S: DiagramShape + Copy + 'static> CostFunction for DiagramCost<'a, S> {
275275
// Compute exclusive regions using shape-specific exact computation
276276
let exclusive_areas = S::compute_exclusive_regions(&shapes);
277277

278-
// Convert to vectors for loss computation
279-
let mut fitted_vec = Vec::new();
280-
let mut target_vec = Vec::new();
281-
282-
for (mask, &target_area) in &self.spec.exclusive_areas {
283-
let fitted_area = exclusive_areas.get(mask).copied().unwrap_or(0.0);
284-
fitted_vec.push(fitted_area);
285-
target_vec.push(target_area);
286-
}
287-
288-
// Use the configured loss function
289-
let error = self.loss_type.compute(&fitted_vec, &target_vec);
278+
// Use the configured loss function directly on HashMaps
279+
let error = self
280+
.loss_type
281+
.compute(&exclusive_areas, &self.spec.exclusive_areas);
290282

291283
Ok(error)
292284
}

crates/eunoia/src/loss.rs

Lines changed: 120 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
//! This module provides simple loss functions that measure the difference
44
//! between fitted and target region areas.
55
6+
use crate::geometry::diagram::RegionMask;
7+
use std::collections::HashMap;
8+
69
/// Loss function type
710
#[derive(Debug, Clone, Copy, PartialEq, Default)]
811
pub enum LossType {
@@ -38,42 +41,51 @@ impl LossType {
3841
Self::MaxAbsolute
3942
}
4043

41-
/// Compute loss between fitted and target values
42-
pub fn compute(&self, fitted: &[f64], target: &[f64]) -> f64 {
43-
assert_eq!(
44-
fitted.len(),
45-
target.len(),
46-
"Fitted and target vectors must have the same length"
47-
);
44+
/// Compute loss between fitted and target region areas
45+
pub fn compute(
46+
&self,
47+
fitted: &HashMap<RegionMask, f64>,
48+
target: &HashMap<RegionMask, f64>,
49+
) -> f64 {
50+
// Collect all unique region masks from both fitted and target
51+
let all_masks: std::collections::HashSet<RegionMask> =
52+
fitted.keys().chain(target.keys()).copied().collect();
4853

49-
if fitted.is_empty() {
54+
if all_masks.is_empty() {
5055
return 0.0;
5156
}
5257

5358
match self {
5459
LossType::Sse => {
5560
// Sum of squared errors
56-
fitted
61+
all_masks
5762
.iter()
58-
.zip(target.iter())
59-
.map(|(f, t)| (f - t).powi(2))
63+
.map(|&mask| {
64+
let fitted_area = fitted.get(&mask).copied().unwrap_or(0.0);
65+
let target_area = target.get(&mask).copied().unwrap_or(0.0);
66+
(fitted_area - target_area).powi(2)
67+
})
6068
.sum()
6169
}
6270
LossType::Rmse => {
6371
// Root mean squared error
64-
let sum_squared: f64 = fitted
72+
let sum_squared: f64 = all_masks
6573
.iter()
66-
.zip(target.iter())
67-
.map(|(f, t)| (f - t).powi(2))
74+
.map(|&mask| {
75+
let f = fitted.get(&mask).copied().unwrap_or(0.0);
76+
let t = target.get(&mask).copied().unwrap_or(0.0);
77+
(f - t).powi(2)
78+
})
6879
.sum();
69-
(sum_squared / fitted.len() as f64).sqrt()
80+
(sum_squared / all_masks.len() as f64).sqrt()
7081
}
7182
LossType::Stress => {
7283
// Stress: sum of squared relative errors
73-
fitted
84+
all_masks
7485
.iter()
75-
.zip(target.iter())
76-
.map(|(f, t)| {
86+
.map(|&mask| {
87+
let f = fitted.get(&mask).copied().unwrap_or(0.0);
88+
let t = target.get(&mask).copied().unwrap_or(0.0);
7789
if t.abs() < 1e-10 {
7890
if f.abs() < 1e-10 {
7991
0.0
@@ -88,10 +100,13 @@ impl LossType {
88100
}
89101
LossType::MaxAbsolute => {
90102
// Maximum absolute error
91-
fitted
103+
all_masks
92104
.iter()
93-
.zip(target.iter())
94-
.map(|(f, t)| (f - t).abs())
105+
.map(|&mask| {
106+
let f = fitted.get(&mask).copied().unwrap_or(0.0);
107+
let t = target.get(&mask).copied().unwrap_or(0.0);
108+
(f - t).abs()
109+
})
95110
.fold(0.0, f64::max)
96111
}
97112
}
@@ -105,8 +120,16 @@ mod tests {
105120
#[test]
106121
fn test_sse() {
107122
let loss = LossType::sse();
108-
let fitted = vec![10.0, 20.0, 30.0];
109-
let target = vec![12.0, 18.0, 28.0];
123+
124+
let mut fitted = HashMap::new();
125+
fitted.insert(0b001, 10.0);
126+
fitted.insert(0b010, 20.0);
127+
fitted.insert(0b100, 30.0);
128+
129+
let mut target = HashMap::new();
130+
target.insert(0b001, 12.0);
131+
target.insert(0b010, 18.0);
132+
target.insert(0b100, 28.0);
110133

111134
// (10-12)² + (20-18)² + (30-28)² = 4 + 4 + 4 = 12
112135
assert_eq!(loss.compute(&fitted, &target), 12.0);
@@ -115,8 +138,16 @@ mod tests {
115138
#[test]
116139
fn test_rmse() {
117140
let loss = LossType::rmse();
118-
let fitted = vec![10.0, 20.0, 30.0];
119-
let target = vec![12.0, 18.0, 28.0];
141+
142+
let mut fitted = HashMap::new();
143+
fitted.insert(0b001, 10.0);
144+
fitted.insert(0b010, 20.0);
145+
fitted.insert(0b100, 30.0);
146+
147+
let mut target = HashMap::new();
148+
target.insert(0b001, 12.0);
149+
target.insert(0b010, 18.0);
150+
target.insert(0b100, 28.0);
120151

121152
// sqrt((4 + 4 + 4) / 3) = sqrt(4) = 2.0
122153
assert_eq!(loss.compute(&fitted, &target), 2.0);
@@ -125,8 +156,14 @@ mod tests {
125156
#[test]
126157
fn test_stress() {
127158
let loss = LossType::stress();
128-
let fitted = vec![10.0, 20.0];
129-
let target = vec![12.0, 18.0];
159+
160+
let mut fitted = HashMap::new();
161+
fitted.insert(0b001, 10.0);
162+
fitted.insert(0b010, 20.0);
163+
164+
let mut target = HashMap::new();
165+
target.insert(0b001, 12.0);
166+
target.insert(0b010, 18.0);
130167

131168
// ((10-12)/12)² + ((20-18)/18)² = (1/6)² + (1/9)² = 0.02778 + 0.01235 ≈ 0.04013
132169
let result = loss.compute(&fitted, &target);
@@ -136,24 +173,74 @@ mod tests {
136173
#[test]
137174
fn test_max_absolute() {
138175
let loss = LossType::max_absolute();
139-
let fitted = vec![10.0, 20.0, 30.0];
140-
let target = vec![8.0, 25.0, 28.0];
176+
177+
let mut fitted = HashMap::new();
178+
fitted.insert(0b001, 10.0);
179+
fitted.insert(0b010, 20.0);
180+
fitted.insert(0b100, 30.0);
181+
182+
let mut target = HashMap::new();
183+
target.insert(0b001, 8.0);
184+
target.insert(0b010, 25.0);
185+
target.insert(0b100, 28.0);
141186

142187
// max(|10-8|, |20-25|, |30-28|) = max(2, 5, 2) = 5
143188
assert_eq!(loss.compute(&fitted, &target), 5.0);
144189
}
145190

146191
#[test]
147-
fn test_empty_vectors() {
192+
fn test_empty_target() {
193+
let loss = LossType::sse();
194+
let fitted = HashMap::new();
195+
let target = HashMap::new();
196+
assert_eq!(loss.compute(&fitted, &target), 0.0);
197+
}
198+
199+
#[test]
200+
fn test_missing_fitted_area() {
148201
let loss = LossType::sse();
149-
assert_eq!(loss.compute(&[], &[]), 0.0);
202+
203+
let fitted = HashMap::new(); // Empty - no fitted areas
204+
205+
let mut target = HashMap::new();
206+
target.insert(0b001, 5.0);
207+
target.insert(0b010, 3.0);
208+
209+
// (0-5)² + (0-3)² = 25 + 9 = 34
210+
assert_eq!(loss.compute(&fitted, &target), 34.0);
211+
}
212+
213+
#[test]
214+
fn test_extra_fitted_area() {
215+
let loss = LossType::sse();
216+
217+
let mut fitted = HashMap::new();
218+
fitted.insert(0b001, 5.0);
219+
fitted.insert(0b010, 3.0);
220+
fitted.insert(0b100, 7.0); // Extra region not in target
221+
222+
let mut target = HashMap::new();
223+
target.insert(0b001, 5.0);
224+
target.insert(0b010, 3.0);
225+
// 0b100 missing from target
226+
227+
// (5-5)² + (3-3)² + (7-0)² = 0 + 0 + 49 = 49
228+
assert_eq!(loss.compute(&fitted, &target), 49.0);
150229
}
151230

152231
#[test]
153232
fn test_stress_with_zero_target() {
154233
let loss = LossType::stress();
155-
let fitted = vec![5.0, 0.0, 3.0];
156-
let target = vec![0.0, 0.0, 3.0];
234+
235+
let mut fitted = HashMap::new();
236+
fitted.insert(0b001, 5.0);
237+
fitted.insert(0b010, 0.0);
238+
fitted.insert(0b100, 3.0);
239+
240+
let mut target = HashMap::new();
241+
target.insert(0b001, 0.0);
242+
target.insert(0b010, 0.0);
243+
target.insert(0b100, 3.0);
157244

158245
// First: target=0, fitted=5: 5² = 25
159246
// Second: target=0, fitted=0: 0
@@ -175,11 +262,4 @@ mod tests {
175262
let cloned = loss;
176263
assert_eq!(loss, cloned);
177264
}
178-
179-
#[test]
180-
#[should_panic(expected = "Fitted and target vectors must have the same length")]
181-
fn test_mismatched_lengths() {
182-
let loss = LossType::sse();
183-
loss.compute(&[1.0, 2.0], &[1.0]);
184-
}
185265
}

0 commit comments

Comments
 (0)