qw-gallery-scenery/python/rgb_histogram_web.py
2021-04-16 23:25:31 +03:00

212 lines
6.4 KiB
Python

from pydantic import BaseModel
from fastapi import FastAPI, File,Body,Form, HTTPException  # NOTE(review): Body appears unused in the visible code
from os import listdir
import numpy as np
from PIL import Image
from sklearn.neighbors import NearestNeighbors  # NOTE(review): appears unused in the visible code
import cv2
# Directory scanned by sync_db(); image files are expected to be named "<image_id>.<ext>".
# Relative path, so it resolves against the process working directory.
IMAGE_PATH="./../public/images"
import sqlite3
import io
# Single module-wide SQLite connection used by every DB helper below (and, through
# them, by the request handlers). Relative path: resolved against the working directory.
conn = sqlite3.connect('rgb_histograms.db')
import nmslib
# dim=4096: each feature vector is a 16*16*16-bin RGB histogram (see get_features).
# Global approximate-nearest-neighbour index: HNSW graph, L1 distance, dense float vectors.
# Filled by init_index() and extended by the /calculate_HIST_features handler.
index = nmslib.init(method='hnsw', space="l1", data_type=nmslib.DataType.DENSE_VECTOR)
# HNSW build parameters passed to every index.createIndex() call.
index_time_params = {'M': 32,'efConstruction': 200}
# Only referenced from commented-out code (brute-force search experiment) in the visible source.
IN_MEMORY_HISTS={}
def _new_index():
    """Return a fresh, empty nmslib index configured like the module-level ``index``.

    Keep these arguments in sync with the module-level ``nmslib.init`` call.
    """
    return nmslib.init(method='hnsw', space="l1", data_type=nmslib.DataType.DENSE_VECTOR)


def init_index():
    """(Re)build the global nearest-neighbour ``index`` from every histogram in the DB.

    A brand-new index object is filled and built, then swapped into the
    module-level ``index`` name. Adding into the existing object would append
    every stored vector again on each call: repeated calls (e.g. from the
    delete endpoint) would duplicate points, and rows already removed from the
    DB would stay searchable.

    With an empty database the new index is left empty and not yet built,
    rather than crashing. A later ``addDataPoint`` + ``createIndex`` (as done
    by the upload endpoint) builds it.
    """
    global index
    image_data = get_all_data()
    new_index = _new_index()
    if image_data:
        ids = np.array([image['image_id'] for image in image_data])
        # Stored histograms may be column vectors (get_features returns the
        # output of cv2.divide, which can come back as shape (dim, 1)).
        # Flattening each one and stacking always gives (n_images, dim).
        # .squeeze() would instead collapse a single-image batch to 1-D.
        features = np.stack([np.ravel(image['features']) for image in image_data])
        new_index.addDataPointBatch(features, ids)
        new_index.createIndex(index_time_params)
    index = new_index
    print("Index is ready")
def read_img_file(image_data):
    """Open raw encoded image bytes (e.g. an uploaded file) as a PIL image."""
    buffer = io.BytesIO(image_data)
    return Image.open(buffer)
class HistogramNotFoundError(LookupError, RuntimeError):
    """Raised when no histogram row exists for a requested image id.

    LookupError is the semantically accurate base. RuntimeError is also a base
    because the /get_similar_images_by_id handler translates RuntimeError into
    its "Image with this id is not found" response.
    """


def get_rgb_histogram_by_id(id):
    """Return the serialized histogram BLOB stored for image ``id``.

    The value is whatever was inserted via ``add_descriptor`` (the upload
    handler stores ``adapt_array(features)``); decode it with ``convert_array``.

    Raises:
        HistogramNotFoundError: if no row has this id. Indexing the ``None``
            returned by ``fetchone()`` would otherwise raise an unrelated
            TypeError that callers do not handle.
    """
    cursor = conn.cursor()
    query = '''
SELECT rgb_histogram
FROM rgb_hists
WHERE id = (?)
'''
    cursor.execute(query, (id,))
    row = cursor.fetchone()
    if row is None:
        raise HistogramNotFoundError(f"no rgb histogram stored for image id {id}")
    return row[0]
