
Commit

cleaned code, and modified README
Raghav010 committed Nov 6, 2023
1 parent c181684 commit 91ca32b
Showing 6 changed files with 122 additions and 96 deletions.
22 changes: 18 additions & 4 deletions README.md
@@ -1,12 +1,26 @@
## Instructions for Setup
# **MedMini**

## _Runs within_ **< 3 GB RAM | 0 VRAM**
## _Inference Time_ **< 3 sec**
### A lightweight LLM-based architecture for a question-answering system on medical data, designed to run on **edge devices**.
### Entirely on-device processing
### Model Size on Disk: 500 MB + 250 MB

![Pipeline](./media/diagram.png)




## **Instructions for Setup**

1. `cd mashqa_data`
2. `python format.py`
3. `cd ..; python dbGen.py`
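
After step 3, the generated store can be sanity-checked with a few lines. This is only a sketch: it assumes the `./med_db` persist directory and the `all-mpnet-base-v2` embedder that the rest of the repository uses.

```python
# Hedged sanity check after dbGen.py has run: reload the persisted Chroma store
# and count the indexed chunks. Directory and embedder are taken from medmini.py.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

hf = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2",
                           model_kwargs={"device": "cpu"})
vectordb = Chroma(persist_directory="./med_db", embedding_function=hf)
print(vectordb._collection.count())  # number of chunks embedded from sentences.txt
```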

## Instructions to run
## **Instructions to run**
1. `cd App; npx expo start`
2. `python backend.py`

## Future work
1. Dockerisation of the project to allow it to run from any IP address without having to do many changes.
## **Future work**
1. Dockerise the project so it can be deployed at any IP address without further changes.
2. Improve the RAG algorithm without compromising efficiency (one possible direction is sketched below).
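
For item 2, one low-cost direction (not part of this commit) is maximal-marginal-relevance retrieval, which the Chroma vectorstore exposes directly and which trades a little similarity for diversity among the retrieved chunks. The `fetch_k` and `lambda_mult` values below are illustrative guesses, and the query is made up.

```python
# Hedged sketch: MMR retrieval instead of the plain similarity_search used in
# medmini.py. Assumes the same ./med_db store built by dbGen.py.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

hf = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2",
                           model_kwargs={"device": "cpu"})
vectordb = Chroma(persist_directory="./med_db", embedding_function=hf)

query = "What are the early symptoms of asthma?"  # made-up example question
# fetch_k candidates are retrieved, then re-ranked for diversity; k are kept
docs = vectordb.max_marginal_relevance_search(query, k=6, fetch_k=20, lambda_mult=0.5)
context = " ".join(d.page_content for d in docs)
print(context)
```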
4 changes: 2 additions & 2 deletions dbGen.py
@@ -1,14 +1,14 @@
from langchain.document_loaders import JSONLoader,TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

med_loader=TextLoader('mashqa_data/sentences.txt')

med_data = med_loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size = 500, chunk_overlap = 50)
all_splits = text_splitter.split_documents(med_data)

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cuda'}
(remaining lines of dbGen.py not shown in this diff)
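
medmini.py later opens the store with `Chroma(persist_directory='./med_db', embedding_function=hf)`, so the truncated remainder of dbGen.py presumably continues the snippet above roughly as follows. This is a hedged reconstruction, not the committed code; the exact arguments are assumptions.

```python
# Hedged continuation of the dbGen.py snippet above. model_name, model_kwargs
# and all_splits come from the lines already shown; everything else is assumed.
hf = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Embed the 500-character chunks and persist them where medmini.py expects them.
vectordb = Chroma.from_documents(documents=all_splits,
                                 embedding=hf,
                                 persist_directory='./med_db')
vectordb.persist()  # write the index to disk for later reuse
```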
Binary file added media/diagram.png
17 changes: 17 additions & 0 deletions media/diagram.svg
Binary file added medmini.pdf
Binary file not shown.
175 changes: 85 additions & 90 deletions medmini.py
@@ -5,9 +5,10 @@


def formatPrompt(prompt,context):
# prompts
# fp=f'Use the context to answer the question.Incorporate the context in your answer.\nContext: {context}\nQuestion: {prompt}\nAnswer: '
# fp = f'Summarize the following:\n {context}'
#fp = f'Summarize the following:\n {context}'
# fp = f'Summarize the following:\n {context}'
fp=f'{context}{prompt}'
return fp

@@ -31,125 +32,119 @@ def infer(prompt):


vectordb=Chroma(persist_directory='./med_db',embedding_function=hf)
print(vectordb._collection.count())
# print(vectordb._collection.count())

docs = vectordb.similarity_search(prompt,k=6)
print(len(docs))
# print(len(docs))
# print(docs)
context=' '.join(d.page_content for d in docs)
print(context)


#### Base GPT2 unquantized
#quantized=False
#extreme_quantization=False
#low_cpu=True

#tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
#model = GPT2Model.from_pretrained('gpt2',low_cpu_mem_usage=low_cpu,device_map='cpu',load_in_8bit=quantized,load_in_4bit=extreme_quantization)

## args={"low_cpu_mem_usage":low_cpu,"device":'cpu',"load_in_8bit":quantized,"load_in_4bit":extreme_quantization,"torch_dtype":torch.float32}
#args={"low_cpu_mem_usage":low_cpu,"device_map":'cpu',"load_in_8bit":quantized,"load_in_4bit":extreme_quantization,"torch_dtype":torch.float32}
#pipe = pipeline('text-generation', model='gpt2',model_kwargs=args)

#output=pipe(formatPrompt(prompt, context), max_new_tokens=60)




#### Base GPT2 quantized
# tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
# tokenizer.pad_token = tokenizer.eos_token

# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.float16
# )

# model = AutoModelForCausalLM.from_pretrained(
# "gpt2",
# quantization_config=bnb_config,
# trust_remote_code=True
# )
# model.config.use_cache = False

# pipe=pipeline('text-generation',model=model,tokenizer=tokenizer)
# output=pipe(formatPrompt(prompt,context))
# print(context)

def useModel(model_name):

#### Phi 1.5 quantized/unquantized
#tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
## tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small", trust_remote_code=True)
#tokenizer.pad_token = tokenizer.eos_token
if model_name=='summarizer':

#bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.float16
#)
#### Summarization
# from transformers import pipeline

#model = AutoModelForCausalLM.from_pretrained(
# "microsoft/phi-1_5",
## "google/flan-t5-small",
# # quantization_config=bnb_config,
# trust_remote_code=True
#)
summarizer = pipeline("summarization", model="Falconsai/medical_summarization")
output = summarizer(context, max_length=512, min_length=32, do_sample=False)

#pipe=pipeline('text-generation',model=model,tokenizer=tokenizer)
#output=pipe(formatPrompt(prompt,context), max_new_tokens=60)
elif model_name=='gpt2':

### Phi 1.5 official
# torch.set_default_device("cpu")
# model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True,device_map='cuda')
# tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
# inputs = tokenizer(formatPrompt(prompt,context), return_tensors="pt", return_attention_mask=False)
#### Base GPT2 unquantized
low_cpu=True

# outputs = model.generate(**inputs)
# output = tokenizer.batch_decode(outputs)[0]
# print(text)
# args={"low_cpu_mem_usage":low_cpu,"device":'cpu',"load_in_8bit":quantized,"load_in_4bit":extreme_quantization,"torch_dtype":torch.float32}
args={"low_cpu_mem_usage":low_cpu,"device_map":'cpu',"torch_dtype":torch.float32}
pipe = pipeline('text-generation', model='gpt2',model_kwargs=args)

output=pipe(formatPrompt(prompt, context), max_new_tokens=60)

elif model_name=='gpt2_quantized':

#### Base GPT2 quantized
tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)

# from transformers import T5Tokenizer, T5ForConditionalGeneration
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
quantization_config=bnb_config,
trust_remote_code=True
)
model.config.use_cache = False

# tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
# model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
pipe=pipeline('text-generation',model=model,tokenizer=tokenizer)
output=pipe(formatPrompt(prompt,context))

# input_text = formatPrompt(prompt, context)
# input_ids = tokenizer(input_text, return_tensors="pt").input_ids
elif model_name=='phi1.5_quantized':

# outputs= model.generate(input_ids, max_new_tokens=400)
# output = tokenizer.decode(outputs[0])
#### Phi 1.5 quantized/unquantized
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
"microsoft/phi-1_5",
# "google/flan-t5-small",
# quantization_config=bnb_config,
trust_remote_code=True
)

pipe=pipeline('text-generation',model=model,tokenizer=tokenizer)
output=pipe(formatPrompt(prompt,context), max_new_tokens=60)

elif model_name=='phi1.5':

### Phi 1.5 official
torch.set_default_device("cpu")
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True,device_map='cuda')
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
inputs = tokenizer(formatPrompt(prompt,context), return_tensors="pt", return_attention_mask=False)

outputs = model.generate(**inputs)
output = tokenizer.batch_decode(outputs)[0]

elif model_name=='flant5':
from transformers import T5Tokenizer, T5ForConditionalGeneration

#### Finetuned GPT2 model
# model = AutoModelForCausalLM.from_pretrained("SidhiPanda/gpt2-finetuned-megathon", trust_remote_code=True, torch_dtype=torch.float32)
# tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
# inputs = tokenizer(formatPrompt(prompt,context), return_tensors="pt", return_attention_mask=False)
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")

# outputs = model.generate(**inputs, max_new_tokens=45)
# output = tokenizer.batch_decode(outputs)[0]
input_text = formatPrompt(prompt, context)
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

##########
outputs= model.generate(input_ids, max_new_tokens=400)
output = tokenizer.decode(outputs[0])

elif model_name=='gpt2_finetuned':

#### Summarization
# from transformers import pipeline
#### Finetuned GPT2 model
model = AutoModelForCausalLM.from_pretrained("SidhiPanda/gpt2-finetuned-megathon", trust_remote_code=True, torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
inputs = tokenizer(formatPrompt(prompt,context), return_tensors="pt", return_attention_mask=False)

summarizer = pipeline("summarization", model="Falconsai/medical_summarization")
output = summarizer(context, max_length=512, min_length=32, do_sample=False)
outputs = model.generate(**inputs, max_new_tokens=45)
output = tokenizer.batch_decode(outputs)[0]


# output

print(output)
return output[0]['summary_text']
# output
# print(output)
return output[0]['summary_text']

return useModel('summarizer')
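
backend.py itself is not part of this commit's diff. Purely as an illustration, the README's `python backend.py` step could wrap `infer` along these lines; Flask, the `/ask` route, and the JSON payload shape are assumptions, not the repository's actual code.

```python
# Hypothetical sketch of a backend around medmini.infer(). The framework, route
# name, and payload shape are assumptions, not the committed backend.py.
from flask import Flask, jsonify, request

from medmini import infer

app = Flask(__name__)

@app.route("/ask", methods=["POST"])
def ask():
    question = request.get_json().get("question", "")
    answer = infer(question)  # retrieval + summarization pipeline from medmini.py
    return jsonify({"answer": answer})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)
```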
