Coverage for cli/src/version_finder_cli/cli.py: 26%
268 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-18 10:30 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-18 10:30 +0000
# cli/src/version_finder_cli/cli.py
2import argparse
3import sys
4import os
5import re
6from typing import List, Any
7from prompt_toolkit import prompt
8from prompt_toolkit.styles import Style
9from prompt_toolkit.completion import WordCompleter, PathCompleter
10from prompt_toolkit.validation import Validator, ValidationError
11from version_finder.logger import get_logger
12from version_finder.version_finder import VersionFinder, GitError
13from version_finder.version_finder import VersionFinderTask, VersionFinderTaskRegistry
14from version_finder.common import parse_arguments
15import threading
16import time
# Module-level logger used by the classes and functions in this file.
logger = get_logger()
class TaskNumberValidator(Validator):
    """Prompt validator that accepts an integer task index in a closed range.

    Args:
        min_index: Lowest acceptable task index (inclusive).
        max_index: Highest acceptable task index (inclusive).
    """

    def __init__(self, min_index: int, max_index: int):
        self.min_index = min_index
        self.max_index = max_index

    def validate(self, document):
        """Check the text currently held by ``document``.

        Args:
            document: Object whose ``text`` attribute is the user's input.

        Raises:
            ValidationError: If the input is empty, is not an integer, or
                lies outside ``[min_index, max_index]``.
        """
        text = document.text.strip()
        if not text:
            raise ValidationError(message="Task number cannot be empty")

        # Only int() can raise ValueError, so it is the only call guarded.
        try:
            task_idx = int(text)
        except ValueError:
            # ``from None``: the int() failure adds nothing to the message
            # shown to the user, so do not chain it.
            raise ValidationError(message="Please enter a valid number") from None

        if not (self.min_index <= task_idx <= self.max_index):
            raise ValidationError(
                message=f"Please select a task number between {self.min_index} and {self.max_index}")
class CommitSHAValidator(Validator):
    """Prompt validator for a commit reference.

    Accepted input is either 7 to 40 hexadecimal characters (abbreviated or
    full SHA) or any text starting with ``HEAD~``.
    """

    # 7-40 hexadecimal digits; compiled once for all validation calls.
    _HEX_SHA = re.compile(r"[0-9a-fA-F]{7,40}")

    def validate(self, document):
        """Check the text currently held by ``document``.

        Args:
            document: Object whose ``text`` attribute is the user's input.

        Raises:
            ValidationError: If the input is empty or matches neither
                accepted form.
        """
        text = document.text.strip()
        if not text:
            raise ValidationError(message="Commit SHA cannot be empty")

        # Relative references are accepted on their prefix alone.
        if text.startswith("HEAD~"):
            return

        # Length alone is not enough: the characters must be hex digits,
        # which is what the error message tells the user to enter.
        if not self._HEX_SHA.fullmatch(text):
            raise ValidationError(message="Invalid commit SHA format. Use 7-40 hex chars or HEAD~n format")
