File size: 23,380 Bytes
f4a41d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
"""This script is used to generate images with different checkpoints and prompts"""
from copy import copy
import os
import re
import subprocess
import sys
from typing import Any, List, Tuple, Union

sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"))
from scripts.Utils import Utils
from scripts.Logger import Logger
from scripts.CivitaihelperPrompts import CivitaihelperPrompts
from scripts.Save import Save
from scripts.BatchParams import BatchParams, get_all_batch_params


import gradio as gr
import modules
import modules.scripts as scripts
import modules.shared as shared
from modules.shared_state import State as shared_state
from modules import processing
from modules.processing import process_images
from modules.ui_components import (FormColumn, FormRow)

from PIL import Image, ImageDraw, ImageFont

import PIL



try:
    import matplotlib.font_manager as fm
except ImportError:
    # Install into the interpreter that is running the webui (its venv), not into
    # whatever "pip" happens to be first on PATH, then retry the import.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib"])
    import matplotlib.font_manager as fm

class ToolButton(gr.Button, gr.components.FormComponent):
    """Compact gradio button holding a single emoji as its label; usable inside forms.

    Every instance is created with ``variant="tool"`` and the CSS class
    ``batch-checkpoint-prompt``; all other keyword arguments go to ``gr.Button``.
    """

    def __init__(self, **kwargs: Any) -> None:
        css_classes = ["batch-checkpoint-prompt"]
        super().__init__(variant="tool", elem_classes=css_classes, **kwargs)

    def get_block_name(self) -> str:
        """Return the gradio block name of this component ("button")."""
        block_name = "button"
        return block_name


