Support embedding func in ChromaDB memory #6267

Open · Tracked by #4006
victordibia opened this issue Apr 10, 2025 · 4 comments
Labels
help wanted, proj-extensions

Comments

@victordibia
Collaborator

victordibia commented Apr 10, 2025

Current Status

The current implementation of ChromaDBVectorMemory in the AutoGen extension package doesn't expose parameters for setting custom embedding functions. It relies entirely on ChromaDB's default embedding function (Sentence Transformers all-MiniLM-L6-v2).
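For reference, a minimal sketch of the current behavior (an assumption based on ChromaDB's standard client API, not code from the extension itself): the collection is created without an embedding_function argument, so ChromaDB falls back to its default.

import chromadb

client = chromadb.PersistentClient(path="./chroma")
# No embedding_function passed, so ChromaDB uses its default
# (Sentence Transformers all-MiniLM-L6-v2) for add() and query().
collection = client.get_or_create_collection(name="memory")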

Goal

Allow users to customize the embedding function used by ChromaDBVectorMemory through a flexible, declarative configuration system that supports:

  1. Default embedding function (current behavior)
  2. Alternative Sentence Transformer models
  3. OpenAI embeddings
  4. Custom user-defined embedding functions

Rough Sketch of an Implementation Plan

1. Create Base Configuration Classes

Create a hierarchy of embedding function configurations:

from typing import Literal

from pydantic import BaseModel, Field


class BaseEmbeddingFunctionConfig(BaseModel):
    """Base configuration for embedding functions."""
    function_type: Literal["default", "sentence_transformer", "openai", "custom"]

class DefaultEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for the default embedding function."""
    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "default"


class SentenceTransformerEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for SentenceTransformer embedding functions."""
    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "sentence_transformer"
    model_name: str = Field(default="all-MiniLM-L6-v2", description="Model name to use")
    

class OpenAIEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for OpenAI embedding functions."""
    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "openai"
    api_key: str = Field(default="", description="OpenAI API key")
    model_name: str = Field(default="text-embedding-ada-002", description="Model name")

2. Support Custom Embedding Functions

Add a configuration for custom embedding functions using the direct function approach:

from typing import Any, Callable, Dict


class CustomEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for custom embedding functions."""
    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "custom"
    function: Callable[..., Any] = Field(description="Function that returns a ChromaDB-compatible embedding function")
    params: Dict[str, Any] = Field(default_factory=dict, description="Parameters passed to the function")

Note: Using a direct function in the configuration will make it non-serializable. The implementation should include appropriate warnings when users attempt to serialize configurations that contain function references.
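One possible way to surface that warning, sketched here as an assumption (the decorator usage and warning text are illustrative, not part of the proposal): a Pydantic v2 field serializer on the function field that emits a UserWarning and serializes only the function's name.

import warnings

from pydantic import field_serializer


class CustomEmbeddingFunctionConfig(BaseEmbeddingFunctionConfig):
    """Configuration for custom embedding functions, warning on serialization."""
    function_type: Literal["default", "sentence_transformer", "openai", "custom"] = "custom"
    function: Callable[..., Any] = Field(description="Function that returns a ChromaDB-compatible embedding function")
    params: Dict[str, Any] = Field(default_factory=dict, description="Parameters passed to the function")

    @field_serializer("function")
    def _warn_on_serialize(self, value: Callable[..., Any]) -> str:
        # The function reference itself cannot be serialized; warn and keep only its name.
        warnings.warn(
            "CustomEmbeddingFunctionConfig contains a function reference; "
            "the serialized form will not be able to reconstruct it.",
            UserWarning,
            stacklevel=2,
        )
        return getattr(value, "__name__", repr(value))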

3. Update ChromaDBVectorMemory Configuration

Extend the existing ChromaDBVectorMemoryConfig class to include the embedding function configuration:

class ChromaDBVectorMemoryConfig(BaseModel):
    # Existing fields...
    embedding_function_config: BaseEmbeddingFunctionConfig = Field(
        default_factory=DefaultEmbeddingFunctionConfig,
        description="Configuration for the embedding function"
    )
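A possible refinement, not part of the sketch above: if the field is typed as the base class, Pydantic will not know which concrete config class to rebuild when the configuration is loaded back from a dict or JSON. A discriminated union over the concrete classes would keep the configuration fully declarative, provided each class narrows function_type to a single-value Literal. The mini classes below (names like _DefaultCfg are hypothetical) only illustrate the pattern.

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


# Hypothetical mini configs; the real classes above would need the same
# single-value Literal narrowing for the discriminator to work.
class _DefaultCfg(BaseModel):
    function_type: Literal["default"] = "default"


class _OpenAICfg(BaseModel):
    function_type: Literal["openai"] = "openai"
    model_name: str = "text-embedding-ada-002"


EmbeddingFunctionConfig = Annotated[
    Union[_DefaultCfg, _OpenAICfg],
    Field(discriminator="function_type"),
]


class _MemoryCfg(BaseModel):
    embedding_function_config: EmbeddingFunctionConfig = Field(default_factory=_DefaultCfg)


# Loading from a plain dict reconstructs the right subclass.
cfg = _MemoryCfg.model_validate({"embedding_function_config": {"function_type": "openai"}})
assert isinstance(cfg.embedding_function_config, _OpenAICfg)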

4. Implement Embedding Function Creation

Add a method to ChromaDBVectorMemory that creates embedding functions based on configuration:

def _create_embedding_function(self):
    """Create an embedding function based on the configuration."""
    from typing import cast  # used below to narrow the config type

    from chromadb.utils import embedding_functions

    config = self._config.embedding_function_config
    
    if config.function_type == "default":
        return embedding_functions.DefaultEmbeddingFunction()
    
    elif config.function_type == "sentence_transformer":
        cfg = cast(SentenceTransformerEmbeddingFunctionConfig, config)
        return embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=cfg.model_name
        )
    
    elif config.function_type == "openai":
        cfg = cast(OpenAIEmbeddingFunctionConfig, config)
        return embedding_functions.OpenAIEmbeddingFunction(
            api_key=cfg.api_key,
            model_name=cfg.model_name
        )
    
    elif config.function_type == "custom":
        cfg = cast(CustomEmbeddingFunctionConfig, config)
        return cfg.function(**cfg.params)
    
    else:
        raise ValueError(f"Unsupported embedding function type: {config.function_type}")

5. Update Collection Initialization

Modify the _ensure_initialized method to use the embedding function:

def _ensure_initialized(self) -> None:
    # ... existing client initialization code ...
    
    if self._collection is None:
        try:
            # Create embedding function
            embedding_function = self._create_embedding_function()
            
            # Create or get collection with embedding function
            self._collection = self._client.get_or_create_collection(
                name=self._config.collection_name,
                metadata={"distance_metric": self._config.distance_metric},
                embedding_function=embedding_function
            )
        except Exception as e:
            logger.error(f"Failed to get/create collection: {e}")
            raise

Example Usage

from autogen_ext.memory.chromadb import ChromaDBVectorMemory, PersistentChromaDBVectorMemoryConfig

# The *EmbeddingFunctionConfig classes used below are the proposed additions sketched above.

# Using default embedding function
memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig()
)

# Using a specific Sentence Transformer model
memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig(
        embedding_function_config=SentenceTransformerEmbeddingFunctionConfig(
            model_name="paraphrase-multilingual-mpnet-base-v2"
        )
    )
)

# Using OpenAI embeddings
memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig(
        embedding_function_config=OpenAIEmbeddingFunctionConfig(
            api_key="sk-...",
            model_name="text-embedding-3-small"
        )
    )
)

# Using a custom embedding function (direct function approach)
from chromadb import Documents, EmbeddingFunction, Embeddings

def create_my_embedder(param1: str = "default"):
    # Return a ChromaDB-compatible embedding function
    class MyCustomEmbeddingFunction(EmbeddingFunction):
        def __init__(self, param1: str):
            self._param1 = param1

        def __call__(self, input: Documents) -> Embeddings:
            # Custom embedding logic here: return one vector per input document
            return [[0.0] * 384 for _ in input]

    return MyCustomEmbeddingFunction(param1)

memory = ChromaDBVectorMemory(
    config=PersistentChromaDBVectorMemoryConfig(
        embedding_function_config=CustomEmbeddingFunctionConfig(
            function=create_my_embedder,
            params={"param1": "custom_value"}
        )
    )
)
 
@victordibia
Collaborator Author

@mpegram3rd
Thanks for the help with the types PR!
Might this be an issue you'd be interested in working on?
I have a rough sketch above, mostly as an initial design ... more work will be needed to arrive at a clean implementation with no side effects.

@jcroucherMSFT

+1 - This would be a very valuable feature!

@mpegram3rd
Contributor

mpegram3rd commented Apr 16, 2025

@victordibia Thanks for thinking of me. I doubt I could do anything with it this week, but if it's not a rush I could take a look at it sometime next week (no promises).

Reading through your suggested fixes, I'd also need some time to really digest how this all fits together in the big picture.

@victordibia
Collaborator Author

No worries at all, thanks again for the fine contribution!
I'll also tag you for a review if I get to it before then.
