Commit 77e4e52c authored by dmMaze's avatar dmMaze
Browse files

fix ref_src_lines mode & speed up mit48px inference"

parent c1ae207c
Loading
Loading
Loading
Loading
+168 −2
Original line number Diff line number Diff line
@@ -159,7 +159,7 @@ class Model48pxOCR:
                image_tensor = image_tensor.to(self.device)

            with torch.no_grad():
                ret = self.model.infer_beam_batch(image_tensor, widths, beams_k = 5, max_seq_length = 255)
                ret = self.model.infer_beam_batch_tensor(image_tensor, widths, beams_k = 5, max_seq_length = 255)
            for i, (pred_chars_index, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred) in enumerate(ret):
                if prob < 0.2:
                    continue
@@ -551,11 +551,15 @@ class OCR(nn.Module):
            encoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
            encoder.forward = transformer_encoder_forward
            self.encoders.append(encoder)
        self.encoders.forward = self.encoder_forward

        for i in range(5) :
            decoder = nn.TransformerDecoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True)
            decoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
            decoder.multihead_attn = XposMultiheadAttention(embd_dim, nhead, encoder_decoder_attention = True)
            self.decoders.append(decoder)
        self.decoders.forward = self.decoder_forward

        self.embd = nn.Embedding(self.dict_size, embd_dim)
        self.pred1 = nn.Sequential(nn.Linear(embd_dim, embd_dim), nn.GELU(), nn.Dropout(0.15))
        self.pred = nn.Linear(embd_dim, self.dict_size)
