"""RGB-histogram image similarity service.

Stores a normalised 16x16x16 RGB histogram per image in SQLite
(``rgb_histograms.db``), keeps an nmslib HNSW index (L1 distance) over those
histograms, and exposes FastAPI endpoints to add, delete and query them.

Run directly (``python rgb_histogram_web.py``) to start uvicorn; uvicorn then
imports this file as the module ``rgb_histogram_web``, which triggers the
table creation / DB sync / index build at the bottom of the file.
"""
import heapq
import io
import sqlite3
from os import listdir
from timeit import default_timer as timer

import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI, File, Body, Form, HTTPException
import numpy as np
from PIL import Image
from sklearn.neighbors import NearestNeighbors
import cv2
import nmslib

IMAGE_PATH = "./../public/images"

conn = sqlite3.connect('rgb_histograms.db')

# dim=4096
index_time_params = {'M': 32, 'efConstruction': 200}
IN_MEMORY_HISTS = {}


def _new_index():
    """Return an empty HNSW index over dense vectors using L1 distance."""
    return nmslib.init(method='hnsw', space="l1", data_type=nmslib.DataType.DENSE_VECTOR)


index = _new_index()


def init_index():
    """(Re)build the global nmslib index from every histogram stored in the DB.

    A brand-new index object is created on every call and the module-level
    ``index`` is rebound to it. Re-adding the rows to the existing index (as
    the previous version did) left deleted images searchable and duplicated
    every remaining vector each time the index was "rebuilt".
    """
    global index
    image_data = get_all_data()
    fresh_index = _new_index()
    if image_data:
        ids = np.array([image['image_id'] for image in image_data])
        # Flatten each stored histogram individually so the batch is always
        # (n_images, n_bins) -- a global squeeze() would collapse the batch
        # axis as well when only one image is stored.
        features = np.array([np.ravel(image['features']) for image in image_data])
        fresh_index.addDataPointBatch(features, ids)
        fresh_index.createIndex(index_time_params)
    # With an empty table nothing is added or built; the add handler calls
    # createIndex itself after inserting the first data point.
    index = fresh_index
    print("Index is ready")


def read_img_file(image_data):
    """Decode raw image bytes into a PIL image."""
    img = Image.open(io.BytesIO(image_data))
    return img


def get_rgb_histogram_by_id(id):
    """Return the stored histogram blob for ``id``, or ``None`` if absent."""
    cursor = conn.cursor()
    query = '''
    SELECT rgb_histogram
    FROM rgb_hists
    WHERE id = (?)
    '''
    cursor.execute(query, (id,))
    row = cursor.fetchone()
    # fetchone() yields None when no row matches; indexing it would raise
    # TypeError, so report the miss explicitly instead.
    if row is None:
        return None
    return row[0]


def get_features(image_buffer):
    """Compute the normalised joint RGB histogram of an encoded image.

    The image is converted to RGB, a 16x16x16 histogram over the three
    channels is computed with OpenCV, flattened, and divided by the number
    of pixels (height * width).
    """
    query_image = np.array(read_img_file(image_buffer).convert('RGB'))
    query_hist_combined = cv2.calcHist(
        [query_image], [0, 1, 2], None, [16, 16, 16], [0, 256, 0, 256, 0, 256]
    )
    query_hist_combined = query_hist_combined.flatten()
    pixel_count = query_image.shape[0] * query_image.shape[1]
    query_hist_combined = cv2.divide(query_hist_combined, pixel_count)
    return query_hist_combined


def create_table():
    """Create the ``rgb_hists`` table if it does not exist yet."""
    cursor = conn.cursor()
    query = '''
    CREATE TABLE IF NOT EXISTS rgb_hists(
        id INTEGER NOT NULL UNIQUE PRIMARY KEY,
        rgb_histogram BLOB NOT NULL
    )
    '''
    cursor.execute(query)
    conn.commit()


def check_if_exists_by_id(id):
    """Return True when a row with this ``id`` exists in ``rgb_hists``."""
    cursor = conn.cursor()
    query = '''SELECT EXISTS(SELECT 1 FROM rgb_hists WHERE id=(?))'''
    cursor.execute(query, (id,))
    all_rows = cursor.fetchone()
    return all_rows[0] == 1


def delete_descriptor_by_id(id):
    """Delete the histogram row for ``id`` (no error if it is absent)."""
    cursor = conn.cursor()
    query = '''DELETE FROM rgb_hists WHERE id=(?)'''
    cursor.execute(query, (id,))
    conn.commit()


def get_all_ids():
    """Return a list of every image id stored in ``rgb_hists``."""
    cursor = conn.cursor()
    query = '''SELECT id FROM rgb_hists'''
    cursor.execute(query)
    return [row[0] for row in cursor.fetchall()]


def convert_array(text):
    """Deserialise a blob written by ``adapt_array`` back into an ndarray."""
    out = io.BytesIO(text)
    out.seek(0)
    return np.load(out)


def get_all_data():
    """Return every stored row as ``{"image_id": ..., "features": ndarray}``."""
    cursor = conn.cursor()
    query = '''
    SELECT id, rgb_histogram
    FROM rgb_hists
    '''
    cursor.execute(query)
    return [
        {"image_id": row[0], "features": convert_array(row[1])}
        for row in cursor.fetchall()
    ]


def adapt_array(arr):
    """Serialise an ndarray with ``np.save`` into a SQLite binary blob."""
    out = io.BytesIO()
    np.save(out, arr)
    out.seek(0)
    return sqlite3.Binary(out.read())


