Skip to content

Commit b84686b

Browse files
Derek-Wdsyou-n-g
authored andcommitted
Update models to enable save/load
1 parent 6a67082 commit b84686b

12 files changed

+35
-34
lines changed

qlib/contrib/model/pytorch_alstm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __init__(
130130
else:
131131
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
132132

133-
self._fitted = False
133+
self.fitted = False
134134
self.ALSTM_model.to(self.device)
135135

136136
def mse(self, pred, label):
@@ -238,7 +238,7 @@ def fit(
238238

239239
# train
240240
self.logger.info("training...")
241-
self._fitted = True
241+
self.fitted = True
242242

243243
for step in range(self.n_epochs):
244244
self.logger.info("Epoch%d:", step)
@@ -270,7 +270,7 @@ def fit(
270270
torch.cuda.empty_cache()
271271

272272
def predict(self, dataset):
273-
if not self._fitted:
273+
if not self.fitted:
274274
raise ValueError("model is not fitted yet!")
275275

276276
x_test = dataset.prepare("test", col_set="feature")

qlib/contrib/model/pytorch_alstm_ts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(
135135
else:
136136
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
137137

138-
self._fitted = False
138+
self.fitted = False
139139
self.ALSTM_model.to(self.device)
140140

141141
def mse(self, pred, label):
@@ -225,7 +225,7 @@ def fit(
225225

226226
# train
227227
self.logger.info("training...")
228-
self._fitted = True
228+
self.fitted = True
229229

230230
for step in range(self.n_epochs):
231231
self.logger.info("Epoch%d:", step)
@@ -257,7 +257,7 @@ def fit(
257257
torch.cuda.empty_cache()
258258

259259
def predict(self, dataset):
260-
if not self._fitted:
260+
if not self.fitted:
261261
raise ValueError("model is not fitted yet!")
262262

263263
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)

qlib/contrib/model/pytorch_gats.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(
142142
else:
143143
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
144144

145-
self._fitted = False
145+
self.fitted = False
146146
self.GAT_model.to(self.device)
147147

148148
def mse(self, pred, label):
@@ -275,7 +275,7 @@ def fit(
275275

276276
# train
277277
self.logger.info("training...")
278-
self._fitted = True
278+
self.fitted = True
279279

280280
for step in range(self.n_epochs):
281281
self.logger.info("Epoch%d:", step)
@@ -307,7 +307,7 @@ def fit(
307307
torch.cuda.empty_cache()
308308

309309
def predict(self, dataset):
310-
if not self._fitted:
310+
if not self.fitted:
311311
raise ValueError("model is not fitted yet!")
312312

313313
x_test = dataset.prepare("test", col_set="feature")

qlib/contrib/model/pytorch_gats_ts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def __init__(
164164
else:
165165
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
166166

167-
self._fitted = False
167+
self.fitted = False
168168
self.GAT_model.to(self.device)
169169

170170
def mse(self, pred, label):
@@ -297,7 +297,7 @@ def fit(
297297

298298
# train
299299
self.logger.info("training...")
300-
self._fitted = True
300+
self.fitted = True
301301

302302
for step in range(self.n_epochs):
303303
self.logger.info("Epoch%d:", step)
@@ -329,7 +329,7 @@ def fit(
329329
torch.cuda.empty_cache()
330330

331331
def predict(self, dataset):
332-
if not self._fitted:
332+
if not self.fitted:
333333
raise ValueError("model is not fitted yet!")
334334

335335
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)

qlib/contrib/model/pytorch_gru.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __init__(
130130
else:
131131
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
132132

133-
self._fitted = False
133+
self.fitted = False
134134
self.gru_model.to(self.device)
135135

136136
def mse(self, pred, label):
@@ -238,7 +238,7 @@ def fit(
238238

239239
# train
240240
self.logger.info("training...")
241-
self._fitted = True
241+
self.fitted = True
242242

243243
for step in range(self.n_epochs):
244244
self.logger.info("Epoch%d:", step)
@@ -270,7 +270,7 @@ def fit(
270270
torch.cuda.empty_cache()
271271

272272
def predict(self, dataset):
273-
if not self._fitted:
273+
if not self.fitted:
274274
raise ValueError("model is not fitted yet!")
275275

276276
x_test = dataset.prepare("test", col_set="feature")

qlib/contrib/model/pytorch_gru_ts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(
135135
else:
136136
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
137137

138-
self._fitted = False
138+
self.fitted = False
139139
self.GRU_model.to(self.device)
140140

141141
def mse(self, pred, label):
@@ -225,7 +225,7 @@ def fit(
225225

226226
# train
227227
self.logger.info("training...")
228-
self._fitted = True
228+
self.fitted = True
229229

230230
for step in range(self.n_epochs):
231231
self.logger.info("Epoch%d:", step)
@@ -257,7 +257,7 @@ def fit(
257257
torch.cuda.empty_cache()
258258

259259
def predict(self, dataset):
260-
if not self._fitted:
260+
if not self.fitted:
261261
raise ValueError("model is not fitted yet!")
262262

263263
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)

qlib/contrib/model/pytorch_lstm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __init__(
130130
else:
131131
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
132132

133-
self._fitted = False
133+
self.fitted = False
134134
self.lstm_model.to(self.device)
135135

136136
def mse(self, pred, label):
@@ -238,7 +238,7 @@ def fit(
238238

239239
# train
240240
self.logger.info("training...")
241-
self._fitted = True
241+
self.fitted = True
242242

243243
for step in range(self.n_epochs):
244244
self.logger.info("Epoch%d:", step)
@@ -270,7 +270,7 @@ def fit(
270270
torch.cuda.empty_cache()
271271

272272
def predict(self, dataset):
273-
if not self._fitted:
273+
if not self.fitted:
274274
raise ValueError("model is not fitted yet!")
275275

276276
x_test = dataset.prepare("test", col_set="feature")

qlib/contrib/model/pytorch_lstm_ts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(
135135
else:
136136
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
137137

138-
self._fitted = False
138+
self.fitted = False
139139
self.LSTM_model.to(self.device)
140140

141141
def mse(self, pred, label):
@@ -225,7 +225,7 @@ def fit(
225225

226226
# train
227227
self.logger.info("training...")
228-
self._fitted = True
228+
self.fitted = True
229229

230230
for step in range(self.n_epochs):
231231
self.logger.info("Epoch%d:", step)
@@ -257,7 +257,7 @@ def fit(
257257
torch.cuda.empty_cache()
258258

259259
def predict(self, dataset):
260-
if not self._fitted:
260+
if not self.fitted:
261261
raise ValueError("model is not fitted yet!")
262262

263263
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)

qlib/contrib/model/pytorch_nn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(
150150
eps=1e-08,
151151
)
152152

153-
self._fitted = False
153+
self.fitted = False
154154
self.dnn_model.to(self.device)
155155

156156
def fit(
@@ -180,7 +180,7 @@ def fit(
180180
evals_result["valid"] = []
181181
# train
182182
self.logger.info("training...")
183-
self._fitted = True
183+
self.fitted = True
184184
# return
185185
# prepare training data
186186
x_train_values = torch.from_numpy(x_train.values).float()
@@ -265,7 +265,7 @@ def get_loss(self, pred, w, target, loss_type):
265265
raise NotImplementedError("loss {} is not supported!".format(loss_type))
266266

267267
def predict(self, dataset):
268-
if not self._fitted:
268+
if not self.fitted:
269269
raise ValueError("model is not fitted yet!")
270270
x_test_pd = dataset.prepare("test", col_set="feature")
271271
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device)

qlib/contrib/model/pytorch_sfm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def __init__(
302302
else:
303303
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
304304

305-
self._fitted = False
305+
self.fitted = False
306306
self.sfm_model.to(self.device)
307307

308308
def test_epoch(self, data_x, data_y):
@@ -386,7 +386,7 @@ def fit(
386386

387387
# train
388388
self.logger.info("training...")
389-
self._fitted = True
389+
self.fitted = True
390390

391391
for step in range(self.n_epochs):
392392
self.logger.info("Epoch%d:", step)
@@ -435,7 +435,7 @@ def metric_fn(self, pred, label):
435435
raise ValueError("unknown metric `%s`" % self.metric)
436436

437437
def predict(self, dataset):
438-
if not self._fitted:
438+
if not self.fitted:
439439
raise ValueError("model is not fitted yet!")
440440

441441
x_test = dataset.prepare("test", col_set="feature")

0 commit comments

Comments
 (0)