class CheckpointLoopScript(scripts.Script):
    """Script for generating images with different checkpoints and prompts.

    This class is instantiated and called by A1111. For every checkpoint/prompt
    entry it runs ``process_images`` and finally builds an overview grid with the
    checkpoint name drawn above each column.
    """

    def __init__(self) -> None:
        self.margin_size = 0
        self.logger = Logger()
        self.logger.debug = False
        # Cached path of the TrueType font used for the grid legend (None until found).
        self.font = None
        self.text_margin_left_and_right = 16
        self.fill_values_symbol = "\U0001f4d2"  # ledger emoji
        # Invisible characters appended to some button labels that share the ledger
        # emoji, presumably so those labels stay distinct from each other.
        self.zero_width_space = '\u200B'  # zero width space
        self.zero_width_joiner = '\u200D'  # zero width joiner
        self.save_symbol = "\U0001F4BE"  # floppy disk emoji
        self.reload_symbol = "\U0001F504"  # counterclockwise arrows emoji
        self.index_symbol = "\U0001F522"  # "1234" input-numbers emoji
        self.rm_index_symbol = "\U0001F5D1"  # wastebasket emoji
        self.save = Save()
        self.utils = Utils()
        self.civitai_helper = CivitaihelperPrompts()
        self.outdir_txt2img_grids = shared.opts.outdir_txt2img_grids
        self.outdir_img2img_grids = shared.opts.outdir_img2img_grids

    def title(self) -> str:
        """Return the script name shown in the A1111 script dropdown."""
        return "Batch Checkpoint and Prompt"

    def save_inputs(self, save_name: str, checkpoints: str, prompt_templates: str, action: str) -> str:
        """Save the inputs to a file.

        Args:
            save_name (str): the save name
            checkpoints (str): the checkpoints
            prompt_templates (str): the prompt templates
            action (str): Possible values: "No", "Overwrite existing save", "append existing save"
                (may also be None when nothing is selected in the radio group)

        Returns:
            str: the save status returned by ``Save.store_values``
        """
        overwrite_existing_save = action == "Overwrite existing save"
        append_existing_save = action == "append existing save"
        return self.save.store_values(
            save_name.strip(), checkpoints.strip(), prompt_templates.strip(),
            overwrite_existing_save, append_existing_save)

    def get_checkpoints(self) -> str:
        """Get the checkpoints from the sd_models module and add an index to each one.

        Returns:
            str: the checkpoints as ``"<name> @index:<i>"`` entries joined by ",\\n"
        """
        return ',\n'.join(
            f"{checkpoint} @index:{i}"
            for i, checkpoint in enumerate(modules.sd_models.checkpoints_list))

    def getCheckpoints_and_prompt_with_index_and_version(self, checkpoint_list: str, prompts: str, add_model_version: bool) -> Tuple[str, str]:
        """Add the index to the checkpoints and prompts
        and optionally add the model version to the checkpoints.

        Args:
            checkpoint_list (str): the checkpoint list
            prompts (str): the prompts
            add_model_version (bool): add the model version to the checkpoints. EXPERIMENTAL!

        Returns:
            Tuple[str, str]: the checkpoints and prompts
        """
        checkpoints = self.utils.add_index_to_string(checkpoint_list)
        if add_model_version:
            checkpoints = self.utils.add_model_version_to_string(checkpoints)
        prompts = self.utils.add_index_to_string(prompts, is_checkpoint=False)
        return checkpoints, prompts

    def refresh_saved(self) -> gr.Dropdown:
        """Refresh the saved values dropdown.

        Returns:
            gr.Dropdown: the dropdown update carrying the current save keys
        """
        return gr.Dropdown.update(choices=self.save.get_keys())

    def remove_checkpoints_prompt_at_index(self, checkpoints: str, prompts: str, index: str) -> List[str]:
        """Remove the checkpoints and prompts at the specified indexes.

        Args:
            checkpoints (str): the checkpoints
            prompts (str): the prompts
            index (str): comma separated indexes, e.g. "0, 3". Blank entries
                (empty input, trailing commas) are ignored.

        Returns:
            List[str]: the checkpoints and prompts (unchanged if no index was given)

        Raises:
            ValueError: if an entry is not an integer
        """
        # int() tolerates surrounding whitespace; blank entries would raise, so skip them.
        index_list_num = [int(part) for part in index.split(",") if part.strip()]
        if not index_list_num:
            return [checkpoints, prompts]
        return self.utils.remove_element_at_index(checkpoints, prompts, index_list_num)

    def ui(self, is_img2img: bool) -> List[Union[gr.components.Textbox, gr.components.Slider]]:
        """Create the UI.

        Args:
            is_img2img (bool): not used.

        Returns:
            List[Union[gr.components.Textbox, gr.components.Slider]]: the components whose
                values A1111 passes to ``run`` (checkpoints, prompts, margin size)
        """
        with gr.Tab("Parameters"):
            with FormRow():
                checkpoints_input = gr.components.Textbox(
                    lines=5, label="Checkpoint Names", placeholder="Checkpoint names (separated with comma)")
                fill_checkpoints_button = ToolButton(
                    value=self.fill_values_symbol, visible=True)
            with FormRow():
                checkpoints_prompt = gr.components.Textbox(
                    lines=5, label="Prompts/prompt templates for Checkpoints", placeholder="prompts/prompt templates (separated with semicolon)")
                civitai_prompt_fill_button = ToolButton(
                    value=self.fill_values_symbol + self.zero_width_joiner, visible=True)
                add_index_button = ToolButton(
                    value=self.index_symbol, visible=True)
            with FormColumn():
                with FormRow():
                    rm_model_prompt_at_indexes_textbox = gr.components.Textbox(
                        lines=1, label="Remove checkpoint and prompt at index",
                        placeholder="Remove checkpoint and prompt at index (separated with comma)")
                    rm_model_prompt_at_indexes_button = ToolButton(value=self.rm_index_symbol, visible=True)
                margin_size = gr.Slider(
                    label="Grid margins (px)", minimum=0, maximum=10, value=0, step=1)

            # save and load inputs
            with FormRow():
                keys = self.save.get_keys()
                saved_inputs_dropdown = gr.components.Dropdown(
                    choices=keys, label="Saved values")
                load_button = ToolButton(
                    value=self.fill_values_symbol + self.zero_width_space, visible=True)
                refresh_button = ToolButton(value=self.reload_symbol, visible=True)

            with FormRow():
                save_name = gr.components.Textbox(
                    lines=1, label="save name", placeholder="save name")
                save_button = ToolButton(value=self.save_symbol, visible=True)
            with FormRow():
                save_action_radio = gr.components.Radio(
                    ["No", "Overwrite existing save", "append existing save"], label="Change saves?")
                save_status = gr.Textbox(label="", interactive=False)

            with gr.Accordion(label='Advanced settings', open=False):
                gr.Markdown("""
                    This can take a long time depending on the number of checkpoints! <br>
                    See the help tab for more information
                """)
                add_model_version_checkbox = gr.components.Checkbox(
                    label="Add model version to checkpoint names", interactive=False,
                    info="Not working in current webui versions")

            # Actions
            fill_checkpoints_button.click(
                fn=self.get_checkpoints, outputs=[checkpoints_input])
            save_button.click(fn=self.save_inputs, inputs=[
                save_name, checkpoints_input, checkpoints_prompt, save_action_radio], outputs=[save_status])
            load_button.click(fn=self.save.read_value, inputs=[saved_inputs_dropdown], outputs=[
                checkpoints_input, checkpoints_prompt])
            civitai_prompt_fill_button.click(fn=self.civitai_helper.createCivitaiPromptString, inputs=[
                checkpoints_input], outputs=[checkpoints_prompt])
            add_index_button.click(fn=self.getCheckpoints_and_prompt_with_index_and_version, inputs=[
                checkpoints_input, checkpoints_prompt, add_model_version_checkbox],
                outputs=[checkpoints_input, checkpoints_prompt])
            refresh_button.click(fn=self.refresh_saved, outputs=[saved_inputs_dropdown])
            rm_model_prompt_at_indexes_button.click(fn=self.remove_checkpoints_prompt_at_index, inputs=[
                checkpoints_input, checkpoints_prompt, rm_model_prompt_at_indexes_textbox],
                outputs=[checkpoints_input, checkpoints_prompt])

        with gr.Tab("help"):
            gr.Markdown(self.utils.get_help_md())

        return [checkpoints_input, checkpoints_prompt, margin_size]

    def show(self, is_img2img: bool) -> bool:
        """Show the UI in txt2img and img2img mode.

        Args:
            is_img2img (bool): not used

        Returns:
            bool: True
        """
        return True

    def _generate_images_with_SD(self, p: Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img],
                                 batch_params: BatchParams, orginal_size: Tuple[int, int],
                                 original_styles: Union[List[str], None] = None) -> modules.processing.Processed:
        """Manipulate the StableDiffusionProcessing object to generate images
        with the new checkpoint, prompt and other parameters.

        Args:
            p (Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img]): the processing object
            batch_params (BatchParams): the batch parameters
            orginal_size (Tuple[int, int]): the original size specified in the UI
            original_styles (List[str] | None): the styles specified in the UI. When given,
                they are restored for entries without their own style, so a style set for
                a previous checkpoint does not leak into this one. None keeps the old
                behavior (styles are left untouched).

        Returns:
            modules.processing.Processed: the processed object

        Raises:
            ValueError: if no checkpoint matches ``batch_params.checkpoint``
        """
        self.logger.debug_log(str(batch_params), False)

        info = modules.sd_models.get_closet_checkpoint_match(batch_params.checkpoint)
        if info is None:
            # Fail before reloading weights; otherwise info.name below raises an opaque AttributeError.
            raise ValueError(
                f"Batch Checkpoint and Prompt: no checkpoint found matching '{batch_params.checkpoint}'")
        modules.sd_models.reload_model_weights(shared.sd_model, info)
        p.override_settings['sd_model_checkpoint'] = info.name
        p.prompt = batch_params.prompt
        p.negative_prompt = batch_params.neg_prompt
        if len(batch_params.style) > 0:
            p.styles = batch_params.style
        elif original_styles is not None:
            # p is shared across all entries: reset like the size below does.
            p.styles = copy(original_styles)
        p.n_iter = batch_params.batch_count
        # NOTE(review): this changes the global option and is never restored afterwards.
        shared.opts.data["CLIP_stop_at_last_layers"] = batch_params.clip_skip
        if batch_params.width > 0 and batch_params.height > 0:
            self.logger.debug_print_attributes(p, False)
            p.height = batch_params.height
            p.width = batch_params.width
        else:
            p.width, p.height = orginal_size
        p.hr_prompt = batch_params.hr_prompt
        p.hr_negative_prompt = p.negative_prompt
        self.logger.debug_log(f"batch count {p.n_iter}")

        return process_images(p)

    def _generate_infotexts(self, pc: Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img],
                            all_infotexts: List[str], n_iter: int) -> List[str]:
        """Append the infotexts for the images of one checkpoint.

        Args:
            pc (Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img]): the processing object
            all_infotexts (List[str]): the infotexts collected so far (extended in place)
            n_iter (int): the number of iterations

        Returns:
            List[str]: the infotexts
        """

        def _a1111_infotext_caller(i: int = 0) -> str:
            """Call A1111 to create an infotext.

            Args:
                i (int, optional): the image index. Defaults to 0. Used to get the correct seed and subseed.

            Returns:
                str: the infotext
            """
            return processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds, position_in_batch=i)

        self.logger.pretty_debug_log(all_infotexts)
        self.logger.debug_print_attributes(pc)

        image_count = n_iter * pc.batch_size
        if image_count == 1:
            all_infotexts.append(_a1111_infotext_caller())
        else:
            # More than one image: A1111 prepends a per-checkpoint grid, whose infotext is the base prompt.
            all_infotexts.append(self.base_prompt)
            for i in range(image_count):
                all_infotexts.append(_a1111_infotext_caller(i))

        return all_infotexts

    def run(self, p: Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img], checkpoints_text: str, checkpoints_prompt: str, margin_size: int) -> modules.processing.Processed:
        """The main function to generate the images.

        Args:
            p (Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img]): the processing object
            checkpoints_text (str): the checkpoints
            checkpoints_prompt (str): the prompts
            margin_size (int): the margin size for the grid

        Returns:
            modules.processing.Processed: the processed object of the first checkpoint,
                extended with the overview grid and the images of all other checkpoints

        Raises:
            ValueError: if no checkpoint/prompt entry was given
        """
        self.margin_size = margin_size
        self.base_prompt: str = p.prompt

        all_batch_params = get_all_batch_params(p, checkpoints_text, checkpoints_prompt)
        if not all_batch_params:
            raise ValueError("Batch Checkpoint and Prompt: no checkpoints given")

        total_batch_count = sum(params.batch_count for params in all_batch_params)
        total_steps = p.steps * total_batch_count
        self.logger.debug_log(f"total steps: {total_steps}")

        shared.state.job_count = total_batch_count
        shared.total_tqdm.updateTotal(total_steps)

        all_infotexts = [self.base_prompt]

        p.extra_generation_params['Script'] = self.title()

        self.logger.log_info(f'will generate {total_batch_count} images over {len(all_batch_params)} checkpoints')

        original_size = p.width, p.height
        original_styles = copy(p.styles)

        image_processed = []
        for i, batch_params in enumerate(all_batch_params):
            self.logger.log_info(f"checkpoint: {i + 1}/{len(all_batch_params)} ({batch_params.checkpoint})")
            self.logger.debug_log(
                f"Prompt with replace: {batch_params.prompt}, neg prompt: {batch_params.neg_prompt}")

            processed_sd_object = self._generate_images_with_SD(
                p, batch_params, original_size, original_styles)
            image_processed.append(processed_sd_object)

            all_infotexts = self._generate_infotexts(copy(p), all_infotexts, batch_params.batch_count)

            # stopping_generation does not exist in older webui versions
            if shared.state.interrupted or getattr(shared.state, "stopping_generation", False):
                break

        result = image_processed[0]
        img_grid = self._create_grid(image_processed, all_batch_params)
        if img_grid is not None:
            result.images.insert(0, img_grid)
            result.index_of_first_image = 1
        for processed in image_processed[1:]:
            result.images.extend(processed.images)
        result.infotexts = all_infotexts

        return result

    def _create_grid(self, image_processed: List[modules.processing.Processed], all_batch_params: List[BatchParams]) -> Union[PIL.Image.Image, None]:
        """Create the overview grid (one labelled column per checkpoint) and save it.

        Args:
            image_processed (List[modules.processing.Processed]): the results, one per checkpoint
            all_batch_params (List[BatchParams]): the batch parameters, in the same order

        Returns:
            PIL.Image.Image | None: the grid, or None if no images were generated
        """
        self.logger.log_info(
            "creating the grid. This can take a while, depending on the amount of images")

        def _get_file_name(save_path: str) -> str:
            """Return the next free, ascending numbered file name (img_0001.png, ...).

            Args:
                save_path (str): the grid output directory

            Returns:
                str: the file path
            """
            save_path = os.path.join(save_path, "Checkpoint-Prompt-Loop")
            self.logger.debug_log(f"save path: {save_path}")
            os.makedirs(save_path, exist_ok=True)

            # Compare numerically: a lexicographic sort ranks img_9999 above img_10000.
            numbers = [
                int(match.group(1))
                for match in (re.match(r"img_(\d+)", name) for name in os.listdir(save_path))
                if match
            ]
            new_number = max(numbers, default=0) + 1

            return os.path.join(save_path, f"img_{new_number:04d}.png")

        spacing = int(self.margin_size)

        # zip: image_processed is shorter than all_batch_params when generation was
        # interrupted; results without images (interrupted early) are skipped.
        columns = [
            self._add_legend(processed.images[0], params.checkpoint)
            for processed, params in zip(image_processed, all_batch_params)
            if processed.images
        ]
        if not columns:
            self.logger.log_info("no images were generated, skipping the grid")
            return None

        # margins only between columns, not after the last one
        total_width = sum(column.size[0] for column in columns) + spacing * (len(columns) - 1)
        max_height = max(column.size[1] for column in columns)

        result_img = Image.new('RGB', (total_width, max_height), "white")

        x_offset = 0
        for column in columns:
            # Legends differ in height, so bottom-align the columns on the white background.
            result_img.paste(column, (x_offset, max_height - column.size[1]))
            x_offset += column.size[0] + spacing

        grid_dir = self.outdir_img2img_grids if self.is_img2img else self.outdir_txt2img_grids
        result_img.save(_get_file_name(grid_dir))

        return result_img

    def _add_legend(self, img: Image.Image, checkpoint_name: str) -> Image.Image:
        """Add the checkpoint name above the image.

        Args:
            img (Image.Image): the image
            checkpoint_name (str): the checkpoint name (may include a path)

        Returns:
            Image.Image: a new image with the checkpoint name as legend on top
        """

        def _find_available_font() -> Union[str, None]:
            """Find (and cache in self.font) the path of a TrueType font.

            Returns:
                str | None: the font path, or None to fall back to PIL's default font
            """
            if self.font is None:
                font_path = fm.findfont(fm.FontProperties(family='DejaVu Sans'))
                if font_path and os.path.isfile(font_path):
                    self.logger.debug_log("DejaVu font")
                    self.font = font_path
                else:
                    for font_file in fm.findSystemFonts(fontpaths=None, fontext='ttf'):
                        candidate = os.path.abspath(font_file)
                        if os.path.isfile(candidate):
                            self.logger.debug_log("font list font")
                            self.font = candidate
                            break
                    else:
                        self.logger.debug_log("default font")

            return self.font

        def _strip_checkpoint_name(checkpoint_name: str) -> str:
            """Remove the path from the checkpoint name.

            Args:
                checkpoint_name (str): the checkpoint with path

            Returns:
                str: the checkpoint name
            """
            checkpoint_name = os.path.basename(checkpoint_name)
            return self.utils.get_clean_checkpoint_path(checkpoint_name)

        def _text_size(draw: ImageDraw.ImageDraw, text: str, font: Any) -> Tuple[int, int]:
            """Measure text as (right edge, bottom edge) when drawn at (0, 0).

            ``ImageDraw.textsize`` was removed in Pillow 10; ``textbbox`` replaces it.
            Older Pillow versions only support ``textbbox`` for TrueType fonts, hence the fallback.
            """
            textbbox = getattr(draw, "textbbox", None)
            if textbbox is not None:
                try:
                    _, _, right, bottom = textbbox((0, 0), text, font=font)
                    return int(round(right)), int(round(bottom))
                except ValueError:
                    pass
            return draw.textsize(text, font)

        def _load_font(font_path: Union[str, None], font_size: int) -> Any:
            """Load the TrueType font at the given size, or PIL's default font if there is no path."""
            return ImageFont.truetype(font_path, font_size) if font_path else ImageFont.load_default()

        def _calculate_font(draw: ImageDraw.ImageDraw, text: str, width: int) -> Tuple[Any, int]:
            """Grow the font size until the text spans the image width (minus margins).

            Args:
                draw (ImageDraw.ImageDraw): the draw object used for measuring
                text (str): the text
                width (int): the image width

            Returns:
                Tuple[font, int]: the font and the text height
            """
            width -= self.text_margin_left_and_right
            font_path = _find_available_font()
            font_size = 1
            font = _load_font(font_path, font_size)
            text_width, text_height = _text_size(draw, text, font)

            # The default font has a fixed size and empty text never gets wider:
            # growing the size would loop forever.
            if not font_path or not text:
                return font, text_height

            while text_width < width:
                self.logger.debug_log(
                    f"text width: {text_width}, img width: {width}")
                font_size += 1
                font = _load_font(font_path, font_size)
                text_width, text_height = _text_size(draw, text, font)

            return font, text_height

        checkpoint_name = _strip_checkpoint_name(checkpoint_name)

        width, height = img.size

        draw = ImageDraw.Draw(img)

        font, text_height = _calculate_font(draw, checkpoint_name, width)

        new_image = Image.new("RGB", (width, height + text_height), "white")
        new_image.paste(img, (0, text_height))

        new_draw = ImageDraw.Draw(new_image)

        new_draw.text((self.text_margin_left_and_right / 4, 0),
                      checkpoint_name, fill="black", font=font)

        return new_image