diff --git a/music_deduplicate.py b/music_deduplicate.py
new file mode 100644
index 0000000..80dc290
--- /dev/null
+++ b/music_deduplicate.py
@@ -0,0 +1,817 @@
+# -*- coding: utf-8 -*-
+"""
+music_deduplicate.py — 音乐去重专用版本
+
+特性概览:
+- 多线程扫描 + 单线程 DatabaseWriterThread 写入
+- safe_remove:硬链接保护
+- 容错导入 librosa/scipy 等(功能降级)
+- 自动检测写入阻塞并自动恢复
+- 基于音频指纹的智能去重
+"""
+from __future__ import annotations
+import os
+import sys
+import time
+import warnings
+import threading
+import queue
+import hashlib
+import shutil
+import sqlite3
+import logging
+import argparse
+import re
+import math
+from pathlib import Path
+from datetime import datetime
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Any, Dict, List, Optional, Tuple
+
+warnings.filterwarnings("ignore", category=UserWarning, module="numba")
+
+# -------------------------
+# logging helper
+# -------------------------
+def setup_logging(log_level=logging.INFO, log_file="music_deduplicate.log"):
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s - %(levelname)s - %(message)s",
+        handlers=[
+            logging.FileHandler(log_file, encoding="utf-8"),
+            logging.StreamHandler(sys.stdout),
+        ],
+    )
+    return logging.getLogger(__name__)
+
+logger = setup_logging()
+
+# -------------------------
+# safe optional imports (robust)
+# -------------------------
+LIBROSA_AVAILABLE = False
+SCIPY_AVAILABLE = False
+
+try:
+    import numpy as np  # type: ignore
+    try:
+        import librosa  # type: ignore
+        LIBROSA_AVAILABLE = True
+        logger.info("librosa available")
+    except Exception as e:
+        librosa = None  # type: ignore
+        LIBROSA_AVAILABLE = False
+        logger.warning(f"librosa 导入失败: {e}")
+
+    try:
+        from scipy import signal as scipy_signal  # type: ignore
+        SCIPY_AVAILABLE = True
+    except Exception as e:
+        scipy_signal = None
+        SCIPY_AVAILABLE = False
+        logger.warning(f"scipy.signal 导入失败: {e}")
+
+except Exception as e:
+    logger.warning(f"科学栈初始化失败: {e}")
+
+# -------------------------
+# utils
+# -------------------------
+def choose_worker_count(requested: Optional[int] = None) -> int:
+    if requested and requested > 0:
+        return requested
+    try:
+        cpu = os.cpu_count() or 1
+        return min(32, max(4, cpu * 2))
+    except Exception:
+        return 4
+
+def file_sha256(path: str, block_size: int = 65536) -> str:
+    h = hashlib.sha256()
+    try:
+        with open(path, "rb") as f:
+            for block in iter(lambda: f.read(block_size), b""):
+                h.update(block)
+        return h.hexdigest()
+    except Exception as e:
+        logger.debug(f"计算哈希失败 {path}: {e}")
+        return ""
+
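+# Illustrative use of the two helpers above (a hypothetical sketch — the paths
+# and results are made up; nothing in this module depends on this block):
+#
+#     workers = choose_worker_count(0)       # auto: 4..32 depending on CPU count
+#     digest = file_sha256("/music/a.mp3")   # hex SHA-256, or "" if unreadable
+#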
+# -------------------------
+# safe_remove (硬链接保护:策略 C)
+# -------------------------
+def safe_remove(path: str, no_backup: bool=False, backup_dir: Optional[str]=None, db_writer: Optional["DatabaseWriterThread"]=None) -> bool:
+    try:
+        st = os.stat(path)
+    except Exception as e:
+        logger.warning(f"无法访问文件 {path}: {e}")
+        return False
+
+    if getattr(st, "st_nlink", 1) > 1:
+        logger.info(f"文件有多个硬链接,跳过删除以保护硬链接: {path}")
+        if db_writer:
+            db_writer.enqueue_operation({
+                "operation_type": "skip_delete_hardlink",
+                "file_path": path,
+                "file_hash": None,
+                "reason": "hardlink_skip",
+                "details": None
+            })
+        return False
+
+    if backup_dir and not no_backup:
+        try:
+            os.makedirs(backup_dir, exist_ok=True)
+            dest = os.path.join(backup_dir, os.path.basename(path))
+            shutil.move(path, dest)
+            logger.info(f"已将文件移动到备份目录: {path} -> {dest}")
+            if db_writer:
+                db_writer.enqueue_operation({
+                    "operation_type": "backup_move",
+                    "file_path": path,
+                    "file_hash": None,
+                    "reason": "moved_to_backup",
+                    "details": dest
+                })
+            return True
+        except Exception as e:
+            logger.warning(f"移动到备份目录失败 {path}: {e}")
+
+    try:
+        os.remove(path)
+        logger.info(f"已删除文件: {path}")
+        if db_writer:
+            db_writer.enqueue_operation({
+                "operation_type": "delete",
+                "file_path": path,
+                "file_hash": None,
+                "reason": "deleted",
+                "details": None
+            })
+        return True
+    except Exception as e:
+        logger.error(f"删除文件失败 {path}: {e}")
+        return False
+
+# -------------------------
+# DatabaseWriterThread (with detection & auto-migrate)
+# -------------------------
+class DatabaseWriterThread(threading.Thread):
+    """
+    Single-threaded DB writer with:
+      - internal queue for files/ops
+      - lock detection and automatic recovery
+      - optional automatic DB migration to a safe directory
+    """
+    def __init__(self, db_path: str = "music_deduplicate.db", batch_limit: int = 200, flush_interval: float = 1.0, lock_detect_timeout: float = 8.0, max_retries: int = 3, auto_migrate: bool = True):
+        super().__init__(daemon=True)
+        self.db_path = str(db_path)
+        self.batch_limit = batch_limit
+        self.flush_interval = flush_interval
+        self.lock_detect_timeout = lock_detect_timeout
+        self.max_retries = max_retries
+        self.auto_migrate = auto_migrate
+
+        self._conn: Optional[sqlite3.Connection] = None
+        self._queue: "queue.Queue[Tuple[str, Any]]" = queue.Queue()
+        self._stop_event = threading.Event()
+        self._last_write_time = 0.0
+        self._consecutive_failures = 0
+
+    # -------------------------
+    # database connection
+    # -------------------------
+    def _connect(self):
+        try:
+            conn = sqlite3.connect(
+                self.db_path,
+                timeout=3,
+                isolation_level=None,
+                check_same_thread=False,
+            )
+            conn.execute("PRAGMA journal_mode=WAL;")
+            conn.execute("PRAGMA synchronous=NORMAL;")
+            conn.execute(
+                """
+                CREATE TABLE IF NOT EXISTS files (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    file_path TEXT UNIQUE,
+                    file_hash TEXT,
+                    file_size INTEGER,
+                    file_mtime REAL,
+                    created_at TEXT
+                );
+                """
+            )
+            conn.execute(
+                """
+                CREATE TABLE IF NOT EXISTS operations (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    operation_type TEXT,
+                    file_path TEXT,
+                    file_hash TEXT,
+                    reason TEXT,
+                    details TEXT,
+                    created_at TEXT
+                );
+                """
+            )
+            conn.commit()
+            logger.info(f"数据库连接成功:{self.db_path}")
+            return conn
+        except Exception as e:
+            logger.error(f"数据库连接失败 {self.db_path}: {e}")
+            return None
+
+    # -------------------------
+    def start(self):
+        # 先建立数据库连接,再启动线程
+        self._conn = self._connect()
+        if not self._conn:
+            logger.error("无法连接数据库,线程启动失败")
+            return
+        super().start()
+
+    def stop(self):
+        self._stop_event.set()
+
+    def join(self, timeout=None):
+        self._stop_event.set()
+        super().join(timeout)
+        if self._conn:
+            try:
+                self._conn.commit()
+                self._conn.close()
+            except Exception:
+                pass
+
+    # -------------------------
+    # Queue interface
+    # -------------------------
+    def enqueue_file(self, record: Dict[str, Any]):
+        """
+        record = {
+            "file_path": str,
+            "file_hash": str,
+            "file_size": int,
+            "file_mtime": float,
+            "created_at": timestamp,
+        }
+        """
+        self._queue.put(("file", record))
+
+    def enqueue_operation(self, record: Dict[str, Any]):
+        self._queue.put(("operation", record))
+
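+    # Illustrative producer-side usage (hypothetical sketch; field values are
+    # made up — scanner threads below build the same record shapes from os.stat()):
+    #
+    #     writer.enqueue_file({
+    #         "file_path": "/music/a.mp3",
+    #         "file_hash": "d2c1...",               # from file_sha256()
+    #         "file_size": 4_200_000,
+    #         "file_mtime": 1700000000.0,
+    #         "created_at": datetime.now().isoformat(),
+    #     })
+    #     writer.enqueue_operation({
+    #         "operation_type": "planned_delete",
+    #         "file_path": "/music/a.mp3",
+    #         "file_hash": None,
+    #         "reason": "dry_run",
+    #         "details": None,
+    #     })
+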
+    # -------------------------
+    # Writer loop
+    # -------------------------
+    def run(self):
+        logger.info("DatabaseWriterThread 启动")
+        buffer_files = []
+        buffer_ops = []
+        last_flush_time = time.time()
+
+        while not self._stop_event.is_set():
+            try:
+                item_type, data = self._queue.get(timeout=self.flush_interval)
+                if item_type == "file":
+                    buffer_files.append(data)
+                elif item_type == "operation":
+                    buffer_ops.append(data)
+            except queue.Empty:
+                pass
+
+            now = time.time()
+            need_flush = False
+
+            if len(buffer_files) >= self.batch_limit or len(buffer_ops) >= self.batch_limit:
+                need_flush = True
+            if now - last_flush_time >= self.flush_interval:
+                need_flush = True
+
+            if need_flush:
+                ok = self._flush(buffer_files, buffer_ops)
+                if ok:
+                    buffer_files.clear()
+                    buffer_ops.clear()
+                last_flush_time = now
+
+        self._flush(buffer_files, buffer_ops)
+        logger.info("DatabaseWriterThread 结束(队列已清空)")
+
+    # -------------------------
+    # Flush — with lock detection and recovery
+    # -------------------------
+    def _flush(self, files: List[Dict[str, Any]], ops: List[Dict[str, Any]]) -> bool:
+        if not self._conn:
+            logger.error("数据库连接失效(conn = None)尝试重新连接…")
+            self._conn = self._connect()
+            if not self._conn:
+                return False
+
+        if not files and not ops:
+            return True
+
+        start = time.time()
+        ok = False
+        last_err = None
+
+        for attempt in range(self.max_retries):
+            try:
+                cur = self._conn.cursor()
+                for rec in files:
+                    cur.execute(
+                        """
+                        INSERT OR REPLACE INTO files (file_path, file_hash, file_size, file_mtime, created_at)
+                        VALUES (?, ?, ?, ?, ?)
+                        """,
+                        (
+                            rec.get("file_path"),
+                            rec.get("file_hash"),
+                            rec.get("file_size"),
+                            rec.get("file_mtime"),
+                            rec.get("created_at"),
+                        )
+                    )
+                for rec in ops:
+                    cur.execute(
+                        """
+                        INSERT INTO operations (operation_type, file_path, file_hash, reason, details, created_at)
+                        VALUES (?, ?, ?, ?, ?, ?)
+                        """,
+                        (
+                            rec.get("operation_type"),
+                            rec.get("file_path"),
+                            rec.get("file_hash"),
+                            rec.get("reason"),
+                            rec.get("details"),
+                            rec.get("created_at", datetime.now().isoformat()),
+                        )
+                    )
+                self._conn.commit()
+                ok = True
+                self._consecutive_failures = 0
+                break
+            except Exception as e:
+                last_err = e
+                logger.warning(f"批量写入数据库失败 (第 {attempt+1}/{self.max_retries} 次):{e}")
+
+                if "locked" in str(e).lower():
+                    time.sleep(0.8 + attempt * 0.4)
+                    continue
+
+                time.sleep(0.5)
+
+        if not ok:
+            self._consecutive_failures += 1
+            elapsed = time.time() - start
+
+            logger.error(f"写入失败超过重试次数:{last_err}")
+
+            if elapsed > self.lock_detect_timeout or "locked" in str(last_err).lower():
+                logger.error("检测到数据库长期锁定,尝试恢复连接…")
+                try:
+                    self._conn.close()
+                except Exception:
+                    pass
+                self._conn = self._connect()
+                if self._conn:
+                    logger.info("数据库重连成功")
+                    return False
+
+                if self.auto_migrate:
+                    logger.error("数据库重连失败,准备自动迁移数据库…")
+                    return self._try_auto_migrate()
+
+        return ok
+
+    # -------------------------
+    # Try automatically migrating DB to a safe path
+    # -------------------------
+    def _try_auto_migrate(self) -> bool:
+        try:
+            safe_dir = "/var/db/music_deduplicate"
+            os.makedirs(safe_dir, exist_ok=True)
+            new_path = os.path.join(safe_dir, "music_deduplicate.db")
+
+            try:
+                shutil.copy2(self.db_path, new_path)
+                logger.info(f"数据库已迁移: {self.db_path} -> {new_path}")
+            except Exception as e:
+                logger.error(f"数据库迁移失败: {e}")
+                return False
+
+            self.db_path = new_path
+            self._conn = self._connect()
+            if self._conn:
+                logger.info("迁移后的数据库连接成功,继续运行")
+                return True
+            else:
+                return False
+        except Exception as e:
+            logger.error(f"自动迁移过程异常: {e}")
+            return False
+
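+# Typical lifecycle of the writer thread (hypothetical sketch; MusicDeduplicator
+# below performs the same steps internally):
+#
+#     writer = DatabaseWriterThread(db_path="music_deduplicate.db")
+#     writer.start()                 # opens the SQLite connection, then starts the thread
+#     writer.enqueue_file(record)    # any number of scanner threads may enqueue
+#     writer.stop()
+#     writer.join(timeout=10)        # final flush + connection close happen here
+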
+# =====================================================
+# 音频指纹分析
+# =====================================================
+
+# -------------------------
+# 音频指纹提取(容错)
+# -------------------------
+class AudioFingerprint:
+    def __init__(self):
+        self.ok = LIBROSA_AVAILABLE or SCIPY_AVAILABLE
+
+    def extract(self, path: str) -> Optional[str]:
+        """
+        提取音频指纹,返回字符串表示
+        """
+        if not self.ok:
+            logger.debug(f"音频指纹模块不可用,跳过: {path}")
+            return None
+
+        try:
+            if LIBROSA_AVAILABLE:
+                y, sr = librosa.load(path, sr=22050, mono=True, duration=30)  # 只分析前30秒
+                mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)
+                # 取MFCC的均值和方差作为特征
+                fp_mean = np.mean(mfcc, axis=1)
+                fp_std = np.std(mfcc, axis=1)
+                fp = np.concatenate([fp_mean, fp_std])
+                # 将浮点数数组转换为字符串
+                return f"audio_{'_'.join([f'{x:.4f}' for x in fp])}"
+
+            # librosa 不可用时,用 scipy_signal
+            if SCIPY_AVAILABLE:
+                import soundfile as sf
+                data, sr = sf.read(path)
+                if data.ndim > 1:
+                    data = data.mean(axis=1)
+                # 只取前30秒
+                max_samples = min(len(data), sr * 30)
+                data = data[:max_samples]
+                freqs, times, Sxx = scipy_signal.spectrogram(data, sr)
+                fp = np.mean(Sxx, axis=1)
+                return f"audio_{'_'.join([f'{x:.4f}' for x in fp])}"
+
+            return None
+
+        except Exception as e:
+            logger.error(f"提取音频指纹失败 {path}: {e}")
+            return None
+
+# =====================================================
+# 扫描器:多线程扫描 + 入队写数据库
+# =====================================================
+
+class FileScanner:
+    EXT_AUDIO = {".mp3", ".aac", ".flac", ".ogg", ".wav", ".m4a", ".wma", ".ape", ".alac"}
+
+    def __init__(self, db_writer: DatabaseWriterThread, workers: int = 8):
+        self.db_writer = db_writer
+        self.workers = choose_worker_count(workers)
+        self.audio_fp = AudioFingerprint()
+
+    # -------------------------
+    def scan(self, root: str):
+        """
+        遍历路径,将文件元数据推送到数据库队列。
+        """
+        root = os.path.abspath(root)
+        logger.info(f"开始扫描路径: {root}")
+
+        file_list: List[str] = []
+        for base, dirs, files in os.walk(root):
+            for f in files:
+                ext = os.path.splitext(f)[1].lower()
+                if ext in self.EXT_AUDIO:
+                    full = os.path.join(base, f)
+                    file_list.append(full)
+
+        logger.info(f"扫描完成,共发现音频文件: {len(file_list)}")
+
+        with ThreadPoolExecutor(max_workers=self.workers) as ex:
+            futures = {ex.submit(self._process_one, path): path for path in file_list}
+            for fut in as_completed(futures):
+                try:
+                    fut.result()
+                except Exception as e:
+                    logger.error(f"处理文件异常: {e}")
+
+    # -------------------------
+    def _process_one(self, path: str):
+        """
+        获取文件大小、时间、hash(快速)并提交数据库线程。
+        """
+        try:
+            st = os.stat(path)
+        except Exception as e:
+            logger.debug(f"无法读取文件 stat: {path}: {e}")
+            return
+
+        ext = os.path.splitext(path)[1].lower()
+
+        # 轻量快速 hash(仅文件大小>1MB才计算)
+        file_hash = ""
+        if st.st_size > 1_000_000:
+            file_hash = file_sha256(path)
+        else:
+            file_hash = f"SMALL-{st.st_size}-{int(st.st_mtime)}"
+
+        record = {
+            "file_path": path,
+            "file_hash": file_hash,
+            "file_size": st.st_size,
+            "file_mtime": st.st_mtime,
+            "created_at": datetime.now().isoformat(),
+        }
+        self.db_writer.enqueue_file(record)
+
+# =====================================================
+# 相似度检测与去重决策
+# =====================================================
+
+class DuplicateFinder:
+    """
+    基于 DB 快照进行相似群组查找
+    """
+    def __init__(self, db_path: str):
+        self.db_path = db_path
+        self.audio_fp = AudioFingerprint()
+
+    def _read_files_from_db(self) -> List[Dict[str, Any]]:
+        out = []
+        try:
+            conn = sqlite3.connect(self.db_path, timeout=30)
+            cur = conn.cursor()
+            cur.execute("SELECT file_path, file_hash, file_size FROM files")
+            for row in cur.fetchall():
+                out.append({"path": row[0], "hash": row[1], "size": row[2]})
+        except Exception as e:
+            logger.warning(f"读取 DB 列表失败: {e}")
+        finally:
+            try:
+                conn.close()
+            except Exception:
+                pass
+        return out
+
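+    # Exact duplicates could also be found straight from the snapshot (an
+    # illustrative query only — the pipeline below instead groups by normalized
+    # file name and then compares fingerprints):
+    #
+    #     SELECT file_hash, COUNT(*) AS n, GROUP_CONCAT(file_path, '|')
+    #     FROM files
+    #     WHERE file_hash != ''
+    #     GROUP BY file_hash
+    #     HAVING n > 1;
+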
re.sub(r"\s+", " ", key).strip() + groups.setdefault(key, []).append(f) + return [g for g in groups.values() if len(g) > 1] + + def are_audios_similar(self, a: str, b: str, threshold: float = 0.85) -> bool: + """ + 判断两个音频文件是否相似 + """ + try: + # 首先比较文件大小 + sa = os.path.getsize(a) + sb = os.path.getsize(b) + size_ratio = min(sa, sb) / max(sa, sb) if max(sa, sb) > 0 else 0 + if size_ratio > 0.95: # 大小相差小于5% + return True + except Exception: + pass + + # 使用音频指纹比较 + try: + fp_a = self.audio_fp.extract(a) + fp_b = self.audio_fp.extract(b) + + if fp_a and fp_b: + # 解析指纹字符串 + if fp_a.startswith("audio_") and fp_b.startswith("audio_"): + values_a = [float(x) for x in fp_a[6:].split("_")] + values_b = [float(x) for x in fp_b[6:].split("_")] + + if len(values_a) == len(values_b): + # 计算余弦相似度 + dot_product = sum(a*b for a, b in zip(values_a, values_b)) + norm_a = math.sqrt(sum(a*a for a in values_a)) + norm_b = math.sqrt(sum(b*b for b in values_b)) + + if norm_a > 0 and norm_b > 0: + similarity = dot_product / (norm_a * norm_b) + return similarity >= threshold + + except Exception as e: + logger.debug(f"音频指纹比较失败: {e}") + + return False + + def find_audio_groups(self) -> List[List[Dict[str,Any]]]: + files = self._read_files_from_db() + name_groups = self.group_by_name(files) + result = [] + + for g in name_groups: + if len(g) <= 1: + continue + used = set() + for i in range(len(g)): + if i in used: + continue + base = g[i] + cluster = [base] + used.add(i) + for j in range(i+1, len(g)): + if j in used: + continue + try: + if self.are_audios_similar(base["path"], g[j]["path"]): + cluster.append(g[j]) + used.add(j) + except Exception as e: + logger.debug(f"比较音频相似度失败: {e}") + pass + if len(cluster) > 1: + result.append(cluster) + + logger.info(f"查找完成:发现 {len(result)} 音频候选组") + return result + +# ------------------------- +# MusicDeduplicator high-level operations +# ------------------------- +class MusicDeduplicator: + def __init__(self, target_dirs: List[str], db_path: str="music_deduplicate.db", prefer_folder: Optional[str]=None, workers: int=0, auto_migrate: bool=True): + self.target_dirs = target_dirs + self.db_path = db_path + self.prefer_folder = prefer_folder + self.db_writer = DatabaseWriterThread(db_path=db_path, auto_migrate=auto_migrate) + # 启动写入线程 + self.db_writer.start() + self.scanner = FileScanner(db_writer=self.db_writer, workers=workers) + self.finder = DuplicateFinder(db_path=self.db_path) + + def scan_all(self): + for d in self.target_dirs: + self.scanner.scan(d) + + def remove_groups(self, groups: List[List[Dict[str,Any]]], dry_run: bool=True, no_backup: bool=False) -> Tuple[List[str], List[str]]: + kept = [] + deleted = [] + for group in groups: + if not group: + continue + # choose keeper + keeper = None + if self.prefer_folder: + for f in group: + if self.prefer_folder in f["path"]: + keeper = f + break + if not keeper: + # 优先保留高比特率或大文件 + keeper = max(group, key=lambda x: x.get("size", 0)) + kept.append(keeper["path"]) + for f in group: + p = f["path"] + if p == keeper["path"]: + continue + if dry_run: + logger.info(f"[dry-run] 删除 {p} (保留 {keeper['path']})") + self.db_writer.enqueue_operation({ + "operation_type": "planned_delete", + "file_path": p, + "file_hash": f.get("hash"), + "reason": "dry_run", + "details": None, + "created_at": datetime.now().isoformat() + }) + deleted.append(p) + else: + ok = safe_remove(p, no_backup=no_backup, backup_dir=None, db_writer=self.db_writer) + if ok: + deleted.append(p) + else: + logger.info(f"跳过删除(可能为硬链接或权限问题): {p}") + return kept, deleted + + 
+    def run_deduplication(self, dry_run: bool = True, no_backup: bool = False) -> Dict[str, Any]:
+        logger.info("开始音乐去重")
+        self.scan_all()
+        logger.info("等待 db_writer 完成写入任务...")
+        # wait until queue is drained or timeout
+        start = time.time()
+        while not self.db_writer._queue.empty():
+            time.sleep(0.5)
+            if time.time() - start > 600:
+                logger.error("等待 db_writer 超过 600 秒,提前退出")
+                break
+        groups = self.finder.find_audio_groups()
+        kept, deleted = self.remove_groups(groups, dry_run=dry_run, no_backup=no_backup)
+        return {"kept": kept, "deleted": deleted, "groups": len(groups)}
+
+# =====================================================
+# CLI & Main Function
+# =====================================================
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="音乐去重工具")
+    parser.add_argument(
+        "-d", "--dirs",
+        nargs="+",
+        required=True,
+        help="指定需要扫描的目录(一个或多个)"
+    )
+    parser.add_argument(
+        "--prefer",
+        type=str,
+        default=None,
+        help="优先保留的路径片段(如果匹配文件路径则优先保留)"
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="演示模式:仅显示将要删除的文件,不实际删除"
+    )
+    parser.add_argument(
+        "--no-backup",
+        action="store_true",
+        help="删除时不创建备份(谨慎)"
+    )
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=0,
+        help="扫描线程数(默认自动计算)"
+    )
+    parser.add_argument(
+        "--db",
+        type=str,
+        default="music_deduplicate.db",
+        help="使用的数据库文件"
+    )
+    parser.add_argument(
+        "--migrate",
+        action="store_true",
+        help="强制允许自动迁移数据库(锁死时会迁移)"
+    )
+
+    return parser.parse_args()
+
+def main():
+    args = parse_args()
+
+    logger.info("==============================================")
+    logger.info(" 音乐去重工具 (Music Deduplicator) ")
+    logger.info("==============================================")
+    logger.info(f"扫描目录:{args.dirs}")
+    logger.info(f"数据库文件:{args.db}")
+    logger.info(f"优先保留路径片段:{args.prefer}")
+    if args.dry_run:
+        logger.info("警告:dry-run 模式(不会删除任何文件)")
+    if args.no_backup:
+        logger.warning("危险:已启用 --no-backup,不会创建备份!")
+
+    cleaner = MusicDeduplicator(
+        target_dirs=args.dirs,
+        db_path=args.db,
+        prefer_folder=args.prefer,
+        workers=args.workers,
+        auto_migrate=args.migrate,
+    )
+
+    result = None
+
+    try:
+        result = cleaner.run_deduplication(
+            dry_run=args.dry_run,
+            no_backup=args.no_backup,
+        )
+    except Exception as e:
+        logger.error(f"运行清理任务发生异常: {e}", exc_info=True)
+    finally:
+        # ensure writer shutdown
+        try:
+            cleaner.db_writer.stop()
+            cleaner.db_writer.join(timeout=10)
+        except Exception:
+            pass
+
+    logger.info("所有任务完成。")
+
+    if result is not None:
+        logger.info("========== 清理结果 ==========")
+        logger.info(f"保留文件数: {len(result['kept'])}")
+        logger.info(f"删除文件数: {len(result['deleted'])}")
+        logger.info(f"发现相似组: {result['groups']}")
+
+        if args.dry_run:
+            logger.info("\n将要删除的文件列表:")
+            for f in result['deleted']:
+                logger.info(f"  {f}")
+        else:
+            logger.info("\n已删除的文件列表:")
+            for f in result['deleted']:
+                logger.info(f"  {f}")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/video_deduplicate.py b/video_deduplicate.py
new file mode 100644
index 0000000..628d6f0
--- /dev/null
+++ b/video_deduplicate.py
@@ -0,0 +1,841 @@
+# -*- coding: utf-8 -*-
+"""
+video_deduplicate.py — 视频去重专用版本
+
+特性概览:
+- 多线程扫描 + 单线程 DatabaseWriterThread 写入(永不出现 database is locked)
+- safe_remove:硬链接保护(策略 C)
+- 容错导入 OpenCV/PIL/imagehash 等(功能降级)
+- 自动检测写入阻塞并自动恢复
+- 基于视频指纹和内容相似度的智能去重
+"""
+from __future__ import annotations
+import os
+import sys
+import time
+import warnings
+import threading
+import queue
+import hashlib
+import shutil
+import sqlite3
+import logging
+import argparse
+import math +import re +from pathlib import Path +from datetime import datetime +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Dict, List, Optional, Tuple + +warnings.filterwarnings("ignore", category=UserWarning, module="numba") + +# ------------------------- +# logging helper +# ------------------------- +def setup_logging(log_level=logging.INFO, log_file="video_deduplicate.log"): + logging.basicConfig( + level=log_level, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(log_file, encoding="utf-8"), + logging.StreamHandler(sys.stdout), + ], + ) + return logging.getLogger(__name__) + +logger = setup_logging() + +# ------------------------- +# safe optional imports (robust) +# ------------------------- +VIDEO_PROCESSING_AVAILABLE = False + +try: + import numpy as np # type: ignore + + try: + import cv2 # type: ignore + import imagehash # type: ignore + from PIL import Image # type: ignore + VIDEO_PROCESSING_AVAILABLE = True + except Exception as e: + VIDEO_PROCESSING_AVAILABLE = False + logger.warning(f"视频处理库导入失败: {e}") + +except Exception as e: + logger.warning(f"科学栈初始化失败: {e}") + +# ------------------------- +# utils +# ------------------------- +def choose_worker_count(requested: Optional[int] = None) -> int: + if requested and requested > 0: + return requested + try: + cpu = os.cpu_count() or 1 + return min(32, max(4, cpu * 2)) + except Exception: + return 4 + +def file_sha256(path: str, block_size: int = 65536) -> str: + h = hashlib.sha256() + try: + with open(path, "rb") as f: + for block in iter(lambda: f.read(block_size), b""): + h.update(block) + return h.hexdigest() + except Exception as e: + logger.debug(f"计算哈希失败 {path}: {e}") + return "" + +# ------------------------- +# safe_remove (硬链接保护:策略 C) +# ------------------------- +def safe_remove(path: str, no_backup: bool=False, backup_dir: Optional[str]=None, db_writer: Optional["DatabaseWriterThread"]=None) -> bool: + try: + st = os.stat(path) + except Exception as e: + logger.warning(f"无法访问文件 {path}: {e}") + return False + + if getattr(st, "st_nlink", 1) > 1: + logger.info(f"文件有多个硬链接,跳过删除以保护硬链接: {path}") + if db_writer: + db_writer.enqueue_operation({ + "operation_type": "skip_delete_hardlink", + "file_path": path, + "file_hash": None, + "reason": "hardlink_skip", + "details": None + }) + return False + + if backup_dir and not no_backup: + try: + os.makedirs(backup_dir, exist_ok=True) + dest = os.path.join(backup_dir, os.path.basename(path)) + shutil.move(path, dest) + logger.info(f"已将文件移动到备份目录: {path} -> {dest}") + if db_writer: + db_writer.enqueue_operation({ + "operation_type": "backup_move", + "file_path": path, + "file_hash": None, + "reason": "moved_to_backup", + "details": dest + }) + return True + except Exception as e: + logger.warning(f"移动到备份目录失败 {path}: {e}") + + try: + os.remove(path) + logger.info(f"已删除文件: {path}") + if db_writer: + db_writer.enqueue_operation({ + "operation_type": "delete", + "file_path": path, + "file_hash": None, + "reason": "deleted", + "details": None + }) + return True + except Exception as e: + logger.error(f"删除文件失败 {path}: {e}") + return False + +# ------------------------- +# DatabaseWriterThread (with detection & auto-migrate) +# ------------------------- +class DatabaseWriterThread(threading.Thread): + """ + Single-threaded DB writer with: + - internal queue for files/ops + - lock detection and automatic recovery + - optional automatic DB migration to a safe directory + """ + def __init__(self, db_path: str 
= "video_deduplicate.db", batch_limit:int = 200, flush_interval: float = 1.0, lock_detect_timeout: float = 8.0, max_retries:int=3, auto_migrate:bool=True): + super().__init__(daemon=True) + self.db_path = str(db_path) + self.batch_limit = batch_limit + self.flush_interval = flush_interval + self.lock_detect_timeout = lock_detect_timeout + self.max_retries = max_retries + self.auto_migrate = auto_migrate + + self._conn: Optional[sqlite3.Connection] = None + self._queue: "queue.Queue[Tuple[str, Any]]" = queue.Queue() + self._stop_event = threading.Event() + self._last_write_time = 0.0 + self._consecutive_failures = 0 + + # ------------------------- + # database connection + # ------------------------- + def _connect(self): + try: + conn = sqlite3.connect( + self.db_path, + timeout=3, + isolation_level=None, + check_same_thread=False, + ) + conn.execute("PRAGMA journal_mode=WAL;") + conn.execute("PRAGMA synchronous=NORMAL;") + conn.execute( + """ + CREATE TABLE IF NOT EXISTS files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_path TEXT UNIQUE, + file_hash TEXT, + file_size INTEGER, + file_mtime REAL, + created_at TEXT + ); + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS operations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + operation_type TEXT, + file_path TEXT, + file_hash TEXT, + reason TEXT, + details TEXT, + created_at TEXT + ); + """ + ) + conn.commit() + logger.info(f"数据库连接成功:{self.db_path}") + return conn + except Exception as e: + logger.error(f"数据库连接失败 {self.db_path}: {e}") + return None + + # ------------------------- + def start(self): + # 先建立数据库连接,再启动线程 + self._conn = self._connect() + if not self._conn: + logger.error("无法连接数据库,线程启动失败") + return + super().start() + + def stop(self): + self._stop_event.set() + + def join(self, timeout=None): + self._stop_event.set() + super().join(timeout) + if self._conn: + try: + self._conn.commit() + self._conn.close() + except: + pass + + # ------------------------- + # Queue interface + # ------------------------- + def enqueue_file(self, record: Dict[str, Any]): + """ + record = { + "file_path": str, + "file_hash": str, + "file_size": int, + "file_mtime": float, + "created_at": timestamp, + } + """ + self._queue.put(("file", record)) + + def enqueue_operation(self, record: Dict[str, Any]): + self._queue.put(("operation", record)) + + # ------------------------- + # Writer loop + # ------------------------- + def run(self): + logger.info("DatabaseWriterThread 启动") + buffer_files = [] + buffer_ops = [] + last_flush_time = time.time() + + while not self._stop_event.is_set(): + try: + item_type, data = self._queue.get(timeout=self.flush_interval) + if item_type == "file": + buffer_files.append(data) + elif item_type == "operation": + buffer_ops.append(data) + except queue.Empty: + pass + + now = time.time() + need_flush = False + + if len(buffer_files) >= self.batch_limit or len(buffer_ops) >= self.batch_limit: + need_flush = True + if now - last_flush_time >= self.flush_interval: + need_flush = True + + if need_flush: + ok = self._flush(buffer_files, buffer_ops) + if ok: + buffer_files.clear() + buffer_ops.clear() + last_flush_time = now + + # 线程结束前最后刷新一次 + self._flush(buffer_files, buffer_ops) + logger.info("DatabaseWriterThread 结束(队列已清空)") + + # ------------------------- + # Flush — with lock detection and recovery + # ------------------------- + def _flush(self, files: List[Dict[str,Any]], ops: List[Dict[str,Any]]) -> bool: + if not self._conn: + logger.error("数据库连接失效(conn = None)尝试重新连接…") + self._conn = self._connect() + if 
not self._conn: + return False + + if not files and not ops: + return True + + start = time.time() + ok = False + last_err = None + + for attempt in range(self.max_retries): + try: + cur = self._conn.cursor() + for rec in files: + cur.execute( + """ + INSERT OR REPLACE INTO files (file_path, file_hash, file_size, file_mtime, created_at) + VALUES (?, ?, ?, ?, ?) + """, + ( + rec.get("file_path"), + rec.get("file_hash"), + rec.get("file_size"), + rec.get("file_mtime"), + rec.get("created_at"), + ) + ) + for rec in ops: + cur.execute( + """ + INSERT INTO operations (operation_type, file_path, file_hash, reason, details, created_at) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + rec.get("operation_type"), + rec.get("file_path"), + rec.get("file_hash"), + rec.get("reason"), + rec.get("details"), + rec.get("created_at", datetime.now().isoformat()), + ) + ) + self._conn.commit() + ok = True + self._consecutive_failures = 0 + break + except Exception as e: + last_err = e + logger.warning(f"批量写入数据库失败 (第 {attempt+1}/{self.max_retries} 次):{e}") + + if "locked" in str(e).lower(): + time.sleep(0.8 + attempt * 0.4) + continue + + time.sleep(0.5) + + if not ok: + self._consecutive_failures += 1 + elapsed = time.time() - start + + logger.error(f"写入失败超过重试次数:{last_err}") + + if elapsed > self.lock_detect_timeout or "locked" in str(last_err).lower(): + logger.error("检测到数据库长期锁定,尝试恢复连接…") + try: + self._conn.close() + except: + pass + self._conn = self._connect() + if self._conn: + logger.info("数据库重连成功") + return False + + if self.auto_migrate: + logger.error("数据库重连失败,准备自动迁移数据库…") + return self._try_auto_migrate() + + return ok + + # ------------------------- + # Try automatically migrating DB to a safe path + # ------------------------- + def _try_auto_migrate(self) -> bool: + try: + safe_dir = "/var/db/video_deduplicate" + os.makedirs(safe_dir, exist_ok=True) + new_path = os.path.join(safe_dir, "video_deduplicate.db") + + try: + shutil.copy2(self.db_path, new_path) + logger.info(f"数据库已迁移: {self.db_path} -> {new_path}") + except Exception as e: + logger.error(f"数据库迁移失败: {e}") + return False + + self.db_path = new_path + self._conn = self._connect() + if self._conn: + logger.info("迁移后的数据库连接成功,继续运行") + return True + else: + return False + except Exception as e: + logger.error(f"自动迁移过程异常: {e}") + return False + +# ===================================================== +# 指纹分析 +# ===================================================== + +# ------------------------- +# 视频指纹提取(容错) +# ------------------------- +class VideoFingerprint: + def __init__(self): + self.ok = VIDEO_PROCESSING_AVAILABLE + + def extract(self, path: str) -> Optional[str]: + """ + 提取视频指纹,返回字符串表示 + """ + if not self.ok: + logger.debug(f"视频指纹模块不可用,跳过: {path}") + return None + + try: + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + logger.error(f"打开视频失败: {path}") + return None + + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if frame_count <= 0: + cap.release() + return None + + sample_count = 5 # 采样5帧 + step = max(1, frame_count // sample_count) + fingerprints = [] + + for i in range(0, frame_count, step): + if len(fingerprints) >= sample_count: + break + + cap.set(cv2.CAP_PROP_POS_FRAMES, i) + ok, frame = cap.read() + if not ok: + continue + + # 转换为灰度图 + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + + # 计算pHash + try: + pil_img = Image.fromarray(gray) + phash = imagehash.phash(pil_img) + fingerprints.append(str(phash)) + except Exception as e: + logger.debug(f"计算pHash失败: {e}") + continue + + cap.release() + + if not fingerprints: + 
return None + + # 返回指纹的字符串表示 + return f"video_{'_'.join(fingerprints)}" + + except Exception as e: + logger.error(f"提取视频指纹失败 {path}: {e}") + return None + +# ===================================================== +# 扫描器:多线程扫描 + 入队写数据库 +# ===================================================== + +class FileScanner: + EXT_VIDEO = {".mp4", ".mkv", ".avi", ".rmvb", ".mov", ".wmv", ".flv", ".ts", ".m2ts", ".webm"} + + def __init__(self, db_writer: DatabaseWriterThread, workers:int=8): + self.db_writer = db_writer + self.workers = choose_worker_count(workers) + self.video_fp = VideoFingerprint() + + # ------------------------- + def scan(self, root: str): + """ + 遍历路径,将文件元数据推送到数据库队列。 + """ + root = os.path.abspath(root) + logger.info(f"开始扫描路径: {root}") + + file_list: List[str] = [] + for base, dirs, files in os.walk(root): + for f in files: + ext = os.path.splitext(f)[1].lower() + if ext in self.EXT_VIDEO: + full = os.path.join(base, f) + file_list.append(full) + + logger.info(f"扫描完成,共发现视频文件: {len(file_list)}") + + with ThreadPoolExecutor(max_workers=self.workers) as ex: + futures = {ex.submit(self._process_one, path): path for path in file_list} + for fut in as_completed(futures): + try: + fut.result() + except Exception as e: + logger.error(f"处理文件异常: {e}") + + # ------------------------- + def _process_one(self, path: str): + """ + 获取文件大小、时间、hash(快速)并提交数据库线程。 + """ + try: + st = os.stat(path) + except Exception as e: + logger.debug(f"无法读取文件 stat: {path}: {e}") + return + + ext = os.path.splitext(path)[1].lower() + + # 轻量快速 hash(仅文件大小>1MB才计算) + file_hash = "" + if st.st_size > 1_000_000: + file_hash = file_sha256(path) + else: + file_hash = f"SMALL-{st.st_size}-{int(st.st_mtime)}" + + record = { + "file_path": path, + "file_hash": file_hash, + "file_size": st.st_size, + "file_mtime": st.st_mtime, + "created_at": datetime.now().isoformat(), + } + self.db_writer.enqueue_file(record) + +# ===================================================== +# 相似度检测与去重决策 +# ===================================================== + +def phash_distance(h1: str, h2: str) -> int: + """ + 计算两个 phash 字符串的汉明距离 + """ + try: + b1 = int(str(h1), 16) + b2 = int(str(h2), 16) + x = b1 ^ b2 + return bin(x).count("1") + except Exception: + return 128 # large + +class DuplicateFinder: + """ + 基于 DB 快照进行相似群组查找 + """ + def __init__(self, db_path: str): + self.db_path = db_path + self.video_fp = VideoFingerprint() + + def _read_files_from_db(self) -> List[Dict[str, Any]]: + out = [] + try: + conn = sqlite3.connect(self.db_path, timeout=30) + cur = conn.cursor() + cur.execute("SELECT file_path, file_hash, file_size FROM files") + for row in cur.fetchall(): + out.append({"path": row[0], "hash": row[1], "size": row[2]}) + except Exception as e: + logger.warning(f"读取 DB 列表失败: {e}") + finally: + try: + conn.close() + except: + pass + return out + + def group_by_name(self, files: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]: + groups = {} + for f in files: + key = Path(f["path"]).stem.lower() + # remove common tokens + key = re.sub(r"(1080p|720p|2160p|4k|x264|x265|h264|h265|hevc|bluray|web-dl|webdl|bdrip|brrip)", "", key) + key = re.sub(r"[\._\-]+", " ", key).strip() + groups.setdefault(key, []).append(f) + return [g for g in groups.values() if len(g) > 1] + + def are_videos_similar(self, a: str, b: str, phash_thresh: int = 10) -> bool: + """ + 判断两个视频是否相似 + """ + try: + # 首先比较文件大小 + sa = os.path.getsize(a) + sb = os.path.getsize(b) + if sa == sb: + return True + except Exception: + pass + + # 使用视频指纹比较 + try: + fp_a = 
self.video_fp.extract(a) + fp_b = self.video_fp.extract(b) + + if fp_a and fp_b: + # 解析指纹字符串 + parts_a = fp_a.split("_")[1:] # 去掉开头的"video_" + parts_b = fp_b.split("_")[1:] + + if len(parts_a) == len(parts_b) and len(parts_a) > 0: + # 计算匹配的帧数 + matches = 0 + for x, y in zip(parts_a, parts_b): + dist = phash_distance(x, y) + if dist <= phash_thresh: + matches += 1 + + ratio = matches / len(parts_a) + if ratio >= 0.6: # 60%的帧相似 + return True + except Exception as e: + logger.debug(f"视频指纹比较失败: {e}") + + return False + + def find_video_groups(self) -> List[List[Dict[str,Any]]]: + files = self._read_files_from_db() + name_groups = self.group_by_name(files) + result = [] + + for g in name_groups: + if len(g) <= 1: + continue + used = set() + for i in range(len(g)): + if i in used: + continue + base = g[i] + cluster = [base] + used.add(i) + for j in range(i+1, len(g)): + if j in used: + continue + try: + if self.are_videos_similar(base["path"], g[j]["path"]): + cluster.append(g[j]) + used.add(j) + except Exception as e: + logger.debug(f"比较视频相似度失败: {e}") + pass + if len(cluster) > 1: + result.append(cluster) + + logger.info(f"查找完成:发现 {len(result)} 视频候选组") + return result + +# ------------------------- +# VideoDeduplicator high-level operations +# ------------------------- +class VideoDeduplicator: + def __init__(self, target_dirs: List[str], db_path: str="video_deduplicate.db", prefer_folder: Optional[str]=None, workers: int=0, auto_migrate: bool=True): + self.target_dirs = target_dirs + self.db_path = db_path + self.prefer_folder = prefer_folder + self.db_writer = DatabaseWriterThread(db_path=db_path, auto_migrate=auto_migrate) + # 启动写入线程 + self.db_writer.start() + self.scanner = FileScanner(db_writer=self.db_writer, workers=workers) + self.finder = DuplicateFinder(db_path=self.db_path) + + def scan_all(self): + for d in self.target_dirs: + self.scanner.scan(d) + + def remove_groups(self, groups: List[List[Dict[str,Any]]], dry_run: bool=True, no_backup: bool=False) -> Tuple[List[str], List[str]]: + kept = [] + deleted = [] + for group in groups: + if not group: + continue + # choose keeper + keeper = None + if self.prefer_folder: + for f in group: + if self.prefer_folder in f["path"]: + keeper = f + break + if not keeper: + keeper = max(group, key=lambda x: x.get("size", 0)) + kept.append(keeper["path"]) + for f in group: + p = f["path"] + if p == keeper["path"]: + continue + if dry_run: + logger.info(f"[dry-run] 删除 {p} (保留 {keeper['path']})") + self.db_writer.enqueue_operation({ + "operation_type": "planned_delete", + "file_path": p, + "file_hash": f.get("hash"), + "reason": "dry_run", + "details": None, + "created_at": datetime.now().isoformat() + }) + deleted.append(p) + else: + ok = safe_remove(p, no_backup=no_backup, backup_dir=None, db_writer=self.db_writer) + if ok: + deleted.append(p) + else: + logger.info(f"跳过删除(可能为硬链接或权限问题): {p}") + return kept, deleted + + def run_deduplication(self, dry_run: bool=True, no_backup: bool=False) -> Dict[str,Any]: + logger.info("开始视频去重") + self.scan_all() + logger.info("等待 db_writer 完成写入任务...") + # wait until queue is drained or timeout + start = time.time() + while not self.db_writer._queue.empty(): + time.sleep(0.5) + if time.time() - start > 600: + logger.error("等待 db_writer 超过 600 秒,提前退出") + break + groups = self.finder.find_video_groups() + kept, deleted = self.remove_groups(groups, dry_run=dry_run, no_backup=no_backup) + return {"kept": kept, "deleted": deleted, "groups": len(groups)} + +# ===================================================== +# CLI 
& Main Function +# ===================================================== + +def parse_args(): + parser = argparse.ArgumentParser(description="视频去重工具") + parser.add_argument( + "-d", "--dirs", + nargs="+", + required=True, + help="指定需要扫描的目录(一个或多个)" + ) + parser.add_argument( + "--prefer", + type=str, + default=None, + help="优先保留的路径片段(如果匹配文件路径则优先保留)" + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="演示模式:仅显示将要删除的文件,不实际删除" + ) + parser.add_argument( + "--no-backup", + action="store_true", + help="删除时不创建备份(谨慎)" + ) + parser.add_argument( + "--workers", + type=int, + default=0, + help="扫描线程数(默认自动计算)" + ) + parser.add_argument( + "--db", + type=str, + default="video_deduplicate.db", + help="使用的数据库文件" + ) + parser.add_argument( + "--migrate", + action="store_true", + help="强制允许自动迁移数据库(锁死时会迁移)" + ) + + return parser.parse_args() + +def main(): + args = parse_args() + + logger.info("==============================================") + logger.info(" 视频去重工具 (Video Deduplicator) ") + logger.info("==============================================") + logger.info(f"扫描目录:{args.dirs}") + logger.info(f"数据库文件:{args.db}") + logger.info(f"优先保留路径片段:{args.prefer}") + if args.dry_run: + logger.info("警告:dry-run 模式(不会删除任何文件)") + if args.no_backup: + logger.warning("危险:已启用 --no-backup,不会创建备份!") + + cleaner = VideoDeduplicator( + target_dirs=args.dirs, + db_path=args.db, + prefer_folder=args.prefer, + workers=args.workers, + auto_migrate=args.migrate, + ) + + result = None + + try: + result = cleaner.run_deduplication( + dry_run=args.dry_run, + no_backup=args.no_backup, + ) + except Exception as e: + logger.error(f"运行清理任务发生异常: {e}", exc_info=True) + finally: + # ensure writer shutdown + try: + cleaner.db_writer.stop() + cleaner.db_writer.join(timeout=10) + except Exception: + pass + + logger.info("所有任务完成。") + + if result is not None: + logger.info("========== 清理结果 ==========") + logger.info(f"保留文件数: {len(result['kept'])}") + logger.info(f"删除文件数: {len(result['deleted'])}") + logger.info(f"发现相似组: {result['groups']}") + + if args.dry_run: + logger.info("\n将要删除的文件列表:") + for f in result['deleted']: + logger.info(f" {f}") + else: + logger.info("\n已删除的文件列表:") + for f in result['deleted']: + logger.info(f" {f}") + +if __name__ == "__main__": + main() \ No newline at end of file