Skip to content

Fix component.label error in AGS frontend #5845

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 6, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,28 @@ A: If you are running the server on a remote machine (or a local machine that fa
autogenstudio ui --port 8081 --host 0.0.0.0
```

## Q: How do I use AutoGen Studio with a different database?

A: By default, AutoGen Studio uses SQLite as the database. However, it uses the SQLModel library, which supports multiple database backends. You can use any database supported by SQLModel, such as PostgreSQL or MySQL. To use a different database, you need to specify the connection string for the database using the `--database-uri` argument when running the application. Example connection strings include:

- SQLite: `sqlite:///database.sqlite`
- PostgreSQL: `postgresql+psycopg://user:password@localhost/dbname`
- MySQL: `mysql+pymysql://user:password@localhost/dbname`
- AzureSQL: `mssql+pyodbc:///?odbc_connect=DRIVER%3D%7BODBC+Driver+17+for+SQL+Server%7D%3BSERVER%3Dtcp%3Aservername.database.windows.net%2C1433%3BDATABASE%3Ddatabasename%3BUID%3Dusername%3BPWD%3Dpassword123%3BEncrypt%3Dyes%3BTrustServerCertificate%3Dno%3BConnection+Timeout%3D30%3B`

You can then run the application with the specified database URI. For example, to use PostgreSQL, you can run the following command:

```bash
autogenstudio ui --database-uri postgresql+psycopg://user:password@localhost/dbname
```

> **Note:** Make sure to install the appropriate database drivers for your chosen database:
>
> - PostgreSQL: `pip install psycopg2` or `pip install psycopg2-binary`
> - MySQL: `pip install pymysql`
> - SQL Server/Azure SQL: `pip install pyodbc`
> - Oracle: `pip install cx_oracle`

## Q: Can I export my agent workflows for use in a python app?

Yes. In the Team Builder view, you select a team and download its specification. This file can be imported in a python application using the `TeamManager` class. For example:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,12 @@ def run_migrations_online() -> None:
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
is_sqlite = connection.dialect.name == "sqlite"
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True
render_as_batch=is_sqlite,
)
with context.begin_transaction():
context.run_migrations()
Expand All @@ -213,10 +215,11 @@ def _generate_alembic_ini_content(self) -> str:
"""
Generates content for alembic.ini file.
"""
engine_url = str(self.engine.url).replace("%", "%%")
return f"""
[alembic]
script_location = {self.alembic_dir}
sqlalchemy.url = {self.engine.url}
sqlalchemy.url = {engine_url}

[loggers]
keys = root,sqlalchemy,alembic
Expand Down
14 changes: 5 additions & 9 deletions python/packages/autogen-studio/autogenstudio/datamodel/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional, Union
from uuid import UUID, uuid4

from autogen_core import ComponentModel
from pydantic import ConfigDict
from sqlalchemy import UUID as SQLAlchemyUUID
from sqlalchemy import ForeignKey, Integer, String
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func

Expand Down Expand Up @@ -45,11 +43,9 @@ class Message(SQLModel, table=True):
version: Optional[str] = "0.0.1"
config: Union[MessageConfig, dict] = Field(default_factory=MessageConfig, sa_column=Column(JSON))
session_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"))
)
run_id: Optional[UUID] = Field(
default=None, sa_column=Column(SQLAlchemyUUID, ForeignKey("run.id", ondelete="CASCADE"))
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="NO ACTION"))
)
run_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("run.id", ondelete="CASCADE")))

message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))

Expand Down Expand Up @@ -84,7 +80,7 @@ class Run(SQLModel, table=True):

__table_args__ = {"sqlite_autoincrement": True}

id: UUID = Field(default_factory=uuid4, sa_column=Column(SQLAlchemyUUID, primary_key=True, index=True, unique=True))
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), server_default=func.now())
)
Expand All @@ -106,7 +102,7 @@ class Run(SQLModel, table=True):
version: Optional[str] = "0.0.1"
messages: Union[List[Message], List[dict]] = Field(default_factory=list, sa_column=Column(JSON))

model_config = ConfigDict(json_encoders={UUID: str, datetime: lambda v: v.isoformat()})
model_config = ConfigDict(json_encoders={datetime: lambda v: v.isoformat()})
user_id: Optional[str] = None


Expand All @@ -125,7 +121,7 @@ class Gallery(SQLModel, table=True):
version: Optional[str] = "0.0.1"
config: Union[GalleryConfig, dict] = Field(default_factory=GalleryConfig, sa_column=Column(JSON))

model_config = ConfigDict(json_encoders={datetime: lambda v: v.isoformat(), UUID: str})
model_config = ConfigDict(json_encoders={datetime: lambda v: v.isoformat()})


class Settings(SQLModel, table=True):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ async def run_stream(
yield event
finally:
# Cleanup - remove our handler
logger.handlers.remove(llm_event_logger)
if llm_event_logger in logger.handlers:
logger.handlers.remove(llm_event_logger)

# Ensure cleanup happens
if team and hasattr(team, "_participants"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import traceback
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional, Union
from uuid import UUID

from autogen_agentchat.base._task import TaskResult
from autogen_agentchat.messages import (
Expand Down Expand Up @@ -42,11 +41,11 @@ class WebSocketManager:

def __init__(self, db_manager: DatabaseManager):
self.db_manager = db_manager
self._connections: Dict[UUID, WebSocket] = {}
self._cancellation_tokens: Dict[UUID, CancellationToken] = {}
self._connections: Dict[int, WebSocket] = {}
self._cancellation_tokens: Dict[int, CancellationToken] = {}
# Track explicitly closed connections
self._closed_connections: set[UUID] = set()
self._input_responses: Dict[UUID, asyncio.Queue] = {}
self._closed_connections: set[int] = set()
self._input_responses: Dict[int, asyncio.Queue] = {}

self._cancel_message = TeamResult(
task_result=TaskResult(
Expand All @@ -63,7 +62,7 @@ def _get_stop_message(self, reason: str) -> dict:
duration=0,
).model_dump()

async def connect(self, websocket: WebSocket, run_id: UUID) -> bool:
async def connect(self, websocket: WebSocket, run_id: int) -> bool:
try:
await websocket.accept()
self._connections[run_id] = websocket
Expand All @@ -80,7 +79,7 @@ async def connect(self, websocket: WebSocket, run_id: UUID) -> bool:
logger.error(f"Connection error for run {run_id}: {e}")
return False

async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None:
async def start_stream(self, run_id: int, task: str, team_config: dict) -> None:
"""Start streaming task execution with proper run management"""
if run_id not in self._connections or run_id in self._closed_connections:
raise ValueError(f"No active connection for run {run_id}")
Expand Down Expand Up @@ -161,7 +160,7 @@ async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None
finally:
self._cancellation_tokens.pop(run_id, None)

async def _save_message(self, run_id: UUID, message: Union[AgentEvent | ChatMessage, ChatMessage]) -> None:
async def _save_message(self, run_id: int, message: Union[AgentEvent | ChatMessage, ChatMessage]) -> None:
"""Save a message to the database"""

run = await self._get_run(run_id)
Expand All @@ -175,7 +174,7 @@ async def _save_message(self, run_id: UUID, message: Union[AgentEvent | ChatMess
self.db_manager.upsert(db_message)

async def _update_run(
self, run_id: UUID, status: RunStatus, team_result: Optional[dict] = None, error: Optional[str] = None
self, run_id: int, status: RunStatus, team_result: Optional[dict] = None, error: Optional[str] = None
) -> None:
"""Update run status and result"""
run = await self._get_run(run_id)
Expand All @@ -187,7 +186,7 @@ async def _update_run(
run.error_message = error
self.db_manager.upsert(run)

def create_input_func(self, run_id: UUID) -> Callable:
def create_input_func(self, run_id: int) -> Callable:
"""Creates an input function for a specific run"""

async def input_handler(prompt: str = "", cancellation_token: Optional[CancellationToken] = None) -> str:
Expand Down Expand Up @@ -216,14 +215,14 @@ async def input_handler(prompt: str = "", cancellation_token: Optional[Cancellat

return input_handler

async def handle_input_response(self, run_id: UUID, response: str) -> None:
async def handle_input_response(self, run_id: int, response: str) -> None:
"""Handle input response from client"""
if run_id in self._input_responses:
await self._input_responses[run_id].put(response)
else:
logger.warning(f"Received input response for inactive run {run_id}")

async def stop_run(self, run_id: UUID, reason: str) -> None:
async def stop_run(self, run_id: int, reason: str) -> None:
if run_id in self._cancellation_tokens:
logger.info(f"Stopping run {run_id}")

Expand Down Expand Up @@ -253,7 +252,7 @@ async def stop_run(self, run_id: UUID, reason: str) -> None:
# We might want to force disconnect here if db update failed
# await self.disconnect(run_id) # Optional

async def disconnect(self, run_id: UUID) -> None:
async def disconnect(self, run_id: int) -> None:
"""Clean up connection and associated resources"""
logger.info(f"Disconnecting run {run_id}")

Expand All @@ -268,11 +267,11 @@ async def disconnect(self, run_id: UUID) -> None:
self._cancellation_tokens.pop(run_id, None)
self._input_responses.pop(run_id, None)

async def _send_message(self, run_id: UUID, message: dict) -> None:
async def _send_message(self, run_id: int, message: dict) -> None:
"""Send a message through the WebSocket with connection state checking

