Skip to content

Commit 06e895a

Browse files
lostellaabdulfatir
andauthored
Backports for v0.16.2 (#3260)
*Description of changes:* backport changes - #3259 - #3261 By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. **Please tag this pr with at least one of these labels to make our release process faster:** BREAKING, new feature, bug fix, other change, dev setup --------- Co-authored-by: Abdul Fatir <[email protected]>
1 parent a164967 commit 06e895a

File tree

4 files changed

+10
-8
lines changed

4 files changed

+10
-8
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
rpy2~=3.5
1+
rpy2~=3.5,<3.6
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1+
# scipy cap can be removed once this is resolved: https://github.com/statsmodels/statsmodels/issues/9584
2+
scipy<1.16.0; python_version > "3.7.0"
3+
scipy~=1.7.3; python_version <= "3.7.0"
14
statsforecast~=1.0

requirements/requirements-pytorch.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ torch>=1.9,<3
22
lightning>=2.2.2,<2.5
33
# Capping `lightning` does not cap `pytorch_lightning`, so we cap manually
44
pytorch_lightning>=2.2.2,<2.5
5-
scipy~=1.10; python_version > "3.7.0"
5+
# scipy cap can be removed once this is resolved: https://github.com/statsmodels/statsmodels/issues/9584
6+
scipy<1.16.0; python_version > "3.7.0"
67
scipy~=1.7.3; python_version <= "3.7.0"

src/gluonts/transform/split.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def __init__(
494494
is_pad_field: str = FieldName.IS_PAD,
495495
start_field: str = FieldName.START,
496496
forecast_start_field: str = FieldName.FORECAST_START,
497-
observed_value_field: str = FieldName.OBSERVED_VALUES,
497+
observed_value_field: Optional[str] = FieldName.OBSERVED_VALUES,
498498
lead_time: int = 0,
499499
output_NTC: bool = True,
500500
time_series_fields: List[str] = [],
@@ -529,11 +529,9 @@ def flatmap_transform(
529529

530530
sampled_indices = self.instance_sampler(target)
531531

532-
slice_cols = (
533-
self.ts_fields
534-
+ self.past_ts_fields
535-
+ [self.target_field, self.observed_value_field]
536-
)
532+
slice_cols = self.ts_fields + self.past_ts_fields + [self.target_field]
533+
if self.observed_value_field is not None:
534+
slice_cols.append(self.observed_value_field)
537535
for i in sampled_indices:
538536
pad_length = max(self.past_length - i, 0)
539537
d = data.copy()

0 commit comments

Comments
 (0)