@@ -670,6 +674,168 @@ class OCR(nn.Module):
            result.append((cur_hypo.out_idx[1:], cur_hypo.prob(), fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))
        return result

    def infer_beam_batch_tensor(self, img: torch.FloatTensor, img_widths: List[int], beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_finished_hypos: int = 2, max_seq_length = 384):
        N, C, H, W = img.shape
        assert H == 48 and C == 3


        memory = self.backbone(img)
        memory = einops.rearrange(memory, 'N C 1 W -> N W C')
        valid_feats_length = [(x + 3) // 4 + 2 for x in img_widths]
        input_mask = torch.zeros(N, memory.size(1), dtype = torch.bool).to(img.device)
        
        for i, l in enumerate(valid_feats_length):
            input_mask[i, l:] = True
        memory = self.encoders(memory, input_mask) # N, W, Dim


        out_idx = torch.full((N, 1), start_tok, dtype=torch.long, device=img.device)  # Shape [N, 1]
        cached_activations = torch.zeros(N, len(self.decoders)+1, max_seq_length, 320, device=img.device)  # [N, L, S, E]
        log_probs = torch.zeros(N, 1, device=img.device)  # Shape [N, 1]        # N, E
        idx_embedded = self.embd(out_idx[:, -1:]) 


        decoded, cached_activations = self.decoders(idx_embedded, cached_activations, memory, input_mask, 0)
        pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)   # N, n_chars
        pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1)  # N, k


        out_idx = torch.cat([out_idx.unsqueeze(1).expand(-1, beams_k, -1), pred_chars_index.unsqueeze(-1)], dim=-1).reshape(-1, 2)  # Shape [N * k, 2]
        log_probs = pred_chars_values.view(-1, 1)  # Shape [N * k, 1]
        memory = memory.repeat_interleave(beams_k, dim=0)
        input_mask = input_mask.repeat_interleave(beams_k, dim=0)
        cached_activations = cached_activations.repeat_interleave(beams_k, dim=0)
        batch_index = torch.arange(N).repeat_interleave(beams_k, dim=0).to(img.device)


        finished_hypos = defaultdict(list)
        N_remaining = N


        for step in range(1, max_seq_length):
            idx_embedded = self.embd(out_idx[:, -1:])
            decoded, cached_activations = self.decoders(idx_embedded, cached_activations, memory, input_mask, step)
            pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)  # Shape [N * k, dict_size]
            pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim=1)  # [N * k, k]


            finished = out_idx[:, -1] == end_tok
            pred_chars_values[finished] = 0
            pred_chars_index[finished] = end_tok


            # Extend hypotheses
            new_out_idx = out_idx.unsqueeze(1).expand(-1, beams_k, -1)  # Shape [N * k, k, seq_len]
            new_out_idx = torch.cat([new_out_idx, pred_chars_index.unsqueeze(-1)], dim=-1)  # Shape [N * k, k, seq_len + 1]
            new_out_idx = new_out_idx.view(-1, step + 2)  # Reshape to [N * k^2, seq_len + 1]
            new_log_probs = log_probs.unsqueeze(1).expand(-1, beams_k, -1) + pred_chars_values.unsqueeze(-1)  # Shape [N * k^2, 1]
            new_log_probs = new_log_probs.view(-1, 1)  # [N * k^2, 1]


            # Sort and select top-k hypotheses per sample
            new_out_idx = new_out_idx.view(N_remaining, -1, step + 2)  # [N, k^2, seq_len + 1]
            new_log_probs = new_log_probs.view(N_remaining, -1)  # [N, k^2]
            batch_topk_log_probs, batch_topk_indices = new_log_probs.topk(beams_k, dim=1)  # [N, k]
            
            # Gather the top-k hypotheses based on log probabilities
            expanded_topk_indices = batch_topk_indices.unsqueeze(-1).expand(-1, -1, new_out_idx.shape[-1])  # Shape [N, k, seq_len + 1]
            out_idx = torch.gather(new_out_idx, 1, expanded_topk_indices).reshape(-1, step + 2)  # [N * k, seq_len + 1]
            log_probs = batch_topk_log_probs.view(-1, 1)  # Reshape to [N * k, 1]


            # Check for finished sequences
            finished = (out_idx[:, -1] == end_tok)  # Check if the last token is the end token
            finished = finished.view(N_remaining, beams_k)  # Reshape to [N, k]
            finished_counts = finished.sum(dim=1)  # Count the number of finished hypotheses per sample
            finished_batch_indices = (finished_counts >= max_finished_hypos).nonzero(as_tuple=False).squeeze()


            if finished_batch_indices.numel() == 0:
                continue


            if finished_batch_indices.dim() == 0:
                finished_batch_indices = finished_batch_indices.unsqueeze(0)
            
            for idx in finished_batch_indices:
                batch_log_probs = batch_topk_log_probs[idx]
                best_beam_idx = batch_log_probs.argmax()
                finished_hypos[batch_index[beams_k * idx].item()] = \
                    out_idx[idx * beams_k + best_beam_idx], \
                    torch.exp(batch_log_probs[best_beam_idx]).item(), \
                    cached_activations[idx * beams_k + best_beam_idx] 


            remaining_indexs = []
            for i in range(N_remaining):
                if i not in finished_batch_indices:
                    for j in range(beams_k):
                        remaining_indexs.append(i * beams_k + j)


            if not remaining_indexs:
                break


            N_remaining = int(len(remaining_indexs) / beams_k)
            out_idx = out_idx.index_select(0, torch.tensor(remaining_indexs, device=img.device))
            log_probs = log_probs.index_select(0, torch.tensor(remaining_indexs, device=img.device))
            memory = memory.index_select(0, torch.tensor(remaining_indexs, device=img.device))
            cached_activations = cached_activations.index_select(0, torch.tensor(remaining_indexs, device=img.device))
            input_mask = input_mask.index_select(0, torch.tensor(remaining_indexs, device=img.device))
            batch_index = batch_index.index_select(0, torch.tensor(remaining_indexs, device=img.device))


        # Ensure we have the correct number of finished hypotheses for each sample
        assert len(finished_hypos) == N


        # Final output processing and color predictions
        result = []
        for i in range(N):
            final_idx, prob, decoded = finished_hypos[i] 
            color_feats = self.color_pred1(decoded[-1].unsqueeze(0))
            fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = \
                self.color_pred_fg(color_feats), \
                self.color_pred_bg(color_feats), \
                self.color_pred_fg_ind(color_feats), \
                self.color_pred_bg_ind(color_feats)
            result.append((final_idx[1:], prob, fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))


        return result

    def encoder_forward(self, memory, encoder_mask):
        for layer in self.encoders :
            memory = layer(layer, src = memory, src_key_padding_mask = encoder_mask)
        return memory

    def decoder_forward(
        self,
        embd: torch.Tensor,
        cached_activations: torch.Tensor,  # Shape [N, L, T, E] where L=num_layers, T=sequence length, E=embedding size
        memory: torch.Tensor,  # Shape [N, H, W, C] (Encoder memory output)
        memory_mask: torch.BoolTensor,
        step: int
    ):

        layer: nn.TransformerDecoderLayer
        tgt = embd  # N, 1, E for the last token embedding

        for l, layer in enumerate(self.decoders):
            combined_activations = cached_activations[:, l, :step, :]  # N, T, E
            combined_activations = torch.cat([combined_activations, tgt], dim=1)  # N, T+1, E
            cached_activations[:, l, step, :] = tgt.squeeze(1)

            # Update cache and perform self attention
            tgt = tgt + layer.self_attn(layer.norm1(tgt), layer.norm1(combined_activations), layer.norm1(combined_activations), q_offset=step)[0]
            tgt = tgt + layer.multihead_attn(layer.norm2(tgt), memory, memory, key_padding_mask=memory_mask, q_offset=step)[0]
            tgt = tgt + layer._ff_block(layer.norm3(tgt))

        cached_activations[:, l+1, step, :] = tgt.squeeze(1) # Append the new activations

        return tgt.squeeze_(1), cached_activations

