Skip to content

Commit d02746c

Browse files
committed
Update the add script to accept a date range and to prioritise questions asked from 9am-12pm and 8pm-10pm, as well as questions asked during the weekend
1 parent 827ec3c commit d02746c

File tree

1 file changed

+72
-23
lines changed

1 file changed

+72
-23
lines changed

core_backend/add_new_data_to_db.py

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
--api-key <API_KEY> \
6666
--nb-workers 8 \
6767
--start-date 01-08-23
68+
--end-date 04-09-24
6869
6970
""",
7071
)
@@ -83,6 +84,16 @@
8384
help="Start date for the records in the format dd-mm-yy",
8485
required=False,
8586
)
87+
parser.add_argument(
88+
"--end-date",
89+
help="End date for the records in the format dd-mm-yy",
90+
required=False,
91+
)
92+
parser.add_argument(
93+
"--subset",
94+
help="Subset of the data to use for testing",
95+
required=False,
96+
)
8697
args = parser.parse_args()
8798

8899

@@ -282,24 +293,60 @@ def process_urgency_detection(_id: int, text: str) -> tuple | None:
282293
return None
283294

284295

def create_random_datetime(start_date: datetime, end_date: datetime) -> datetime:
    """
    Create a random datetime within the closed range [start_date, end_date].

    Args:
        start_date: Lower bound of the range (inclusive).
        end_date: Upper bound of the range (inclusive).

    Returns:
        A datetime equal to ``start_date`` plus a random whole number of
        seconds that never exceeds ``end_date``.

    Raises:
        ValueError: If ``end_date`` is earlier than ``start_date``.
    """
    if end_date < start_date:
        raise ValueError("end_date must not be earlier than start_date")

    # Draw a single offset over the exact span in seconds. Picking a random day
    # and then adding up to 86399 extra seconds could overshoot end_date by
    # almost a full day.
    span_in_seconds = int((end_date - start_date).total_seconds())
    random_offset = random.randint(0, span_in_seconds)
    return start_date + timedelta(seconds=random_offset)
300309

301310

302-
def update_date_of_records(models: list, random_dates: list, api_key: str) -> None:
311+
def is_within_time_range(date: datetime) -> bool:
    """
    Check whether the date falls inside one of the prioritised time windows.

    The prioritised windows are 9am-12pm (hours 9, 10, 11) and
    8pm-10pm (hours 20, 21); the upper bound of each window is exclusive.

    Args:
        date: The datetime whose hour is inspected.

    Returns:
        True if the hour of ``date`` lies in a prioritised window, else False.
    """
    return 9 <= date.hour < 12 or 20 <= date.hour < 22
319+
320+
321+
def generate_distributed_dates(n: int, start: datetime, end: datetime) -> list[datetime]:
    """
    Generate ``n`` dates between ``start`` and ``end`` with a skewed distribution.

    Candidates are drawn with ``create_random_datetime`` and accepted by
    rejection sampling:

    * Weekend candidates (Saturday/Sunday) are kept when they fall in a
      prioritised time window, otherwise with probability 0.4.
    * Weekday candidates must first pass a 0.6 gate (so weekends are
      favoured); those that pass are kept when they fall in a prioritised
      time window, otherwise with probability 0.55.

    Args:
        n: Number of dates to generate.
        start: Lower bound passed to ``create_random_datetime``.
        end: Upper bound passed to ``create_random_datetime``.

    Returns:
        A list containing ``n`` accepted datetimes (empty when ``n <= 0``).
    """
    weekend_off_peak_chance = 0.4
    weekday_gate_chance = 0.6
    weekday_off_peak_chance = 0.55

    dates: list[datetime] = []
    while len(dates) < n:
        date = create_random_datetime(start, end)
        in_peak_hours = is_within_time_range(date)

        if date.weekday() >= 5:
            # Weekend: always keep peak hours, otherwise a 40% chance.
            keep = in_peak_hours or random.random() < weekend_off_peak_chance
        else:
            # Weekday: 60% gate first, then peak hours or a 55% chance.
            keep = random.random() < weekday_gate_chance and (
                in_peak_hours or random.random() < weekday_off_peak_chance
            )

        if keep:
            dates.append(date)

    return dates
342+
343+
344+
def update_date_of_records(
345+
models: list,
346+
api_key: str,
347+
start_date: datetime,
348+
end_date: datetime,
349+
) -> None:
303350
"""
304351
Update the date of the records in the database
305352
"""
@@ -309,11 +356,7 @@ def update_date_of_records(models: list, random_dates: list, api_key: str) -> No
309356
select(UserDB).where(UserDB.hashed_api_key == hashed_token)
310357
).scalar_one()
311358
queries = [c for c in session.query(QueryDB).all() if c.user_id == user.user_id]
312-
if len(queries) > len(random_dates):
313-
random_dates = random_dates + [
314-
create_random_datetime_from_string(start_date)
315-
for _ in range(len(queries) - len(random_dates))
316-
]
359+
random_dates = generate_distributed_dates(len(queries), start_date, end_date)
317360
# Create a dictionary to map the query_id to the random date
318361
date_map_dic = {queries[i].query_id: random_dates[i] for i in range(len(queries))}
319362
for model in models:
@@ -324,8 +367,8 @@ def update_date_of_records(models: list, random_dates: list, api_key: str) -> No
324367

325368
for i, row in enumerate(rows):
326369
# Set the date attribute to the random date
327-
if hasattr(row, "query_id"):
328-
date = date_map_dic[row.query_id]
370+
if hasattr(row, "query_id") and model[0] != UrgencyQueryDB:
371+
date = date_map_dic.get(row.query_id, None)
329372
else:
330373
date = random_dates[i]
331374
setattr(row, model[1], date)
@@ -351,17 +394,26 @@ def update_date_of_contents(date: datetime) -> None:
351394
NB_WORKERS = int(args.nb_workers) if args.nb_workers else 8
352395
API_KEY = args.api_key if args.api_key else ADMIN_API_KEY
353396

354-
date_string = args.start_date if args.start_date else "01-08-23"
397+
start_date_string = args.start_date if args.start_date else "01-08-23"
398+
end_date_string = args.end_date if args.end_date else None
355399
date_format = "%d-%m-%y"
356-
start_date = datetime.strptime(date_string, date_format)
400+
start_date = datetime.strptime(start_date_string, date_format)
401+
end_date = (
402+
datetime.strptime(end_date_string, date_format)
403+
if end_date_string
404+
else datetime.now()
405+
)
406+
assert end_date, "Invalid end date. Please provide a valid date. Format is dd-mm-yy"
357407
assert (
358-
start_date and start_date < datetime.now()
359-
), "Invalid start date. Please provide a valid start date."
408+
start_date and start_date < end_date
409+
), "Invalid start date. Please provide a valid start date. Format is dd-mm-yy"
360410

411+
subset = int(args.subset) if args.subset else None
361412
path = args.csv
362-
df = pd.read_csv(path)
413+
df = pd.read_csv(path, nrows=subset)
363414
saved_queries = defaultdict(list)
364415
print("Processing search queries...")
416+
365417
# Using multithreading to speed up the process
366418
with ThreadPoolExecutor(max_workers=NB_WORKERS) as executor:
367419
future_to_text = {
@@ -444,11 +496,8 @@ def update_date_of_contents(date: datetime) -> None:
444496
result = future.result()
445497
print("Urgency Detection successfully processed")
446498

447-
random_dates = [
448-
create_random_datetime_from_string(start_date) for _ in range(len(df))
449-
]
450499
print("Updating the date of the records...")
451-
update_date_of_records(MODELS, random_dates, API_KEY)
500+
update_date_of_records(MODELS, API_KEY, start_date, end_date)
452501

453502
print("Updating the date of the content records...")
454503
update_date_of_contents(start_date)

0 commit comments

Comments
 (0)