diff --git a/.gitignore b/.gitignore index b19c34b..1664006 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,10 @@ __pycache__/ *.tar.gz 2025-T1/Food-Image-Classifier/scripts/data/food-101/images/ -2025-T2/Multi-Image-Classifier/scripts/VFN/Images/ \ No newline at end of file +2025-T2/Multi-Image-Classifier/scripts/VFN/Images/.ipynb_checkpoints/ +_env +.env +.chroma/ +.chroma/ +_env +.ipynb_checkpoints/ diff --git a/2025-T2/nutribot_training.ipynb b/2025-T2/nutribot_training.ipynb index 24fb465..b89c127 100644 --- a/2025-T2/nutribot_training.ipynb +++ b/2025-T2/nutribot_training.ipynb @@ -24,7 +24,7 @@ "from groq import Groq\n", "from dotenv import load_dotenv\n", "\n", - "load_dotenv()" + "load_dotenv()\n" ] }, { @@ -36,29 +36,26 @@ "source": [ "class RAG:\n", " def __init__(self, collection_name=\"aus_food_nutrition\"):\n", - " self.chroma_client = chromadb.CloudClient(\n", - " tenant='a0123436-2e87-4752-8983-73168aafe2e9',\n", - " database='nutribot',\n", - " api_key=os.environ.get(\"CHROMA_API_KEY\"),\n", - " )\n", + " self.chroma_client = chromadb.PersistentClient(path=\".chroma\")\n", " self.collection = self.chroma_client.get_or_create_collection(name=collection_name)\n", " self.count = self.collection.count()\n", "\n", - "\n", " def add_documents(self, docs):\n", " for doc in docs:\n", " _id = f\"id{self.count}\"\n", " self.collection.upsert(ids=[_id], documents=[doc])\n", " self.count += 1\n", "\n", - " def retrieve(self, prompt, n_results=2):\n", - " res = self.collection.query(query_texts=[prompt], n_results=n_results)\n", - " return res.get(\"documents\", [[]])[0]" + " def retrieve(self, prompt, n_results=5):\n", + " return self.collection.query(\n", + " query_texts=[prompt],\n", + " n_results=n_results\n", + " )\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "7475d430", "metadata": {}, "outputs": [], @@ -67,12 +64,12 @@ "rag.add_documents([\n", " \"Vegemite is a popular Australian spread made from brewers' yeast extract.\",\n", " \"Kangaroo meat is a lean source of protein, low in fat.\",\n", - "])" + "])\n" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 4, "id": "99a1118e", "metadata": {}, "outputs": [ @@ -80,7 +77,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Retrieved texts: ['The Australian Dietary Guidelines advise reducing the intake of processed foods and sugary drinks.', 'A balanced diet, as recommended by Australian Dietary Guidelines, includes moderate portions of protein and whole grains.']\n" + "Retrieved texts: {'ids': [['id142', 'id381']], 'embeddings': None, 'documents': [['Read about how much you need to eat each day , or use the energy requirements calculator to estimate what’s right for you.', 'Read about how much you need to eat each day , or use the energy requirements calculator to estimate what’s right for you.']], 'uris': None, 'included': ['metadatas', 'documents', 'distances'], 'data': None, 'metadatas': [[None, None]], 'distances': [[1.2633346319198608, 1.2633346319198608]]}\n" ] } ], @@ -118,109 +115,218 @@ "\"\"\"\n", "relevant_texts = rag.retrieve(prompt, n_results=2)\n", "prompt = prompt + f\"\\tUse this as context for answering: {relevant_texts}\"\n", - "print(\"Retrieved texts:\", relevant_texts)" + "print(\"Retrieved texts:\", relevant_texts)\n" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 5, "id": "84a86eb0", "metadata": {}, + "outputs": [], + "source": [ + "def keyword_score(query, doc):\n", + " query_words = set(query.lower().split())\n", + " doc_words = set(doc.lower().split())\n", + " return len(query_words & doc_words)\n", + "\n", + "def best_chunks(query, initial_k=10, final_k=4):\n", + " rag = RAG()\n", + " res = rag.retrieve(query, n_results=initial_k)\n", + "\n", + " docs = res.get(\"documents\", [[]])[0]\n", + " distances = res.get(\"distances\", [[]])[0] if res.get(\"distances\") else [999] * len(docs)\n", + "\n", + " ranked = []\n", + " seen = set()\n", + "\n", + " for i, doc in enumerate(docs):\n", + " clean_doc = doc.strip()\n", + "\n", + " if not clean_doc:\n", + " continue\n", + "\n", + " if clean_doc in seen:\n", + " continue\n", + " seen.add(clean_doc)\n", + "\n", + " dist = distances[i] if i < len(distances) else 999\n", + " overlap = keyword_score(query, clean_doc)\n", + "\n", + " ranked.append({\n", + " \"doc\": clean_doc,\n", + " \"distance\": dist,\n", + " \"overlap\": overlap\n", + " })\n", + "\n", + " ranked = sorted(\n", + " ranked,\n", + " key=lambda x: (x[\"distance\"], -x[\"overlap\"], len(x[\"doc\"]))\n", + ")\n", + "\n", + " return ranked[:final_k]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a7ffca05", + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Based on the given health condition and the provided knowledge, I'll generate a weekly plan for the user. \n", "\n", - "Given the user's obesity level is categorized as Overweight_Level_II with 10% confidence in obesity prediction, and the user has diabetes with 90% confidence, here's a suggested weekly plan:\n", + "Query: What proportion of Australian children aged 2 to 17 do not eat enough vegetables?\n", + "================================================================================\n", + "\n", + "Selected Chunk 1\n", + "distance = 0.5301\n", + "keyword overlap = 9\n", + "Only 1 in 13 adults eat enough fruit and vegetables 94% of children aged 2 to 17 don't eat enough vegetables One-third of our energy intake comes from food we don’t need Source: Australian Bureau of Statistics Resources Resources 03:08 Improving food and nutrition in aged care: Heywood Rural Health 1 August 2025 Story Meet Sharron, Brad and Leigh from Heywood Rural Health aged care home.\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Selected Chunk 2\n", + "distance = 0.7319\n", + "keyword overlap = 3\n", + "Food and nutrition in Australia Research has found that Australians of all ages generally: don’t eat enough of the 5 food groups – vegetables, fruit, grains, meat and alternatives, and dairy products and alternatives eat too much discretionary food (food we don’t need that’s high in energy and low in nutrients), like pastries, ice cream or chips eat too much sugar, salt and saturated fat.\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Selected Chunk 3\n", + "distance = 0.8010\n", + "keyword overlap = 2\n", + "6% of children met both the fruit and vegetables recommendations.\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Selected Chunk 4\n", + "distance = 0.8178\n", + "keyword overlap = 4\n", + "National research includes the: National Health Survey – surveyed about 21,000 people about various aspects of their health, including their diet Australian Health Survey – the largest health survey ever done in Australia, which looked at the years 2011 to 2013 National Aboriginal and Torres Strait Islander Health Survey – surveyed one-third of all First Nations people about their health, including fruit and vegetable consumption Growing up in Australia – a longitudinal study of 10,000 Australian children Australian Institute of Health and Welfare analysis of nutrition across the life stages and poor diet among Australians.\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Query: Which section covers actions to reduce harmful ingredients in processed foods?\n", + "================================================================================\n", + "\n", + "Selected Chunk 1\n", + "distance = 0.4564\n", + "keyword overlap = 4\n", + "Reducing harmful ingredients Find out how we’re working with industry to make processed foods healthier.\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Selected Chunk 2\n", + "distance = 0.4938\n", + "keyword overlap = 5\n", + "Read about our work with industry to reduce harmful ingredients from processed and manufactured foods .\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Selected Chunk 3\n", + "distance = 0.7671\n", + "keyword overlap = 4\n", + "Find out what we’re doing to reduce sugar, salt and saturated fat in processed and manufactured foods and drinks.\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Selected Chunk 4\n", + "distance = 0.7711\n", + "keyword overlap = 4\n", + "This program aims to reduce the amount of sugar, sodium and saturated fat in processed and manufactured foods.\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Query: Name one 2025 resource related to the Healthy Food Partnership.\n", + "================================================================================\n", "\n", - "```json\n", - "{\n", - " \"suggestion\": \"Maintain a strict diet and regular exercise to manage your diabetes and work towards healthy weight management.\",\n", - " \"weekly_plan\": [\n", - " {\n", - " \"week\": 1,\n", - " \"target_calories_per_day\": 1700,\n", - " \"focus\": \"Balancing macronutrients (proteins, healthy fats, complex carbohydrates)\",\n", - " \"workouts\": [\n", - " \"Monday: 30 minutes of brisk walking\",\n", - " \"Wednesday: Bodyweight exercises (20 reps of push-ups, squats, lunges)\",\n", - " \"Friday: 20 minutes of cycling\"\n", - " ],\n", - " \"meal_notes\": \"Eat frequent, balanced meals with portion control. Include whole grains, lean proteins, and plenty of vegetables and fruits.\",\n", - " \"reminders\": [\n", - " \"Drink at least 8 glasses of water daily\",\n", - " \"Monitor and record your blood glucose levels\"\n", - " ]\n", - " },\n", - " {\n", - " \"week\": 2,\n", - " \"target_calories_per_day\": 1800,\n", - " \"focus\": \"Incorporating healthy fats into your diet\",\n", - " \"workouts\": [\n", - " \"Monday: Swimming for 30 minutes\",\n", - " \"Wednesday: Resistance training with dumbbells\",\n", - " \"Friday: Yoga for flexibility\"\n", - " ],\n", - " \"meal_notes\": \"Make healthy choices like avocados, nuts, and seeds. Replace processed foods with whole foods.\",\n", - " \"reminders\": [\n", - " \"Increase your fiber intake by 5 grams daily\",\n", - " \"Limit sugary drinks and fast food\"\n", - " ]\n", - " },\n", - " {\n", - " \"week\": 3,\n", - " \"target_calories_per_day\": 1900,\n", - " \"focus\": \"Strengthening muscles through progressive overload\",\n", - " \"workouts\": [\n", - " \"Monday: Weightlifting (30 minutes)\",\n", - " \"Wednesday: High-Intensity Interval Training (HIIT)\",\n", - " \"Friday: Resting or active recovery\"\n", - " ],\n", - " \"meal_notes\": \"Stay hydrated by drinking water-rich foods like watermelon and cucumbers. Eat regular meals to avoid spikes in blood sugar levels.\",\n", - " \"reminders\": [\n", - " \"Eat a small, balanced meal or snack 1 hour before exercise\",\n", - " \"Avoid excessive stress by getting enough sleep (7-8 hours) each night\"\n", - " ]\n", - " },\n", - " {\n", - " \"week\": 4,\n", - " \"target_calories_per_day\": 2000,\n", - " \"focus\": \"Fine-tuning your diet and exercise routine for better management of blood sugar levels and weight\",\n", - " \"workouts\": [\n", - " \"Monday: 45 minutes of jogging\",\n", - " \"Wednesday: Flexibility and stretching exercises\",\n", - " \"Friday: Swimming laps for 30 minutes\"\n", - " ],\n", - " \"meal_notes\": \"Eat a variety of nutrient-dense foods to ensure you get enough vitamins and minerals. Plan your meals, especially when dining out or eating at a social event.\",\n", - " \"reminders\": [\n", - " \"Monitor and adjust your carbohydrate intake according to your blood glucose levels\",\n", - " \"Exercise regularly, aiming for at least 150 minutes of moderate aerobic activity weekly\"\n", - " ]\n", - " }\n", - " ]\n", - "}\n", - "```\n", + "Selected Chunk 1\n", + "distance = 0.4561\n", + "keyword overlap = 4\n", + "It is part of the Healthy Food Partnership.\n", + "--------------------------------------------------------------------------------\n", "\n", - "This plan focuses on gradual weight loss and improved diabetes management by balancing macronutrients, incorporating healthy fats, and engaging in regular physical activity. The user is reminded to drink sufficient water, monitor their blood glucose levels, and limit their intake of processed foods and sugary drinks.\n" + "Selected Chunk 2\n", + "distance = 0.4704\n", + "keyword overlap = 4\n", + "Healthy Food Partnership Program – Executive Committee communique – February 2025 20 February 2025 Meeting minutes This communique presents a summary of discussions and decisions by the Healthy Food Partnership Program Executive Committee from their 17th meeting held on 20 February 2025.\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Selected Chunk 3\n", + "distance = 0.5963\n", + "keyword overlap = 4\n", + "We’re also working with the food industry to promote healthy eating and reduce the risk of disease, through the Healthy Food Partnership’s Reformulation Program .\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Selected Chunk 4\n", + "distance = 0.6343\n", + "keyword overlap = 4\n", + "Through the Healthy Food Partnership , we work with the food industry and the public health sector to make healthier food choices easier and more accessible.\n", + "--------------------------------------------------------------------------------\n" ] } ], "source": [ + "def inspect_best_chunks(query, initial_k=10, final_k=4):\n", + " best_chunk = best_chunks(query, initial_k=initial_k, final_k=final_k)\n", + "\n", + " print(f\"\\nQuery: {query}\")\n", + " print(\"=\" * 80)\n", + "\n", + " for i, item in enumerate(best_chunk, 1):\n", + " print(f\"\\nSelected Chunk {i}\")\n", + " print(f\"distance = {item['distance']:.4f}\")\n", + " print(f\"keyword overlap = {item['overlap']}\")\n", + " print(item[\"doc\"])\n", + " print(\"-\" * 80)\n", + "\n", + "inspect_best_chunks(\"What proportion of Australian children aged 2 to 17 do not eat enough vegetables?\")\n", + "inspect_best_chunks(\"Which section covers actions to reduce harmful ingredients in processed foods?\")\n", + "inspect_best_chunks(\"Name one 2025 resource related to the Healthy Food Partnership.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "86de40a9", + "metadata": {}, + "outputs": [], + "source": [ + "def format_context(chunks):\n", + " return \"\\n\\n\".join(\n", + " [f\"Context {i+1}:\\n{item['doc']}\" for i, item in enumerate(chunks)]\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "e2b1899e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "How can I assist you today?\n" + ] + } + ], + "source": [ + "from dotenv import load_dotenv\n", + "import os\n", + "from groq import Groq\n", + "\n", + "load_dotenv()\n", + "\n", "client = Groq(api_key=os.environ.get(\"GROQ_API_KEY\"))\n", "\n", "chat_completion = client.chat.completions.create(\n", " messages=[\n", " {\n", " \"role\": \"user\",\n", - " \"content\": prompt,\n", + " \"content\": \"Hello\",\n", " }\n", " ],\n", " model=\"llama-3.1-8b-instant\",\n", ")\n", "\n", - "print(chat_completion.choices[0].message.content)" + "print(chat_completion.choices[0].message.content)\n" ] }, { @@ -237,45 +343,31 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 9, "id": "cc4c1739", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6030c49429484d7dac4660dfb2c96c9b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "VBox(children=(HTML(value='
\"\n", " )\n", - " try:\n", - " # relevant_texts = rag.retrieve(prompt, n_results=2)\n", "\n", + " try:\n", " resp = client.chat.completions.create(\n", " messages=[\n", - " {\"role\":\"system\",\"content\":\"You are a multiple-choice grader. Return the answer in EXACTLY this format and nothing else:\\n@@ANSWER=@@\"},\n", - " {\"role\": \"user\", \"content\": f'{prompt}'}\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": (\n", + " \"You are a multiple-choice grader. \"\n", + " \"Return the answer in EXACTLY this format and nothing else:\\n\"\n", + " \n", + " ),\n", + " },\n", + " {\"role\": \"user\", \"content\": prompt},\n", " ],\n", " model=model,\n", + " temperature=0,\n", " )\n", + "\n", " ans = (resp.choices[0].message.content or \"\").strip()\n", - " # print(ans)\n", - " m = re.compile(r\"@@ANSWER=([A-D])@@\\s*$\", flags=re.MULTILINE).search(ans.strip())\n", + " print(ans)\n", + "\n", + " m = re.search(r\"ANSWER=([A-D])\", ans)\n", " return m.group(1).upper() if m else None\n", + "\n", " except Exception as e:\n", " print(f\"API error: {e}\")\n", " raise APIUnavailable from e\n", "\n", + "\n", "def run_eval(num_questions):\n", - " dataset = load_dataset(\"Idavidrein/gpqa\", \"gpqa_diamond\")[\"train\"]\n", + " try:\n", + " dataset = load_dataset(\"Idavidrein/gpqa\", \"gpqa_diamond\", split=\"train\")\n", + " except Exception as e:\n", + " print(\"Dataset loading error:\", e)\n", + " return\n", + "\n", " correct = 0\n", " attempted = 0\n", "\n", " for i in range(min(num_questions, len(dataset))):\n", " q, ca, inc = _extract_item(dataset[i])\n", + "\n", " if not (q and ca and len(inc) >= 3):\n", + " print(f\"Skipping Q{i+1}: missing fields\")\n", " continue\n", "\n", " options = inc[:3] + [ca]\n", @@ -455,6 +512,7 @@ " try:\n", " pred = evaluate_gpqa(q, options)\n", " except APIUnavailable:\n", + " print(\"Stopping because API is unavailable.\")\n", " break\n", "\n", " attempted += 1\n", @@ -469,8 +527,9 @@ " pct = 100.0 * correct / attempted\n", " print(f\"\\nFinal Score: {correct}/{attempted} ({pct:.1f}%)\")\n", "\n", + "\n", "if __name__ == \"__main__\":\n", - " run_eval(100)" + " run_eval(10)\n" ] }, { @@ -483,39 +542,223 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "6b33ae89", "metadata": {}, "outputs": [], "source": [ - "def chat_with_rag(prompt):\n", - " rag = RAG()\n", - " client = Groq(api_key=os.environ.get(\"GROQ_API_KEY\"))\n", - " relevant_texts = rag.retrieve(prompt, n_results=5)\n", - " prompt = prompt + f\"\\tUse this as context for answering: {relevant_texts}\"\n", - " chat_completion = client.chat.completions.create(\n", - " messages=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": prompt,\n", - " }\n", - " ],\n", - " model=\"llama-3.1-8b-instant\",\n", - " )\n", - " return chat_completion.choices[0].message.content\n", - "\n", - "def simple_chat(prompt):\n", - " client = Groq(api_key=os.environ.get(\"GROQ_API_KEY\"))\n", - " chat_completion = client.chat.completions.create(\n", - " messages=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": prompt,\n", - " }\n", - " ],\n", - " model=\"llama-3.1-8b-instant\",\n", + "import os\n", + "from groq import Groq\n", + "\n", + "\n", + "def improve_retrieval(user_question, retrieved_items, top_k=3):\n", + " \"\"\"\n", + " Rank retrieved chunks using simple runtime heuristics:\n", + " - prefer smaller Chroma distance\n", + " - reward keyword overlap with the user question\n", + " - keep the final top_k chunks\n", + " \"\"\"\n", + "\n", + " question_words = set(\n", + " word.strip(\".,?!:;()[]{}'\\\"\").lower()\n", + " for word in user_question.split()\n", + " if len(word.strip()) > 2\n", " )\n", - " return chat_completion.choices[0].message.content" + "\n", + " ranked = []\n", + "\n", + " for item in retrieved_items:\n", + " doc = item.get(\"doc\", \"\").strip()\n", + " distance = item.get(\"distance\", None)\n", + "\n", + " if not doc:\n", + " continue\n", + "\n", + " doc_lower = doc.lower()\n", + "\n", + " overlap = sum(1 for word in question_words if word in doc_lower)\n", + "\n", + " # Smaller distance is better. If missing, give weaker default.\n", + " distance_score = 0 if distance is None else (1 / (1 + distance))\n", + "\n", + " # Combined score: distance + keyword overlap\n", + " score = distance_score + (0.15 * overlap)\n", + "\n", + " ranked.append({\n", + " \"doc\": doc,\n", + " \"distance\": distance,\n", + " \"overlap\": overlap,\n", + " \"score\": score\n", + " })\n", + "\n", + " ranked.sort(key=lambda x: x[\"score\"], reverse=True)\n", + " return ranked[:top_k]\n", + "\n", + "\n", + "def chat_with_rag(user_question):\n", + " rag = RAG()\n", + "\n", + " api_key = os.environ.get(\"GROQ_API_KEY\")\n", + " if not api_key:\n", + " raise RuntimeError(\"GROQ_API_KEY missing\")\n", + "\n", + " client = Groq(api_key=api_key)\n", + "\n", + " # Clean query before retrieval\n", + " clean_question = user_question.replace(\"Keep concise\", \"\").strip()\n", + "\n", + " # Retrieve from Chroma\n", + " results = rag.retrieve(clean_question, n_results=10)\n", + "\n", + " retrieved_items = []\n", + "\n", + " if results:\n", + " raw_docs = results.get(\"documents\", [[]])[0]\n", + " raw_distances = results.get(\"distances\", [[]])[0] if \"distances\" in results else []\n", + "\n", + " for i, doc in enumerate(raw_docs):\n", + " if not doc:\n", + " continue\n", + "\n", + " doc = doc.strip()\n", + "\n", + " # Filter weak / noisy chunks\n", + " if len(doc) <= 50:\n", + " continue\n", + "\n", + " distance = raw_distances[i] if i < len(raw_distances) else None\n", + "\n", + " retrieved_items.append({\n", + " \"doc\": doc,\n", + " \"distance\": distance\n", + " })\n", + "\n", + " # Remove exact duplicate chunks\n", + " deduped_items = []\n", + " seen_docs = set()\n", + "\n", + " for item in retrieved_items:\n", + " if item[\"doc\"] not in seen_docs:\n", + " seen_docs.add(item[\"doc\"])\n", + " deduped_items.append(item)\n", + "\n", + " # Re-rank best chunks for runtime context\n", + " selected_chunks = improve_retrieval(clean_question, deduped_items, top_k=3)\n", + "\n", + " # Build final context\n", + " if selected_chunks:\n", + " context_text = \"\\n\\n\".join(item[\"doc\"] for item in selected_chunks)\n", + " else:\n", + " context_text = \"No relevant context found.\"\n", + "\n", + " final_prompt = f\"\"\"\n", + "You are a helpful NutriHelp assistant.\n", + "\n", + "Use only the retrieved context to answer the question.\n", + "\n", + "Rules:\n", + "- Return exactly ONE short sentence\n", + "- Give only the most direct answer\n", + "- Do not explain or add extra details\n", + "- Do not restate the question\n", + "- Do not write \"Answer:\"\n", + "- If the question asks for a section or page name, return only the section or page name\n", + "- If the question asks for one resource, return only the resource name\n", + "- If the answer includes a statistic, include the full statistic in one sentence\n", + "- If the answer is not supported by the context, say exactly:\n", + "I could not find the answer in the provided context.\n", + "\n", + "Examples:\n", + "94% of children aged 2 to 17 don't eat enough vegetables.\n", + "Only 1 in 13 adults eat enough fruit and vegetables.\n", + "About one-third of total energy intake comes from discretionary foods.\n", + "On the Food and nutrition topic page under \"Eating what your body needs.\"\n", + "The \"Reducing harmful ingredients\" section in the Food and nutrition topic.\n", + "Healthy Food Partnership Program – Executive Committee communique – February 2025.\n", + "\n", + "Context:\n", + "{context_text}\n", + "\n", + "Question:\n", + "{clean_question}\n", + "\n", + "Answer:\n", + "\"\"\"\n", + "\n", + " try:\n", + " chat_completion = client.chat.completions.create(\n", + " messages=[{\"role\": \"user\", \"content\": final_prompt}],\n", + " model=\"llama-3.1-8b-instant\",\n", + " temperature=0\n", + " )\n", + "\n", + " answer = chat_completion.choices[0].message.content.strip()\n", + "\n", + " if answer.lower().startswith(\"answer:\"):\n", + " answer = answer[len(\"answer:\"):].strip()\n", + "\n", + " if not answer:\n", + " answer = \"I could not find the answer in the provided context.\"\n", + "\n", + " return answer.replace(\"’\", \"'\")\n", + "\n", + " except Exception as e:\n", + " return f\"Error: {str(e)}\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a02f3df3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "======================================================================\n", + "QUESTION: What proportion of Australian children aged 2 to 17 do not eat enough vegetables?\n", + "ANSWER: 94% of children aged 2 to 17 don't eat enough vegetables.\n", + "\n", + "======================================================================\n", + "QUESTION: Roughly what proportion of Australian adults eat enough fruit and vegetables?\n", + "ANSWER: Only 1 in 13 adults eat enough fruit and vegetables.\n", + "\n", + "======================================================================\n", + "QUESTION: About what share of Australia’s total energy intake comes from foods we don’t need?\n", + "ANSWER: About one-third of total energy intake comes from discretionary foods.\n", + "\n", + "======================================================================\n", + "QUESTION: Where on the department website can I read about getting the nutrients I need, at any age?\n", + "ANSWER: On the Food and nutrition topic page under \"Eating what your body needs.\"\n", + "\n", + "======================================================================\n", + "QUESTION: Which section covers actions to reduce harmful ingredients in processed foods?\n", + "ANSWER: The \"Reducing harmful ingredients\" section.\n", + "\n", + "======================================================================\n", + "QUESTION: Name one 2025 resource related to the Healthy Food Partnership.\n", + "ANSWER: Healthy Food Partnership Program – Executive Committee communique – February 2025.\n", + "\n" + ] + } + ], + "source": [ + "questions = [\n", + " \"What proportion of Australian children aged 2 to 17 do not eat enough vegetables?\",\n", + " \"Roughly what proportion of Australian adults eat enough fruit and vegetables?\",\n", + " \"About what share of Australia’s total energy intake comes from foods we don’t need?\",\n", + " \"Where on the department website can I read about getting the nutrients I need, at any age?\",\n", + " \"Which section covers actions to reduce harmful ingredients in processed foods?\",\n", + " \"Name one 2025 resource related to the Healthy Food Partnership.\"\n", + "]\n", + "\n", + "for q in questions:\n", + " print(\"=\" * 70)\n", + " print(f\"QUESTION: {q}\")\n", + " answer = chat_with_rag(q)\n", + " print(f\"ANSWER: {answer}\")\n", + " print()\n" ] }, { @@ -544,8 +787,8 @@ }, { "cell_type": "code", - "execution_count": 15, - "id": "51510a43", + "execution_count": 14, + "id": "f1251e53", "metadata": {}, "outputs": [ { @@ -553,49 +796,40 @@ "output_type": "stream", "text": [ "---------- RAG response ----------\n", - "-> Q1:According to the Australian Health Survey and the context provided, 94% of children aged 2 to 17 in Australia do not eat enough vegetables.\n", - "-> Q2:According to the provided information, only 1 in 13 (approximately 7.69%) Australian adults eat enough fruit and vegetables.\n", - "-> Q3:About 25-33% of Australia's total energy intake comes from discretionary foods.\n", - "-> Q4:On the department website, you can read about getting the nutrients you need at any age by visiting the page that discusses detailed information about individual nutrients.\n", - "-> Q5:\"Reducing harmful ingredients from processed and manufactured foods.\"\n", - "-> Q6:Unfortunately, I have no access to the 2025 resource related to the Healthy Food Partnership mentioned. However, one 2024 resource, isn't related to the question so I will look for a 2024 or 2023 resource mentioned:\n", - "\n", - "However, in December 2023, a resource related to the Healthy Food Partnership was mentioned - \n", - "\n", - "Healthy Food Partnership Program – Executive Committee communique – December 2023 2023 Meeting minutes\n", - "\n", - "---------- Non RAG response ----------\n", - "-> Q1:Unfortunately, I don't have the exact and up-to-date information on this topic. However, I can suggest some related statistics from 2019-2020 that might provide a rough idea:\n", + "-> Q1: 94% of children aged 2 to 17 don't eat enough vegetables.\n", + "-> Q2: Only 1 in 13 adults eat enough fruit and vegetables.\n", + "-> Q3: About one-third of total energy intake comes from discretionary foods.\n", + "-> Q4: On the Food and nutrition topic page under \"Eating what your body needs.\"\n", + "-> Q5: The \"Reducing harmful ingredients\" section.\n", + "-> Q6: Healthy Food Partnership Program – Executive Committee communique – February 2025.\n", "\n", - "* A 2019-2020 survey by the Australian Bureau of Statistics (ABS) found that:\n", - " + 22% of children aged 2-3 years and 27% of children aged 4-8 years did not meet the recommended daily intake of vegetables.\n", - " + 35% of children aged 9-13 years and 39% of teenagers aged 14-17 years did not meet the recommended daily intake of vegetables.\n", - "\n", - "Please note that these statistics might not reflect the current situation and are based on available data from 2019-2020.\n", - "-> Q2:According to 2019-2020 Australian data, only about 8% of adults meet the daily recommended intake of five serves of vegetables, and around 3% meet the daily recommended intake of two serves of fruits.\n", - "-> Q3:About 30-40% of Australia's total energy intake is estimated to come from discretionary foods, according to the Australian Bureau of Statistics (2011-2012 data).\n", - "-> Q4:I'm a large language model, I don't have specific department websites to direct you to. However, you can usually find information about getting the necessary nutrients at any age on a government website's health or nutrition section, often within the 'Healthy Living', 'Nutrition', or 'Wellness' category.\n", - "-> Q5:This section would likely be found under 'Health and Nutrition' or 'Food Security' within a broader topic.\n", - "-> Q6:I cannot confirm any current information regarding 2025 resources related to the Healthy Food Partnership. However, you might want to look at The Food and Agriculture Organization of the UN related to this topic. They often publish articles about their efforts.\n" + "---------- Non RAG response ----------\n" ] } ], "source": [ + "test_questions = [\n", + " \"What proportion of Australian children aged 2 to 17 do not eat enough vegetables? Keep concise\",\n", + " \"Roughly what proportion of Australian adults eat enough fruit and vegetables? Keep concise\",\n", + " \"About what share of Australia’s total energy intake comes from foods we don’t need (discretionary foods)? Keep concise\",\n", + " \"Where on the department website can I read about getting the nutrients I need, at any age? Keep concise\",\n", + " \"Which section covers actions to reduce harmful ingredients in processed foods? Keep concise\",\n", + " \"Name one 2025 resource mentioned on the Food and nutrition page related to the Healthy Food Partnership. Keep concise\"\n", + "]\n", + "\n", "print(\"---------- RAG response ----------\")\n", - "print(\"-> Q1:\" + chat_with_rag(\"What proportion of Australian children aged 2 to 17 do not eat enough vegetables? Keep concise\"))\n", - "print(\"-> Q2:\" + chat_with_rag(\"Roughly what proportion of Australian adults eat enough fruit and vegetables? Keep concise\"))\n", - "print(\"-> Q3:\" + chat_with_rag(\"About what share of Australia’s total energy intake comes from foods we don’t need (discretionary foods)? Keep concise\"))\n", - "print(\"-> Q4:\" + chat_with_rag(\"Where on the department website can I read about getting the nutrients I need, at any age? Keep concise\"))\n", - "print(\"-> Q5:\" + chat_with_rag(\"Which section covers actions to reduce harmful ingredients in processed foods? Keep concise\"))\n", - "print(\"-> Q6:\" + chat_with_rag(\"Name one 2025 resource mentioned on the Food and nutrition page related to the Healthy Food Partnership. Keep concise\"))\n", + "for i, question in enumerate(test_questions, 1):\n", + " try:\n", + " print(f\"-> Q{i}: {chat_with_rag(question)}\")\n", + " except Exception as e:\n", + " continue\n", "\n", "print(\"\\n---------- Non RAG response ----------\")\n", - "print(\"-> Q1:\" + simple_chat(\"What proportion of Australian children aged 2 to 17 do not eat enough vegetables? Keep concise\"))\n", - "print(\"-> Q2:\" + simple_chat(\"Roughly what proportion of Australian adults eat enough fruit and vegetables? Keep concise\"))\n", - "print(\"-> Q3:\" + simple_chat(\"About what share of Australia’s total energy intake comes from foods we don’t need (discretionary foods)? Keep concise\"))\n", - "print(\"-> Q4:\" + simple_chat(\"Where on the department website can I read about getting the nutrients I need, at any age? Keep concise\"))\n", - "print(\"-> Q5:\" + simple_chat(\"Which section covers actions to reduce harmful ingredients in processed foods? Keep concise\"))\n", - "print(\"-> Q6:\" + simple_chat(\"Name one 2025 resource mentioned on the Food and nutrition page related to the Healthy Food Partnership. Keep concise\"))" + "for i, question in enumerate(test_questions, 1):\n", + " try:\n", + " print(f\"-> Q{i}: {simple_chat(question)}\")\n", + " except Exception as e:\n", + " continue\n" ] }, { @@ -614,26 +848,38 @@ "outputs": [], "source": [ "import json\n", + "import os\n", "\n", "rag = RAG()\n", - "\n", "docs = []\n", - "with open(\"document-parser/data/sentences.jsonl\", \"r\", encoding=\"utf-8\") as f:\n", - " for line in f:\n", - " rec = json.loads(line)\n", - " sent = rec.get(\"sentence\", \"\").strip()\n", - " if sent:\n", - " docs.append(sent)\n", "\n", - "# docs = docs[:1000]\n", + "file_path = \"document-parser/data/sentences.jsonl\"\n", + "\n", + "if not os.path.exists(file_path):\n", + " print(\"File not found:\", file_path)\n", + "else:\n", + " with open(file_path, \"r\", encoding=\"utf-8\") as f:\n", + " for line_number, line in enumerate(f, start=1):\n", + " try:\n", + " rec = json.loads(line)\n", + " sent = rec.get(\"sentence\", \"\").strip()\n", + " if sent:\n", + " docs.append(sent)\n", + " except json.JSONDecodeError:\n", + " print(f\"Skipping invalid JSON on line {line_number}\")\n", + "\n", + " print(\"Number of documents loaded:\", len(docs))\n", + "\n", + " # docs = docs[:1000] # optional for testing\n", "\n", - "rag.add_documents(docs)" + " rag.add_documents(docs)\n", + " print(\"Documents added successfully.\")\n" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -647,7 +893,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.13.5" } }, "nbformat": 4,