Spaces:
Sleeping
Sleeping
| import importlib | |
| import warnings | |
| from collections import defaultdict | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from config import Config | |
| from data_utils.image_utils import _to_2d | |
# Silence every warning raised in this process from this point on.
warnings.filterwarnings("ignore")
# "DocTr-Plus" contains a hyphen, so that package cannot be named in a plain
# ``import`` statement; importlib resolves it from the dotted string instead.
DocTr_Plus = importlib.import_module("models.DocTr-Plus.inference")
DocScanner = importlib.import_module("models.DocScanner.inference")
# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
# Despite its name, this torch.device may therefore refer to the CPU.
cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Module-level mapping whose missing keys default to 0.
mask_dict = defaultdict(int)
def load_geotrp_model(cuda, path=""):
    """Build a DocTr-Plus ``GeoTrP`` network and prepare it for inference.

    Args:
        cuda: Device the network is moved to.
        path: Checkpoint location handed to ``DocTr_Plus.reload_model``.

    Returns:
        The ``GeoTrP`` instance, moved to ``cuda`` and switched to eval mode.
    """
    network = DocTr_Plus.GeoTrP().to(cuda)
    # The checkpoint is loaded into the inner ``GeoTr`` sub-module rather than
    # into the ``GeoTrP`` wrapper itself.
    DocTr_Plus.reload_model(network.GeoTr, path)
    network.eval()
    return network
def load_docscanner_model(cuda, path_l="", path_m=""):
    """Build a DocScanner ``Net`` and prepare it for inference.

    Args:
        cuda: Device the network is moved to; also forwarded to both
            ``DocScanner.reload_*`` helpers.
        path_l: Checkpoint path passed to ``reload_rec_model`` for ``net.bm``.
        path_m: Checkpoint path passed to ``reload_seg_model`` for ``net.msk``.

    Returns:
        The ``Net`` instance, moved to ``cuda`` and switched to eval mode.
    """
    scanner = DocScanner.Net()
    scanner = scanner.to(cuda)
    # Segmentation weights first, then the rectification weights.
    DocScanner.reload_seg_model(cuda, scanner.msk, path_m)
    DocScanner.reload_rec_model(cuda, scanner.bm, path_l)
    scanner.eval()
    return scanner
def preprocess_image(img, target_size=(288, 288)):
    """Normalise an image and build the batched tensor fed to the networks.

    Args:
        img: Image array indexed as ``[row, col, channel]`` with at least
            three channels; any channels past the third are dropped.
        target_size: ``(width, height)`` the network input is resized to.
            Any two-element sequence is accepted.

    Returns:
        A 4-tuple ``(im_ori, im, h_, w_)``:
            im_ori: The first three channels of ``img`` divided by 255, at
                the original resolution.
            im: ``torch.float32`` tensor holding the resized image with the
                channel axis moved first and a leading batch axis of 1.
            h_, w_: Height and width of ``im_ori``.
    """
    im_ori = img[:, :, :3] / 255.0
    h_, w_, _ = im_ori.shape
    # Resize exactly once, straight to the requested size. The previous
    # hard-coded intermediate 288x288 resize was a no-op for the default
    # target and only discarded detail for any other target size.
    # ``dsize`` is passed as a tuple because older OpenCV bindings reject lists.
    im = cv2.resize(im_ori, tuple(target_size))
    im = im.transpose(2, 0, 1)  # HWC -> CHW
    im = torch.from_numpy(im).float().unsqueeze(0)
    return im_ori, im, h_, w_
def geotrp_rec(img, model, cuda):
    """Rectify ``img`` with a GeoTrP model via ``cv2.remap``.

    Args:
        img: Input image array, preprocessed by ``preprocess_image``.
        model: Callable network whose output's first batch element holds two
            planes, used here as the x and y remap tables.
        cuda: Device the input tensor is moved to before the forward pass.

    Returns:
        The remapped image scaled back to the 0-255 range and resized to the
        input's original width and height.
    """
    original, batch, src_h, src_w = preprocess_image(img)
    with torch.no_grad():
        prediction = model(batch.to(cuda))
    flow = prediction.cpu().numpy()[0]
    # Smooth each map with a 3x3 box filter before sampling.
    map_x = cv2.blur(flow[0, :, :], (3, 3))
    map_y = cv2.blur(flow[1, :, :], (3, 3))
    # NOTE(review): cv2.remap reads the maps as pixel coordinates into
    # ``original`` (the full-resolution image) - confirm the model emits its
    # maps in that coordinate system.
    rectified = cv2.remap(original, map_x, map_y, cv2.INTER_LINEAR) * 255
    return cv2.resize(rectified, (src_w, src_h))
def docscanner_get_mask(img, model, cuda):
    """Run the DocScanner network and return only its mask output.

    Args:
        img: Input image array, preprocessed by ``preprocess_image``.
        model: Callable network returning a pair whose second item is the
            mask tensor.
        cuda: Device the input tensor is moved to before the forward pass.

    Returns:
        ``uint8`` mask (network output multiplied by 255) resized to the
        input's original width and height.
    """
    _, batch, src_h, src_w = preprocess_image(img)
    with torch.no_grad():
        outputs = model(batch.to(cuda))
    _, seg = outputs
    # Take the first batch element / first channel and scale it by 255.
    seg_u8 = (seg.cpu()[0, 0].numpy() * 255).astype(np.uint8)
    return cv2.resize(seg_u8, (src_w, src_h))
