Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: pgvector/pgvector-java
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: v0.1.5
Choose a base ref
...
head repository: pgvector/pgvector-java
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: v0.1.6
Choose a head ref
  • 11 commits
  • 9 files changed
  • 1 contributor

Commits on Jun 29, 2024

  1. Improved example [skip ci]

    ankane committed Jun 29, 2024
    Copy the full SHA
    b343f4a View commit details

Commits on Jul 16, 2024

  1. Verified

    This commit was signed with the committer’s verified signature.
    hsbt Hiroshi SHIBATA
    Copy the full SHA
    5354be1 View commit details
  2. Added Cohere example [skip ci]

    ankane committed Jul 16, 2024
    Copy the full SHA
    5b3bef1 View commit details
  3. Improved example [skip ci]

    ankane committed Jul 16, 2024
    Copy the full SHA
    afd9160 View commit details
  4. Copy the full SHA
    e788a2f View commit details

Commits on Jul 17, 2024

  1. Improved example [skip ci]

    ankane committed Jul 17, 2024
    Copy the full SHA
    6f92c1a View commit details
  2. Improved overflow handling for PGbit

    ankane committed Jul 17, 2024
    Copy the full SHA
    e187e2c View commit details
  3. Fixed CI

    ankane committed Jul 17, 2024
    Copy the full SHA
    fce784a View commit details
  4. Added test for empty array for PGbit

    ankane committed Jul 17, 2024
    Copy the full SHA
    1fdb455 View commit details
  5. Reverted back to previous exception

    ankane committed Jul 17, 2024
    Copy the full SHA
    c9a5f5d View commit details
  6. Copy the full SHA
    10a6444 View commit details
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@ jobs:
# Hibernate 6.4 and HttpClient require Java 11+
- if: ${{ matrix.java == 8 }}
run: |
rm src/test/java/com/pgvector/CohereTest.java
rm src/test/java/com/pgvector/HibernateTest.java
rm src/test/java/com/pgvector/OpenAITest.java
- run: mvn -B -ntp test
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.1.6 (2024-07-17)

- Added `byte[]` constructor to `PGbit`

## 0.1.5 (2024-06-25)

