Coverage for cli/src/version_finder_cli/cli.py: 26%

268 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-18 10:30 +0000

1# version_finder/__main__.py 

2import argparse 

3import sys 

4import os 

5import re 

6from typing import List, Any 

7from prompt_toolkit import prompt 

8from prompt_toolkit.styles import Style 

9from prompt_toolkit.completion import WordCompleter, PathCompleter 

10from prompt_toolkit.validation import Validator, ValidationError 

11from version_finder.logger import get_logger 

12from version_finder.version_finder import VersionFinder, GitError 

13from version_finder.version_finder import VersionFinderTask, VersionFinderTaskRegistry 

14from version_finder.common import parse_arguments 

15import threading 

16import time 

17 

18# Initialize module logger 

19logger = get_logger() 

20 

21 

class TaskNumberValidator(Validator):
    """Prompt validator that accepts only an integer inside an inclusive index range."""

    def __init__(self, min_index: int, max_index: int):
        # Both bounds are inclusive; see the range check in validate().
        self.max_index = max_index
        self.min_index = min_index

    def validate(self, document):
        """
        Validate the text currently held by the prompt document.

        Args:
            document: Object exposing the typed text through its ``text`` attribute.

        Raises:
            ValidationError: If the text is empty, is not an integer, or falls
                outside ``min_index``..``max_index``.
        """
        entered = document.text.strip()
        if entered == "":
            raise ValidationError(message="Task number cannot be empty")

        try:
            chosen = int(entered)
        except ValueError:
            raise ValidationError(message="Please enter a valid number")

        within_bounds = self.min_index <= chosen <= self.max_index
        if not within_bounds:
            raise ValidationError(
                message=f"Please select a task number between {self.min_index} and {self.max_index}")

38 

39 

class CommitSHAValidator(Validator):
    """Prompt validator for a commit reference: a 7-40 character hex SHA or a HEAD~n expression."""

    def validate(self, document):
        """
        Validate the text currently held by the prompt document.

        Args:
            document: Object exposing the typed text through its ``text`` attribute.

        Raises:
            ValidationError: If the text is empty, or is neither a 7-40 character
                hexadecimal string nor a reference starting with ``HEAD~``.
        """
        text = document.text.strip()
        if not text:
            raise ValidationError(message="Commit SHA cannot be empty")

        # Relative references are accepted as before: anything starting with HEAD~.
        if text.startswith("HEAD~"):
            return

        # The length alone is not enough: the error message promises hex characters,
        # so reject strings such as "not-a-sha" that merely have a plausible length.
        if not re.fullmatch(r"[0-9a-fA-F]{7,40}", text):
            raise ValidationError(message="Invalid commit SHA format. Use 7-40 hex chars or HEAD~n format")

48 

49 

class ProgressIndicator:
    """Console spinner that animates on a background thread between start() and stop()."""

    def __init__(self, message="Processing"):
        self.message = message
        self.running = False
        self.thread = None

    def start(self):
        """Raise the running flag and launch the animation on a daemon thread."""
        self.running = True
        worker = threading.Thread(target=self._show_progress)
        worker.daemon = True
        self.thread = worker
        worker.start()

    def stop(self):
        """Lower the running flag, wait for the animation thread and blank the spinner line."""
        self.running = False
        if self.thread:
            self.thread.join()
        # Overwrite the spinner text with spaces, then return the cursor to column 0.
        blank = " " * (len(self.message) + 10)
        sys.stdout.write("\r" + blank + "\r")
        sys.stdout.flush()

    def _show_progress(self):
        """Redraw the message with the next spinner frame every 0.1 s while running."""
        frames = ["-", "\\", "|", "/"]
        position = 0
        while self.running:
            sys.stdout.write(f"\r{self.message} {frames[position]} ")
            sys.stdout.flush()
            time.sleep(0.1)
            position = (position + 1) % len(frames)

83 

84 

def with_progress(message):
    """
    Build a decorator that shows a progress indicator while the wrapped function runs.

    Args:
        message: Text displayed next to the spinner.

    Returns:
        A decorator. The decorated function keeps its name, docstring and
        signature metadata, and returns whatever the original returns.
    """
    # Imported here so the block does not depend on the module's import list.
    import functools

    def decorator(func):
        # Without wraps() every decorated function would report itself as "wrapper".
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            progress = ProgressIndicator(message)
            progress.start()
            try:
                return func(*args, **kwargs)
            finally:
                # Stop the spinner even when func raises.
                progress.stop()
        return wrapper
    return decorator

98 

99 

100class VersionFinderCLI: 

101 """ 

102 Version Finder CLI class. 

103 """ 

104 

