How to perform vector search in Java with the Jedis client library?

Last updated 20 Apr 2024

Question

How to perform vector search in Java with the Jedis client library?

Answer

Create a Java Maven project (see the Maven instructions for generating a scaffold project) and add the following dependencies, adjusting the versions as needed:

    <dependency>
      <groupId>redis.clients</groupId>
      <artifactId>jedis</artifactId>
      <version>5.0.1</version>
    </dependency>
    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>api</artifactId>
      <version>0.24.0</version>
    </dependency>
    <dependency>
      <groupId>ai.djl.huggingface</groupId>
      <artifactId>tokenizers</artifactId>
      <version>0.24.0</version>
    </dependency>

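The example stores three sentences ("That is a very happy person", "That is a happy dog", "Today is a sunny day") as Redis hashes, each with a vector field, and then runs a vector search to find which stored sentences are closest to the test sentence "That is a happy person". The query returns the three nearest neighbors (KNN 3). The vectors are the token IDs produced by the Hugging Face tokenizer for sentence-transformers/all-mpnet-base-v2, converted to floats.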
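The query string used in the code, *=>[KNN $K @embedding $BLOB AS score], has two parts: the expression before => selects the candidate documents (* means all of them), and the bracketed clause asks for the $K vectors in the embedding field nearest to the $BLOB parameter, returning the distance under the alias score. Replacing * with a filter, for example the tag expression @genre:{persons}, restricts the search to matching hashes. The sketch below only assembles such strings; the class KnnQueryString and its build method are hypothetical helpers, not part of Jedis, and no Redis call is made.

```java
// Hypothetical helper (not part of Jedis): builds a KNN query string in the
// DIALECT 2 syntax used by the example, optionally with a pre-filter.
final class KnnQueryString {

    // prefilter: "*" for all documents, or a query expression such as "@genre:{persons}"
    static String build(String prefilter, String vectorField, String scoreAlias) {
        // Group any filter other than "*" in parentheses before the "=>" operator
        String filter = "*".equals(prefilter) ? "*" : "(" + prefilter + ")";
        return filter + "=>[KNN $K @" + vectorField + " $BLOB AS " + scoreAlias + "]";
    }

    public static void main(String[] args) {
        // Same query string as in the example
        System.out.println(build("*", "embedding", "score"));
        // Only hashes whose genre tag is "persons" are considered
        System.out.println(build("@genre:{persons}", "embedding", "score"));
    }
}
```

The resulting string can be passed to new Query(...) with the same addParam("K", ...), addParam("BLOB", ...) and dialect(2) calls as in the example.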

    package com.redis.app;

    import redis.clients.jedis.UnifiedJedis;
    import redis.clients.jedis.search.*;

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

    public class App {
        // Encodes floats as little-endian FLOAT32 bytes, matching TYPE FLOAT32 in the schema
        public static byte[] floatArrayToByteArray(float[] input) {
            byte[] bytes = new byte[Float.BYTES * input.length];
            ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(input);
            return bytes;
        }

        // Converts token IDs (long) to floats, then to a FLOAT32 byte blob
        public static byte[] longArrayToByteArray(long[] input) {
            return floatArrayToByteArray(longArrayToFloatArray(input));
        }

        public static float[] longArrayToFloatArray(long[] input) {
            float[] floats = new float[input.length];
            for (int i = 0; i < input.length; i++) {
                floats[i] = input[i];
            }
            return floats;
        }

        public static void main(String[] args) {
            // Connect to Redis (REDIS_URL if set, otherwise a local instance)
            String url = System.getenv().getOrDefault("REDIS_URL", "redis://localhost:6379");
            try (UnifiedJedis unifiedJedis = new UnifiedJedis(url)) {

                // Index every hash whose key starts with "doc:"
                IndexDefinition definition = new IndexDefinition().setPrefixes(new String[]{"doc:"});

                // HNSW vector field: 768 FLOAT32 components compared with L2 distance
                Map<String, Object> attr = new HashMap<>();
                attr.put("TYPE", "FLOAT32");
                attr.put("DIM", 768);
                attr.put("DISTANCE_METRIC", "L2");
                attr.put("INITIAL_CAP", 3);
                Schema schema = new Schema()
                        .addTextField("content", 1)
                        .addTagField("genre")
                        .addHNSWVectorField("embedding", attr);

                // Create the index; if it already exists, print the error and continue
                try {
                    unifiedJedis.ftCreate("vector_idx", IndexOptions.defaultOptions().setDefinition(definition), schema);
                } catch (Exception e) {
                    System.out.println(e.getMessage());
                }

                // Load the tokenizer. Note: it produces token IDs, not sentence embeddings.
                // Each vector written to Redis must have exactly DIM (768) components.
                Map<String, String> options = Map.of("maxLength", "768", "modelMaxLength", "768");
                HuggingFaceTokenizer sentenceTokenizer =
                        HuggingFaceTokenizer.newInstance("sentence-transformers/all-mpnet-base-v2", options);

                // Store each sentence as a hash: text fields first, then the vector as a binary field
                String sentence1 = "That is a very happy person";
                unifiedJedis.hset("doc:1", Map.of("content", sentence1, "genre", "persons"));
                unifiedJedis.hset("doc:1".getBytes(), "embedding".getBytes(),
                        longArrayToByteArray(sentenceTokenizer.encode(sentence1).getIds()));

                String sentence2 = "That is a happy dog";
                unifiedJedis.hset("doc:2", Map.of("content", sentence2, "genre", "pets"));
                unifiedJedis.hset("doc:2".getBytes(), "embedding".getBytes(),
                        longArrayToByteArray(sentenceTokenizer.encode(sentence2).getIds()));

                String sentence3 = "Today is a sunny day";
                unifiedJedis.hset("doc:3", Map.of("content", sentence3, "genre", "weather"));
                unifiedJedis.hset("doc:3".getBytes(), "embedding".getBytes(),
                        longArrayToByteArray(sentenceTokenizer.encode(sentence3).getIds()));

                // The test sentence to compare against the stored ones
                String sentence = "That is a happy person";

                // KNN query: the K vectors nearest to $BLOB, with the distance returned as "score"
                int K = 3;
                Query q = new Query("*=>[KNN $K @embedding $BLOB AS score]")
                        .returnFields("content", "score")
                        .addParam("K", K)
                        .addParam("BLOB", longArrayToByteArray(sentenceTokenizer.encode(sentence).getIds()))
                        .dialect(2);

                // Execute the query and print the matching documents
                List<Document> docs = unifiedJedis.ftSearch("vector_idx", q).getDocuments();
                System.out.println(docs);
            }
        }
    }
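The example writes each vector as raw FLOAT32 bytes in little-endian order, matching TYPE FLOAT32 in the schema. To check what was stored, the sketch below adds the inverse conversion; the class VectorBytes and the method fromBytes are hypothetical names, and toBytes mirrors floatArrayToByteArray. You could apply fromBytes to the value returned by hget with the byte[] key and field.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

// Hypothetical helper: encodes and decodes FLOAT32 vectors using the same
// little-endian layout as floatArrayToByteArray in the example.
final class VectorBytes {

    // Same encoding as floatArrayToByteArray: 4 bytes per float, little-endian
    static byte[] toBytes(float[] input) {
        byte[] bytes = new byte[Float.BYTES * input.length];
        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(input);
        return bytes;
    }

    // Inverse conversion: reads a stored blob back into floats
    static float[] fromBytes(byte[] bytes) {
        if (bytes.length % Float.BYTES != 0) {
            throw new IllegalArgumentException("Blob length is not a multiple of 4 bytes");
        }
        float[] floats = new float[bytes.length / Float.BYTES];
        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(floats);
        return floats;
    }

    public static void main(String[] args) {
        float[] vector = {101f, 2008f, 2003f, 102f};
        byte[] blob = toBytes(vector);
        System.out.println(blob.length);                      // 16
        System.out.println(Arrays.toString(fromBytes(blob))); // [101.0, 2008.0, 2003.0, 102.0]
    }
}
```

Since the schema declares DIM 768, each stored blob should be 768 × 4 = 3072 bytes long; a hash whose vector has a different length is not indexed, and FT.INFO vector_idx counts it under hash_indexing_failures.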

Ensure that your Redis Stack instance (or a Redis Cloud database) is running and that you have set the REDIS_URL environment variable if necessary. Example:

    export REDIS_URL=redis://user:password@host:port

If REDIS_URL is not set, the example connects to a Redis Stack instance on localhost, port 6379.

The example is a Maven project, so you can compile it with:

    mvn package

And run it with:

    mvn exec:java -Dexec.mainClass=com.redis.app.App

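The score property holds the L2 distance between each stored vector and the test vector, so a smaller value means a closer match (the score: 1.0 printed for every document is the default document score, not the distance). Note that the results are not printed in distance order; add setSortBy("score", true) to the query to list the nearest document first.

Because the vectors are token IDs rather than sentence embeddings, the distance reflects which IDs occur at which positions, not meaning. In the output below, "That is a happy dog" (doc:2) is the nearest, since it differs from the test sentence by a single word in the same position, while "That is a very happy person" (doc:1), which is closer in meaning, is farther away because the extra word shifts the tokens that follow. For a search that ranks by semantic similarity, store sentence embeddings computed by an embedding model (such as sentence-transformers/all-mpnet-base-v2, which produces 768-dimensional vectors) instead of token IDs.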

    [id:doc:1, score: 1.0, properties:[score=9301635, content=That is a very happy person], id:doc:2, score: 1.0, properties:[score=1411344, content=That is a happy dog], id:doc:3, score: 1.0, properties:[score=67178800, content=Today is a sunny day]]
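Each score property is the distance between a stored vector and the query vector under the DISTANCE_METRIC declared in the schema (L2). The values are consistent with squared Euclidean distances, with no square root applied: for instance 1411344, the score of doc:2, is exactly 1188². The sketch below is plain Java with no Redis connection; the class L2Ranking and its methods are hypothetical names. It computes that quantity locally and orders candidates nearest first, which mirrors sorting the search results by score in ascending order.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: computes the squared Euclidean (L2) distance locally and
// ranks documents by it, nearest first, like a KNN result sorted by score.
final class L2Ranking {

    // Sum of squared component differences; no square root is taken
    static double squaredL2(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have the same dimension");
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = (double) a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    // Returns document ids ordered from nearest to farthest from the query vector
    static List<String> rank(float[] query, Map<String, float[]> docs) {
        List<String> ids = new ArrayList<>(docs.keySet());
        ids.sort(Comparator.comparingDouble(id -> squaredL2(query, docs.get(id))));
        return ids;
    }

    public static void main(String[] args) {
        Map<String, float[]> docs = new LinkedHashMap<>();
        docs.put("doc:a", new float[]{4f, 6f});
        docs.put("doc:b", new float[]{1f, 3f});
        float[] query = {1f, 2f};
        System.out.println(squaredL2(query, docs.get("doc:a"))); // 25.0
        System.out.println(rank(query, docs));                    // [doc:b, doc:a]
    }
}
```

Skipping the square root does not change the ranking, because the square root is monotonic on non-negative values.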