class ProgressIndicator:
    """A simple spinner for CLI operations, drawn from a background thread.

    Attributes (kept public for callers that inspect them):
        message: Text printed in front of the spinner symbol.
        running: True between ``start()`` and ``stop()``.
        thread: The spinner thread while running, otherwise ``None``.
    """

    # Spinner frames, shown in order and then repeated.
    _SYMBOLS = ("-", "\\", "|", "/")
    # Seconds between two frames.
    _INTERVAL = 0.1

    def __init__(self, message="Processing"):
        self.message = message
        self.running = False
        self.thread = None
        # Lets stop() wake the spinner thread without waiting out a frame.
        self._stop_event = threading.Event()

    def start(self):
        """Start the spinner; a call while it is already running is a no-op."""
        if self.thread is not None and self.thread.is_alive():
            # Starting twice would spawn a second thread writing to the same
            # line and lose the handle needed to join the first one.
            return
        self._stop_event.clear()
        self.running = True
        self.thread = threading.Thread(target=self._show_progress, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop the spinner, wait for its thread, and blank the spinner line."""
        self.running = False
        self._stop_event.set()
        if self.thread is not None:
            self.thread.join()
            self.thread = None
        # Overwrite the indicator with spaces and return the cursor to column 0.
        sys.stdout.write("\r" + " " * (len(self.message) + 10) + "\r")
        sys.stdout.flush()

    def _show_progress(self):
        """Draw spinner frames until ``running`` becomes False."""
        frame = 0
        while self.running:
            sys.stdout.write(f"\r{self.message} {self._SYMBOLS[frame]} ")
            sys.stdout.flush()
            # Event.wait returns as soon as stop() sets the event, so stop()
            # does not have to sit through the remainder of the interval.
            self._stop_event.wait(self._INTERVAL)
            frame = (frame + 1) % len(self._SYMBOLS)
def with_progress(message):
    """Decorator factory: show a ``ProgressIndicator`` while the function runs.

    Args:
        message: Text passed to ``ProgressIndicator`` for the spinner line.

    Returns:
        A decorator. The decorated function starts the indicator, calls the
        original with the same arguments, returns its result, and always
        stops the indicator, also when the original raises.
    """
    # Function-scope import so this block needs no new top-of-file import.
    import functools

    def decorator(func):
        # wraps() keeps the original __name__, __doc__ and signature metadata
        # visible on the wrapper.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            progress = ProgressIndicator(message)
            progress.start()
            try:
                return func(*args, **kwargs)
            finally:
                progress.stop()
        return wrapper
    return decorator
class VersionFinderCLI:
    """
    Version Finder CLI class.

    Collects the repository path, branch and task (from parsed arguments or
    interactive prompts), runs the selected task through the task registry,
    and optionally restores the repository state afterwards.
    """

    def __init__(self):
        """
        Create the task registry and the style used by interactive prompts.
        """
        self.registry = VersionFinderTaskRegistry()
        self.prompt_style = Style.from_dict({
            # Colour of the "current branch" / "current directory" value.
            'current_status': '#00aa00',
        })

    def get_task_functions(self) -> dict:
        """
        Map registered task indices to the CLI methods that implement them.

        Returns:
            dict: ``{task.index: bound method}`` for every registered task
            whose name is one of the three names handled here. Tasks with
            any other name are left out.
        """
        handlers_by_name = {
            "Find all commits between two versions": self.find_all_commits_between_versions,
            "Find commit by text": self.find_commits_by_text,
            "Find first version containing commit": self.find_first_version_containing_commit,
        }
        return {
            task.index: handlers_by_name[task.name]
            for task in self.registry._tasks_by_index.values()
            if task.name in handlers_by_name
        }

    def run(self, args: argparse.Namespace):
        """
        Run the CLI with the provided arguments.

        Args:
            args: Parsed command-line arguments. Reads ``path``, ``force``,
                ``branch``, ``task`` and ``restore_state``.

        Returns:
            int: 0 on success or when the user cancels, 1 on error
        """
        try:
            self.path = self.handle_path_input(args.path)

            # Initialize VersionFinder with force=True to allow uncommitted changes
            self.finder = VersionFinder(path=self.path, force=True)

            if not self._confirm_uncommitted_changes(args):
                logger.info("Operation cancelled by user")
                return 0

            actions = self.get_task_functions()
            params = self.finder.get_task_api_functions_params()
            self.registry.initialize_actions_and_args(actions, params)

            self.branch = self.handle_branch_input(args.branch)
            self.finder.update_repository(self.branch)

            self.task_name = self.handle_task_input(args.task)
            self.run_task(self.task_name)

            # Restore original state if requested
            if args.restore_state:
                self._restore_and_verify_state()

            return 0

        except KeyboardInterrupt:
            logger.info("\nOperation cancelled by user")
            self._restore_state_after_interruption(args)
            return 0
        except Exception as e:
            # Top-level boundary: report, try to restore, signal failure.
            logger.error("Error during task execution: %s", str(e))
            self._restore_state_after_interruption(args)
            return 1

    def _confirm_uncommitted_changes(self, args: argparse.Namespace) -> bool:
        """Return False only if the user declines to continue with a dirty repository."""
        state = self.finder.get_saved_state()
        if not state.get("has_changes", False):
            return True

        logger.warning("Repository has uncommitted changes")
        if args.force:
            return True

        # Build message with details about what will happen
        message = (
            "Repository has uncommitted changes. Version Finder will:\n"
            "1. Stash your changes with a unique identifier\n"
            "2. Perform the requested operations\n"
            "3. Restore your original branch and stashed changes when closing\n"
        )
        if state.get("submodules", {}):
            message += "Submodules with uncommitted changes will also be handled similarly.\n"
        message += "Proceed anyway? (y/N): "

        # Tolerate surrounding whitespace in the answer.
        return input(message).strip().lower() == 'y'

    def _restore_and_verify_state(self):
        """Restore the saved repository state and log what was restored."""
        logger.info("Restoring original repository state")

        # Read the saved state before restoring so it can be reported on.
        state = self.finder.get_saved_state()
        has_changes = state.get("has_changes", False)

        if has_changes:
            if state.get("stash_created", False):
                logger.info("Attempting to restore stashed changes")
            else:
                logger.warning("Repository had changes but they were not stashed")

        if not self.finder.restore_repository_state():
            logger.warning("Failed to restore original repository state")
            return

        logger.info("Original repository state restored successfully")

        # Verify the restoration
        current_branch = self.finder.get_current_branch()
        original_branch = state.get("branch")
        if original_branch and current_branch:
            if original_branch.startswith("HEAD:"):
                logger.info("Restored to detached HEAD state")
            else:
                logger.info("Restored to branch: %s", current_branch)

        # Check if we still have uncommitted changes
        if has_changes:
            if self.finder.has_uncommitted_changes():
                logger.info("Uncommitted changes were successfully restored")
            else:
                logger.error("Failed to restore uncommitted changes")

    def _restore_state_after_interruption(self, args: argparse.Namespace):
        """Best-effort restore used by the error and cancel paths of ``run``."""
        # ``finder`` does not exist yet if the failure happened before it was created.
        finder = getattr(self, 'finder', None)
        if not (finder and args.restore_state):
            return

        logger.info("Restoring original repository state")
        if finder.restore_repository_state():
            logger.info("Original repository state restored successfully")
        else:
            logger.warning("Failed to restore original repository state")

    def handle_task_input(self, task_name: str) -> str:
        """
        Return the task to run, prompting for a task number when none was given.

        Args:
            task_name: Task name from the command line, or None to prompt.

        Returns:
            str: ``task_name`` unchanged when provided, otherwise the name of
            the task the user selected by index.
        """
        if task_name is not None:
            # A task given up front needs no prompt; hand it back as-is.
            return task_name

        print("You have not selected a task.")
        print("Please select a task:")
        # Iterate through tasks in index order
        tasks = self.registry.get_tasks_by_index()
        for task in tasks:
            print(f"{task.index}: {task.name}")

        task_validator = TaskNumberValidator(tasks[0].index, tasks[-1].index)
        task_idx = int(prompt(
            "Enter task number: ",
            validator=task_validator,
            validate_while_typing=True
        ).strip())

        logger.debug("Selected task: %d", task_idx)
        # The validator only checks the range; the index must also be registered.
        if not self.registry.has_index(task_idx):
            logger.error("Invalid task selected")
            sys.exit(1)

        return self.registry.get_by_index(task_idx).name

    def handle_branch_input(self, branch_name: str) -> str:
        """
        Handle branch input from user with auto-completion.

        Args:
            branch_name: Optional branch name from command line

        Returns:
            str: Selected branch name
        """
        if branch_name is not None:
            return branch_name

        current_branch = self.finder.get_current_branch()
        logger.info("Current branch: %s", current_branch)

        if not current_branch:
            # There is no current branch to offer as the default, so ask for
            # an explicit choice instead of returning None.
            return self.get_branch_selection()

        branch_completer = WordCompleter(
            self.finder.list_branches(),
            ignore_case=True,
            match_middle=True,
            pattern=re.compile(r'\S+')  # Matches non-whitespace characters
        )
        prompt_message = [
            ('', 'Current branch: '),
            ('class:current_status', f'{current_branch}'),
            ('', '\nPress [ENTER] to use the current branch or type to select a different branch: '),
        ]
        branch_name = prompt(
            prompt_message,
            completer=branch_completer,
            complete_while_typing=True,
            style=self.prompt_style
        ).strip()

        # Empty input means "keep the current branch".
        return branch_name or current_branch

    def handle_submodule_input(self, submodule_name: str = None) -> str:
        """
        Handle submodule input from user with auto-completion.

        Args:
            submodule_name: Optional submodule name; the user is prompted
                when it is None.

        Returns:
            str: The given name, or the stripped text the user entered
            (empty when the user just pressed ENTER).
        """
        if submodule_name is None:
            submodule_list = self.finder.list_submodules()
            submodule_completer = WordCompleter(submodule_list, ignore_case=True, match_middle=True)
            # Take input from user
            submodule_name = prompt(
                "\nEnter submodule name (Tab for completion) or [ENTER] to continue without a submodule:",
                completer=submodule_completer,
                complete_while_typing=True
            ).strip()
        return submodule_name

    def handle_path_input(self, path: str = None) -> str:
        """
        Handle path input from user using prompt_toolkit.

        Exits the process with status 1 if the resulting path is not an
        existing directory.

        Args:
            path: Optional path from command line

        Returns:
            str: Directory path, with a leading ``~`` expanded
        """
        if path is None:
            prompt_msg = [
                ('', 'Current directory: '),
                ('class:current_status', f'{os.getcwd()}'),
                ('', ':\nPress [ENTER] to use the current directory or type to select a different directory: '),
            ]

            path_completer = PathCompleter(
                only_directories=True,
                expanduser=True
            )
            path = prompt(
                prompt_msg,
                completer=path_completer,
                complete_while_typing=True,
                style=self.prompt_style
            ).strip()

        if not path:
            path = os.getcwd()

        # The completer is configured with expanduser=True, so a typed path
        # may start with "~"; expand it before checking the filesystem.
        path = os.path.expanduser(path)

        # Validate the path (isdir is False for paths that do not exist).
        if not os.path.isdir(path):
            print(f"Error: Invalid path '{path}'", file=sys.stderr)
            sys.exit(1)

        return path

    def get_branch_selection(self) -> str:
        """
        Get branch selection from user with auto-completion.

        Keeps prompting until the entered name is one of the listed
        branches. Exits the process with status 0 on Ctrl-C or EOF.

        Returns:
            Selected branch name
        """
        branches = self.finder.list_branches()
        branch_completer = WordCompleter(branches, ignore_case=True, match_middle=True)

        while True:
            try:
                logger.debug("\nAvailable branches:")
                for branch in branches:
                    logger.debug(" - %s", branch)

                branch = prompt(
                    "\nEnter branch name (Tab for completion): ",
                    completer=branch_completer,
                    complete_while_typing=True
                ).strip()

                if branch in branches:
                    return branch

                logger.error("Invalid branch selected")

            except (KeyboardInterrupt, EOFError):
                logger.info("\nOperation cancelled by user")
                sys.exit(0)

    def run_task(self, task_name: str):
        """
        Run the selected task.

        Args:
            task_name: Name of a task known to the registry.
        """
        self.registry.get_by_name(task_name).run()

    def fetch_arguments_per_task(self, task_name: str) -> List[Any]:
        """
        Fetch arguments for the selected task.

        Args:
            task_name: Name of a task known to the registry.

        Returns:
            List[Any]: For each name in the task's ``args``, the value of
            the attribute of that name on this object, in the same order.
        """
        return [
            getattr(self, arg_name)
            for arg_name in self.registry.get_by_name(task_name).args
        ]

    @with_progress("Finding version for commit")
    def find_first_version_containing_commit(self, commit_sha: str, submodule: str = None):
        """
        Find the first version containing a commit and print it.

        Errors raised by the finder are printed, not propagated.

        Args:
            commit_sha: The commit SHA to find the version for
            submodule: Optional submodule path
        """
        try:
            version = self.finder.find_version(commit_sha, submodule)
            if version:
                print(f"\nVersion for commit {commit_sha}: {version}")
            else:
                print(f"\nNo version found for commit {commit_sha}")
        except Exception as e:
            # Deliberate best-effort: report the failure to the user.
            print(f"\nError: {str(e)}")

    @with_progress("Finding commits between versions")
    def find_all_commits_between_versions(self, from_version: str, to_version: str, submodule: str = None):
        """
        Find all commits between two versions and print them.

        Errors raised by the finder are printed, not propagated.

        Args:
            from_version: The starting version
            to_version: The ending version
            submodule: Optional submodule path
        """
        try:
            commits = self.finder.find_commits_between_versions(from_version, to_version, submodule)
            if commits:
                print(f"\nFound {len(commits)} commits between {from_version} and {to_version}:")
                for commit in commits:
                    print(f"{commit.sha[:8]} - {commit.subject}")
            else:
                print(f"\nNo commits found between {from_version} and {to_version}")
        except Exception as e:
            # Deliberate best-effort: report the failure to the user.
            print(f"\nError: {str(e)}")

    @with_progress("Searching for commits")
    def find_commits_by_text(self, text: str, submodule: str = None):
        """
        Find commits containing specific text and print them.

        Errors raised by the finder are printed, not propagated.

        Args:
            text: The text to search for
            submodule: Optional submodule path
        """
        try:
            commits = self.finder.find_commits_by_text(text, submodule)
            if commits:
                print(f"\nFound {len(commits)} commits containing '{text}':")
                for commit in commits:
                    print(f"{commit.sha[:8]} - {commit.subject}")
            else:
                print(f"\nNo commits found containing '{text}'")
        except Exception as e:
            # Deliberate best-effort: report the failure to the user.
            print(f"\nError: {str(e)}")
def cli_main(args: argparse.Namespace) -> int:
    """Main entry point for the version finder CLI.

    Args:
        args: Parsed command-line arguments. When ``args.version`` is set
            the version string is printed and nothing else runs.

    Returns:
        int: Exit status. 0 on success, 1 when a ``GitError`` escapes the
        CLI, otherwise whatever ``VersionFinderCLI.run`` returned.
    """
    if args.version:
        # Imported here so the version module is only loaded when asked for.
        from .__version__ import __version__
        print(f"version_finder cli-v{__version__}")
        return 0

    cli = VersionFinderCLI()
    try:
        exit_code = cli.run(args)
    except GitError as e:
        logger.error("Git operation failed: %s", e)
        return 1

    # Propagate run()'s status, not a hard-coded 0, so failures reach
    # the shell. run() can finish without an explicit return value; treat
    # that as success.
    return 0 if exit_code is None else exit_code
def main() -> int:
    """Parse the command-line arguments and pass them to ``cli_main``.

    Returns:
        int: The value returned by ``cli_main``.
    """
    return cli_main(parse_arguments())
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())