def add_descriptor(id, rgb_histogram):
    """Insert a histogram blob for ``id``.

    The ``id`` column is a UNIQUE primary key, so inserting an id that is
    already present raises ``sqlite3.IntegrityError``.
    """
    cursor = conn.cursor()
    query = '''INSERT INTO rgb_hists(id, rgb_histogram) VALUES (?,?)'''
    cursor.execute(query, (id, rgb_histogram))
    conn.commit()


def _file_id_from_name(file_name):
    """Return the integer before the first '.' in ``file_name``, or None."""
    stem = file_name.partition('.')[0]
    try:
        return int(stem)
    except ValueError:
        # Dotfiles, thumbnails, READMEs etc. are not image-id files.
        return None


def sync_db():
    """Delete DB rows whose image file no longer exists in ``IMAGE_PATH``.

    File names are expected to look like ``<integer id>.<ext>``; names that
    do not parse as an integer id are ignored rather than aborting startup.
    """
    ids_in_db = set(get_all_ids())
    ids_on_disk = set()
    for file_name in listdir(IMAGE_PATH):
        file_id = _file_id_from_name(file_name)
        if file_id is not None:
            ids_on_disk.add(file_id)
    for id in ids_in_db - ids_on_disk:
        # NOTE(review): the original author marked this deletion "Fix this";
        # the intended change is not recorded -- confirm before relying on it.
        delete_descriptor_by_id(id)
        print(f"deleting {id}")
    print("db synced")


app = FastAPI()


@app.get("/")
async def read_root():
    """Trivial liveness endpoint."""
    return {"Hello": "World"}


@app.post("/calculate_HIST_features")
async def calculate_HIST_features_handler(image: bytes = File(...), image_id: str = Form(...)):
    """Compute, persist and index the histogram of an uploaded image."""
    numeric_id = int(image_id)
    features = get_features(image)
    add_descriptor(numeric_id, adapt_array(features))
    # IN_MEMORY_HISTS[numeric_id]=features
    index.addDataPoint(numeric_id, features)
    index.createIndex(index_time_params)
    return {"status": "200"}


class Item_image_id(BaseModel):
    """Request body carrying a single image id."""
    image_id: int


def _image_not_found():
    """Build the error returned when an image id cannot be served."""
    # Status 500 is kept for backward compatibility with existing clients,
    # although 404 would describe the condition more accurately.
    return HTTPException(status_code=500, detail="Image with this id is not found")


@app.post("/get_similar_images_by_id")
async def get_similar_images_by_id_handler(item: Item_image_id):
    """Return the ids of up to 20 nearest neighbours of a stored image.

    Raises:
        HTTPException: status 500 when the id has no stored histogram or the
            index query raises ``RuntimeError``.
    """
    start = timer()
    stored_blob = get_rgb_histogram_by_id(item.image_id)
    if stored_blob is None:
        raise _image_not_found()
    try:
        target_features = convert_array(stored_blob)
        labels, _ = index.knnQuery(target_features, k=20)
    except RuntimeError:
        raise _image_not_found()
    end = timer()
    print((end - start)*1000)  # elapsed time in milliseconds
    return labels.tolist()


# Previous brute-force implementation, kept for reference (disabled).
# def find_bruteforce(target_image_id,k):
#     query_hist=IN_MEMORY_HISTS[target_image_id]
#     # print(query_hist[0])
#     # print(query_hist[1])
#     # heap=[]
#     # for key in IN_MEMORY_HISTS:
#     #     # similarity=np.sum(np.minimum(query_hist,IN_MEMORY_HISTS[key]))
#     #     similarity=cv2.compareHist(query_hist,IN_MEMORY_HISTS[key],cv2.HISTCMP_INTERSECT)
#     #     if len(heap) < k or similarity > heap[0][0]:
#     #         # If the heap is full, remove the smallest element on the heap.
#     #         if len(heap) == k: heapq.heappop(heap)
#     #         # add the current element as the new smallest.
#     #         heapq.heappush( heap, (similarity,key) )
#     # heap=[heapq.heappop(heap) for i in range(len(heap))]
#     # heap.reverse()
#     # print(heap)
#     # found_images_filenames=list(map(lambda el: el[1],heap))
#     found_images=[]
#     for key in IN_MEMORY_HISTS:
#         # similarity=cv2.compareHist(query_hist,IN_MEMORY_HISTS[key],cv2.HISTCMP_INTERSECT)
#         similarity=np.abs(query_hist-IN_MEMORY_HISTS[key]).sum()
#         found_images.append({"similarity":similarity,"file_name":key})
#     found_images.sort(key=lambda item: item["similarity"],reverse=False)
#     found_images=found_images[:20]
#     found_images_filenames=list(map(lambda el: el["file_name"],found_images))
#     return found_images_filenames

# @app.post("/get_similar_images_by_id")
# async def get_similar_images_by_id_handler(item: Item_image_id):
#     start = timer()
#     similar=find_bruteforce(item.image_id,10)
#     end = timer()
#     print((end - start)*1000) # Time in seconds, e.g. 5.38091952400282
#     return similar


@app.post("/delete_HIST_features")
async def delete_hist_features_handler(item: Item_image_id):
    """Delete an image's histogram and rebuild the index without it."""
    delete_descriptor_by_id(item.image_id)
    init_index()
    return {"status": "200"}


print(__name__)
# uvicorn imports this file under the name 'rgb_histogram_web'; only that
# import (not the launcher process's __main__ copy) prepares the DB and index.
if __name__ == 'rgb_histogram_web':
    create_table()
    sync_db()
    init_index()

if __name__ == '__main__':
    uvicorn.run('rgb_histogram_web:app', host='127.0.0.1', port=33335, log_level="info")