Args:
run_id: UUID of the run
run_id: id of the run
message: Message dictionary to send
"""
if run_id in self._closed_connections:
Expand All @@ -292,7 +291,7 @@ async def _send_message(self, run_id: UUID, message: dict) -> None:
await self._update_run_status(run_id, RunStatus.ERROR, str(e))
await self.disconnect(run_id)

async def _handle_stream_error(self, run_id: UUID, error: Exception) -> None:
async def _handle_stream_error(self, run_id: int, error: Exception) -> None:
"""Handle stream errors with proper run updates"""
if run_id not in self._closed_connections:
error_result = TeamResult(
Expand Down Expand Up @@ -366,11 +365,11 @@ def _format_message(self, message: Any) -> Optional[dict]:
logger.error(f"Message formatting error: {e}")
return None

async def _get_run(self, run_id: UUID) -> Optional[Run]:
async def _get_run(self, run_id: int) -> Optional[Run]:
"""Get run from database

Args:
run_id: UUID of the run to retrieve
run_id: id of the run to retrieve

Returns:
Optional[Run]: Run object if found, None otherwise
Expand All @@ -388,11 +387,11 @@ async def _get_settings(self, user_id: str) -> Optional[Settings]:
response = self.db_manager.get(filters={"user_id": user_id}, model_class=Settings, return_json=False)
return response.data[0] if response.status and response.data else None