def get_features(image_buffer):
    """Compute a normalized 3-D RGB colour histogram for an encoded image.

    The image is decoded, converted to RGB and binned into 16 x 16 x 16
    colour cells (4096 values). Each bin count is divided by the number of
    pixels, giving the fraction of pixels that fall in that bin.

    Args:
        image_buffer: raw encoded image bytes.

    Returns:
        The normalized histogram exactly as produced by ``cv2.divide``.
        NOTE(review): likely shape (4096, 1) rather than flat, which is why
        init_index squeezes the stacked features — confirm.
    """
    pixels = np.array(read_img_file(image_buffer).convert('RGB'))
    height, width = pixels.shape[:2]
    histogram = cv2.calcHist(
        [pixels],
        [0, 1, 2],                 # R, G, B channels (PIL RGB order)
        None,                      # no mask: whole image
        [16, 16, 16],              # bins per channel
        [0, 256, 0, 256, 0, 256],  # value range per channel
    ).flatten()
    return cv2.divide(histogram, height * width)
def create_table():
    """Create the ``rgb_hists`` table (integer id -> histogram BLOB) if missing, then commit."""
    ddl = '''
CREATE TABLE IF NOT EXISTS rgb_hists(
id INTEGER NOT NULL UNIQUE PRIMARY KEY,
rgb_histogram BLOB NOT NULL
)
'''
    conn.execute(ddl)
    conn.commit()
def check_if_exists_by_id(id):
    """Return True if a histogram row is stored for image ``id``, otherwise False."""
    (found,) = conn.execute(
        '''SELECT EXISTS(SELECT 1 FROM rgb_hists WHERE id=(?))''', (id,)
    ).fetchone()
    return found == 1
def delete_descriptor_by_id(id):
    """Delete the histogram row for image ``id`` and commit; an unknown id deletes nothing."""
    conn.execute('''DELETE FROM rgb_hists WHERE id=(?)''', (id,))
    conn.commit()
def get_all_ids():
    """Return a list containing the id of every stored histogram row."""
    rows = conn.execute('''SELECT id FROM rgb_hists''').fetchall()
    return [image_id for (image_id,) in rows]
def convert_array(text):
    """Deserialize bytes written by ``np.save`` (as ``adapt_array`` produces) back into an ndarray."""
    return np.load(io.BytesIO(text))
def get_all_data():
    """Load every stored histogram from ``rgb_hists``.

    Returns:
        A list with one dict per row: ``{"image_id": <row id>, "features": <ndarray>}``,
        where the BLOB column has been decoded with ``convert_array``.
    """
    query = '''
SELECT id, rgb_histogram
FROM rgb_hists
'''
    rows = conn.execute(query).fetchall()
    return [
        {"image_id": image_id, "features": convert_array(blob)}
        for image_id, blob in rows
    ]
def adapt_array(arr):
    """Serialize ``arr`` with ``np.save`` and wrap the bytes as an SQLite binary value."""
    buffer = io.BytesIO()
    np.save(buffer, arr)
    return sqlite3.Binary(buffer.getvalue())
def add_descriptor(id, rgb_histogram):
    """Insert one ``(id, rgb_histogram)`` row into ``rgb_hists`` and commit.

    ``id`` is the table's UNIQUE primary key, so inserting an id that is
    already stored raises ``sqlite3.IntegrityError``.
    """
    conn.execute(
        '''INSERT INTO rgb_hists(id, rgb_histogram) VALUES (?,?)''',
        (id, rgb_histogram),
    )
    conn.commit()
def sync_db():
ids_in_db=set(get_all_ids())
file_names=listdir(IMAGE_PATH)
for file_name in file_names:
file_id=int(file_name[:file_name.index('.')])
if file_id in ids_in_db:
ids_in_db.remove(file_id)
for id in ids_in_db:
delete_descriptor_by_id(id) #Fix this
print(f"deleting {id}")
print("db synced")
app = FastAPI()


# Root endpoint: always answers with a fixed JSON greeting.
@app.get("/")
async def read_root():
    greeting = {"Hello": "World"}
    return greeting
@app.post("/calculate_HIST_features")
async def calculate_HIST_features_handler(image: bytes = File(...), image_id: str = Form(...)):
    # Compute the uploaded image's histogram, persist it under image_id, then
    # add it to the global index and rebuild the index so it becomes searchable.
    # NOTE(review): there is no await, so decoding, hashing and the rebuild
    # block the event loop; createIndex runs on every upload — confirm the cost is acceptable.
    histogram = get_features(image)
    numeric_id = int(image_id)
    add_descriptor(numeric_id, adapt_array(histogram))
    index.addDataPoint(numeric_id, histogram)
    index.createIndex(index_time_params)
    return {"status": "200"}
