Spaces:
Build error
Build error
| import torch | |
| import cv2 | |
| import pytesseract | |
| from PIL import Image, ImageDraw, ImageFont | |
| from collections import deque | |
| import numpy as np | |
| import os | |
| # pytesseract.pytesseract.tesseract_cmd = 'Tesseract\\tesseract.exe' | |
| def get_full_img_path(src_dir): | |
| """ | |
| input: Đường dẫn đền folder chứa ảnh | |
| output: Danh sách tên của tất cả các ảnh | |
| """ | |
| list_img_names = [] | |
| for dirname, _, filenames in os.walk(src_dir): | |
| for filename in filenames: | |
| path = os.path.join(dirname, filename).replace(src_dir, '') | |
| if path[0] == '/': | |
| path = path[1:] | |
| list_img_names.append(path) | |
| return list_img_names | |
| def create_text_mask(src_img, detect_text_model, kernel_size=5, iterations=3): | |
| """ | |
| input: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C] | |
| output: Mask đánh dấu text trong ảnh gốc, 0 là chữ, 1 là nền; shape: [H, W] | |
| """ | |
| img = torch.from_numpy(src_img).to(torch.uint8).to(detect_text_model.device) | |
| imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2) | |
| detect_text_model.eval() | |
| with torch.no_grad(): | |
| result = detect_text_model(imgT).squeeze() | |
| result = (result >= 0.5).detach().cpu().numpy() | |
| mask = ((1-result) * 255).astype(np.uint8) | |
| kernel = np.ones((kernel_size, kernel_size), np.uint8) | |
| mask = cv2.erode(mask, kernel, iterations=iterations) | |
| mask = cv2.dilate(mask, kernel, iterations=2*iterations) | |
| mask = cv2.erode(mask, kernel, iterations=iterations) | |
| mask = (1 - mask // 255).astype(np.uint8) | |
| return mask | |
| def create_wordball_mask(src_img, detect_wordball_model, kernel_size=5, iterations=3): | |
| """ | |
| input: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C] | |
| output: Mask đánh dấu text trong ảnh gốc, 0 là chữ, 1 là nền; shape: [H, W] | |
| """ | |
| img = torch.from_numpy(src_img).to(torch.uint8).to(detect_wordball_model.device) | |
| imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2) | |
| detect_wordball_model.eval() | |
| with torch.no_grad(): | |
| result = detect_wordball_model(imgT).squeeze() | |
| result = (result >= 0.5).detach().cpu().numpy() | |
| mask = ((1-result) * 255).astype(np.uint8) | |
| kernel = np.ones((kernel_size, kernel_size), np.uint8) | |
| mask = cv2.erode(mask, kernel, iterations=iterations) | |
| mask = cv2.dilate(mask, kernel, iterations=2*iterations) | |
| mask = cv2.erode(mask, kernel, iterations=iterations) | |
| mask = (1 - mask // 255).astype(np.uint8) | |
| return mask | |
| def clear_text(src_img, text_msk, wordball_msk, text_value=0, non_text_value=1, r=5): | |
| """ | |
| input: src_img: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C] | |
| text_msk: Mask đánh dấu text trong ảnh gốc; shape: [H, W] | |
| text_value: Giá trị mà trong mặt nạ nó là text | |
| non_text_value: Giá trị mà trong mặt nạ nó là nền | |
| r: Bán kính để sử dụng cho việc xoá text và vẽ lại phần bị xoá | |
| output: Ảnh sau khi xoá text, để dưới định dạng là np.array, shape: [H, W, C] | |
| """ | |
| MAX = max(text_value, non_text_value) | |
| MIN = min(text_value, non_text_value) | |
| scale_text_value = (text_value - MIN) / (MAX - MIN) | |
| scale_non_text_value = (non_text_value - MIN) / (MAX - MIN) | |
| text_msk[text_msk==text_value] = scale_text_value | |
| text_msk[text_msk==non_text_value] = scale_non_text_value | |
| wordball_msk[wordball_msk==text_value] = scale_text_value | |
| wordball_msk[wordball_msk==non_text_value] = scale_non_text_value | |
| if scale_text_value == 0: | |
| text_msk = 1 - text_msk | |
| wordball_msk = 1 - wordball_msk | |
| text_msk = text_msk * 255 | |
| remove_txt = cv2.inpaint(src_img, text_msk, r, cv2.INPAINT_TELEA) | |
| remove_wordball = remove_txt.copy() | |
| remove_wordball[wordball_msk==1] = 255 | |
| return remove_wordball | |
| def dfs(grid, y, x, visited, value): | |
| """ | |
| Thuật toán tìm miền liên thông, xem thêm về đồ thị nếu không biết nó là gì | |
| Output: Một HCN bao phủ miền liên thông + Diện tích của miền liên thông | |
| """ | |
| max_y, max_x = y, x | |
| min_y, min_x = y+1, x+1 | |
| area = 0 | |
| stack = deque([(y, x)]) | |
| while stack: | |
| y, x = stack.pop() | |
| max_x = max(max_x, x) | |
| max_y = max(max_y, y) | |
| min_x = min(min_x, x) | |
| min_y = min(min_y, y) | |
| if (y, x) not in visited: | |
| visited.add((y, x)) | |
| area += 1 | |
| # Kiểm tra các ô liền kề | |
| for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]: | |
| nx, ny = x + dx, y + dy | |
| if 0 <= ny < grid.shape[0] and 0 <= nx < grid.shape[1] and grid[ny, nx] == value and (ny, nx) not in visited: | |
| stack.append((ny, nx)) | |
| return (min_x, min_y, max_x, max_y), area | |
| def find_clusters(grid, value): | |
| """ | |
| Thuật toán tìm danh sách các miền liên thông | |
| """ | |
| visited = set() | |
| clusters = [] | |
| areas = [] | |
| for y in range(grid.shape[0]): | |
| for x in range(grid.shape[1]): | |
| if grid[y, x] == value and (y, x) not in visited: | |
| cluster, area = dfs(grid, y, x, visited, value) | |
| clusters.append(cluster) | |
| areas.append(area) | |
| return clusters, areas | |
| def get_text_positions(text_msk, text_value=0): | |
| """ | |
| input: text_msk: Mask đánh dấu text trong ảnh gốc; shape: [H, W] | |
| text_value: Giá trị mà trong mặt nạ nó là text | |
| min_area: Giả trị tối thiểu của vùng có thể có text | |
| output: Danh sách các cùng chứa text, định dạng (min_x, min_y, max_x, max_y) | |
| """ | |
| clusters, areas = find_clusters(text_msk, value=text_value) | |
| return clusters, areas | |
| def filter_text_positions(clusters, areas, min_area=1200, max_area=10000): | |
| clusters = clusters[(areas >= min_area) & (areas <= max_area)] | |
| return clusters | |
| def get_list_texts(src_img, text_positions, lang='eng'): | |
| """ | |
| input: src_img: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C] | |
| text_positions: Danh sách các cùng chứa text, định dạng (min_x, min_y, max_x, max_y) | |
| lang: Ngôn ngữ của text | |
| output: Danh sách các câu text | |
| """ | |
| list_texts = [] | |
| for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions): | |
| crop_img = src_img[min_y:max_y+1, min_x:max_x+1] | |
| img_rgb = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB) | |
| img = Image.fromarray(img_rgb) | |
| text = pytesseract.image_to_string(img, lang=lang).replace('\n', ' ').strip() | |
| while ' ' in text: | |
| text = text.replace(' ', ' ') | |
| list_texts.append(text) | |
| return list_texts | |
| def translate(list_texts, translator): | |
| translated_texts = [] | |
| for text in list_texts: | |
| if not text: | |
| text = 'a' | |
| translated_text = translator.translate(text, src='en', dest='vi').text | |
| translated_texts.append(translated_text) | |
| return translated_texts | |
| def add_centered_multiline_text(image, text, box, font_path="arial.ttf", font_size=36, pad=5, text_color=0): | |
| # Mở ảnh | |
| draw = ImageDraw.Draw(image) | |
| # Giải nén box (min_x, min_y, max_x, max_y) | |
| min_x, min_y, max_x, max_y = box | |
| # Tạo font | |
| font = ImageFont.truetype(font_path, font_size) | |
| # Chia văn bản thành nhiều dòng nếu cần | |
| wrapped_lines = wrap_text(text, font, draw, max_x - min_x) | |
| # Tính chiều cao của tất cả các dòng cộng lại | |
| total_text_height = sum(get_text_height(line, draw, font) for line in wrapped_lines) | |
| # Tính toạ độ y bắt đầu để căn giữa theo chiều dọc | |
| start_y = min_y + (max_y - min_y - total_text_height) // 2 | |
| # Vẽ từng dòng và căn giữa theo chiều ngang | |
| current_y = start_y | |
| for line in wrapped_lines: | |
| text_width, text_height = get_text_dimensions(line, draw, font) | |
| text_x = min_x + (max_x - min_x - text_width) // 2 # Căn giữa theo chiều ngang | |
| draw.text((text_x, current_y), line, fill=text_color, font=font) | |
| current_y += text_height + pad # Di chuyển y xuống để vẽ dòng tiếp theo | |
| # Lưu ảnh mới | |
| return image | |
| def get_text_dimensions(text, draw, font): | |
| """Trả về (width, height) của văn bản.""" | |
| bbox = draw.textbbox((0, 0), text, font=font) | |
| width = bbox[2] - bbox[0] | |
| height = bbox[3] - bbox[1] | |
| return width, height | |
| def get_text_height(text, draw, font): | |
| """Trả về chiều cao của văn bản.""" | |
| _, _, _, height = draw.textbbox((0, 0), text, font=font) | |
| return height | |
| def wrap_text(text, font, draw, max_width): | |
| """Chia văn bản thành nhiều dòng dựa trên chiều rộng tối đa.""" | |
| words = text.split() | |
| lines = [] | |
| current_line = "" | |
| for word in words: | |
| # Thử thêm từ vào dòng hiện tại | |
| test_line = f"{current_line} {word}".strip() | |
| test_width, _ = get_text_dimensions(test_line, draw, font) | |
| if test_width <= max_width: | |
| current_line = test_line | |
| else: | |
| # Nếu quá rộng, lưu dòng hiện tại và bắt đầu dòng mới | |
| lines.append(current_line) | |
| current_line = word | |
| # Thêm dòng cuối cùng | |
| if current_line: | |
| lines.append(current_line) | |
| return lines | |
| def insert_text(non_text_src_img, list_translated_texts, text_positions, font=['MTO Astro City.ttf'], font_size=[20], pad=[5], text_color=0, stroke=[3]): | |
| # Copy ảnh không chữ | |
| img_bgr = non_text_src_img.copy() | |
| # Thêm text vào măt nạ 1 | |
| for idx, text in enumerate(list_translated_texts): | |
| # Tạo mặt nạ trắng | |
| mask1 = Image.new("L", img_bgr.shape[:2][::-1], 255) | |
| mask2 = Image.new("L", img_bgr.shape[:2][::-1], 255) | |
| mask1 = add_centered_multiline_text(mask1, text, text_positions[idx], f'MTO Font/{font[idx]}', font_size[idx], pad=pad[idx], text_color=text_color) | |
| # Chuyển ảnh từ PIL sang cv2 | |
| mask1 = (np.array(mask1) >= 127).astype(np.uint8) * 255 | |
| mask1 = cv2.cvtColor(mask1, cv2.COLOR_RGB2BGR) | |
| if stroke[idx] > 0: | |
| mask2 = np.array(mask2).astype(np.uint8) | |
| mask2 = cv2.cvtColor(mask2, cv2.COLOR_RGB2BGR) | |
| mask2 = mask2 - mask1 | |
| kernel = np.ones((stroke[idx]+1, stroke[idx]+1), np.uint8) | |
| mask2 = cv2.dilate(mask2, kernel, iterations=1) | |
| img_bgr[mask2==255] = 255 | |
| img_bgr[mask1==text_color] = text_color | |
| return img_bgr | |
| def save_img(path, translated_text_src_img): | |
| """ | |
| input: path: Đường dẫn đến ảnh gốc ban đầu | |
| translated_text_src_img: Ảnh sau khi được dịch | |
| output: Ảnh sau dịch được lưu lại, trong tên có thêm "translated-" | |
| """ | |
| dot = path.rfind('.') | |
| last_slash = -1 | |
| if '/' in path: | |
| last_slash = path.rfind('/') | |
| ext = path[dot:] | |
| parent_path = path[:last_slash+1] | |
| name = path[last_slash+1:dot] | |
| if parent_path and not os.path.exists(parent_path): | |
| os.mkdir(parent_path) | |
| cv2.imwrite(f'{parent_path}translated-{name}{ext}', translated_text_src_img) | |
| print(f'Image saved at {parent_path}translated-{name}{ext}') |