|
| 1 | +''' |
| 2 | +demo/arithmetic_calc.py |
| 3 | +
|
| 4 | +Based on an exaple from /llama-cpp-agent |
| 5 | +https://llama-cpp-agent.readthedocs.io/en/latest/parallel_function_calling/ |
| 6 | +''' |
| 7 | + |
| 8 | +import asyncio |
| 9 | +from enum import Enum, auto |
| 10 | + |
| 11 | +from toolio.tool import tool, param |
| 12 | +from toolio.llm_helper import model_manager, extract_content |
| 13 | + |
| 14 | + |
| 15 | +class arithmetic_op(Enum): |
| 16 | + ADD = 'add' |
| 17 | + SUBTRACT = 'subtract' |
| 18 | + MULTIPLY = 'multiply' |
| 19 | + DIVIDE = 'divide' |
| 20 | + |
| 21 | + |
| 22 | +@tool('arithmetic_calc', params=[ |
| 23 | + param('num1', float, 'Number on the left hand side of the calculation', True), |
| 24 | + param('num2', float, 'Number on the left hand side of the calculation', True), |
| 25 | + param('op', arithmetic_op, 'Arithmetic operation to make on the two numbers', True), |
| 26 | + ]) |
| 27 | +async def arithmetic_calc(num1=None, num2=None, op=None): |
| 28 | + 'Very basic arithmetic calculator' |
| 29 | + match op: |
| 30 | + case arithmetic_op.ADD: |
| 31 | + result = num1 + num2 |
| 32 | + case arithmetic_op.SUBTRACT: |
| 33 | + result = num1 - num2 |
| 34 | + case arithmetic_op.MULTIPLY: |
| 35 | + result = num1 * num2 |
| 36 | + case arithmetic_op.DIVIDE: |
| 37 | + result = num1 / num2 |
| 38 | + case _: |
| 39 | + raise ValueError('Unknown operator') # Shouldn't happen |
| 40 | + return result |
| 41 | + |
| 42 | + |
| 43 | +# Had a problem using Hermes-2-Theta-Llama-3-8B-4bit 😬 |
| 44 | +# MLX_MODEL_PATH = 'mlx-community/Hermes-2-Theta-Llama-3-8B-4bit' |
| 45 | +MLX_MODEL_PATH = 'mlx-community/Mistral-Nemo-Instruct-2407-4bit' |
| 46 | + |
| 47 | +toolio_mm = model_manager(MLX_MODEL_PATH, tool_reg=[arithmetic_calc], trace=True) |
| 48 | + |
| 49 | +# PROMPT = 'Solve the following calculations: 42 * 42, 24 * 24, 5 * 5, 89 * 75, 42 * 46, 69 * 85, 422 * 420, 753 * 321, 72 * 55, 240 * 204, 789 * 654, 123 * 321, 432 * 89, 564 * 321?' # noqa: E501 |
| 50 | +PROMPT = 'Solve the following calculation: 4242 * 2424.2' |
| 51 | +async def async_main(tmm): |
| 52 | + msgs = [ {'role': 'user', 'content': PROMPT} ] |
| 53 | + async for chunk in extract_content(tmm.complete_with_tools(msgs)): |
| 54 | + print(chunk, end='') |
| 55 | + |
| 56 | +asyncio.run(async_main(toolio_mm)) |
0 commit comments