Skip to content

Commit e1c0939

Browse files
committed
improve issm api
1 parent 48b2ded commit e1c0939

File tree

3 files changed

+463
-91
lines changed

3 files changed

+463
-91
lines changed

src/gluonts/model/deepstate/_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def create_transformation(self) -> Transformation:
269269
),
270270
# Unnormalized seasonal features
271271
AddTimeFeatures(
272-
time_features=CompositeISSM.seasonal_features(self.freq),
272+
time_features=self.issm.time_features(),
273273
pred_length=self.prediction_length,
274274
start_field=FieldName.START,
275275
target_field=FieldName.TARGET,

src/gluonts/model/deepstate/issm.py

Lines changed: 93 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from typing import List, Tuple
1515

16+
import numpy as np
17+
import pandas as pd
1618
from pandas.tseries.frequencies import to_offset
1719

1820
from gluonts.core.component import validated
@@ -89,14 +91,23 @@ def _make_2_block_diagonal(F, left: Tensor, right: Tensor) -> Tensor:
8991
return _block_diagonal
9092

9193

94+
class ZeroFeature(TimeFeature):
95+
"""
96+
A feature that is identically zero.
97+
"""
98+
99+
def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
100+
return np.zeros(index.values.shape)
101+
102+
92103
class ISSM:
93104
r"""
94105
An abstract class for providing the basic structure of Innovation State Space Model (ISSM).
95106
96107
The structure of ISSM is given by
97108
98109
* dimension of the latent state
99-
* transition and emission coefficents of the transition model
110+
* transition and innovation coefficents of the transition model
100111
* emission coefficient of the observation model
101112
102113
"""
@@ -111,22 +122,25 @@ def latent_dim(self) -> int:
111122
def output_dim(self) -> int:
112123
raise NotImplementedError
113124

114-
def emission_coeff(self, seasonal_indicators: Tensor):
125+
def time_features(self) -> List[TimeFeature]:
115126
raise NotImplementedError
116127

117-
def transition_coeff(self, seasonal_indicators: Tensor):
128+
def emission_coeff(self, features: Tensor) -> Tensor:
118129
raise NotImplementedError
119130

120-
def innovation_coeff(self, seasonal_indicators: Tensor):
131+
def transition_coeff(self, features: Tensor) -> Tensor:
132+
raise NotImplementedError
133+
134+
def innovation_coeff(self, features: Tensor) -> Tensor:
121135
raise NotImplementedError
122136

123137
def get_issm_coeff(
124-
self, seasonal_indicators: Tensor
138+
self, features: Tensor
125139
) -> Tuple[Tensor, Tensor, Tensor]:
126140
return (
127-
self.emission_coeff(seasonal_indicators),
128-
self.transition_coeff(seasonal_indicators),
129-
self.innovation_coeff(seasonal_indicators),
141+
self.emission_coeff(features),
142+
self.transition_coeff(features),
143+
self.innovation_coeff(features),
130144
)
131145

132146

@@ -137,52 +151,47 @@ def latent_dim(self) -> int:
137151
def output_dim(self) -> int:
138152
return 1
139153

154+
def time_features(self) -> List[TimeFeature]:
155+
return [ZeroFeature()]
156+
140157
def emission_coeff(
141-
self, seasonal_indicators: Tensor # (batch_size, time_length)
158+
self, feature: Tensor # (batch_size, time_length, 1)
142159
) -> Tensor:
143-
F = getF(seasonal_indicators)
160+
F = getF(feature)
144161

145162
_emission_coeff = F.ones(shape=(1, 1, 1, self.latent_dim()))
146163

147-
# get the right shape: (batch_size, seq_length, obs_dim, latent_dim)
164+
# get the right shape: (batch_size, time_length, obs_dim, latent_dim)
148165
zeros = _broadcast_param(
149-
F.zeros_like(
150-
seasonal_indicators.slice_axis(
151-
axis=-1, begin=0, end=1
152-
).squeeze(axis=-1)
153-
),
166+
feature.squeeze(axis=2),
154167
axes=[2, 3],
155168
sizes=[1, self.latent_dim()],
156169
)
157170

158171
return _emission_coeff.broadcast_like(zeros)
159172

160173
def transition_coeff(
161-
self, seasonal_indicators: Tensor # (batch_size, time_length)
174+
self, feature: Tensor # (batch_size, time_length, 1)
162175
) -> Tensor:
163-
F = getF(seasonal_indicators)
176+
F = getF(feature)
164177

165178
_transition_coeff = (
166179
F.eye(self.latent_dim()).expand_dims(axis=0).expand_dims(axis=0)
167180
)
168181

169-
# get the right shape: (batch_size, seq_length, latent_dim, latent_dim)
182+
# get the right shape: (batch_size, time_length, latent_dim, latent_dim)
170183
zeros = _broadcast_param(
171-
F.zeros_like(
172-
seasonal_indicators.slice_axis(
173-
axis=-1, begin=0, end=1
174-
).squeeze(axis=-1)
175-
),
184+
feature.squeeze(axis=2),
176185
axes=[2, 3],
177186
sizes=[self.latent_dim(), self.latent_dim()],
178187
)
179188

180189
return _transition_coeff.broadcast_like(zeros)
181190

182191
def innovation_coeff(
183-
self, seasonal_indicators: Tensor # (batch_size, time_length)
192+
self, feature: Tensor # (batch_size, time_length, 1)
184193
) -> Tensor:
185-
return self.emission_coeff(seasonal_indicators).squeeze(axis=2)
194+
return self.emission_coeff(feature).squeeze(axis=2)
186195

187196

188197
class LevelTrendISSM(LevelISSM):
@@ -192,24 +201,23 @@ def latent_dim(self) -> int:
192201
def output_dim(self) -> int:
193202
return 1
194203

204+
def time_features(self) -> List[TimeFeature]:
205+
return [ZeroFeature()]
206+
195207
def transition_coeff(
196-
self, seasonal_indicators: Tensor # (batch_size, time_length)
208+
self, feature: Tensor # (batch_size, time_length, 1)
197209
) -> Tensor:
198-
F = getF(seasonal_indicators)
210+
F = getF(feature)
199211

200212
_transition_coeff = (
201213
(F.diag(F.ones(shape=(2,)), k=0) + F.diag(F.ones(shape=(1,)), k=1))
202214
.expand_dims(axis=0)
203215
.expand_dims(axis=0)
204216
)
205217

206-
# get the right shape: (batch_size, seq_length, latent_dim, latent_dim)
218+
# get the right shape: (batch_size, time_length, latent_dim, latent_dim)
207219
zeros = _broadcast_param(
208-
F.zeros_like(
209-
seasonal_indicators.slice_axis(
210-
axis=-1, begin=0, end=1
211-
).squeeze(axis=-1)
212-
),
220+
feature.squeeze(axis=2),
213221
axes=[2, 3],
214222
sizes=[self.latent_dim(), self.latent_dim()],
215223
)
@@ -223,26 +231,47 @@ class SeasonalityISSM(LevelISSM):
223231
"""
224232

225233
@validated()
226-
def __init__(self, num_seasons: int) -> None:
234+
def __init__(self, num_seasons: int, time_feature: TimeFeature) -> None:
227235
super(SeasonalityISSM, self).__init__()
228236
self.num_seasons = num_seasons
237+
self.time_feature = time_feature
229238

230239
def latent_dim(self) -> int:
231240
return self.num_seasons
232241

233242
def output_dim(self) -> int:
234243
return 1
235244

236-
def emission_coeff(self, seasonal_indicators: Tensor) -> Tensor:
237-
F = getF(seasonal_indicators)
238-
return F.one_hot(seasonal_indicators, depth=self.latent_dim())
245+
def time_features(self) -> List[TimeFeature]:
246+
return [self.time_feature]
239247

240-
def innovation_coeff(self, seasonal_indicators: Tensor) -> Tensor:
241-
F = getF(seasonal_indicators)
242-
# seasonal_indicators = F.modulo(seasonal_indicators - 1, self.latent_dim)
243-
return F.one_hot(seasonal_indicators, depth=self.latent_dim()).squeeze(
244-
axis=2
245-
)
248+
def emission_coeff(self, feature: Tensor) -> Tensor:
249+
F = getF(feature)
250+
return F.one_hot(feature, depth=self.latent_dim())
251+
252+
def innovation_coeff(self, feature: Tensor) -> Tensor:
253+
F = getF(feature)
254+
return F.one_hot(feature, depth=self.latent_dim()).squeeze(axis=2)
255+
256+
257+
def MonthOfYearSeasonalISSM():
258+
return SeasonalityISSM(num_seasons=12, time_feature=MonthOfYearIndex())
259+
260+
261+
def WeekOfYearSeasonalISSM():
262+
return SeasonalityISSM(num_seasons=53, time_feature=WeekOfYearIndex())
263+
264+
265+
def DayOfWeekSeasonalISSM():
266+
return SeasonalityISSM(num_seasons=7, time_feature=DayOfWeekIndex())
267+
268+
269+
def HourOfDaySeasonalISSM():
270+
return SeasonalityISSM(num_seasons=24, time_feature=HourOfDayIndex())
271+
272+
273+
def MinuteOfHourSeasonalISSM():
274+
return SeasonalityISSM(num_seasons=60, time_feature=MinuteOfHourIndex())
246275

247276

248277
class CompositeISSM(ISSM):
@@ -269,79 +298,53 @@ def latent_dim(self) -> int:
269298
def output_dim(self) -> int:
270299
return self.nonseasonal_issm.output_dim()
271300

301+
def time_features(self) -> List[TimeFeature]:
302+
ans = self.nonseasonal_issm.time_features()
303+
for issm in self.seasonal_issms:
304+
ans.extend(issm.time_features())
305+
return ans
306+
272307
@classmethod
273308
def get_from_freq(cls, freq: str, add_trend: bool = DEFAULT_ADD_TREND):
274309
offset = to_offset(freq)
275310

276311
seasonal_issms: List[SeasonalityISSM] = []
277312

278313
if offset.name == "M":
279-
seasonal_issms = [
280-
SeasonalityISSM(num_seasons=12) # month-of-year seasonality
281-
]
314+
seasonal_issms = [MonthOfYearSeasonalISSM()]
282315
elif offset.name == "W-SUN":
283-
seasonal_issms = [
284-
SeasonalityISSM(num_seasons=53) # week-of-year seasonality
285-
]
316+
seasonal_issms = [WeekOfYearSeasonalISSM()]
286317
elif offset.name == "D":
287-
seasonal_issms = [
288-
SeasonalityISSM(num_seasons=7)
289-
] # day-of-week seasonality
318+
seasonal_issms = [DayOfWeekSeasonalISSM()]
290319
elif offset.name == "B": # TODO: check this case
291-
seasonal_issms = [
292-
SeasonalityISSM(num_seasons=7)
293-
] # day-of-week seasonality
320+
seasonal_issms = [DayOfWeekSeasonalISSM()]
294321
elif offset.name == "H":
295322
seasonal_issms = [
296-
SeasonalityISSM(num_seasons=24), # hour-of-day seasonality
297-
SeasonalityISSM(num_seasons=7), # day-of-week seasonality
323+
HourOfDaySeasonalISSM(),
324+
DayOfWeekSeasonalISSM(),
298325
]
299326
elif offset.name == "T":
300327
seasonal_issms = [
301-
SeasonalityISSM(num_seasons=60), # minute-of-hour seasonality
302-
SeasonalityISSM(num_seasons=24), # hour-of-day seasonality
328+
MinuteOfHourSeasonalISSM(),
329+
HourOfDaySeasonalISSM(),
303330
]
304331
else:
305332
RuntimeError(f"Unsupported frequency {offset.name}")
306333

307334
return cls(seasonal_issms=seasonal_issms, add_trend=add_trend)
308335

309-
@classmethod
310-
def seasonal_features(cls, freq: str) -> List[TimeFeature]:
311-
offset = to_offset(freq)
312-
if offset.name == "M":
313-
return [MonthOfYearIndex()]
314-
elif offset.name == "W-SUN":
315-
return [WeekOfYearIndex()]
316-
elif offset.name == "D":
317-
return [DayOfWeekIndex()]
318-
elif offset.name == "B": # TODO: check this case
319-
return [DayOfWeekIndex()]
320-
elif offset.name == "H":
321-
return [HourOfDayIndex(), DayOfWeekIndex()]
322-
elif offset.name == "T":
323-
return [
324-
MinuteOfHourIndex(),
325-
HourOfDayIndex(),
326-
]
327-
else:
328-
RuntimeError(f"Unsupported frequency {offset.name}")
329-
330-
return []
331-
332336
def get_issm_coeff(
333-
self, seasonal_indicators: Tensor # (batch_size, time_length)
337+
self, features: Tensor # (batch_size, time_length, num_features)
334338
) -> Tuple[Tensor, Tensor, Tensor]:
335-
F = getF(seasonal_indicators)
339+
F = getF(features)
336340
emission_coeff_ls, transition_coeff_ls, innovation_coeff_ls = zip(
337-
self.nonseasonal_issm.get_issm_coeff(seasonal_indicators),
338341
*[
339342
issm.get_issm_coeff(
340-
seasonal_indicators.slice_axis(
341-
axis=-1, begin=ix, end=ix + 1
342-
)
343+
features.slice_axis(axis=-1, begin=ix, end=ix + 1)
344+
)
345+
for ix, issm in enumerate(
346+
[self.nonseasonal_issm] + self.seasonal_issms
343347
)
344-
for ix, issm in enumerate(self.seasonal_issms)
345348
],
346349
)
347350

0 commit comments

Comments
 (0)