- Added support for `halfvec`, `bit`, and `sparsevec` types
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -14,14 +14,14 @@ For Maven, add to `pom.xml` under `<dependencies>`:
<dependency>
<groupId>com.pgvector</groupId>
<artifactId>pgvector</artifactId>
<version>0.1.5</version>
<version>0.1.6</version>
</dependency>
```

For sbt, add to `build.sbt`:

```sbt
libraryDependencies += "com.pgvector" % "pgvector" % "0.1.5"
libraryDependencies += "com.pgvector" % "pgvector" % "0.1.6"
```

For other build tools, see [this page](https://central.sonatype.com/artifact/com.pgvector/pgvector).
@@ -36,6 +36,8 @@ And follow the instructions for your database library:
Or check out an example:

- [Embeddings](src/test/java/com/pgvector/OpenAITest.java) with OpenAI
- [Binary embeddings](src/test/java/com/pgvector/CohereTest.java) with Cohere
- [Bulk loading](src/test/java/com/pgvector/LoadingTest.java) with `COPY`

## JDBC (Java)

2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
<packaging>jar</packaging>
<description>pgvector support for Java, Kotlin, Groovy, and Scala</description>
<url>https://github.com/pgvector/pgvector-java</url>
<version>0.1.5</version>
<version>0.1.6</version>
<licenses>
<license>
<name>MIT</name>
11 changes: 11 additions & 0 deletions src/main/java/com/pgvector/PGbit.java
Original file line number Diff line number Diff line change
@@ -39,6 +39,17 @@ public PGbit(boolean[] v) {
}
}

/**
* Constructor
*
* @param v byte array
*/
public PGbit(byte[] v) {
this();
length = Math.multiplyExact(v.length, 8);
data = v;
}

/**
* Constructor
*
101 changes: 101 additions & 0 deletions src/test/java/com/pgvector/CohereTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package com.pgvector;

import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpRequest.BodyPublishers;
import java.net.http.HttpResponse;
import java.net.http.HttpResponse.BodyHandlers;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.pgvector.PGvector;
import org.postgresql.PGConnection;
import org.junit.jupiter.api.Test;

public class CohereTest {
@Test
void example() throws IOException, InterruptedException, SQLException {
String apiKey = System.getenv("CO_API_KEY");
if (apiKey == null) {
return;
}

Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example");

Statement setupStmt = conn.createStatement();
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
setupStmt.executeUpdate("DROP TABLE IF EXISTS documents");

PGvector.addVectorType(conn);

Statement createStmt = conn.createStatement();
createStmt.executeUpdate("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding bit(1024))");

String[] input = {
"The dog is barking",
"The cat is purring",
"The bear is growling"
};
List<byte[]> embeddings = fetchEmbeddings(input, "search_document", apiKey);
for (int i = 0; i < input.length; i++) {
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO documents (content, embedding) VALUES (?, ?)");
insertStmt.setString(1, input[i]);
insertStmt.setObject(2, new PGbit(embeddings.get(i)));
insertStmt.executeUpdate();
}

String query = "forest";
byte[] queryEmbedding = fetchEmbeddings(new String[] {query}, "search_query", apiKey).get(0);
PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM documents ORDER BY embedding <~> ? LIMIT 5");
neighborStmt.setObject(1, new PGbit(queryEmbedding));
ResultSet rs = neighborStmt.executeQuery();
while (rs.next()) {
System.out.println(rs.getString("content"));
}

conn.close();
}

// https://docs.cohere.com/reference/embed
private List<byte[]> fetchEmbeddings(String[] texts, String inputType, String apiKey) throws IOException, InterruptedException {
ObjectMapper mapper = new ObjectMapper();
ObjectNode root = mapper.createObjectNode();
for (String v : texts) {
root.withArray("texts").add(v);
}
root.put("model", "embed-english-v3.0");
root.put("input_type", inputType);
root.withArray("embedding_types").add("binary");
String json = mapper.writeValueAsString(root);

HttpClient client = HttpClient.newHttpClient();
HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create("https://api.cohere.com/v1/embed"))
.header("Authorization", "Bearer " + apiKey)
.header("Content-Type", "application/json")
.POST(BodyPublishers.ofString(json))
.build();
HttpResponse<String> response = client.send(request, BodyHandlers.ofString());

List<byte[]> embeddings = new ArrayList<>();
for (JsonNode n : mapper.readTree(response.body()).get("embeddings").get("binary")) {
byte[] embedding = new byte[n.size()];
int i = 0;
for (JsonNode v : n) {
embedding[i++] = (byte) v.asInt();
}
embeddings.add(embedding);
}
return embeddings;
}
}
82 changes: 82 additions & 0 deletions src/test/java/com/pgvector/LoadingTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.pgvector;

import java.io.UnsupportedEncodingException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import com.pgvector.PGvector;
import org.postgresql.PGConnection;
import org.postgresql.copy.CopyIn;
import org.postgresql.copy.CopyManager;
import org.postgresql.core.BaseConnection;
import org.junit.jupiter.api.Test;

public class LoadingTest {
@Test
void example() throws SQLException {
if (System.getenv("TEST_LOADING") == null) {
return;
}

// generate random data
int rows = 1000000;
int dimensions = 128;
ArrayList<float[]> embeddings = new ArrayList<>(rows);
for (int i = 0; i < rows; i++) {
float[] embedding = new float[dimensions];
for (int j = 0; j < dimensions; j++) {
embedding[j] = (float) Math.random();
}
embeddings.add(embedding);
}

// enable extension
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example");
Statement setupStmt = conn.createStatement();
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
PGvector.addVectorType(conn);

// create table
setupStmt.executeUpdate("DROP TABLE IF EXISTS items");
setupStmt.executeUpdate("CREATE TABLE items (id bigserial, embedding vector(128))");

// load data
System.out.println("Loading 1000000 rows");

CopyManager copyManager = new CopyManager((BaseConnection) conn);
// TODO use binary format
CopyIn copyIn = copyManager.copyIn("COPY items (embedding) FROM STDIN");
for (int i = 0; i < rows; i++) {
if (i % 10000 == 0) {
System.out.print(".");
}

PGvector embedding = new PGvector(embeddings.get(i));
byte[] bytes = (embedding.getValue() + "\n").getBytes();
copyIn.writeToCopy(bytes, 0, bytes.length);
}
copyIn.endCopy();

System.out.println("\nSuccess!");

// create any indexes *after* loading initial data (skipping for this example)
boolean createIndex = false;
if (createIndex) {
System.out.println("Creating index");
Statement createIndexStmt = conn.createStatement();
createIndexStmt.executeUpdate("SET maintenance_work_mem = '8GB'");
createIndexStmt.executeUpdate("SET max_parallel_maintenance_workers = 7");
createIndexStmt.executeUpdate("CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops)");
}

// update planner statistics for good measure
Statement analyzeStmt = conn.createStatement();
analyzeStmt.executeUpdate("ANALYZE items");

conn.close();
}
}
2 changes: 1 addition & 1 deletion src/test/java/com/pgvector/OpenAITest.java
Original file line number Diff line number Diff line change
@@ -73,7 +73,7 @@ private List<float[]> fetchEmbeddings(String[] input, String apiKey) throws IOEx
for (String v : input) {
root.withArray("input").add(v);
}
root.put("model", "text-embedding-ada-002");
root.put("model", "text-embedding-3-small");
String json = mapper.writeValueAsString(root);

HttpClient client = HttpClient.newHttpClient();
7 changes: 7 additions & 0 deletions src/test/java/com/pgvector/PGbitTest.java
Original file line number Diff line number Diff line change
@@ -17,6 +17,13 @@ void testArrayConstructor() {
assertArrayEquals(new boolean[] {false, true, false, true, false, false, false, false, true}, vec.toArray());
}

void testEmptyArrayConstructor() {
PGbit vec = new PGbit(new boolean[] {});
assertEquals(0, vec.length());
assertArrayEquals(new byte[] {}, vec.toByteArray());
assertArrayEquals(new boolean[] {}, vec.toArray());
}

@Test
void testStringConstructor() throws SQLException {
PGbit vec = new PGbit("010100001");