Skip to content

Commit 380fa0f

Browse files
committed
fixed broken typical sampler issues
1 parent cf5d918 commit 380fa0f

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

gpttype_adapter.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,21 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
366366
}
367367
}
368368

369+
// Returns the smallest value in `logits`, capped at zero: if the minimum is
// negative it is returned as-is, otherwise 0 is returned.
// An empty vector yields 0, matching the pointer/size overload.
static float LowestLogit(const std::vector<float> & logits)
{
    if (logits.empty())
    {
        // std::min_element returns end() for an empty range; reading through
        // it (or indexing element 0) would be an out-of-bounds access.
        return 0.0f;
    }
    // Dereference the iterator directly instead of converting it to an int
    // index, which avoids narrowing the iterator difference.
    const float lowest = *std::min_element(logits.begin(), logits.end());
    return (lowest < 0 ? lowest : 0);
}
374+
// Pointer/length overload: scans `size` floats starting at `logits` and
// returns the minimum when it is below zero, or 0 otherwise.
// A zero-length range returns 0 without touching `logits`.
static float LowestLogit(const float *logits, size_t size)
{
    // Nothing to scan: report the neutral value.
    if (size == 0)
    {
        return 0.0;
    }

    const float *smallest = std::min_element(logits, logits + size);
    if (*smallest < 0)
    {
        return *smallest;
    }
    return 0;
}
383+
369384
static std::string RemoveBell(const std::string & input) //removes the bell character
370385
{
371386
std::string word2;
@@ -1442,23 +1457,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
14421457
eosID = llama_v3_token_eos();
14431458
}
14441459

1460+
float lowestLogit = LowestLogit(logitsPtr,n_vocab);
14451461
if (!unbanTokens)
14461462
{
14471463
// set the logit of the eos token (2) to -INF to avoid sampling it
1448-
logitsPtr[eosID] = -INFINITY;
1464+
logitsPtr[eosID] = lowestLogit;
14491465
}
14501466

14511467
if(btsize>0)
14521468
{
14531469
for(int t=0;t<btsize;++t)
14541470
{
1455-
logitsPtr[banned_token_ids[t]]=-INFINITY;
1471+
logitsPtr[banned_token_ids[t]]=lowestLogit;
14561472
}
14571473
}
14581474
}
14591475
else
14601476
{
14611477
logitsPtr = logits.data();
1478+
float lowestLogit = LowestLogit(logits);
14621479
if (!unbanTokens)
14631480
{
14641481
//gpt2 uses negative logits, so we cant zero it
@@ -1474,17 +1491,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
14741491
file_format == FileFormat::GPTJ_5)
14751492
{
14761493
eosID = 50256;
1494+
14771495
if(logits.size() > eosID)
14781496
{
1479-
logits[eosID] = -INFINITY;
1497+
logits[eosID] = lowestLogit;
14801498
}
14811499
else
14821500
{
14831501
//special case, starcoder models use ID 0 for EOS
14841502
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
14851503
{
14861504
eosID = 0;
1487-
logits[eosID] = -INFINITY;
1505+
logits[eosID] = lowestLogit;
1506+
14881507
}
14891508
}
14901509
}
@@ -1502,15 +1521,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
15021521
file_format == FileFormat::MPT_1)
15031522
{
15041523
eosID = 0;
1505-
logits[eosID] = -INFINITY;
1524+
logits[eosID] = lowestLogit;
15061525
}
15071526
}
15081527

15091528
if(btsize>0)
15101529
{
15111530
for (int t = 0; t < btsize; ++t)
15121531
{
1513-
logits[banned_token_ids[t]] = -INFINITY;
1532+
logits[banned_token_ids[t]] = lowestLogit;
15141533
}
15151534
}
15161535
}

0 commit comments

Comments
 (0)