Skip to content

Commit b50b62e

Browse files
committed
move state filtering into engine; targets->candidate
1 parent 5a7dc04 commit b50b62e

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

tsuchinoko/execution/bluesky_adaptive.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import time
33
from typing import Tuple, List
44

5+
import numpy as np
56
import zmq
67
from loguru import logger
78

89
from . import Engine
10+
from ..adaptive.gpCAM_in_process import GPCAMInProcessEngine
911

1012
SLEEP_FOR_AGENT_TIME = .1
1113
SLEEP_FOR_TSUCHINOKO_TIME = .1
@@ -17,7 +19,11 @@ class BlueskyAdaptiveEngine(Engine):
1719
A `tsuchinoko.execution.Engine` that sends targets to Blueskly-Adaptive and receives back measured data.
1820
"""
1921

20-
def __init__(self, host: str = '127.0.0.1', port: int = 5557):
22+
suggest_blacklist = ["x_data",
23+
"y_data",
24+
"noise_variances"] # keys with ragged state
25+
26+
def __init__(self, adaptive_engine:GPCAMInProcessEngine, host: str = '127.0.0.1', port: int = 5557):
2127
"""
2228
2329
Parameters
@@ -29,6 +35,7 @@ def __init__(self, host: str = '127.0.0.1', port: int = 5557):
2935
"""
3036
super(BlueskyAdaptiveEngine, self).__init__()
3137

38+
self.adaptive_engine = adaptive_engine
3239
self.position = None
3340
self.context = None
3441
self.socket = None
@@ -59,8 +66,16 @@ def update_targets(self, targets: List[Tuple]):
5966
if self.has_fresh_points_on_server:
6067
time.sleep(SLEEP_FOR_AGENT_TIME) # chill if the Agent hasn't measured any points from the previous list
6168
else:
69+
# checkpoint optimizer state
70+
gpcam_state = self.adaptive_engine.optimizer.__getstate__()
71+
sanitized_gpcam_state = dict(
72+
**{key if key not in self.suggest_blacklist else f"STATEDICT-{key}": np.asarray(val)
73+
for key, val in gpcam_state.items()
74+
if key in self.suggest_blacklist})
75+
6276
# send targets to TsuchinokoAgent
63-
self.has_fresh_points_on_server = self.send_payload({'targets': targets})
77+
self.has_fresh_points_on_server = self.send_payload({'candidate': targets,
78+
'optimizer': sanitized_gpcam_state})
6479
self._last_targets_sent = targets
6580

6681
def get_measurements(self) -> List[Tuple]:

0 commit comments

Comments
 (0)