Skip to content

Commit 2780b9e

Browse files
committed
improve issm api
1 parent 488ed9b commit 2780b9e

File tree

3 files changed

+351
-87
lines changed

3 files changed

+351
-87
lines changed

src/gluonts/model/deepstate/_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def create_transformation(self) -> Transformation:
269269
),
270270
# Unnormalized seasonal features
271271
AddTimeFeatures(
272-
time_features=CompositeISSM.seasonal_features(self.freq),
272+
time_features=self.issm.time_features(),
273273
pred_length=self.prediction_length,
274274
start_field=FieldName.START,
275275
target_field=FieldName.TARGET,

src/gluonts/model/deepstate/issm.py

Lines changed: 94 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from typing import List, Tuple
1515

16+
import numpy as np
17+
import pandas as pd
1618
from pandas.tseries.frequencies import to_offset
1719

1820
from gluonts.core.component import validated
@@ -89,14 +91,23 @@ def _make_2_block_diagonal(F, left: Tensor, right: Tensor) -> Tensor:
8991
return _block_diagonal
9092

9193

94+
class ZeroFeature(TimeFeature):
95+
"""
96+
A feature that is identically zero.
97+
"""
98+
99+
def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
100+
return np.zeros(index.values.shape)
101+
102+
92103
class ISSM:
93104
r"""
94105
An abstract class for providing the basic structure of Innovation State Space Model (ISSM).
95106
96107
The structure of ISSM is given by
97108
98109
* dimension of the latent state
99-
* transition and emission coefficents of the transition model
110+
* transition and innovation coefficents of the transition model
100111
* emission coefficient of the observation model
101112
102113
"""
@@ -106,27 +117,30 @@ def __init__(self):
106117
pass
107118

108119
def latent_dim(self) -> int:
109-
raise NotImplemented()
120+
raise NotImplementedError
110121

111122
def output_dim(self) -> int:
112-
raise NotImplemented()
123+
raise NotImplementedError
124+
125+
def time_features(self) -> List[TimeFeature]:
126+
raise NotImplementedError
113127

114-
def emission_coeff(self, seasonal_indicators: Tensor):
115-
raise NotImplemented()
128+
def emission_coeff(self, features: Tensor) -> Tensor:
129+
raise NotImplementedError
116130

117-
def transition_coeff(self, seasonal_indicators: Tensor):
118-
raise NotImplemented()
131+
def transition_coeff(self, features: Tensor) -> Tensor:
132+
raise NotImplementedError
119133

120-
def innovation_coeff(self, seasonal_indicators: Tensor):
121-
raise NotImplemented()
134+
def innovation_coeff(self, features: Tensor) -> Tensor:
135+
raise NotImplementedError
122136

123137
def get_issm_coeff(
124-
self, seasonal_indicators: Tensor
138+
self, features: Tensor
125139
) -> Tuple[Tensor, Tensor, Tensor]:
126140
return (
127-
self.emission_coeff(seasonal_indicators),
128-
self.transition_coeff(seasonal_indicators),
129-
self.innovation_coeff(seasonal_indicators),
141+
self.emission_coeff(features),
142+
self.transition_coeff(features),
143+
self.innovation_coeff(features),
130144
)
131145

132146

@@ -137,52 +151,47 @@ def latent_dim(self) -> int:
137151
def output_dim(self) -> int:
138152
return 1
139153

154+
def time_features(self) -> List[TimeFeature]:
155+
return [ZeroFeature()]
156+
140157
def emission_coeff(
141-
self, seasonal_indicators: Tensor # (batch_size, time_length)
158+
self, feature: Tensor # (batch_size, time_length)
142159
) -> Tensor:
143-
F = getF(seasonal_indicators)
160+
F = getF(feature)
144161

145162
_emission_coeff = F.ones(shape=(1, 1, 1, self.latent_dim()))
146163

147164
# get the right shape: (batch_size, seq_length, obs_dim, latent_dim)
148165
zeros = _broadcast_param(
149-
F.zeros_like(
150-
seasonal_indicators.slice_axis(
151-
axis=-1, begin=0, end=1
152-
).squeeze(axis=-1)
153-
),
166+
feature.slice_axis(axis=-1, begin=0, end=1).squeeze(axis=-1),
154167
axes=[2, 3],
155168
sizes=[1, self.latent_dim()],
156169
)
157170

158171
return _emission_coeff.broadcast_like(zeros)
159172

160173
def transition_coeff(
161-
self, seasonal_indicators: Tensor # (batch_size, time_length)
174+
self, feature: Tensor # (batch_size, time_length)
162175
) -> Tensor:
163-
F = getF(seasonal_indicators)
176+
F = getF(feature)
164177

165178
_transition_coeff = (
166179
F.eye(self.latent_dim()).expand_dims(axis=0).expand_dims(axis=0)
167180
)
168181

169182
# get the right shape: (batch_size, seq_length, latent_dim, latent_dim)
170183
zeros = _broadcast_param(
171-
F.zeros_like(
172-
seasonal_indicators.slice_axis(
173-
axis=-1, begin=0, end=1
174-
).squeeze(axis=-1)
175-
),
184+
feature.slice_axis(axis=-1, begin=0, end=1).squeeze(axis=-1),
176185
axes=[2, 3],
177186
sizes=[self.latent_dim(), self.latent_dim()],
178187
)
179188

180189
return _transition_coeff.broadcast_like(zeros)
181190

182191
def innovation_coeff(
183-
self, seasonal_indicators: Tensor # (batch_size, time_length)
192+
self, feature: Tensor # (batch_size, time_length)
184193
) -> Tensor:
185-
return self.emission_coeff(seasonal_indicators).squeeze(axis=2)
194+
return self.emission_coeff(feature).squeeze(axis=2)
186195

187196

188197
class LevelTrendISSM(LevelISSM):
@@ -192,10 +201,13 @@ def latent_dim(self) -> int:
192201
def output_dim(self) -> int:
193202
return 1
194203

204+
def time_features(self) -> List[TimeFeature]:
205+
return [ZeroFeature()]
206+
195207
def transition_coeff(
196-
self, seasonal_indicators: Tensor # (batch_size, time_length)
208+
self, feature: Tensor # (batch_size, time_length)
197209
) -> Tensor:
198-
F = getF(seasonal_indicators)
210+
F = getF(feature)
199211

200212
_transition_coeff = (
201213
(F.diag(F.ones(shape=(2,)), k=0) + F.diag(F.ones(shape=(1,)), k=1))
@@ -205,11 +217,7 @@ def transition_coeff(
205217

206218
# get the right shape: (batch_size, seq_length, latent_dim, latent_dim)
207219
zeros = _broadcast_param(
208-
F.zeros_like(
209-
seasonal_indicators.slice_axis(
210-
axis=-1, begin=0, end=1
211-
).squeeze(axis=-1)
212-
),
220+
feature.slice_axis(axis=-1, begin=0, end=1).squeeze(axis=-1),
213221
axes=[2, 3],
214222
sizes=[self.latent_dim(), self.latent_dim()],
215223
)
@@ -223,26 +231,28 @@ class SeasonalityISSM(LevelISSM):
223231
"""
224232

225233
@validated()
226-
def __init__(self, num_seasons: int) -> None:
234+
def __init__(self, num_seasons: int, time_feature: TimeFeature) -> None:
227235
super(SeasonalityISSM, self).__init__()
228236
self.num_seasons = num_seasons
237+
self.time_feature = time_feature
229238

230239
def latent_dim(self) -> int:
231240
return self.num_seasons
232241

233242
def output_dim(self) -> int:
234243
return 1
235244

236-
def emission_coeff(self, seasonal_indicators: Tensor) -> Tensor:
237-
F = getF(seasonal_indicators)
238-
return F.one_hot(seasonal_indicators, depth=self.latent_dim())
245+
def time_features(self) -> List[TimeFeature]:
246+
return [self.time_feature]
239247

240-
def innovation_coeff(self, seasonal_indicators: Tensor) -> Tensor:
241-
F = getF(seasonal_indicators)
242-
# seasonal_indicators = F.modulo(seasonal_indicators - 1, self.latent_dim)
243-
return F.one_hot(seasonal_indicators, depth=self.latent_dim()).squeeze(
244-
axis=2
245-
)
248+
def emission_coeff(self, feature: Tensor) -> Tensor:
249+
F = getF(feature)
250+
return F.one_hot(feature, depth=self.latent_dim())
251+
252+
def innovation_coeff(self, feature: Tensor) -> Tensor:
253+
F = getF(feature)
254+
# feature = F.modulo(feature - 1, self.latent_dim)
255+
return F.one_hot(feature, depth=self.latent_dim()).squeeze(axis=2)
246256

247257

248258
class CompositeISSM(ISSM):
@@ -269,6 +279,12 @@ def latent_dim(self) -> int:
269279
def output_dim(self) -> int:
270280
return self.nonseasonal_issm.output_dim()
271281

282+
def time_features(self) -> List[TimeFeature]:
283+
ans = self.nonseasonal_issm.time_features()
284+
for issm in self.seasonal_issms:
285+
ans.extend(issm.time_features())
286+
return ans
287+
272288
@classmethod
273289
def get_from_freq(cls, freq: str, add_trend: bool = DEFAULT_ADD_TREND):
274290
offset = to_offset(freq)
@@ -277,71 +293,63 @@ def get_from_freq(cls, freq: str, add_trend: bool = DEFAULT_ADD_TREND):
277293

278294
if offset.name == "M":
279295
seasonal_issms = [
280-
SeasonalityISSM(num_seasons=12) # month-of-year seasonality
296+
SeasonalityISSM( # month-of-year seasonality
297+
num_seasons=12, time_feature=MonthOfYear(normalized=False)
298+
)
281299
]
282300
elif offset.name == "W-SUN":
283301
seasonal_issms = [
284-
SeasonalityISSM(num_seasons=53) # week-of-year seasonality
302+
SeasonalityISSM( # week-of-year seasonality
303+
num_seasons=53, time_feature=WeekOfYear(normalized=False)
304+
)
285305
]
286306
elif offset.name == "D":
287307
seasonal_issms = [
288-
SeasonalityISSM(num_seasons=7)
289-
] # day-of-week seasonality
308+
SeasonalityISSM( # day-of-week seasonality
309+
num_seasons=7, time_feature=DayOfWeek(normalized=False)
310+
)
311+
]
290312
elif offset.name == "B": # TODO: check this case
291313
seasonal_issms = [
292-
SeasonalityISSM(num_seasons=7)
293-
] # day-of-week seasonality
314+
SeasonalityISSM( # day-of-week seasonality
315+
num_seasons=7, time_feature=DayOfWeek(normalized=False)
316+
)
317+
]
294318
elif offset.name == "H":
295319
seasonal_issms = [
296-
SeasonalityISSM(num_seasons=24), # hour-of-day seasonality
297-
SeasonalityISSM(num_seasons=7), # day-of-week seasonality
320+
SeasonalityISSM( # hour-of-day seasonality
321+
num_seasons=24, time_feature=HourOfDay(normalized=False)
322+
),
323+
SeasonalityISSM( # day-of-week seasonality
324+
num_seasons=7, time_feature=DayOfWeek(normalized=False)
325+
),
298326
]
299327
elif offset.name == "T":
300328
seasonal_issms = [
301-
SeasonalityISSM(num_seasons=60), # minute-of-hour seasonality
302-
SeasonalityISSM(num_seasons=24), # hour-of-day seasonality
329+
SeasonalityISSM( # minute-of-hour seasonality
330+
num_seasons=60, time_feature=MinuteOfHour(normalized=False)
331+
),
332+
SeasonalityISSM( # hour-of-day seasonality
333+
num_seasons=24, time_features=HourOfDay(normalized=False)
334+
),
303335
]
304336
else:
305337
RuntimeError(f"Unsupported frequency {offset.name}")
306338

307339
return cls(seasonal_issms=seasonal_issms, add_trend=add_trend)
308340

309-
@classmethod
310-
def seasonal_features(cls, freq: str) -> List[TimeFeature]:
311-
offset = to_offset(freq)
312-
if offset.name == "M":
313-
return [MonthOfYear(normalized=False)]
314-
elif offset.name == "W-SUN":
315-
return [WeekOfYear(normalized=False)]
316-
elif offset.name == "D":
317-
return [DayOfWeek(normalized=False)]
318-
elif offset.name == "B": # TODO: check this case
319-
return [DayOfWeek(normalized=False)]
320-
elif offset.name == "H":
321-
return [HourOfDay(normalized=False), DayOfWeek(normalized=False)]
322-
elif offset.name == "T":
323-
return [
324-
MinuteOfHour(normalized=False),
325-
HourOfDay(normalized=False),
326-
]
327-
else:
328-
RuntimeError(f"Unsupported frequency {offset.name}")
329-
330-
return []
331-
332341
def get_issm_coeff(
333-
self, seasonal_indicators: Tensor # (batch_size, time_length)
342+
self, features: Tensor # (batch_size, time_length)
334343
) -> Tuple[Tensor, Tensor, Tensor]:
335-
F = getF(seasonal_indicators)
344+
F = getF(features)
336345
emission_coeff_ls, transition_coeff_ls, innovation_coeff_ls = zip(
337-
self.nonseasonal_issm.get_issm_coeff(seasonal_indicators),
338346
*[
339347
issm.get_issm_coeff(
340-
seasonal_indicators.slice_axis(
341-
axis=-1, begin=ix, end=ix + 1
342-
)
348+
features.slice_axis(axis=-1, begin=ix, end=ix + 1)
349+
)
350+
for ix, issm in enumerate(
351+
[self.nonseasonal_issm] + self.seasonal_issms
343352
)
344-
for ix, issm in enumerate(self.seasonal_issms)
345353
],
346354
)
347355

0 commit comments

Comments
 (0)