Skip to content

Commit

Permalink
Refactor sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Jun 13, 2023
1 parent fa0e40b commit 0b5ab44
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions rwkv/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ def sample_probs(probs: np.ndarray, temperature: float = 1.0, top_p: float = 0.8
if top_p == 0.0:
top_p = 1.0

if logit_bias is not None:
if logit_bias is not None and len(logit_bias) > 0:
logits = np.log(probs)

if len(logit_bias) > 0:
ids, values = zip(*logit_bias.items())
logits[list(ids)] += values

ids, values = zip(*logit_bias.items())
logits[list(ids)] += values

# Makes calculation more numerically stable, does not change the result
logits -= logits.max(axis=-1, keepdims=True)

probs = np.exp(logits) / np.sum(np.exp(logits))

if temperature == 0.0:
Expand Down

0 comments on commit 0b5ab44

Please sign in to comment.