Fix issues with server send_embeddings function
Fixes #404
jart committed May 8, 2024
1 parent 22aba95 commit 0e2845a
Showing 3 changed files with 61 additions and 65 deletions.
60 changes: 29 additions & 31 deletions llama.cpp/llama.cpp
@@ -17130,43 +17130,41 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embd;
 }
 
+static float * llama_get_embeddings_ith_fail(int i, std::string reason) {
+    LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, reason);
+    return nullptr;
+}
+
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
     int32_t j = -1;
 
-    llama_synchronize(ctx);
-
-    try {
-        if (ctx->embd == nullptr) {
-            throw std::runtime_error("no embeddings");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->embd + j*ctx->model.hparams.n_embd;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
+    // [jart] DO NOT SYNC this function
+    if (ctx->embd == nullptr) {
+        return llama_get_embeddings_ith_fail(i, "no embeddings");
     }
+    if (i < 0) {
+        j = ctx->n_outputs + i;
+        if (j < 0) {
+            return llama_get_embeddings_ith_fail(
+                i, format("negative index out of range [0, %d)", ctx->n_outputs));
+        }
+    } else if ((size_t) i >= ctx->output_ids.size()) {
+        return llama_get_embeddings_ith_fail(
+            i, format("out of range [0, %lu)", ctx->output_ids.size()));
+    } else {
+        j = ctx->output_ids[i];
+    }
+    if (j < 0) {
+        return llama_get_embeddings_ith_fail(
+            i, format("batch.logits[%d] != true", i));
+    }
+    if (j >= ctx->n_outputs) {
+        // This should not happen
+        return llama_get_embeddings_ith_fail(
+            i, format("corrupt output buffer (j=%d, n_outputs=%d)",
+                      j, ctx->n_outputs));
+    }
+    return ctx->embd + j*ctx->model.hparams.n_embd;
 }
 
 float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
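
The net effect of this hunk is that llama_get_embeddings_ith() no longer synchronizes or throws: every failure is routed through llama_get_embeddings_ith_fail(), which logs the reason and returns nullptr, while the index semantics stay the same (a negative i counts back from the last output, a non-negative i is translated through output_ids). Below is a small standalone sketch of just that index arithmetic, with made-up values in place of the real context state; it is an illustration, not code from the commit.

    // Illustration of the index handling in llama_get_embeddings_ith(): a negative
    // i counts back from the end of the outputs (-1 is the last output), and any
    // index that does not map to a valid output row is rejected instead of
    // throwing or asserting. All values below are made up.
    #include <cstdio>
    #include <vector>

    static int resolve_output_index(int i, int n_outputs, const std::vector<int> & output_ids) {
        int j;
        if (i < 0) {
            j = n_outputs + i;                       // e.g. i = -1 -> last output
            if (j < 0) return -1;                    // negative index out of range
        } else if (i >= (int) output_ids.size()) {
            return -1;                               // positive index out of range
        } else {
            j = output_ids[i];                       // token index -> output row
        }
        if (j < 0 || j >= n_outputs) return -1;      // logits not requested / corrupt buffer
        return j;
    }

    int main() {
        std::vector<int> output_ids = {-1, 0, -1, 1}; // only tokens 1 and 3 requested outputs
        const int n_outputs = 2;
        for (int i : {-1, 1, 2, 5}) {
            std::printf("i=%2d -> j=%d\n", i, resolve_output_index(i, n_outputs, output_ids));
        }
    }
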
64 changes: 31 additions & 33 deletions llama.cpp/server/server.cpp
@@ -1293,42 +1293,40 @@ struct llama_server_context
         res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
-        const int n_embd = llama_n_embd(model);
-        if (!params.embedding)
-        {
-            LOG_WARNING("embedding disabled", {
-                            {"params.embedding", params.embedding},
-                        });
-            res.result_json = json
-            {
-                {"embedding", std::vector<float>(n_embd, 0.0f)},
-            };
-        }
-        else
-        {
-            std::vector<float> embd_res(n_embd, 0.0f);
-
-            for (int i = 0; i < batch.n_tokens; i++) {
-                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-                if (embd == NULL) {
-                    embd = llama_get_embeddings_ith(ctx, i);
-                }
-                if (embd == NULL) {
-                    LOG_ERROR("failed to get embeddings", {
+        int n_embd = llama_n_embd(model);
+        if (n_embd > 16777216u) {
+            LOG_ERROR("model has more than 2**24 embeddings (please report this)", {{"n_embd", n_embd}});
+            n_embd = 0;
+        }
+        std::vector<float> embd_res(n_embd);
+        for (int i = 0; i < batch.n_tokens; i++) {
+            if (!batch.logits[i])
+                continue;
+            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            if (embd == NULL)
+                embd = llama_get_embeddings_ith(ctx, i);
+            if (embd == NULL) {
+                LOG_ERROR("failed to get embeddings (please report this)", {
                     {"token",  batch.token [i]},
                     {"seq_id", batch.seq_id[i][0]}
                 });
-                    res.result_json = json {
-                        {"embedding", std::vector<float>(n_embd, 0.0f)},
-                    };
-                    continue;
-                }
-                llama_embd_normalize(embd, embd_res.data(), n_embd);
-                res.result_json = json {
-                    {"embedding", embd_res},
-                };
-            }
-        }
+                continue;
+            }
+            float * beg = &embd_res[0];
+            float * end = beg + embd_res.size();
+            float * out = beg + batch.seq_id[i][0] * n_embd;
+            if (beg <= out && out + n_embd <= end) {
+                llama_embd_normalize(embd, out, n_embd);
+            } else {
+                LOG_ERROR("embeddings out of bounds (please report this)", {
+                    {"token",  batch.token [i]},
+                    {"seq_id", batch.seq_id[i][0]}
+                });
+                continue;
+            }
+            res.result_json = json {
+                {"embedding", embd_res},
+            };
+        }
         queue_results.send(res);
     }
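
Besides skipping tokens whose logits flag is not set and capping n_embd at 2**24 as a sanity check, the rewritten loop now verifies that the destination offset seq_id * n_embd actually lies inside embd_res before normalizing into it. The following standalone sketch shows that bounds check in isolation; it uses indices rather than the server's raw beg/out/end pointers so the example stays self-contained and well-defined, and the sizes are made up.

    // One embedding row per sequence is written at offset seq_id * n_embd inside a
    // single flat buffer; a row is written only when it fits entirely inside that
    // buffer, otherwise it is skipped (the server logs an error in that case).
    #include <cstdio>
    #include <vector>

    int main() {
        const std::size_t n_embd = 4;
        std::vector<float> embd_res(n_embd);          // room for exactly one row
        for (std::size_t seq_id : {0u, 1u}) {         // seq_id 1 must be rejected
            const std::size_t out = seq_id * n_embd;
            if (out + n_embd <= embd_res.size()) {
                std::printf("seq_id %zu: row fits at embd_res[%zu..%zu)\n", seq_id, out, out + n_embd);
            } else {
                std::printf("seq_id %zu: out of bounds, skipped\n", seq_id);
            }
        }
    }
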
@@ -1511,7 +1509,7 @@ struct llama_server_context
         {
             // if no slot is available, we defer this task for processing later
             LOG_VERBOSE("no slot is available", {{"task_id", task.id}});
-            queue_tasks.defer(task);
+            queue_tasks.defer_(task);
             break;
         }
 
2 changes: 1 addition & 1 deletion llama.cpp/server/utils.h
@@ -267,7 +267,7 @@ struct llama_server_queue {
     }
 
     // Add a new task, but defer until one slot is available
-    void defer(task_server task) {
+    void defer_(task_server task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         queue_tasks_deferred.push_back(std::move(task));
     }
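
For context, defer_() (previously defer()) parks a task in queue_tasks_deferred when no slot is free; those tasks are meant to be re-queued once a slot opens up. A stripped-down, standalone sketch of that defer-then-requeue pattern follows; the Task struct and the drain() helper are illustrative stand-ins, not the server's actual types.

    #include <deque>
    #include <mutex>
    #include <vector>

    struct Task { int id; };                          // stand-in for task_server

    struct DeferQueue {
        std::mutex mutex_tasks;
        std::deque<Task> queue_tasks_deferred;

        // Add a new task, but defer until one slot is available.
        void defer_(Task task) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            queue_tasks_deferred.push_back(std::move(task));
        }

        // When a slot frees up, take the deferred tasks back for rescheduling.
        std::vector<Task> drain() {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            std::vector<Task> out(queue_tasks_deferred.begin(), queue_tasks_deferred.end());
            queue_tasks_deferred.clear();
            return out;
        }
    };
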