def __init__(self):
    """
    Create the task registry and the style applied to interactive prompts.
    """
    # 'current_status' is the style class used to highlight the current
    # branch / directory fragment inside prompt messages.
    style_rules = {
        'current_status': '#00aa00',
    }
    self.registry = VersionFinderTaskRegistry()
    self.prompt_style = Style.from_dict(style_rules)

117 

def get_task_functions(self) -> dict:
    """
    Map each registered task's index to the CLI method that implements it.

    Tasks whose name has no matching CLI method are left out of the result.

    Returns:
        dict: Task index -> bound method of this CLI instance.
    """
    handlers_by_name = {
        "Find all commits between two versions": self.find_all_commits_between_versions,
        "Find commit by text": self.find_commits_by_text,
        "Find first version containing commit": self.find_first_version_containing_commit,
    }
    return {
        task.index: handlers_by_name[task.name]
        for task in self.registry._tasks_by_index.values()
        if task.name in handlers_by_name
    }

137 

def run(self, args: argparse.Namespace):
    """
    Run the CLI with the provided arguments.

    Resolves the repository path, opens it, asks for confirmation when there
    are uncommitted changes, resolves branch and task, runs the task and
    finally restores the repository state when ``args.restore_state`` is set.

    Args:
        args: Parsed command-line arguments. Reads ``path``, ``force``,
            ``branch``, ``task`` and ``restore_state``.

    Returns:
        int: 0 on success or when the user cancels, 1 on error
    """
    try:
        self.path = self.handle_path_input(args.path)

        # Initialize VersionFinder with force=True to allow uncommitted changes
        self.finder = VersionFinder(path=self.path, force=True)

        if not self._confirm_uncommitted_changes(args):
            logger.info("Operation cancelled by user")
            return 0

        actions = self.get_task_functions()
        params = self.finder.get_task_api_functions_params()
        self.registry.initialize_actions_and_args(actions, params)

        self.branch = self.handle_branch_input(args.branch)
        self.finder.update_repository(self.branch)

        self.task_name = self.handle_task_input(args.task)
        self.run_task(self.task_name)

        if args.restore_state:
            self._restore_and_verify_state()

        # Report success explicitly instead of falling off the end with None.
        return 0

    except KeyboardInterrupt:
        logger.info("\nOperation cancelled by user")
        self._restore_state_after_abort(args)
        return 0
    except Exception as e:
        logger.error("Error during task execution: %s", str(e))
        self._restore_state_after_abort(args)
        return 1

def _confirm_uncommitted_changes(self, args: argparse.Namespace) -> bool:
    """Return True when it is fine to continue; ask the user if the repo has uncommitted changes."""
    state = self.finder.get_saved_state()
    if not state.get("has_changes", False):
        return True

    logger.warning("Repository has uncommitted changes")
    if args.force:
        # --force skips the interactive confirmation.
        return True

    # Build message with details about what will happen
    message = (
        "Repository has uncommitted changes. Version Finder will:\n"
        "1. Stash your changes with a unique identifier\n"
        "2. Perform the requested operations\n"
        "3. Restore your original branch and stashed changes when closing\n"
    )
    if state.get("submodules", {}):
        message += "Submodules with uncommitted changes will also be handled similarly.\n"
    message += "Proceed anyway? (y/N): "

    # Surrounding whitespace in the answer is ignored; anything but "y" declines.
    return input(message).strip().lower() == 'y'

def _restore_and_verify_state(self) -> None:
    """Restore the saved repository state after a task and log what was recovered."""
    logger.info("Restoring original repository state")

    # Read the saved state before restoring so it can be compared afterwards.
    state = self.finder.get_saved_state()
    has_changes = state.get("has_changes", False)
    if has_changes:
        if state.get("stash_created", False):
            logger.info("Attempting to restore stashed changes")
        else:
            logger.warning("Repository had changes but they were not stashed")

    if not self.finder.restore_repository_state():
        logger.warning("Failed to restore original repository state")
        return

    logger.info("Original repository state restored successfully")

    current_branch = self.finder.get_current_branch()
    original_branch = state.get("branch")
    if original_branch and current_branch:
        if original_branch.startswith("HEAD:"):
            logger.info("Restored to detached HEAD state")
        else:
            logger.info("Restored to branch: %s", current_branch)

    # Changes that existed before must be present again after the restore.
    if has_changes:
        if self.finder.has_uncommitted_changes():
            logger.info("Uncommitted changes were successfully restored")
        else:
            logger.error("Failed to restore uncommitted changes")

def _restore_state_after_abort(self, args: argparse.Namespace) -> None:
    """Best-effort restore used when run() is interrupted or fails."""
    # self.finder does not exist yet if the failure happened before it was created.
    finder = getattr(self, 'finder', None)
    if not (finder and args.restore_state):
        return

    logger.info("Restoring original repository state")
    if finder.restore_repository_state():
        logger.info("Original repository state restored successfully")
    else:
        logger.warning("Failed to restore original repository state")

