
Support OpenAI reasoning models #1841


Merged
19 commits merged on Apr 22, 2025
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20250325000101658359.json
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Support OpenAI reasoning models."
}
43 changes: 13 additions & 30 deletions graphrag/config/defaults.py
@@ -41,13 +41,7 @@ class BasicSearchDefaults:
"""Default values for basic search."""

prompt: None = None
text_unit_prop: float = 0.5
conversation_history_max_turns: int = 5
temperature: float = 0
top_p: float = 1
n: int = 1
max_tokens: int = 12_000
llm_max_tokens: int = 2000
k: int = 10
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID

@@ -104,13 +98,10 @@ class DriftSearchDefaults:

prompt: None = None
reduce_prompt: None = None
temperature: float = 0
top_p: float = 1
n: int = 1
max_tokens: int = 12_000
data_max_tokens: int = 12_000
reduce_max_tokens: int = 2_000
reduce_max_tokens: None = None
reduce_temperature: float = 0
reduce_max_completion_tokens: None = None
concurrency: int = 32
drift_k_followups: int = 20
primer_folds: int = 5
@@ -124,7 +115,8 @@ class DriftSearchDefaults:
local_search_temperature: float = 0
local_search_top_p: float = 1
local_search_n: int = 1
local_search_llm_max_gen_tokens: int = 4_096
local_search_llm_max_gen_tokens = None
local_search_llm_max_gen_completion_tokens = None
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID

@@ -168,7 +160,6 @@ class ExtractClaimsDefaults:
)
max_gleanings: int = 1
strategy: None = None
encoding_model: None = None
model_id: str = DEFAULT_CHAT_MODEL_ID


@@ -182,7 +173,6 @@ class ExtractGraphDefaults:
)
max_gleanings: int = 1
strategy: None = None
encoding_model: None = None
model_id: str = DEFAULT_CHAT_MODEL_ID


@@ -228,20 +218,14 @@ class GlobalSearchDefaults:
map_prompt: None = None
reduce_prompt: None = None
knowledge_prompt: None = None
temperature: float = 0
top_p: float = 1
n: int = 1
max_tokens: int = 12_000
max_context_tokens: int = 12_000
data_max_tokens: int = 12_000
map_max_tokens: int = 1000
reduce_max_tokens: int = 2000
concurrency: int = 32
dynamic_search_llm: str = "gpt-4o-mini"
map_max_length: int = 1000
reduce_max_length: int = 2000
dynamic_search_threshold: int = 1
dynamic_search_keep_parent: bool = False
dynamic_search_num_repeats: int = 1
dynamic_search_use_summary: bool = False
dynamic_search_concurrent_coroutines: int = 16
dynamic_search_max_level: int = 2
chat_model_id: str = DEFAULT_CHAT_MODEL_ID

@@ -271,8 +255,10 @@ class LanguageModelDefaults:
api_key: None = None
auth_type = AuthType.APIKey
encoding_model: str = ""
max_tokens: int = 4000
max_tokens: int | None = None
temperature: float = 0
max_completion_tokens: int | None = None
reasoning_effort: str | None = None
top_p: float = 1
n: int = 1
frequency_penalty: float = 0.0
@@ -305,11 +291,7 @@ class LocalSearchDefaults:
conversation_history_max_turns: int = 5
top_k_entities: int = 10
top_k_relationships: int = 10
temperature: float = 0
top_p: float = 1
n: int = 1
max_tokens: int = 12_000
llm_max_tokens: int = 2000
max_context_tokens: int = 12_000
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID

@@ -364,6 +346,7 @@ class SummarizeDescriptionsDefaults:

prompt: None = None
max_length: int = 500
max_input_tokens: int = 4_000
strategy: None = None
model_id: str = DEFAULT_CHAT_MODEL_ID

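Note on the defaults above: OpenAI's o-series reasoning models reject the legacy max_tokens parameter (and fixed sampling settings such as temperature), expecting max_completion_tokens, which also budgets the hidden reasoning tokens, plus an optional reasoning_effort. The sketch below illustrates that split; it is not code from this PR, and the helper name and model-prefix check are assumptions.

def build_completion_kwargs(
    model: str,
    max_tokens: int | None = None,
    max_completion_tokens: int | None = None,
    reasoning_effort: str | None = None,
    temperature: float = 0,
) -> dict:
    """Illustrative only: map config values to chat-completion parameters."""
    kwargs: dict = {"model": model}
    # Assumption: identify reasoning models by a simple name prefix.
    if model.startswith(("o1", "o3", "o4")):
        if max_completion_tokens is not None:
            kwargs["max_completion_tokens"] = max_completion_tokens
        if reasoning_effort is not None:
            kwargs["reasoning_effort"] = reasoning_effort
    else:
        kwargs["temperature"] = temperature
        if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens
    return kwargs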
30 changes: 3 additions & 27 deletions graphrag/config/models/basic_search_config.py
@@ -23,31 +23,7 @@ class BasicSearchConfig(BaseModel):
description="The model ID to use for text embeddings.",
default=graphrag_config_defaults.basic_search.embedding_model_id,
)
text_unit_prop: float = Field(
description="The text unit proportion.",
default=graphrag_config_defaults.basic_search.text_unit_prop,
)
conversation_history_max_turns: int = Field(
description="The conversation history maximum turns.",
default=graphrag_config_defaults.basic_search.conversation_history_max_turns,
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=graphrag_config_defaults.basic_search.temperature,
)
top_p: float = Field(
description="The top-p value to use for token generation.",
default=graphrag_config_defaults.basic_search.top_p,
)
n: int = Field(
description="The number of completions to generate.",
default=graphrag_config_defaults.basic_search.n,
)
max_tokens: int = Field(
description="The maximum tokens.",
default=graphrag_config_defaults.basic_search.max_tokens,
)
llm_max_tokens: int = Field(
description="The LLM maximum tokens.",
default=graphrag_config_defaults.basic_search.llm_max_tokens,
k: int = Field(
description="The number of text units to include in search context.",
default=graphrag_config_defaults.basic_search.k,
)
1 change: 0 additions & 1 deletion graphrag/config/models/community_reports_config.py
@@ -50,7 +50,6 @@ def resolved_strategy(
return self.strategy or {
"type": CreateCommunityReportsStrategyType.graph_intelligence,
"llm": model_config.model_dump(),
"num_threads": model_config.concurrent_requests,
"graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
encoding="utf-8"
)
30 changes: 12 additions & 18 deletions graphrag/config/models/drift_search_config.py
@@ -27,28 +27,12 @@ class DRIFTSearchConfig(BaseModel):
description="The model ID to use for drift search.",
default=graphrag_config_defaults.drift_search.embedding_model_id,
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=graphrag_config_defaults.drift_search.temperature,
)
top_p: float = Field(
description="The top-p value to use for token generation.",
default=graphrag_config_defaults.drift_search.top_p,
)
n: int = Field(
description="The number of completions to generate.",
default=graphrag_config_defaults.drift_search.n,
)
max_tokens: int = Field(
description="The maximum context size in tokens.",
default=graphrag_config_defaults.drift_search.max_tokens,
)
data_max_tokens: int = Field(
description="The data llm maximum tokens.",
default=graphrag_config_defaults.drift_search.data_max_tokens,
)

reduce_max_tokens: int = Field(
reduce_max_tokens: int | None = Field(
description="The reduce llm maximum tokens response to produce.",
default=graphrag_config_defaults.drift_search.reduce_max_tokens,
)
@@ -58,6 +42,11 @@ class DRIFTSearchConfig(BaseModel):
default=graphrag_config_defaults.drift_search.reduce_temperature,
)

reduce_max_completion_tokens: int | None = Field(
description="The reduce llm maximum tokens response to produce.",
default=graphrag_config_defaults.drift_search.reduce_max_completion_tokens,
)

concurrency: int = Field(
description="The number of concurrent requests.",
default=graphrag_config_defaults.drift_search.concurrency,
@@ -123,7 +112,12 @@ class DRIFTSearchConfig(BaseModel):
default=graphrag_config_defaults.drift_search.local_search_n,
)

local_search_llm_max_gen_tokens: int = Field(
local_search_llm_max_gen_tokens: int | None = Field(
description="The maximum number of generated tokens for the LLM in local search.",
default=graphrag_config_defaults.drift_search.local_search_llm_max_gen_tokens,
)

local_search_llm_max_gen_completion_tokens: int | None = Field(
description="The maximum number of generated tokens for the LLM in local search.",
default=graphrag_config_defaults.drift_search.local_search_llm_max_gen_completion_tokens,
)
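The paired fields above (reduce_max_tokens / reduce_max_completion_tokens, and the local-search equivalents) default to None so that only the limit matching the target model needs to be set. A minimal sketch of that selection logic, assuming the caller passes whichever value came from config; the helper itself is illustrative, not part of this PR.

def pick_token_limit(
    max_tokens: int | None, max_completion_tokens: int | None
) -> dict:
    """Return the single generation-limit kwarg that is actually configured."""
    if max_completion_tokens is not None:
        return {"max_completion_tokens": max_completion_tokens}
    if max_tokens is not None:
        return {"max_tokens": max_tokens}
    return {}  # neither configured: let the provider apply its own default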
6 changes: 0 additions & 6 deletions graphrag/config/models/extract_claims_config.py
@@ -38,24 +38,18 @@ class ClaimExtractionConfig(BaseModel):
description="The override strategy to use.",
default=graphrag_config_defaults.extract_claims.strategy,
)
encoding_model: str | None = Field(
default=graphrag_config_defaults.extract_claims.encoding_model,
description="The encoding model to use.",
)

def resolved_strategy(
self, root_dir: str, model_config: LanguageModelConfig
) -> dict:
"""Get the resolved claim extraction strategy."""
return self.strategy or {
"llm": model_config.model_dump(),
"num_threads": model_config.concurrent_requests,
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)
if self.prompt
else None,
"claim_description": self.description,
"max_gleanings": self.max_gleanings,
"encoding_name": model_config.encoding_model,
}
6 changes: 0 additions & 6 deletions graphrag/config/models/extract_graph_config.py
@@ -34,10 +34,6 @@ class ExtractGraphConfig(BaseModel):
description="Override the default entity extraction strategy",
default=graphrag_config_defaults.extract_graph.strategy,
)
encoding_model: str | None = Field(
default=graphrag_config_defaults.extract_graph.encoding_model,
description="The encoding model to use.",
)

def resolved_strategy(
self, root_dir: str, model_config: LanguageModelConfig
@@ -50,12 +46,10 @@ def resolved_strategy(
return self.strategy or {
"type": ExtractEntityStrategyType.graph_intelligence,
"llm": model_config.model_dump(),
"num_threads": model_config.concurrent_requests,
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
encoding="utf-8"
)
if self.prompt
else None,
"max_gleanings": self.max_gleanings,
"encoding_name": model_config.encoding_model,
}
40 changes: 8 additions & 32 deletions graphrag/config/models/global_search_config.py
@@ -27,44 +27,24 @@ class GlobalSearchConfig(BaseModel):
description="The global search general prompt to use.",
default=graphrag_config_defaults.global_search.knowledge_prompt,
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=graphrag_config_defaults.global_search.temperature,
)
top_p: float = Field(
description="The top-p value to use for token generation.",
default=graphrag_config_defaults.global_search.top_p,
)
n: int = Field(
description="The number of completions to generate.",
default=graphrag_config_defaults.global_search.n,
)
max_tokens: int = Field(
max_context_tokens: int = Field(
description="The maximum context size in tokens.",
default=graphrag_config_defaults.global_search.max_tokens,
default=graphrag_config_defaults.global_search.max_context_tokens,
)
data_max_tokens: int = Field(
description="The data llm maximum tokens.",
default=graphrag_config_defaults.global_search.data_max_tokens,
)
map_max_tokens: int = Field(
description="The map llm maximum tokens.",
default=graphrag_config_defaults.global_search.map_max_tokens,
map_max_length: int = Field(
description="The map llm maximum response length in words.",
default=graphrag_config_defaults.global_search.map_max_length,
)
reduce_max_tokens: int = Field(
description="The reduce llm maximum tokens.",
default=graphrag_config_defaults.global_search.reduce_max_tokens,
)
concurrency: int = Field(
description="The number of concurrent requests.",
default=graphrag_config_defaults.global_search.concurrency,
reduce_max_length: int = Field(
description="The reduce llm maximum response length in words.",
default=graphrag_config_defaults.global_search.reduce_max_length,
)

# configurations for dynamic community selection
dynamic_search_llm: str = Field(
description="LLM model to use for dynamic community selection",
default=graphrag_config_defaults.global_search.dynamic_search_llm,
)
dynamic_search_threshold: int = Field(
description="Rating threshold in include a community report",
default=graphrag_config_defaults.global_search.dynamic_search_threshold,
@@ -81,10 +61,6 @@
description="Use community summary instead of full_context",
default=graphrag_config_defaults.global_search.dynamic_search_use_summary,
)
dynamic_search_concurrent_coroutines: int = Field(
description="Number of concurrent coroutines to rate community reports",
default=graphrag_config_defaults.global_search.dynamic_search_concurrent_coroutines,
)
dynamic_search_max_level: int = Field(
description="The maximum level of community hierarchy to consider if none of the processed communities are relevant",
default=graphrag_config_defaults.global_search.dynamic_search_max_level,
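The renamed map_max_length and reduce_max_length express the requested response length in words rather than tokens, so they describe the desired answer length (e.g., in the prompt instructions) instead of acting as an API token limit; max_context_tokens keeps the token-denominated context budget. A small hedged sketch; the instruction wording is illustrative, not the library's actual prompt.

map_max_length = 1000  # words, per the default above
map_instruction = (
    f"Answer using the provided community reports, "
    f"in at most {map_max_length} words."
)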
10 changes: 9 additions & 1 deletion graphrag/config/models/language_model_config.py
@@ -223,14 +223,22 @@ def _validate_deployment_name(self) -> None:
default=language_model_defaults.responses,
description="Static responses to use in mock mode.",
)
max_tokens: int = Field(
max_tokens: int | None = Field(
description="The maximum number of tokens to generate.",
default=language_model_defaults.max_tokens,
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=language_model_defaults.temperature,
)
max_completion_tokens: int | None = Field(
description="The maximum number of tokens to consume. This includes reasoning tokens for the o* reasoning models.",
default=language_model_defaults.max_completion_tokens,
)
reasoning_effort: str | None = Field(
description="Level of effort OpenAI reasoning models should expend. Supported options are 'low', 'medium', 'high'; and OAI defaults to 'medium'.",
default=language_model_defaults.reasoning_effort,
)
top_p: float = Field(
description="The top-p value to use for token generation.",
default=language_model_defaults.top_p,
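A brief usage sketch for the two new fields: for a reasoning model, max_tokens is typically left unset and the generation budget is expressed through max_completion_tokens, with reasoning_effort tuning how much thinking the model does. The dictionary below is illustrative only; the field names match the diff, while the model name and surrounding shape are assumptions.

reasoning_model_settings = {
    "model": "o3-mini",               # example reasoning model (assumption)
    "max_tokens": None,               # leave unset for o-series models
    "max_completion_tokens": 12_000,  # includes hidden reasoning tokens
    "reasoning_effort": "medium",     # 'low' | 'medium' | 'high'
}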
20 changes: 2 additions & 18 deletions graphrag/config/models/local_search_config.py
@@ -43,23 +43,7 @@ class LocalSearchConfig(BaseModel):
description="The top k mapped relations.",
default=graphrag_config_defaults.local_search.top_k_relationships,
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=graphrag_config_defaults.local_search.temperature,
)
top_p: float = Field(
description="The top-p value to use for token generation.",
default=graphrag_config_defaults.local_search.top_p,
)
n: int = Field(
description="The number of completions to generate.",
default=graphrag_config_defaults.local_search.n,
)
max_tokens: int = Field(
max_context_tokens: int = Field(
description="The maximum tokens.",
default=graphrag_config_defaults.local_search.max_tokens,
)
llm_max_tokens: int = Field(
description="The LLM maximum tokens.",
default=graphrag_config_defaults.local_search.llm_max_tokens,
default=graphrag_config_defaults.local_search.max_context_tokens,
)