# JSON request body for the similarity-search and delete endpoints:
# an object of the form {"image_id": <int>}.
class Item_image_id(BaseModel):
    image_id: int  # integer image id; matches the rgb_hists primary key
from timeit import default_timer as timer
@app.post("/get_similar_images_by_id")
async def get_similar_images_by_id_handler(item: Item_image_id):
    """Return the ids of the 20 indexed images nearest to ``item.image_id``'s histogram.

    Distances are the L1 metric the nmslib index was created with. When no
    histogram is stored for the id, responds with HTTP 500 and detail
    "Image with this id is not found" (same status and detail as before).
    """
    # Check explicitly so an unknown id always produces the not-found response,
    # whatever exception the DB helper would raise for a missing row.
    if not check_if_exists_by_id(item.image_id):
        raise HTTPException(
            status_code=500, detail="Image with this id is not found")
    try:
        start = timer()
        target_features = convert_array(get_rgb_histogram_by_id(item.image_id))
        labels, _ = index.knnQuery(target_features, k=20)
        end = timer()
        print((end - start)*1000)  # query latency in milliseconds
        return labels.tolist()
    except RuntimeError:
        raise HTTPException(
            status_code=500, detail="Image with this id is not found")
import heapq
# def find_bruteforce(target_image_id,k):
# query_hist=IN_MEMORY_HISTS[target_image_id]
# # print(query_hist[0])
# # print(query_hist[1])
# # heap=[]
# # for key in IN_MEMORY_HISTS:
# # # similarity=np.sum(np.minimum(query_hist,IN_MEMORY_HISTS[key]))
# # similarity=cv2.compareHist(query_hist,IN_MEMORY_HISTS[key],cv2.HISTCMP_INTERSECT)
# # if len(heap) < k or similarity > heap[0][0]:
# # # If the heap is full, remove the smallest element on the heap.
# # if len(heap) == k: heapq.heappop(heap)
# # # add the current element as the new smallest.
# # heapq.heappush( heap, (similarity,key) )
# # heap=[heapq.heappop(heap) for i in range(len(heap))]
# # heap.reverse()
# # print(heap)
# # found_images_filenames=list(map(lambda el: el[1],heap))
# found_images=[]
# for key in IN_MEMORY_HISTS:
# # similarity=cv2.compareHist(query_hist,IN_MEMORY_HISTS[key],cv2.HISTCMP_INTERSECT)
# similarity=np.abs(query_hist-IN_MEMORY_HISTS[key]).sum()
# found_images.append({"similarity":similarity,"file_name":key})
# found_images.sort(key=lambda item: item["similarity"],reverse=False)
# found_images=found_images[:20]
# found_images_filenames=list(map(lambda el: el["file_name"],found_images))
# return found_images_filenames
# @app.post("/get_similar_images_by_id")
# async def get_similar_images_by_id_handler(item: Item_image_id):
# start = timer()
# similar=find_bruteforce(item.image_id,10)
# end = timer()
# print((end - start)*1000) # Time in milliseconds, e.g. 5.38091952400282
# return similar
@app.post("/delete_HIST_features")
async def delete_hist_features_handler(item:Item_image_id):
    """Delete the stored histogram for ``item.image_id`` and rebuild the search index.

    If no histogram is stored for the id, nothing is deleted and the costly
    index rebuild is skipped. The response is ``{"status": "200"}`` either way.
    """
    if check_if_exists_by_id(item.image_id):
        delete_descriptor_by_id(item.image_id)
        # Rebuild the index from the remaining DB rows; no per-point removal is used here.
        init_index()
    return {"status":"200"}
# Initialise storage and the search index at import time, but only when this
# module is loaded under its own name (presumably by the ASGI server as
# "rgb_histogram_web:app"). Importing it under any other name skips the
# DB/disk work. Only the last dotted component is compared, so
# package-qualified imports such as "pkg.rgb_histogram_web" also initialise.
if __name__.rsplit('.', 1)[-1] == 'rgb_histogram_web':
    create_table()
    sync_db()
    init_index()