
Commit e788a2f

Added bulk loading example [skip ci]
1 parent afd9160

2 files changed: +83 -0 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ Or check out an example:

 - [Embeddings](src/test/java/com/pgvector/OpenAITest.java) with OpenAI
 - [Binary embeddings](src/test/java/com/pgvector/CohereTest.java) with Cohere
+- [Bulk loading](src/test/java/com/pgvector/LoadingTest.java) with `COPY`

 ## JDBC (Java)

src/test/java/com/pgvector/LoadingTest.java

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
package com.pgvector;

import java.io.UnsupportedEncodingException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import com.pgvector.PGvector;
import org.postgresql.PGConnection;
import org.postgresql.copy.CopyIn;
import org.postgresql.copy.CopyManager;
import org.postgresql.core.BaseConnection;
import org.junit.jupiter.api.Test;

public class LoadingTest {
    @Test
    void example() throws SQLException, UnsupportedEncodingException {
        if (System.getenv("TEST_LOADING") == null) {
            return;
        }

        // generate random data
        int rows = 1000000;
        int dimensions = 128;
        ArrayList<float[]> embeddings = new ArrayList<>(rows);
        for (int i = 0; i < rows; i++) {
            float[] embedding = new float[dimensions];
            for (int j = 0; j < dimensions; j++) {
                embedding[j] = (float) Math.random();
            }
            embeddings.add(embedding);
        }

        // enable extension
        Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example");
        Statement setupStmt = conn.createStatement();
        setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
        PGvector.addVectorType(conn);

        // create table
        setupStmt.executeUpdate("DROP TABLE IF EXISTS items");
        setupStmt.executeUpdate("CREATE TABLE items (id bigserial, embedding vector(128))");

        // load data
        System.out.println("Loading 1000000 rows");

        CopyManager copyManager = new CopyManager((BaseConnection) conn);
        // TODO use binary format
        CopyIn copyIn = copyManager.copyIn("COPY items (embedding) FROM STDIN");
        for (int i = 0; i < rows; i++) {
            if (i % 10000 == 0) {
                System.out.print(".");
            }

            PGvector embedding = new PGvector(embeddings.get(i));
            byte[] bytes = (embedding.getValue() + "\n").getBytes("UTF-8");
            copyIn.writeToCopy(bytes, 0, bytes.length);
        }
        copyIn.endCopy();

        System.out.println("\nSuccess!");

        // create any indexes *after* loading initial data (skipping for this example)
        boolean createIndex = false;
        if (createIndex) {
            System.out.println("Creating index");
            Statement createIndexStmt = conn.createStatement();
            createIndexStmt.executeUpdate("SET maintenance_work_mem = '8GB'");
            createIndexStmt.executeUpdate("SET max_parallel_maintenance_workers = 7");
            createIndexStmt.executeUpdate("CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops)");
        }

        // update planner statistics for good measure
        Statement analyzeStmt = conn.createStatement();
        analyzeStmt.executeUpdate("ANALYZE items");

        conn.close();
    }
}
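The loading loop above leaves a TODO for switching COPY to binary format. As a rough sketch only (not part of this commit), the same load could be done with COPY ... WITH (FORMAT BINARY), assuming PostgreSQL's binary COPY framing (the PGCOPY signature, 32-bit flags and header-extension fields, a 16-bit field count plus a 32-bit length per field for each tuple, and a -1 trailer) and pgvector's binary vector layout (16-bit dimension, 16-bit unused, then float4 elements). The BinaryLoadingSketch class and its copyBinary helper are invented for illustration:

package com.pgvector;

import java.nio.ByteBuffer;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import org.postgresql.copy.CopyIn;
import org.postgresql.copy.CopyManager;
import org.postgresql.core.BaseConnection;

// Hypothetical sketch (not in this commit): bulk load vectors with COPY ... FORMAT BINARY
public class BinaryLoadingSketch {
    static void copyBinary(Connection conn, List<float[]> embeddings) throws SQLException {
        CopyManager copyManager = new CopyManager((BaseConnection) conn);
        CopyIn copyIn = copyManager.copyIn("COPY items (embedding) FROM STDIN WITH (FORMAT BINARY)");

        // COPY BINARY file header: 11-byte signature, 32-bit flags, 32-bit extension length
        // (ByteBuffer writes in big-endian/network order by default)
        ByteBuffer header = ByteBuffer.allocate(19);
        header.put(new byte[] {'P', 'G', 'C', 'O', 'P', 'Y', '\n', (byte) 0xFF, '\r', '\n', 0});
        header.putInt(0);
        header.putInt(0);
        copyIn.writeToCopy(header.array(), 0, header.position());

        for (float[] embedding : embeddings) {
            // one tuple: 16-bit field count, then 32-bit length + data per field
            ByteBuffer row = ByteBuffer.allocate(2 + 4 + 4 + 4 * embedding.length);
            row.putShort((short) 1);                // one field (the embedding column)
            row.putInt(4 + 4 * embedding.length);   // field length in bytes
            row.putShort((short) embedding.length); // vector: dimension
            row.putShort((short) 0);                // vector: unused
            for (float v : embedding) {
                row.putFloat(v);                    // vector: float4 elements
            }
            copyIn.writeToCopy(row.array(), 0, row.position());
        }

        // file trailer: 16-bit -1
        ByteBuffer trailer = ByteBuffer.allocate(2);
        trailer.putShort((short) -1);
        copyIn.writeToCopy(trailer.array(), 0, trailer.position());
        copyIn.endCopy();
    }
}

A real implementation might batch rows differently or live in the library itself; the sketch only shows the on-wire framing.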

0 commit comments
