# dikdimon's picture
# Upload extensions using SD-Hub extension
# bb7f1f4 verified
"""This module provides utility functions."""
from scripts.Logger import Logger
import os
import re
import requests
from typing import List
import modules
from modules.sd_models import read_state_dict
from modules.sd_models_config import (find_checkpoint_config, config_default, config_sd2, config_sd2v, config_sd2_inpainting,
config_depth_model, config_unclip, config_unopenclip, config_inpainting, config_instruct_pix2pix, config_alt_diffusion)
import sys
# Put "<directory of this file>/scripts" at the front of the module search path.
# NOTE(review): this statement executes after `from scripts.Logger import Logger`
# at the top of the file, so it cannot affect that import — confirm which later
# imports (if any) still rely on it.
sys.path.insert(0, os.path.join(os.path.dirname(
os.path.abspath(__file__)), "scripts"))
class Utils():
    """
    methods that are needed in different classes

    The helpers parse and rebuild the two text inputs used by the extension:
    a comma separated checkpoint list and a semicolon separated prompt list.
    Entries may carry an ``@index:<n>`` tag, a ``@version:<name>`` tag and a
    `` [hash]`` suffix.
    """

    # Known model version names that can follow "@version:".
    _MODEL_VERSIONS = (
        'sd1',
        'sd2',
        'sd2v',
        'sd2-inpainting',
        'depth',
        'unclip',
        'unopenclip',
        'sd1-inpainting',
        'pix2pix',
        'alt',
    )

    # Longest names first: regex alternation takes the first alternative that
    # matches, so "sd1-inpainting" has to be tried before its prefix "sd1"
    # (same for "sd2v" / "sd2-inpainting" vs. "sd2").
    _VERSION_TAG_RE = re.compile(
        "@version:(?:"
        + "|".join(re.escape(name) for name in sorted(
            _MODEL_VERSIONS, key=len, reverse=True))
        + ")")

    def __init__(self) -> None:
        """Create the logger and work out where the help file lives."""
        self.logger = Logger()
        self.logger.debug = False

        # Two directory levels above this file.
        script_path = os.path.dirname(
            os.path.dirname(os.path.abspath(__file__)))
        help_file = "HelpBatchCheckpointsPrompt.md"
        self.held_md_file_name = os.path.join(script_path, help_file)
        # Build the URL from the bare file name; the local path above is an
        # absolute file system path and must not end up in the URL.
        # NOTE(review): assumes the file sits at the repository root - confirm.
        self.held_md_url = f"https://raw.githubusercontent.com/h43lb1t0/BatchCheckpointPrompt/main/{help_file}"

    def split_prompts(self, text: str) -> List[str]:
        """Split the prompts by the ; and remove empty strings and newlines

        Args:
            text (str): the input string

        Returns:
            List[str]: a list of prompts
        """
        # strip() is empty exactly for '' and whitespace-only parts.
        return [part.replace('\n', '').strip()
                for part in text.split(";") if part.strip()]

    def remove_index_from_string(self, input: str) -> str:
        """Remove the index from the string

        Args:
            input (str): the input string

        Returns:
            str: the string without any ``@index:<n>`` tag, stripped of
                leading and trailing whitespace
        """
        return re.sub(r"@index:\d+", "", input).strip()

    def remove_model_version_from_string(self, checkpoints_text: str) -> str:
        """Remove the model version from the string

        Args:
            checkpoints_text (str): the input string with all checkpoints

        Returns:
            str: the string without the ``@version:<name>`` tags
        """
        return self._VERSION_TAG_RE.sub('', checkpoints_text)

    def get_clean_checkpoint_path(self, checkpoint: str) -> str:
        """Remove the checkpoint hash from the filename

        Args:
            checkpoint (str): the input string with hash

        Returns:
            str: the string without the `` [hash]`` suffix
        """
        return re.sub(r' \[.*?\]', '', checkpoint).strip()

    def getCheckpointListFromInput(self, checkpoints_text: str, clean: bool = True) -> List[str]:
        """Get a list of checkpoints from the input string

        The model version tags are always removed.

        Args:
            checkpoints_text (str): the input string with all checkpoints
            clean (bool): remove the index and hash from the string

        Returns:
            List[str]: a list of checkpoints
        """
        self.logger.debug_log(f"checkpoints: {checkpoints_text}")
        checkpoints_text = self.remove_model_version_from_string(checkpoints_text)
        if clean:
            checkpoints_text = self.remove_index_from_string(checkpoints_text)
            checkpoints_text = self.get_clean_checkpoint_path(checkpoints_text)
        return [entry.replace('\n', '').strip()
                for entry in checkpoints_text.split(",") if entry.strip()]

    def get_help_md(self) -> str:
        """Gets the help md file.

        If the file is not found locally it is downloaded from the github
        repository and stored for the next call.

        Returns:
            str: the help md file as a string, or a short fallback message
                when it can neither be read nor downloaded
        """
        fallback = "could not get help file. Check Github for more information"

        if os.path.isfile(self.held_md_file_name):
            with open(self.held_md_file_name, encoding="utf-8", errors="replace") as f:
                return f.read()

        self.logger.debug_log("downloading help md")
        try:
            # The timeout keeps the caller from hanging on a dead connection.
            result = requests.get(self.held_md_url, timeout=10)
        except requests.RequestException as err:
            self.logger.debug_log(f"could not download help md: {err}")
            return fallback
        if result.status_code != 200:
            return fallback

        try:
            with open(self.held_md_file_name, "wb") as file:
                file.write(result.content)
        except OSError as err:
            # Caching is best effort; the downloaded text is still returned.
            self.logger.debug_log(f"could not store help md: {err}")
        return result.content.decode("utf-8", errors="replace")

    def add_index_to_string(self, text: str, is_checkpoint: bool = True) -> str:
        """Add the index to the string

        Existing ``@index`` tags are dropped and every entry is renumbered
        starting at 0.

        Args:
            text (str): the input string
            is_checkpoint (bool): if the string is a checkpoint list or a prompt list

        Returns:
            str: the string with the index
        """
        if is_checkpoint:
            entries = self.getCheckpointListFromInput(text)
            return "".join(
                f"{self.remove_index_from_string(entry)} @index:{i},\n"
                for i, entry in enumerate(entries))

        entries = self.split_prompts(text)
        return "".join(
            f"{self.remove_index_from_string(entry)} @index:{i};\n\n"
            for i, entry in enumerate(entries))

    def add_model_version_to_string(self, checkpoints_text: str) -> str:
        """Add the model version to the string.

        EXPERIMENTAL!

        Args:
            checkpoints_text (str): the input string with all checkpoints

        Returns:
            str: the string with the model version
        """
        # Checked in this order; the first equal config wins.
        known_configs = (
            (config_default, "sd1"),
            (config_sd2, "sd2"),
            (config_sd2v, "sd2v"),
            (config_sd2_inpainting, "sd2-inpainting"),
            (config_depth_model, "depth"),
            (config_unclip, "unclip"),
            (config_unopenclip, "unopenclip"),
            (config_inpainting, "sd1-inpainting"),
            (config_instruct_pix2pix, "pix2pix"),
            (config_alt_diffusion, "alt"),
        )

        # Index and hash are kept in the output, the cleaned names are only
        # used to look the checkpoints up.
        checkpoints_not_cleaned = self.getCheckpointListFromInput(
            checkpoints_text, clean=False)
        checkpoints = self.getCheckpointListFromInput(checkpoints_text)

        parts = []
        for raw_entry, checkpoint in zip(checkpoints_not_cleaned, checkpoints):
            info = modules.sd_models.get_closet_checkpoint_match(checkpoint)
            state_dict = read_state_dict(info.filename)
            detected = find_checkpoint_config(state_dict, None)
            # An unknown config is written out as returned by the lookup.
            version_string = next(
                (name for config, name in known_configs if detected == config),
                detected)
            checkpoint_partly_cleaned = raw_entry.replace(
                "\n", "").replace(",", "")
            parts.append(
                f"{checkpoint_partly_cleaned} @version:{version_string},\n\n")
        return "".join(parts)

    def remove_element_at_index(self, checkpoints: str, prompts: str, index: List[int]) -> List[str]:
        """Remove the element at the given index from the string

        The same positions are removed from the checkpoint list and from the
        prompt list, then both lists are renumbered. When nothing can be
        removed the two input strings are returned unchanged.

        Args:
            checkpoints (str): the input string with all checkpoints
            prompts (str): the input string with all prompts
            index (List[int]): the indices to remove

        Returns:
            List[str]: a list with the new checkpoints and prompts
        """
        checkpoints_list = self.getCheckpointListFromInput(checkpoints)
        prompts_list = self.split_prompts(prompts)

        if not index:
            self.logger.debug_log("no index given, nothing to remove")
            return [checkpoints, prompts]

        # NOTE(review): the second clause lets lists of different length
        # through when at least as many indices as prompts are given; it is
        # kept for backward compatibility - confirm that this is intended.
        same_length = len(checkpoints_list) == len(prompts_list)
        if not (same_length or len(prompts_list) - len(index) <= 0):
            self.logger.debug_log(
                f"checkpoints and prompts are not the same length cp: {len(checkpoints_list)} p: {len(prompts_list)}")
            return [checkpoints, prompts]

        # Every index has to exist in both lists, negative ones are rejected.
        limit = min(len(checkpoints_list), len(prompts_list))
        if any(i < 0 or i >= limit for i in index):
            self.logger.debug_log("index is out of range")
            return [checkpoints, prompts]

        # Pop from the back so earlier removals do not shift the positions
        # that are still to be removed; duplicates are removed only once.
        for i in sorted(set(index), reverse=True):
            checkpoints_list.pop(i)
            prompts_list.pop(i)

        checkpoints = "".join(f"{c}," for c in checkpoints_list)
        prompts = "".join(f"{p};" for p in prompts_list)
        result = [self.add_index_to_string(checkpoints, True),
                  self.add_index_to_string(prompts, False)]
        self.logger.debug_log(f"result: {result}")
        return result