@@ -366,6 +366,21 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
366
366
}
367
367
}
368
368
369
// Returns the lowest logit in the vector, clamped to at most 0.
// The caller assigns this to a token's logit as a "soft ban": the token
// becomes the least-likely candidate without using -INFINITY (some model
// families use negative logits, so zeroing is not a safe ban either).
// An empty vector yields 0.0f — the original indexed into the empty
// vector (UB); this guard matches the pointer/size overload's behavior.
static float LowestLogit(const std::vector<float> &logits)
{
    if (logits.empty())
    {
        return 0.0f; // nothing to scan — neutral fallback
    }
    const float lowest = *std::min_element(logits.begin(), logits.end());
    return (lowest < 0 ? lowest : 0);
}
374
// Returns the smallest value in the logits array, clamped to at most 0
// (a positive minimum yields 0.0f). A zero-length array yields 0.0f.
static float LowestLogit(const float *logits, size_t size)
{
    if (size == 0)
    {
        // Nothing to scan — fall back to a neutral zero.
        return 0.0f;
    }
    const float *lowest = std::min_element(logits, logits + size);
    return (*lowest < 0 ? *lowest : 0);
}
369
384
static std::string RemoveBell (const std::string & input) // removes the bell character
370
385
{
371
386
std::string word2;
@@ -1442,23 +1457,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
1442
1457
eosID = llama_v3_token_eos ();
1443
1458
}
1444
1459
1460
+ float lowestLogit = LowestLogit (logitsPtr,n_vocab);
1445
1461
if (!unbanTokens)
1446
1462
{
1447
1463
// set the logit of the eos token (2) to -INF to avoid sampling it
1448
- logitsPtr[eosID] = -INFINITY ;
1464
+ logitsPtr[eosID] = lowestLogit ;
1449
1465
}
1450
1466
1451
1467
if (btsize>0 )
1452
1468
{
1453
1469
for (int t=0 ;t<btsize;++t)
1454
1470
{
1455
- logitsPtr[banned_token_ids[t]]=-INFINITY ;
1471
+ logitsPtr[banned_token_ids[t]]=lowestLogit ;
1456
1472
}
1457
1473
}
1458
1474
}
1459
1475
else
1460
1476
{
1461
1477
logitsPtr = logits.data ();
1478
+ float lowestLogit = LowestLogit (logits);
1462
1479
if (!unbanTokens)
1463
1480
{
1464
1481
// gpt2 uses negative logits, so we cant zero it
@@ -1474,17 +1491,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
1474
1491
file_format == FileFormat::GPTJ_5)
1475
1492
{
1476
1493
eosID = 50256 ;
1494
+
1477
1495
if (logits.size () > eosID)
1478
1496
{
1479
- logits[eosID] = -INFINITY ;
1497
+ logits[eosID] = lowestLogit ;
1480
1498
}
1481
1499
else
1482
1500
{
1483
1501
// special case, starcoder models use ID 0 for EOS
1484
1502
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
1485
1503
{
1486
1504
eosID = 0 ;
1487
- logits[eosID] = -INFINITY;
1505
+ logits[eosID] = lowestLogit;
1506
+
1488
1507
}
1489
1508
}
1490
1509
}
@@ -1502,15 +1521,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
1502
1521
file_format == FileFormat::MPT_1)
1503
1522
{
1504
1523
eosID = 0 ;
1505
- logits[eosID] = -INFINITY ;
1524
+ logits[eosID] = lowestLogit ;
1506
1525
}
1507
1526
}
1508
1527
1509
1528
if (btsize>0 )
1510
1529
{
1511
1530
for (int t = 0 ; t < btsize; ++t)
1512
1531
{
1513
- logits[banned_token_ids[t]] = -INFINITY ;
1532
+ logits[banned_token_ids[t]] = lowestLogit ;
1514
1533
}
1515
1534
}
1516
1535
}
0 commit comments