Source code for danoan.llm_assistant.prompt.core.api

"""
prompt-manager interface.
"""

from danoan.llm_assistant.common import config
from danoan.llm_assistant.common.logging_config import setup_logging
from danoan.llm_assistant.common.model import PromptRepositoryConfiguration
from danoan.llm_assistant.prompt.core import model, utils

import copy
from dataclasses import asdict, dataclass
from enum import Enum
import git
import logging
from pathlib import Path
import re
import toml
from typing import Generator, List, Literal, Optional

# Configure logging once, at import time, through the project's shared helper,
# then create the logger used by every function in this module.
setup_logging()
logger = logging.getLogger(__name__)


def get_prompts_folder() -> Path:
    """
    Locate the folder that holds every tracked prompt.

    The folder is read from the `prompt.prompt_collection_folder` entry of the
    current configuration and resolved with
    `config.get_absolute_configuration_path`.
    """
    collection_folder = config.get_configuration().prompt.prompt_collection_folder
    return config.get_absolute_configuration_path(collection_folder)
def get_prompt_configuration_filepath(prompt_name: str) -> Path:
    """
    Build the path of the `config.toml` file of the given prompt.

    The file is expected directly inside the prompt's own folder, which itself
    lives in the prompts folder. The path is not checked for existence.
    """
    prompt_folder = get_prompts_folder() / prompt_name
    return prompt_folder / "config.toml"
def is_prompt_repository(path: Path) -> bool:
    """
    Tell whether `path` is the root of a prompt repository.

    A folder qualifies when it contains a `config.toml` file declaring both
    mandatory keys:

        - user_prompt
        - system_prompt

    Returns False when the file is absent or when either key is missing.
    """
    configuration_prompt_path = path / "config.toml"
    if not configuration_prompt_path.exists():
        return False

    logger.debug(f"Loading: {configuration_prompt_path}")
    loaded = toml.load(configuration_prompt_path)

    # Only the first missing key (in declaration order) is reported.
    key = next((k for k in ("user_prompt", "system_prompt") if k not in loaded), None)
    if key is not None:
        logger.debug(f"Missing mandatory key {key}")
        return False
    return True
def get_tracked_prompt(repository_name: str) -> model.TrackedPrompt:
    """
    Return a TrackedPrompt describing a local prompt repository.

    Args:
        repository_name: name of the repository folder inside the prompts
            folder.

    Returns:
        A TrackedPrompt built from the repository name, its folder, the name
        of a tag pointing at HEAD (None when HEAD is invalid or untagged) and
        the local branches other than "master".

    Raises:
        FileNotFoundError: if prompt folder does not exist.
        AttributeError: if head points to an invalid object or reference.
        git.exc.InvalidGitRepositoryError: if prompt folder is not a valid
            git repository.
    """
    repository_folder = get_prompts_folder() / repository_name
    if not repository_folder.exists():
        # Name the missing folder so the caller can tell which prompt failed.
        raise FileNotFoundError(
            f"Prompt repository folder does not exist: {repository_folder}"
        )

    repository = git.Repo(repository_folder)

    current_tag = None
    if repository.head.is_valid():
        head_commit = repository.head.commit
        # No early exit: when several tags point at HEAD, the last one listed
        # by the repository wins.
        for tag in repository.tags:
            if tag.commit == head_commit:
                current_tag = tag.name

    branches = [b for b in repository.branches if b.name != "master"]
    logger.debug(f"Branches of {repository_name}: {branches}")

    return model.TrackedPrompt(
        repository_name, repository_folder, current_tag, branches
    )
def get_tracked_prompts() -> Generator[model.TrackedPrompt, None, None]:
    """
    Yield a TrackedPrompt for every prompt tracked by the tool.

    A prompt is tracked when it is a prompt repository (see
    `is_prompt_repository`) located directly inside the folder returned by
    `get_prompts_folder`. Entries that are not prompt repositories are
    skipped.
    """
    for entry in filter(is_prompt_repository, get_prompts_folder().iterdir()):
        yield get_tracked_prompt(entry.name)
def get_prompt_test_regression_filepath(prompt_name: str) -> Path:
    """
    Build the path of the regression test file of the given prompt.

    The file is `tests/regression/regression.json` inside the prompt's folder.
    The path is not checked for existence.
    """
    return get_prompts_folder().joinpath(
        prompt_name, "tests", "regression", "regression.json"
    )
def __resolve_version__(prompt_name: str, version: str):
    """
    Translate a requested version into a concrete one.

    A blank string or "last" (case-insensitive) stands for the highest version
    listed by `get_prompt_versions` for the prompt repository. Any other value
    is returned untouched, as is the request itself when the repository has no
    version at all.
    """
    # Looked up first, so an unknown prompt fails even for explicit versions.
    tracked = get_tracked_prompt(prompt_name)

    wants_latest = version.strip() == "" or version.lower() == "last"
    if not wants_latest:
        return version

    known_versions = get_prompt_versions(tracked.repository_path)
    return known_versions[-1] if known_versions else version