251 

def handle_task_input(self, task_name: str) -> str:
    """
    Resolve the task to run, prompting for a task number when none was given.

    Args:
        task_name: Optional task name from command line; returned unchanged
            when it is not None.

    Returns:
        str: Name of the selected task.

    The process exits with status 1 if the entered number is not a
    registered task index.
    """
    # A task chosen on the command line needs no interactive selection.
    if task_name is not None:
        return task_name

    print("You have not selected a task.")
    print("Please select a task:")
    # Iterate through tasks in index order
    tasks = self.registry.get_tasks_by_index()
    for task in tasks:
        print(f"{task.index}: {task.name}")

    # The first and last entries supply the bounds of the accepted range.
    task_validator = TaskNumberValidator(tasks[0].index, tasks[-1].index)
    task_idx = int(prompt(
        "Enter task number: ",
        validator=task_validator,
        validate_while_typing=True
    ).strip())

    logger.debug("Selected task: %d", task_idx)
    if not self.registry.has_index(task_idx):
        logger.error("Invalid task selected")
        sys.exit(1)

    return self.registry.get_by_index(task_idx).name

279 

def handle_branch_input(self, branch_name: str) -> str:
    """
    Handle branch input from user with auto-completion.

    Args:
        branch_name: Optional branch name from command line; returned
            unchanged when it is not None.

    Returns:
        str: Selected branch name. Pressing ENTER on an empty prompt selects
        the current branch when there is one.
    """
    if branch_name is not None:
        return branch_name

    branches = self.finder.list_branches()
    branch_completer = WordCompleter(
        branches,
        ignore_case=True,
        match_middle=True,
        pattern=re.compile(r'\S+')  # Matches non-whitespace characters
    )

    current_branch = self.finder.get_current_branch()
    logger.info("Current branch: %s", current_branch)

    if current_branch:
        prompt_message = [
            ('', 'Current branch: '),
            ('class:current_status', f'{current_branch}'),
            ('', '\nPress [ENTER] to use the current branch or type to select a different branch: '),
        ]
    else:
        # No current branch was reported, so there is no default to offer.
        prompt_message = [
            ('', 'Enter branch name (Tab for completion): '),
        ]

    while True:
        typed = prompt(
            prompt_message,
            completer=branch_completer,
            complete_while_typing=True,
            style=self.prompt_style
        ).strip()
        if typed:
            return typed
        if current_branch:
            return current_branch
        # Empty input and no default: ask again instead of returning nothing.
        logger.error("A branch name is required")

320 

def handle_submodule_input(self, submodule_name: str = None) -> str:
    """
    Handle submodule input from user with auto-completion.

    Args:
        submodule_name: Optional submodule name; returned unchanged when it
            is not None.

    Returns:
        str: The given name, or the stripped text typed at the prompt
        (empty when the user just presses ENTER).
    """
    if submodule_name is not None:
        return submodule_name

    known_submodules = self.finder.list_submodules()
    completer = WordCompleter(known_submodules, ignore_case=True, match_middle=True)
    # Take input from user
    typed = prompt(
        "\nEnter submodule name (Tab for completion) or [ENTER] to continue without a submodule:",
        completer=completer,
        complete_while_typing=True
    )
    return typed.strip()

335 

def handle_path_input(self, path: str = None) -> str:
    """
    Handle path input from user using prompt_toolkit.

    Args:
        path: Optional path from command line; the user is prompted when it
            is None.

    Returns:
        str: Path to an existing directory, with a leading ``~`` expanded.
        An empty path selects the current working directory.

    The process exits with status 1 if the path is not an existing directory.
    """
    if path is None:
        prompt_msg = [
            ('', 'Current directory: '),
            ('class:current_status', f'{os.getcwd()}'),
            ('', ':\nPress [ENTER] to use the current directory or type to select a different directory: '),
        ]

        path_completer = PathCompleter(
            only_directories=True,
            expanduser=True
        )
        path = prompt(
            prompt_msg,
            completer=path_completer,
            complete_while_typing=True,
            style=self.prompt_style
        ).strip()

    if not path:
        path = os.getcwd()

    # The completer is created with expanduser=True, so "~/repo" is a normal
    # thing to type; expand it here or the directory check below rejects it.
    path = os.path.expanduser(path)

    # Validate the path (isdir is already False for paths that do not exist)
    if not os.path.isdir(path):
        print(f"Error: Invalid path '{path}'", file=sys.stderr)
        sys.exit(1)

    return path

373 