async def _update_run_status(self, run_id: UUID, status: RunStatus, error: Optional[str] = None) -> None:
async def _update_run_status(self, run_id: int, status: RunStatus, error: Optional[str] = None) -> None:
"""Update run status in database

Args:
run_id: UUID of the run to update
run_id: id of the run to update
status: New status to set
error: Optional error message
"""
Expand Down Expand Up @@ -451,11 +450,11 @@ async def cleanup(self) -> None:
self._input_responses.clear()

@property
def active_connections(self) -> set[UUID]:
def active_connections(self) -> set[int]:
"""Get set of active run IDs"""
return set(self._connections.keys()) - self._closed_connections

@property
def active_runs(self) -> set[UUID]:
def active_runs(self) -> set[int]:
"""Get set of runs with active cancellation tokens"""
return set(self._cancellation_tokens.keys())
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# /api/runs routes
from typing import Dict
from uuid import UUID

from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel
Expand Down Expand Up @@ -40,7 +39,7 @@ async def create_run(
),
return_json=False,
)
return {"status": run.status, "data": {"run_id": str(run.data.id)}}
return {"status": run.status, "data": {"run_id": run.data.id}}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

Expand All @@ -49,7 +48,7 @@ async def create_run(


@router.get("/{run_id}")
async def get_run(run_id: UUID, db=Depends(get_db)) -> Dict:
async def get_run(run_id: int, db=Depends(get_db)) -> Dict:
"""Get run details including task and result"""
run = db.get(Run, filters={"id": run_id}, return_json=False)
if not run.status or not run.data:
Expand All @@ -59,7 +58,7 @@ async def get_run(run_id: UUID, db=Depends(get_db)) -> Dict:


@router.get("/{run_id}/messages")
async def get_run_messages(run_id: UUID, db=Depends(get_db)) -> Dict:
async def get_run_messages(run_id: int, db=Depends(get_db)) -> Dict:
"""Get all messages for a run"""
messages = db.get(Message, filters={"run_id": run_id}, order="created_at asc", return_json=False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi import APIRouter, Depends, HTTPException

from ...datamodel import Team
from ...gallery.builder import create_default_gallery
from ..deps import get_db

router = APIRouter()
Expand All @@ -13,6 +14,13 @@
async def list_teams(user_id: str, db=Depends(get_db)) -> Dict:
"""List all teams for a user"""
response = db.get(Team, filters={"user_id": user_id})
if not response.data or len(response.data) == 0:
default_gallery = create_default_gallery()
default_team = Team(user_id=user_id, component=default_gallery.components.teams[0].model_dump())

db.upsert(default_team)
response = db.get(Team, filters={"user_id": user_id})

return {"status": True, "data": response.data}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import asyncio
import json
from datetime import datetime
from uuid import UUID

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from loguru import logger
Expand All @@ -17,7 +16,7 @@
@router.websocket("/runs/{run_id}")
async def run_websocket(
websocket: WebSocket,
run_id: UUID,
run_id: int,
ws_manager: WebSocketManager = Depends(get_websocket_manager),
db=Depends(get_db),
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ export interface DBModel {
export interface Message extends DBModel {
config: AgentMessageConfig;
session_id: number;
run_id: string;
run_id: number;
}

export interface Team extends DBModel {
Expand Down Expand Up @@ -321,7 +321,7 @@ export interface TeamResult {
}

export interface Run {
id: string;
id: number;
created_at: string;
updated_at?: string;
status: RunStatus;
Expand Down
Loading
Loading