Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
54 changes: 43 additions & 11 deletions RenAIssance_Transformer_OCR_Utsav_Rai/code/app/app_streamlit.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import sys
import os

# Make the app directory and the sibling CRAFT directory importable
# regardless of the current working directory, so `from craft import ...`
# and local app modules resolve when launched via `streamlit run`.
# NOTE(review): the original span contained both the pre- and post-diff
# versions of this setup; only the current (APP_DIR-based) version is kept.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
CRAFT_DIR = os.path.abspath(os.path.join(APP_DIR, "..", "CRAFT"))
for path in (APP_DIR, CRAFT_DIR):
    # Only prepend real directories, and never duplicate sys.path entries.
    if os.path.isdir(path) and path not in sys.path:
        sys.path.insert(0, path)
import torch
import torch.backends.cudnn as cudnn
from collections import OrderedDict
Expand All @@ -17,14 +19,28 @@
from PIL import Image, ImageEnhance
import cv2
import numpy as np
import os
import math
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import streamlit as st
from deskew import determine_skew

st.set_page_config(layout="wide")


def resolve_existing_path(env_var, *candidates):
    """Return the first usable filesystem path for a required asset.

    An environment-variable override named ``env_var`` always wins when it
    is set to a non-empty value — it is returned as-is, without checking
    that it exists on disk.  Otherwise the first candidate that exists is
    returned.

    Args:
        env_var: Name of the environment variable that may override the path.
        *candidates: Fallback paths, tried in order.

    Raises:
        FileNotFoundError: if no override is set and none of the candidates
            exist on disk.
    """
    configured = os.getenv(env_var)
    if configured:
        return configured

    existing = next((c for c in candidates if os.path.exists(c)), None)
    if existing is not None:
        return existing

    raise FileNotFoundError(
        f"Could not resolve a path for {env_var or 'required asset'}. "
        f"Tried: {', '.join(candidates)}"
    )

def copyStateDict(state_dict):
if list(state_dict.keys())[0].startswith("module"):
start_idx = 1
Expand All @@ -39,7 +55,11 @@ def copyStateDict(state_dict):
@st.cache_resource
def load_craft_model():
# Define the path to the pre-trained CRAFT model weights
trained_model_path = '../../weights/craft_mlt_25k.pth'
trained_model_path = resolve_existing_path(
"RENAISSANCE_CRAFT_MODEL_PATH",
os.path.join(APP_DIR, "weights", "craft_mlt_25k.pth"),
os.path.abspath(os.path.join(APP_DIR, "..", "..", "weights", "craft_mlt_25k.pth")),
)

# Initialize the CRAFT model
net = CRAFT() # initialize
Expand All @@ -57,7 +77,11 @@ def load_craft_model():
refine = True # Set to True if using refine_net
if refine:
from refinenet import RefineNet
refiner_model_path = '../../weights/craft_refiner_CTW1500.pth' # Update the path
refiner_model_path = resolve_existing_path(
"RENAISSANCE_CRAFT_REFINER_PATH",
os.path.join(APP_DIR, "weights", "craft_refiner_CTW1500.pth"),
os.path.abspath(os.path.join(APP_DIR, "..", "..", "weights", "craft_refiner_CTW1500.pth")),
)
refine_net = RefineNet()
refine_net.load_state_dict(copyStateDict(torch.load(refiner_model_path, map_location=device)))
refine_net.to(device)
Expand Down Expand Up @@ -109,9 +133,17 @@ def test_net(net, image, text_threshold, link_threshold, low_text, *, cuda, poly
@st.cache_resource
def load_ocr_model():
    """Load the TrOCR processor and model once per Streamlit session.

    Paths are resolved via `resolve_existing_path`, so they can be
    overridden with the RENAISSANCE_OCR_MODEL_DIR /
    RENAISSANCE_OCR_PROCESSOR_DIR environment variables; otherwise the
    bundled `models` directory (next to the app, or two levels up) is used.
    The processor defaults to the same directory as the model weights.

    Returns:
        tuple: (processor, model, device) — the TrOCRProcessor, the
        VisionEncoderDecoderModel moved onto `device`, and the torch
        device (CUDA when available, else CPU).
    """
    # NOTE(review): stale pre-diff assignments of model_path/processor_path
    # to "../../models" were dead code (immediately shadowed) and removed.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = resolve_existing_path(
        "RENAISSANCE_OCR_MODEL_DIR",
        os.path.join(APP_DIR, "models"),
        os.path.abspath(os.path.join(APP_DIR, "..", "..", "models")),
    )
    processor_path = resolve_existing_path(
        "RENAISSANCE_OCR_PROCESSOR_DIR",
        model_path,  # model_path already exists, so this is the usual result
        os.path.join(APP_DIR, "models"),
        os.path.abspath(os.path.join(APP_DIR, "..", "..", "models")),
    )
    processor = TrOCRProcessor.from_pretrained(processor_path)
    model = VisionEncoderDecoderModel.from_pretrained(model_path).to(device)
    return processor, model, device
Expand Down Expand Up @@ -771,4 +803,4 @@ def get_virtual_page(pdf_document, virtual_index, dpi, **kwargs):
st.write("No image to display.")

else:
st.info("Please upload a PDF file from the left panel.")
st.info("Please upload a PDF file from the left panel.")