def sync(repo_config: PromptRepositoryConfiguration, progress_callback=None):
    """
    Read prompt repository configuration file and sync local folder with the
    specified prompt version.

    The synchronisation runs in two passes:

    1. For every `prompt_name -> version` entry of `repo_config.versioning`,
       clone the repository from
       `https://github.com/<git_user>/<prompt_name>.git` if its folder is
       missing, fetch its tags and check out `tags/v<version>`.
    2. For every prompt repository found in the prompts folder that is not
       listed in `repo_config.versioning`, record the tag currently checked
       out in a copy of the configuration, then write the whole configuration
       back to the configuration file.

    Args:
        repo_config: prompt repository configuration to synchronise with.
        progress_callback: optional callable invoked as
            `progress_callback(event=..., name=..., value=...)`, where `event`
            is the string value of one of the `Events` members below.
    """

    class Events(Enum):
        SYNC_CONFIG = "sync_config"
        FETCH = "fetch"
        CHECKOUT = "checkout"
        SYNCED = "synced"
        SYNC_LOCAL_FOLDER = "sync_prompt_collection_folder"
        NOT_TRACKED = "not_tracked"
        NOT_PROMPT_REPOSITORY = "not_prompt_repository"
        GIT = "git"
        ITEM = "item"
        BEGIN = "begin"
        END = "end"

    @dataclass
    class SyncItem:
        event: Literal[
            Events.SYNC_CONFIG,
            Events.FETCH,
            Events.CHECKOUT,
            Events.SYNCED,
            Events.SYNC_LOCAL_FOLDER,
            Events.NOT_TRACKED,
            Events.NOT_PROMPT_REPOSITORY,
            Events.GIT,
            Events.ITEM,
            Events.BEGIN,
            Events.END,
        ]
        name: Optional[str] = None
        value: Optional[str] = None

    def _progress_callback(sync_item: SyncItem):
        # Forward the item as keyword arguments, with the event as a string.
        if progress_callback:
            d = asdict(sync_item)
            d["event"] = sync_item.event.value
            progress_callback(**d)

    def _git_progress_callback(op_code, cur_count, max_count=None, message=""):
        # Adapter given to GitPython: each git progress update is reported as
        # a GIT event followed by a BEGIN/ITEM.../END group.
        _progress_callback(SyncItem(Events.GIT))
        _progress_callback(SyncItem(Events.BEGIN))
        _progress_callback(SyncItem(Events.ITEM, "op_code", op_code))
        _progress_callback(SyncItem(Events.ITEM, "cur_count", cur_count))
        _progress_callback(SyncItem(Events.ITEM, "max_count", max_count))
        _progress_callback(SyncItem(Events.ITEM, "message", message))
        _progress_callback(SyncItem(Events.END))

    # Sync prompt repository configuration
    logger.debug("Align prompt collection folder with configuration file")
    _progress_callback(SyncItem(Events.SYNC_CONFIG))
    for prompt_name, version in repo_config.versioning.items():
        logger.debug(f"Start syncing of {prompt_name} {version}")
        _progress_callback(SyncItem(Events.SYNC_CONFIG, "prompt_name", prompt_name))
        _progress_callback(SyncItem(Events.SYNC_CONFIG, "version", version))

        prompt_folder = get_prompts_folder() / prompt_name
        prompt_repo_url = f"https://github.com/{repo_config.git_user}/{prompt_name}.git"

        if not prompt_folder.exists():
            _progress_callback(SyncItem(Events.FETCH, "prompt", prompt_repo_url))
            repo = git.Repo.clone_from(
                prompt_repo_url, prompt_folder, progress=_git_progress_callback
            )
        else:
            repo = git.Repo(prompt_folder)

        # Fetch tags BEFORE resolving the version: "last" is resolved against
        # the local tags, so it must see versions published since the last
        # sync.
        repo.remote().fetch(tags=True)
        version = __resolve_version__(prompt_name, version)

        _progress_callback(SyncItem(Events.CHECKOUT, "version", version))
        repo.git.checkout(f"tags/v{version}")

    # Sync prompt repository local folder
    logger.debug("Align configuration file with prompt collection folder")
    updated_repo_config = copy.deepcopy(repo_config)
    _progress_callback(SyncItem(Events.SYNC_LOCAL_FOLDER))
    for prompt_folder in get_prompts_folder().iterdir():
        _progress_callback(SyncItem(Events.SYNC_LOCAL_FOLDER, "folder", prompt_folder))

        if not is_prompt_repository(prompt_folder):
            _progress_callback(SyncItem(Events.NOT_PROMPT_REPOSITORY))
            continue

        # Use the full folder name (not `.stem`): a folder whose name contains
        # a dot would otherwise be truncated and no longer match the folder
        # that get_tracked_prompt looks up.
        prompt_name = prompt_folder.name
        if prompt_name in repo_config.versioning:
            _progress_callback(SyncItem(Events.SYNCED))
            continue

        tp = get_tracked_prompt(prompt_name)
        # NOTE(review): tp.current_tag is the raw tag name, while the checkout
        # above prepends "v" to the configured version; confirm the stored
        # value is meant to keep the tag's own prefix.
        updated_repo_config.versioning[prompt_name] = tp.current_tag
        _progress_callback(SyncItem(Events.NOT_TRACKED, "version", tp.current_tag))

    # Persist the configuration, now including the untracked prompts found.
    llma_config = config.get_configuration()
    llma_config.prompt = updated_repo_config
    with open(config.get_configuration_filepath(), "w") as f:
        toml.dump(llma_config.__asdict__(), f)
