|
| 1 | +package com.pgvector; |
| 2 | + |
| 3 | +import java.io.UnsupportedEncodingException; |
| 4 | +import java.sql.Connection; |
| 5 | +import java.sql.DriverManager; |
| 6 | +import java.sql.PreparedStatement; |
| 7 | +import java.sql.ResultSet; |
| 8 | +import java.sql.SQLException; |
| 9 | +import java.sql.Statement; |
| 10 | +import java.util.ArrayList; |
| 11 | +import com.pgvector.PGvector; |
| 12 | +import org.postgresql.PGConnection; |
| 13 | +import org.postgresql.copy.CopyIn; |
| 14 | +import org.postgresql.copy.CopyManager; |
| 15 | +import org.postgresql.core.BaseConnection; |
| 16 | +import org.junit.jupiter.api.Test; |
| 17 | + |
| 18 | +public class LoadingTest { |
| 19 | + @Test |
| 20 | + void example() throws SQLException, UnsupportedEncodingException { |
| 21 | + if (System.getenv("TEST_LOADING") == null) { |
| 22 | + return; |
| 23 | + } |
| 24 | + |
| 25 | + // generate random data |
| 26 | + int rows = 1000000; |
| 27 | + int dimensions = 128; |
| 28 | + ArrayList<float[]> embeddings = new ArrayList<>(rows); |
| 29 | + for (int i = 0; i < rows; i++) { |
| 30 | + float[] embedding = new float[dimensions]; |
| 31 | + for (int j = 0; j < dimensions; j++) { |
| 32 | + embedding[j] = (float) Math.random(); |
| 33 | + } |
| 34 | + embeddings.add(embedding); |
| 35 | + } |
| 36 | + |
| 37 | + // enable extension |
| 38 | + Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example"); |
| 39 | + Statement setupStmt = conn.createStatement(); |
| 40 | + setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector"); |
| 41 | + PGvector.addVectorType(conn); |
| 42 | + |
| 43 | + // create table |
| 44 | + setupStmt.executeUpdate("DROP TABLE IF EXISTS items"); |
| 45 | + setupStmt.executeUpdate("CREATE TABLE items (id bigserial, embedding vector(128))"); |
| 46 | + |
| 47 | + // load data |
| 48 | + System.out.println("Loading 1000000 rows"); |
| 49 | + |
| 50 | + CopyManager copyManager = new CopyManager((BaseConnection) conn); |
| 51 | + // TODO use binary format |
| 52 | + CopyIn copyIn = copyManager.copyIn("COPY items (embedding) FROM STDIN"); |
| 53 | + for (int i = 0; i < rows; i++) { |
| 54 | + if (i % 10000 == 0) { |
| 55 | + System.out.print("."); |
| 56 | + } |
| 57 | + |
| 58 | + PGvector embedding = new PGvector(embeddings.get(i)); |
| 59 | + byte[] bytes = (embedding.getValue() + "\n").getBytes("UTF-8"); |
| 60 | + copyIn.writeToCopy(bytes, 0, bytes.length); |
| 61 | + } |
| 62 | + copyIn.endCopy(); |
| 63 | + |
| 64 | + System.out.println("\nSuccess!"); |
| 65 | + |
| 66 | + // create any indexes *after* loading initial data (skipping for this example) |
| 67 | + boolean createIndex = false; |
| 68 | + if (createIndex) { |
| 69 | + System.out.println("Creating index"); |
| 70 | + Statement createIndexStmt = conn.createStatement(); |
| 71 | + createIndexStmt.executeUpdate("SET maintenance_work_mem = '8GB'"); |
| 72 | + createIndexStmt.executeUpdate("SET max_parallel_maintenance_workers = 7"); |
| 73 | + createIndexStmt.executeUpdate("CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops)"); |
| 74 | + } |
| 75 | + |
| 76 | + // update planner statistics for good measure |
| 77 | + Statement analyzeStmt = conn.createStatement(); |
| 78 | + analyzeStmt.executeUpdate("ANALYZE items"); |
| 79 | + |
| 80 | + conn.close(); |
| 81 | + } |
| 82 | +} |
0 commit comments