import numpy as np

def convert_pl_model(filename: str) :
@@ -707,7 +873,7 @@ def test_infer() :
    img_torch = einops.rearrange((torch.from_numpy(img) / 127.5 - 1.0), 'h w c -> 1 c h w')

    with torch.no_grad() :
        idx, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = model.infer_beam_batch(img_torch, [new_w], 5, max_seq_length = 32)[0]
        idx, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = model.infer_beam_batch_tensor(img_torch, [new_w], 5, max_seq_length = 32)[0]
        txt = ''
        for i in idx :
            txt += dictionary[i]
+1 −1
Original line number Diff line number Diff line
@@ -88,7 +88,7 @@ class BaseTranslator(BaseModule):
            if TRANSLATORS.module_dict[key] == self.__class__:
                self.name = key
                break
        self.textblk_break = '\n###\n'
        self.textblk_break = '\n##\n'
        self.lang_source: str = lang_source
        self.lang_target: str = lang_target
        self.lang_map: Dict = LANGMAP_GLOBAL.copy()
+3 −3
Original line number Diff line number Diff line
@@ -220,7 +220,7 @@ def layout_lines_aligncenter(
            elif mask[pos_y: line_bottom - lh_pad, new_x].mean() < border_thr or\
                mask[pos_y: line_bottom - lh_pad, right_x].mean() < border_thr:
                line_valid = False
                if (len(lines) == 1 and ref_src_lines or line_right_no + 1 >= len(srcline_wlist)) and \
                if ref_src_lines and (len(wl_list) == 1 or line_right_no + 1 >= len(srcline_wlist)) and \
                    line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_right_no, line_height, ref_src_lines):
                    line_valid = True
            else:
@@ -271,7 +271,7 @@ def layout_lines_aligncenter(
            elif mask[pos_y: line_bottom - lh_pad, new_x].mean() < border_thr or\
                mask[pos_y: line_bottom - lh_pad, right_x].mean() < border_thr:
                line_valid = False
                if line_left_no - 1 < 0 and \
                if ref_src_lines and line_left_no - 1 < 0 and \
                    line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_left_no, line_height, ref_src_lines):
                    line_valid = True
            else:
@@ -358,7 +358,7 @@ def layout_lines_alignside(
                if mask[np.clip(pos_y, 0, bh - 1): np.clip(line_bottom - lh_pad, 0, bh), new_x].mean() > 240:
                    line_valid = True
                else:
                    if line_id + 1 >= len(srcline_wlist) and line_is_valid(line, new_len, delimiter_len, max_width, words_length, srcline_wlist, line_id, line_height, ref_src_lines):
                    if ref_src_lines and line_id + 1 >= len(srcline_wlist) and line_is_valid(line, new_len, delimiter_len, max_width, words_length, srcline_wlist, line_id, line_height, ref_src_lines):
                        line_valid = True
            if line_valid:
                line_valid = line_is_valid(line, new_len, delimiter_len, max_width, words_length, srcline_wlist, line_id, line_height, ref_src_lines)
+12 −11
Original line number Diff line number Diff line
@@ -392,17 +392,18 @@ class TextBlock:
    def get_transformed_region(self, img: np.ndarray, idx: int, textheight: int, maxwidth: int = None) -> np.ndarray :
        im_h, im_w = img.shape[:2]

        lines = np.round(np.array(self.lines[idx])).astype(np.int64)[None]
        line = np.round(np.array(self.lines[idx])).astype(np.int64)
        

        expand_size = max(int(self._detected_font_size * 0.1), 2)
        if not self.src_is_vertical and self.det_model == 'ctd':
            # ctd detected horizontal bbox is smaller than GT
            expand_size = max(int(self._detected_font_size * 0.1), 3)
            rad = np.deg2rad(self.angle)
            shifted_vec = np.array([[[-1, -1],[1, -1],[1, 1],[-1, 1]]])
            shifted_vec = shifted_vec * np.array([[[np.sin(rad), np.cos(rad)]]]) * expand_size
        lines = lines + shifted_vec
        lines[..., 0] = np.clip(lines[..., 0], 0, im_w)
        lines[..., 1] = np.clip(lines[..., 1], 0, im_h)
        line = np.round(lines[0]).astype(np.int64)
            line = line + shifted_vec
            line[..., 0] = np.clip(line[..., 0], 0, im_w)
            line[..., 1] = np.clip(line[..., 1], 0, im_h)
            line = np.round(line[0]).astype(np.int64)

        x1, y1, x2, y2 = line[:, 0].min(), line[:, 1].min(), line[:, 0].max(), line[:, 1].max()