###################################
# Prompt Versioning
###################################
def get_prompt_versions(repository_path: Path) -> List[str]:
    """
    List, in ascending order, the versions available in a local prompt
    repository.

    Versions are ordered as `model.PromptVersion` objects and returned as
    strings.
    """
    ordered_versions = sorted(
        model.PromptVersion(raw) for raw in utils.get_versions(repository_path)
    )
    return list(map(str, ordered_versions))
def get_most_recent_version_before_commit(
    repository_folder: Path, commit_hash: str
) -> Optional[str]:
    """
    Find the most recent version prior to a given commit hash.

    The tags returned by `utils.get_most_recent_tags_before_commit` are
    scanned in order; the first one whose name starts with "v" gives the
    version (its name without that first character). Returns None when no
    such tag exists.
    """
    candidate_tags = utils.get_most_recent_tags_before_commit(
        repository_folder, commit_hash
    )
    return next(
        (tag.name[1:] for tag in candidate_tags if tag.name[0] == "v"),
        None,
    )
def get_current_version(repository_folder: Path) -> Optional[str]:
    """
    Get the version the repository HEAD currently stands on.

    Returns None when `utils.get_most_recent_tags_before_commit` reports no
    tag at all for HEAD. Otherwise, if HEAD itself carries tags, the first of
    them (without its leading character) is the version; if not, the version
    is the one found by `get_most_recent_version_before_commit`.
    """
    repo = git.Repo(repository_folder)
    head_tags = utils.get_commit_tags(repository_folder, repo.commit())
    recent_tags = utils.get_most_recent_tags_before_commit(
        repository_folder, repo.commit()
    )

    if len(recent_tags) == 0:
        return None

    if len(head_tags) == 0:
        # HEAD is not tagged: fall back to the closest version found.
        return get_most_recent_version_before_commit(
            repository_folder, repo.commit()
        )

    # Drop the first character of the tag (the "v" of "v<version>").
    return str(head_tags[0])[1:]
def suggest_next_version(
    repository_path: Path, current_version: str, cn: model.ChangeNature
) -> str:
    """
    Suggest a new version based on the type of changes.

    The unit to increment depends on the nature of the change:

        - PromptTweak: the fix unit (also used for any other value).
        - InterfaceUpdate: the minor unit.
        - ScopeChange: the major unit.

    Starting from `current_version`, the version is incremented until it is
    not one of the versions already present in the repository.
    """
    incrementer = model.PromptVersion.IncrementerType.Fix
    if cn == model.ChangeNature.InterfaceUpdate:
        incrementer = model.PromptVersion.IncrementerType.Minor
    elif cn == model.ChangeNature.ScopeChange:
        incrementer = model.PromptVersion.IncrementerType.Major

    candidate = model.PromptVersion(current_version, incrementer)

    taken_versions = set(get_prompt_versions(repository_path))
    while str(candidate) in taken_versions:
        candidate += 1
    return str(candidate)
def update_changelog(changelog_str: str) -> str:
    """
    Rewrites the changelog document such that versions are listed in
    numerical order and divided by major versions.

    The document is scanned for `##` headers. For each one, the first version
    string found from that header onwards names the section, and the section
    content runs up to the next `#` character, or to the end of the document
    for the last section. When a version appears more than once, only its
    first section is kept. Scanning stops at the first `##` header after which
    no version string can be found.

    Args:
        changelog_str: the full text of the changelog.

    Returns:
        The rewritten document: sections sorted by version, each group of
        sections sharing a major version preceded by a `# V<major>` title.
        An empty string when no versioned section is found.
    """
    version_part = r"\d+[\.xyz]?"
    version_pattern = re.compile(f"({version_part}{version_part}{version_part})")

    sections = []
    seen_versions = set()
    cursor = 0
    while True:
        header_start = changelog_str.find("##", cursor)
        if header_start == -1:
            break

        # Searching with a start position avoids copying the remainder of the
        # document at every iteration.
        match = version_pattern.search(changelog_str, header_start)
        if match is None:
            break

        content_start = match.end()
        content_end = changelog_str.find("#", content_start)
        if content_end == -1:
            # Last section: there is no following header. Using -1 as a slice
            # bound would silently drop the final character of the document.
            content_end = len(changelog_str)
        cursor = content_end

        version_name = match.group()
        if version_name in seen_versions:
            continue
        seen_versions.add(version_name)
        sections.append(
            (model.PromptVersion(version_name), changelog_str[content_start:content_end])
        )

    sections.sort(key=lambda section: section[0])

    parts = []
    current_major = None
    for version, content in sections:
        if current_major != version.major:
            # Entering a new major version: emit its title once.
            current_major = version.major
            parts.append(f"# V{current_major}\n\n")
        parts.append(f"## {version}{content}\n\n")
    return "".join(parts)