33//! This module provides simple loss functions that measure the difference
44//! between fitted and target region areas.
55
6+ use crate :: geometry:: diagram:: RegionMask ;
7+ use std:: collections:: HashMap ;
8+
69/// Loss function type
710#[ derive( Debug , Clone , Copy , PartialEq , Default ) ]
811pub enum LossType {
@@ -38,42 +41,51 @@ impl LossType {
3841 Self :: MaxAbsolute
3942 }
4043
41- /// Compute loss between fitted and target values
42- pub fn compute ( & self , fitted : & [ f64 ] , target : & [ f64 ] ) -> f64 {
43- assert_eq ! (
44- fitted. len( ) ,
45- target. len( ) ,
46- "Fitted and target vectors must have the same length"
47- ) ;
44+ /// Compute loss between fitted and target region areas
45+ pub fn compute (
46+ & self ,
47+ fitted : & HashMap < RegionMask , f64 > ,
48+ target : & HashMap < RegionMask , f64 > ,
49+ ) -> f64 {
50+ // Collect all unique region masks from both fitted and target
51+ let all_masks: std:: collections:: HashSet < RegionMask > =
52+ fitted. keys ( ) . chain ( target. keys ( ) ) . copied ( ) . collect ( ) ;
4853
49- if fitted . is_empty ( ) {
54+ if all_masks . is_empty ( ) {
5055 return 0.0 ;
5156 }
5257
5358 match self {
5459 LossType :: Sse => {
5560 // Sum of squared errors
56- fitted
61+ all_masks
5762 . iter ( )
58- . zip ( target. iter ( ) )
59- . map ( |( f, t) | ( f - t) . powi ( 2 ) )
63+ . map ( |& mask| {
64+ let fitted_area = fitted. get ( & mask) . copied ( ) . unwrap_or ( 0.0 ) ;
65+ let target_area = target. get ( & mask) . copied ( ) . unwrap_or ( 0.0 ) ;
66+ ( fitted_area - target_area) . powi ( 2 )
67+ } )
6068 . sum ( )
6169 }
6270 LossType :: Rmse => {
6371 // Root mean squared error
64- let sum_squared: f64 = fitted
72+ let sum_squared: f64 = all_masks
6573 . iter ( )
66- . zip ( target. iter ( ) )
67- . map ( |( f, t) | ( f - t) . powi ( 2 ) )
74+ . map ( |& mask| {
75+ let f = fitted. get ( & mask) . copied ( ) . unwrap_or ( 0.0 ) ;
76+ let t = target. get ( & mask) . copied ( ) . unwrap_or ( 0.0 ) ;
77+ ( f - t) . powi ( 2 )
78+ } )
6879 . sum ( ) ;
69- ( sum_squared / fitted . len ( ) as f64 ) . sqrt ( )
80+ ( sum_squared / all_masks . len ( ) as f64 ) . sqrt ( )
7081 }
7182 LossType :: Stress => {
7283 // Stress: sum of squared relative errors
73- fitted
84+ all_masks
7485 . iter ( )
75- . zip ( target. iter ( ) )
76- . map ( |( f, t) | {
86+ . map ( |& mask| {
87+ let f = fitted. get ( & mask) . copied ( ) . unwrap_or ( 0.0 ) ;
88+ let t = target. get ( & mask) . copied ( ) . unwrap_or ( 0.0 ) ;
7789 if t. abs ( ) < 1e-10 {
7890 if f. abs ( ) < 1e-10 {
7991 0.0
@@ -88,10 +100,13 @@ impl LossType {
88100 }
89101 LossType :: MaxAbsolute => {
90102 // Maximum absolute error
91- fitted
103+ all_masks
92104 . iter ( )
93- . zip ( target. iter ( ) )
94- . map ( |( f, t) | ( f - t) . abs ( ) )
105+ . map ( |& mask| {
106+ let f = fitted. get ( & mask) . copied ( ) . unwrap_or ( 0.0 ) ;
107+ let t = target. get ( & mask) . copied ( ) . unwrap_or ( 0.0 ) ;
108+ ( f - t) . abs ( )
109+ } )
95110 . fold ( 0.0 , f64:: max)
96111 }
97112 }
@@ -105,8 +120,16 @@ mod tests {
105120 #[ test]
106121 fn test_sse ( ) {
107122 let loss = LossType :: sse ( ) ;
108- let fitted = vec ! [ 10.0 , 20.0 , 30.0 ] ;
109- let target = vec ! [ 12.0 , 18.0 , 28.0 ] ;
123+
124+ let mut fitted = HashMap :: new ( ) ;
125+ fitted. insert ( 0b001 , 10.0 ) ;
126+ fitted. insert ( 0b010 , 20.0 ) ;
127+ fitted. insert ( 0b100 , 30.0 ) ;
128+
129+ let mut target = HashMap :: new ( ) ;
130+ target. insert ( 0b001 , 12.0 ) ;
131+ target. insert ( 0b010 , 18.0 ) ;
132+ target. insert ( 0b100 , 28.0 ) ;
110133
111134 // (10-12)² + (20-18)² + (30-28)² = 4 + 4 + 4 = 12
112135 assert_eq ! ( loss. compute( & fitted, & target) , 12.0 ) ;
@@ -115,8 +138,16 @@ mod tests {
115138 #[ test]
116139 fn test_rmse ( ) {
117140 let loss = LossType :: rmse ( ) ;
118- let fitted = vec ! [ 10.0 , 20.0 , 30.0 ] ;
119- let target = vec ! [ 12.0 , 18.0 , 28.0 ] ;
141+
142+ let mut fitted = HashMap :: new ( ) ;
143+ fitted. insert ( 0b001 , 10.0 ) ;
144+ fitted. insert ( 0b010 , 20.0 ) ;
145+ fitted. insert ( 0b100 , 30.0 ) ;
146+
147+ let mut target = HashMap :: new ( ) ;
148+ target. insert ( 0b001 , 12.0 ) ;
149+ target. insert ( 0b010 , 18.0 ) ;
150+ target. insert ( 0b100 , 28.0 ) ;
120151
121152 // sqrt((4 + 4 + 4) / 3) = sqrt(4) = 2.0
122153 assert_eq ! ( loss. compute( & fitted, & target) , 2.0 ) ;
@@ -125,8 +156,14 @@ mod tests {
125156 #[ test]
126157 fn test_stress ( ) {
127158 let loss = LossType :: stress ( ) ;
128- let fitted = vec ! [ 10.0 , 20.0 ] ;
129- let target = vec ! [ 12.0 , 18.0 ] ;
159+
160+ let mut fitted = HashMap :: new ( ) ;
161+ fitted. insert ( 0b001 , 10.0 ) ;
162+ fitted. insert ( 0b010 , 20.0 ) ;
163+
164+ let mut target = HashMap :: new ( ) ;
165+ target. insert ( 0b001 , 12.0 ) ;
166+ target. insert ( 0b010 , 18.0 ) ;
130167
131168 // ((10-12)/12)² + ((20-18)/18)² = (1/6)² + (1/9)² = 0.02778 + 0.01235 ≈ 0.04013
132169 let result = loss. compute ( & fitted, & target) ;
@@ -136,24 +173,74 @@ mod tests {
136173 #[ test]
137174 fn test_max_absolute ( ) {
138175 let loss = LossType :: max_absolute ( ) ;
139- let fitted = vec ! [ 10.0 , 20.0 , 30.0 ] ;
140- let target = vec ! [ 8.0 , 25.0 , 28.0 ] ;
176+
177+ let mut fitted = HashMap :: new ( ) ;
178+ fitted. insert ( 0b001 , 10.0 ) ;
179+ fitted. insert ( 0b010 , 20.0 ) ;
180+ fitted. insert ( 0b100 , 30.0 ) ;
181+
182+ let mut target = HashMap :: new ( ) ;
183+ target. insert ( 0b001 , 8.0 ) ;
184+ target. insert ( 0b010 , 25.0 ) ;
185+ target. insert ( 0b100 , 28.0 ) ;
141186
142187 // max(|10-8|, |20-25|, |30-28|) = max(2, 5, 2) = 5
143188 assert_eq ! ( loss. compute( & fitted, & target) , 5.0 ) ;
144189 }
145190
146191 #[ test]
147- fn test_empty_vectors ( ) {
192+ fn test_empty_target ( ) {
193+ let loss = LossType :: sse ( ) ;
194+ let fitted = HashMap :: new ( ) ;
195+ let target = HashMap :: new ( ) ;
196+ assert_eq ! ( loss. compute( & fitted, & target) , 0.0 ) ;
197+ }
198+
199+ #[ test]
200+ fn test_missing_fitted_area ( ) {
148201 let loss = LossType :: sse ( ) ;
149- assert_eq ! ( loss. compute( & [ ] , & [ ] ) , 0.0 ) ;
202+
203+ let fitted = HashMap :: new ( ) ; // Empty - no fitted areas
204+
205+ let mut target = HashMap :: new ( ) ;
206+ target. insert ( 0b001 , 5.0 ) ;
207+ target. insert ( 0b010 , 3.0 ) ;
208+
209+ // (0-5)² + (0-3)² = 25 + 9 = 34
210+ assert_eq ! ( loss. compute( & fitted, & target) , 34.0 ) ;
211+ }
212+
213+ #[ test]
214+ fn test_extra_fitted_area ( ) {
215+ let loss = LossType :: sse ( ) ;
216+
217+ let mut fitted = HashMap :: new ( ) ;
218+ fitted. insert ( 0b001 , 5.0 ) ;
219+ fitted. insert ( 0b010 , 3.0 ) ;
220+ fitted. insert ( 0b100 , 7.0 ) ; // Extra region not in target
221+
222+ let mut target = HashMap :: new ( ) ;
223+ target. insert ( 0b001 , 5.0 ) ;
224+ target. insert ( 0b010 , 3.0 ) ;
225+ // 0b100 missing from target
226+
227+ // (5-5)² + (3-3)² + (7-0)² = 0 + 0 + 49 = 49
228+ assert_eq ! ( loss. compute( & fitted, & target) , 49.0 ) ;
150229 }
151230
152231 #[ test]
153232 fn test_stress_with_zero_target ( ) {
154233 let loss = LossType :: stress ( ) ;
155- let fitted = vec ! [ 5.0 , 0.0 , 3.0 ] ;
156- let target = vec ! [ 0.0 , 0.0 , 3.0 ] ;
234+
235+ let mut fitted = HashMap :: new ( ) ;
236+ fitted. insert ( 0b001 , 5.0 ) ;
237+ fitted. insert ( 0b010 , 0.0 ) ;
238+ fitted. insert ( 0b100 , 3.0 ) ;
239+
240+ let mut target = HashMap :: new ( ) ;
241+ target. insert ( 0b001 , 0.0 ) ;
242+ target. insert ( 0b010 , 0.0 ) ;
243+ target. insert ( 0b100 , 3.0 ) ;
157244
158245 // First: target=0, fitted=5: 5² = 25
159246 // Second: target=0, fitted=0: 0
@@ -175,11 +262,4 @@ mod tests {
175262 let cloned = loss;
176263 assert_eq ! ( loss, cloned) ;
177264 }
178-
179- #[ test]
180- #[ should_panic( expected = "Fitted and target vectors must have the same length" ) ]
181- fn test_mismatched_lengths ( ) {
182- let loss = LossType :: sse ( ) ;
183- loss. compute ( & [ 1.0 , 2.0 ] , & [ 1.0 ] ) ;
184- }
185265}
0 commit comments