Add support for local models via Ollama #6
@@ -0,0 +1,14 @@
# Ollama setup
1. Download and install [Ollama](https://ollama.com/)
2. Once Ollama is running on your system, run `ollama pull llama3.1`
   > This is currently a ~5 GB download, so it's best to pull it before the workshop if you plan on using it
3. Update the `MODEL_NAME` in your `dot.env` file to `ollama`

You're now ready to begin the workshop! Head back to the [Readme.md](Readme.md)
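Optionally, before heading back, you can confirm the pulled model responds locally. The following is a minimal sketch (not part of the workshop files); it assumes the `langchain-ollama` Python package is installed and the Ollama server is running:

```python
# Quick sanity check that the local llama3.1 model is reachable via Ollama.
# Assumes `ollama pull llama3.1` has finished and the Ollama service is running.
from langchain_ollama import ChatOllama

llm = ChatOllama(model="llama3.1", temperature=0)
print(llm.invoke("Reply with the single word: ready").content)
```

If this prints a short response, the local model is working.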
## Restarting the workshop | ||
Reviewer comment: This may bear further investigation, but in my tests it was best to kill and re-create.
Mixing use of llama and openai on the same Redis instance can cause unexpected behavior. If you want to switch from one to the other, it is recommended to kill and re-create the instance. To do this:
1. Run `docker ps` and take note of the container ID of the running Redis container
2. `docker stop <containerId>`
3. `docker rm <containerId>`
4. Start a new instance using the command from earlier: `docker run -d --name redis -p 6379:6379 -p 8001:8001 redis/redis-stack:latest`
@@ -3,4 +3,5 @@ OPENAI_API_KEY=openai_key
LANGCHAIN_TRACING_V2=
LANGCHAIN_ENDPOINT=
LANGCHAIN_API_KEY=
LANGCHAIN_PROJECT=
MODEL_NAME=openai
Reviewer comment: defaulting to openai
@@ -1,18 +1,26 @@
import os
from functools import lru_cache

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langgraph.prebuilt import ToolNode

from example_agent.utils.ex_tools import tools

from .ex_state import AgentState, MultipleChoiceResponse

load_dotenv()

ENVIRON_MODEL_NAME = os.environ.get("MODEL_NAME")

@lru_cache(maxsize=4)
def _get_tool_model(model_name: str):
    if model_name == "openai":
        model = ChatOpenAI(temperature=0, model_name="gpt-4o")
    elif model_name == "ollama":
        model = ChatOllama(temperature=0, model="llama3.1", num_ctx=4096)
    else:
        raise ValueError(f"Unsupported model type: {model_name}")

Reviewer comment (on the `num_ctx=4096` line): increasing the context from the default (which is pretty low) provided much more reliable results
@@ -24,6 +32,8 @@ def _get_tool_model(model_name: str):
def _get_response_model(model_name: str):
    if model_name == "openai":
        model = ChatOpenAI(temperature=0, model_name="gpt-4o")
    elif model_name == "ollama":
        model = ChatOllama(temperature=0, model="llama3.1", num_ctx=4096)
    else:
        raise ValueError(f"Unsupported model type: {model_name}")
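A small usage note (illustrative only, not part of the diff): because `_get_tool_model` is wrapped in `functools.lru_cache`, the chat-model client for a given `model_name` is constructed once and then reused on subsequent graph-node invocations, so repeated lookups are cheap:

```python
# Illustrative only: repeated lookups with the same model_name hit the lru_cache,
# so the ChatOpenAI / ChatOllama client is only constructed on the first call.
first = _get_tool_model("ollama")
second = _get_tool_model("ollama")
assert first is second  # the same cached object is returned
```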
@@ -36,7 +46,7 @@ def multi_choice_structured(state: AgentState, config):
    # We call the model with structured output in order to return the same format to the user every time
    # state['messages'][-2] is the last ToolMessage in the convo, which we convert to a HumanMessage for the model to use
    # We could also pass the entire chat history, but this saves tokens since all we care to structure is the output of the tool
-   model_name = config.get("configurable", {}).get("model_name", "openai")
+   model_name = config.get("configurable", {}).get("model_name", ENVIRON_MODEL_NAME)

    response = _get_response_model(model_name).invoke(
        [
@@ -75,7 +85,7 @@ def call_tool_model(state: AgentState, config):
    messages = [{"role": "system", "content": system_prompt}] + state["messages"]

    # Get from LangGraph config
-   model_name = config.get("configurable", {}).get("model_name", "openai")
+   model_name = config.get("configurable", {}).get("model_name", ENVIRON_MODEL_NAME)

    # Get our model that binds our tools
    model = _get_tool_model(model_name)
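To illustrate the fallback order here (a hedged sketch, not part of the diff): a per-invocation `configurable` value still takes precedence, and `ENVIRON_MODEL_NAME` from the `.env` file is only used when no `model_name` is supplied. `graph` is assumed to be the compiled LangGraph agent from this repo, and `HumanMessage` is the import shown above:

```python
# Hypothetical usage: override the MODEL_NAME default for a single run.
result = graph.invoke(
    {"messages": [HumanMessage(content="Hello, agent")]},
    config={"configurable": {"model_name": "ollama"}},
)

# With no override, the nodes fall back to ENVIRON_MODEL_NAME ("openai" by default).
result = graph.invoke({"messages": [HumanMessage(content="Hello, agent")]})
```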
Reviewer comment: Keeping most ollama-specific instructions separate to be less intrusive