def get_branch_selection(self) -> str:
    """
    Prompt repeatedly, with auto-completion, until a listed branch is entered.

    Returns:
        Selected branch name

    Ctrl-C or end-of-input at the prompt exits the process with status 0.
    """
    known_branches = self.finder.list_branches()
    completer = WordCompleter(known_branches, ignore_case=True, match_middle=True)

    while True:
        try:
            logger.debug("\nAvailable branches:")
            for name in known_branches:
                logger.debug(f" - {name}")

            chosen = prompt(
                "\nEnter branch name (Tab for completion): ",
                completer=completer,
                complete_while_typing=True
            ).strip()

            # Anything that is not an exact branch name triggers another round.
            if chosen not in known_branches:
                logger.error("Invalid branch selected")
                continue
            return chosen

        except (KeyboardInterrupt, EOFError):
            logger.info("\nOperation cancelled by user")
            sys.exit(0)

404 

def run_task(self, task_name: str):
    """
    Look up the task registered under ``task_name`` and call its ``run()``.
    """
    selected_task = self.registry.get_by_name(task_name)
    selected_task.run()

411 

def fetch_arguments_per_task(self, task_name: str) -> List[Any]:
    """
    Collect the values of the selected task's arguments from this instance.

    Each name listed in the task's ``args`` is read as an attribute of
    ``self``; the values are returned in the same order.
    """
    arg_names = self.registry.get_by_name(task_name).args
    return [getattr(self, arg_name) for arg_name in arg_names]

421 

@with_progress("Finding version for commit")
def find_first_version_containing_commit(self, commit_sha: str, submodule: str = None):
    """
    Print the version reported by the finder for a commit.

    Args:
        commit_sha: The commit SHA to find the version for
        submodule: Optional submodule path

    Any exception raised during the lookup is printed instead of propagated.
    """
    try:
        found = self.finder.find_version(commit_sha, submodule)
        outcome = (
            f"\nVersion for commit {commit_sha}: {found}"
            if found
            else f"\nNo version found for commit {commit_sha}"
        )
        print(outcome)
    except Exception as err:
        print(f"\nError: {str(err)}")

439 

@with_progress("Finding commits between versions")
def find_all_commits_between_versions(self, from_version: str, to_version: str, submodule: str = None):
    """
    Print the commits the finder reports between two versions.

    Args:
        from_version: The starting version
        to_version: The ending version
        submodule: Optional submodule path

    Each commit is shown as the first 8 characters of its SHA and its subject.
    Any exception raised during the lookup is printed instead of propagated.
    """
    try:
        found = self.finder.find_commits_between_versions(from_version, to_version, submodule)
        if not found:
            print(f"\nNo commits found between {from_version} and {to_version}")
            return
        print(f"\nFound {len(found)} commits between {from_version} and {to_version}:")
        for entry in found:
            print(f"{entry.sha[:8]} - {entry.subject}")
    except Exception as err:
        print(f"\nError: {str(err)}")

460 

@with_progress("Searching for commits")
def find_commits_by_text(self, text: str, submodule: str = None):
    """
    Print the commits the finder reports as containing the given text.

    Args:
        text: The text to search for
        submodule: Optional submodule path

    Each commit is shown as the first 8 characters of its SHA and its subject.
    Any exception raised during the search is printed instead of propagated.
    """
    try:
        matches = self.finder.find_commits_by_text(text, submodule)
        if not matches:
            print(f"\nNo commits found containing '{text}'")
            return
        print(f"\nFound {len(matches)} commits containing '{text}':")
        for entry in matches:
            print(f"{entry.sha[:8]} - {entry.subject}")
    except Exception as err:
        print(f"\nError: {str(err)}")

480 

481 

def cli_main(args: argparse.Namespace) -> int:
    """
    Main entry point for the version finder CLI.

    Args:
        args: Parsed command-line arguments. When ``args.version`` is set the
            version string is printed and nothing else runs.

    Returns:
        int: Exit code; the code reported by ``VersionFinderCLI.run`` (0 when
        it reports nothing), or 1 when a GitError escapes it.
    """
    if args.version:
        from .__version__ import __version__
        print(f"version_finder cli-v{__version__}")
        return 0

    # Initialize CLI
    cli = VersionFinderCLI()
    # Run CLI
    try:
        # Propagate run()'s status so failures reach the process exit code;
        # a missing value (None) is treated as success.
        return cli.run(args) or 0
    except GitError as e:
        logger.error("Git operation failed: %s", e)
        return 1

499 

500 

def main() -> int:
    """Parse the command-line arguments and hand them to cli_main, returning its exit code."""
    return cli_main(parse_arguments())


# Script entry point: the exit code of main() becomes the process status.
if __name__ == "__main__":
    sys.exit(main())