22import time
33from typing import Tuple , List
44
5+ import numpy as np
56import zmq
67from loguru import logger
78
89from . import Engine
10+ from ..adaptive .gpCAM_in_process import GPCAMInProcessEngine
911
1012SLEEP_FOR_AGENT_TIME = .1
1113SLEEP_FOR_TSUCHINOKO_TIME = .1
@@ -17,7 +19,11 @@ class BlueskyAdaptiveEngine(Engine):
1719 A `tsuchinoko.execution.Engine` that sends targets to Blueskly-Adaptive and receives back measured data.
1820 """
1921
20- def __init__ (self , host : str = '127.0.0.1' , port : int = 5557 ):
22+ suggest_blacklist = ["x_data" ,
23+ "y_data" ,
24+ "noise_variances" ] # keys with ragged state
25+
26+ def __init__ (self , adaptive_engine :GPCAMInProcessEngine , host : str = '127.0.0.1' , port : int = 5557 ):
2127 """
2228
2329 Parameters
@@ -29,6 +35,7 @@ def __init__(self, host: str = '127.0.0.1', port: int = 5557):
2935 """
3036 super (BlueskyAdaptiveEngine , self ).__init__ ()
3137
38+ self .adaptive_engine = adaptive_engine
3239 self .position = None
3340 self .context = None
3441 self .socket = None
@@ -59,8 +66,16 @@ def update_targets(self, targets: List[Tuple]):
5966 if self .has_fresh_points_on_server :
6067 time .sleep (SLEEP_FOR_AGENT_TIME ) # chill if the Agent hasn't measured any points from the previous list
6168 else :
69+ # checkpoint optimizer state
70+ gpcam_state = self .adaptive_engine .optimizer .__getstate__ ()
71+ sanitized_gpcam_state = dict (
72+ ** {key if key not in self .suggest_blacklist else f"STATEDICT-{ key } " : np .asarray (val )
73+ for key , val in gpcam_state .items ()
74+ if key in self .suggest_blacklist })
75+
6276 # send targets to TsuchinokoAgent
63- self .has_fresh_points_on_server = self .send_payload ({'targets' : targets })
77+ self .has_fresh_points_on_server = self .send_payload ({'candidate' : targets ,
78+ 'optimizer' : sanitized_gpcam_state })
6479 self ._last_targets_sent = targets
6580
6681 def get_measurements (self ) -> List [Tuple ]:
0 commit comments