def docscanner_rec_img(img, model, cuda):
    """Rectify ``img`` with the DocScanner network and return only the image.

    Args:
        img: Input image array, preprocessed by ``preprocess_image``.
        model: Callable network. It may return either the sampling-grid
            tensor alone or a tuple/list whose first item is that tensor.
        cuda: Device the input tensor is moved to before the forward pass.

    Returns:
        ``uint8`` rectified image at the input's original resolution, with
        the channel order reversed relative to ``img``.
    """
    im_ori, im, h, w = preprocess_image(img)
    with torch.no_grad():
        output = model(im.to(cuda))
    # The other DocScanner call sites in this file unpack the network output
    # as ``(bm, msk)``; calling ``.cpu()`` on that tuple raised
    # AttributeError. Accept both the tuple form and a bare tensor.
    bm = output[0] if isinstance(output, (tuple, list)) else output
    bm = bm.cpu()

    # Upsample both grid planes to the original resolution, then smooth them.
    bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
    bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
    bm0 = cv2.blur(bm0, (3, 3))
    bm1 = cv2.blur(bm1, (3, 3))
    lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 x h x w x 2

    out = F.grid_sample(
        torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(),
        lbl,
        align_corners=True,
    )
    # Back to HWC, rescale to 0-255 and reverse the channel order.
    img = (((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1]).astype(np.uint8)
    return img
def docscanner_rec(img, model, cuda):
    """Rectify ``img`` with the DocScanner network and also return its mask.

    Args:
        img: Image array indexed as ``[row, col, channel]`` with at least
            three channels.
        model: Callable network returning a ``(grid, mask)`` pair of tensors.
        cuda: Device the input tensor is moved to before the forward pass.

    Returns:
        A pair ``(rectified, mask)``: the ``uint8`` rectified image with the
        channel order reversed relative to ``img``, and the ``uint8`` mask
        (network output multiplied by 255). Both are at the input's original
        resolution.
    """
    source = img[:, :, :3] / 255.0
    height, width = source.shape[:2]

    # Network input: 288x288, channels first, with a leading batch axis.
    net_input = torch.from_numpy(cv2.resize(source, (288, 288)).transpose(2, 0, 1))
    net_input = net_input.float().unsqueeze(0)

    with torch.no_grad():
        flow, seg = model(net_input.to(cuda))
    flow, seg = flow.cpu(), seg.cpu()

    seg_u8 = (seg[0, 0].numpy() * 255).astype(np.uint8)
    mask_img = cv2.resize(seg_u8, (width, height))

    # Upsample the x and y grid planes to full resolution, then smooth each.
    planes = [
        cv2.blur(cv2.resize(flow[0, axis].numpy(), (width, height)), (3, 3))
        for axis in (0, 1)
    ]
    grid = torch.from_numpy(np.stack(planes, axis=2)).unsqueeze(0)  # 1 x h x w x 2

    source_nchw = torch.from_numpy(source).permute(2, 0, 1).unsqueeze(0).float()
    warped = F.grid_sample(source_nchw, grid, align_corners=True)

    # Back to HWC, rescale to 0-255 and reverse the channel order.
    rectified = (warped[0] * 255).permute(1, 2, 0).numpy()[:, :, ::-1].astype(np.uint8)
    return rectified, mask_img
# TODO: move this helper into data_utils later.
def get_mask_white_area(mask):
    """Locate every positive ("white") element of a mask.

    Args:
        mask (np.ndarray): Mask image; it is passed through ``_to_2d`` before
            being inspected.

    Returns:
        np.ndarray: One row of indices per element greater than zero, as
        produced by ``np.argwhere`` (row, column order for a 2-D mask).
    """
    flat_mask = _to_2d(mask)
    return np.argwhere(flat_mask > 0)
def main():
    """Load both models, compute the document mask for one image, record it.

    Reads the image at ``input/test.jpg``, runs the DocScanner mask branch on
    it and stores the number of white mask pixels in the module-level
    ``mask_dict`` under the image path.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    config = Config()
    image_path = "input/test.jpg"  # NOTE: adjust this path before running.
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None instead of raising when the file is missing
        # or unreadable; fail here with a clear message rather than later
        # with a TypeError on indexing.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # NOTE(review): the ``config.get_*_path`` attributes are read without
    # being called; this is only right if Config exposes them as properties.
    docscanner = load_docscanner_model(
        cuda, path_l=config.get_rec_model_path, path_m=config.get_seg_model_path
    )
    # Loaded but not used yet in this script.
    doctr = load_geotrp_model(cuda, path=config.get_geotr_model_path)

    mask = docscanner_get_mask(img, docscanner, cuda)
    # ``mask_dict`` is a defaultdict(int) and has no ``add`` method, so the
    # previous ``mask_dict.add(...)`` raised AttributeError. Store the white
    # area as a pixel count keyed by the image path instead.
    # NOTE(review): presumed intent - confirm the desired key and value.
    mask_dict[image_path] = len(get_mask_white_area(mask))


if __name__ == "__main__":
    main()