# -*- coding: utf-8 -*-
"""
music_deduplicate.py — 音乐去重专用版本

特性概览:
- 多线程扫描 + 单线程 DatabaseWriterThread 写入
- safe_remove:硬链接保护
- 容错导入 librosa/scipy 等(功能降级)
- 自动检测写入阻塞并自动恢复
- 基于音频指纹的智能去重
"""
from __future__ import annotations
|
||
import os
|
||
import sys
|
||
import time
|
||
import warnings
|
||
import threading
|
||
import queue
|
||
import hashlib
|
||
import shutil
|
||
import sqlite3
|
||
import logging
|
||
import argparse
|
||
import re
|
||
import math
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
warnings.filterwarnings("ignore", category=UserWarning, module="numba")
|
||
|
||
# -------------------------
# logging helper
# -------------------------
def setup_logging(log_level=logging.INFO, log_file="music_deduplicate.log"):
|
||
logging.basicConfig(
|
||
level=log_level,
|
||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||
handlers=[
|
||
logging.FileHandler(log_file, encoding="utf-8"),
|
||
logging.StreamHandler(sys.stdout),
|
||
],
|
||
)
|
||
return logging.getLogger(__name__)
|
||
|
||
logger = setup_logging()
|
||
|
||
# -------------------------
# safe optional imports (robust)
# -------------------------
LIBROSA_AVAILABLE = False
|
||
SCIPY_AVAILABLE = False
|
||
|
||
try:
|
||
import numpy as np # type: ignore
|
||
try:
|
||
import librosa # type: ignore
|
||
LIBROSA_AVAILABLE = True
|
||
logger.info("librosa available")
|
||
except Exception as e:
|
||
librosa = None # type: ignore
|
||
LIBROSA_AVAILABLE = False
|
||
logger.warning(f"librosa 导入失败: {e}")
|
||
|
||
try:
|
||
from scipy import signal as scipy_signal # type: ignore
|
||
SCIPY_AVAILABLE = True
|
||
except Exception as e:
|
||
scipy_signal = None
|
||
SCIPY_AVAILABLE = False
|
||
logger.warning(f"scipy.signal 导入失败: {e}")
|
||
|
||
except Exception as e:
|
||
logger.warning(f"科学栈初始化失败: {e}")
|
||
|
||
# -------------------------
# utils
# -------------------------
def choose_worker_count(requested: Optional[int] = None) -> int:
|
||
if requested and requested > 0:
|
||
return requested
|
||
try:
|
||
cpu = os.cpu_count() or 1
|
||
return min(32, max(4, cpu * 2))
|
||
except Exception:
|
||
return 4
|
||
|
||
def file_sha256(path: str, block_size: int = 65536) -> str:
|
||
h = hashlib.sha256()
|
||
try:
|
||
with open(path, "rb") as f:
|
||
for block in iter(lambda: f.read(block_size), b""):
|
||
h.update(block)
|
||
return h.hexdigest()
|
||
except Exception as e:
|
||
logger.debug(f"计算哈希失败 {path}: {e}")
|
||
return ""
|
||
|
||
# -------------------------
# safe_remove (硬链接保护:策略 C)
# -------------------------
def safe_remove(path: str, no_backup: bool=False, backup_dir: Optional[str]=None, db_writer: Optional["DatabaseWriterThread"]=None) -> bool:
|
||
try:
|
||
st = os.stat(path)
|
||
except Exception as e:
|
||
logger.warning(f"无法访问文件 {path}: {e}")
|
||
return False
|
||
|
||
if getattr(st, "st_nlink", 1) > 1:
|
||
logger.info(f"文件有多个硬链接,跳过删除以保护硬链接: {path}")
|
||
if db_writer:
|
||
db_writer.enqueue_operation({
|
||
"operation_type": "skip_delete_hardlink",
|
||
"file_path": path,
|
||
"file_hash": None,
|
||
"reason": "hardlink_skip",
|
||
"details": None
|
||
})
|
||
return False
|
||
|
||
if backup_dir and not no_backup:
|
||
try:
|
||
os.makedirs(backup_dir, exist_ok=True)
|
||
dest = os.path.join(backup_dir, os.path.basename(path))
|
||
shutil.move(path, dest)
|
||
logger.info(f"已将文件移动到备份目录: {path} -> {dest}")
|
||
if db_writer:
|
||
db_writer.enqueue_operation({
|
||
"operation_type": "backup_move",
|
||
"file_path": path,
|
||
"file_hash": None,
|
||
"reason": "moved_to_backup",
|
||
"details": dest
|
||
})
|
||
return True
|
||
except Exception as e:
|
||
logger.warning(f"移动到备份目录失败 {path}: {e}")
|
||
|
||
try:
|
||
os.remove(path)
|
||
logger.info(f"已删除文件: {path}")
|
||
if db_writer:
|
||
db_writer.enqueue_operation({
|
||
"operation_type": "delete",
|
||
"file_path": path,
|
||
"file_hash": None,
|
||
"reason": "deleted",
|
||
"details": None
|
||
})
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"删除文件失败 {path}: {e}")
|
||
return False
|
||
|
||
# -------------------------
# DatabaseWriterThread (with detection & auto-migrate)
# -------------------------
class DatabaseWriterThread(threading.Thread):
|
||
"""
|
||
Single-threaded DB writer with:
|
||
- internal queue for files/ops
|
||
- lock detection and automatic recovery
|
||
- optional automatic DB migration to a safe directory
|
||
"""
|
||
def __init__(self, db_path: str = "music_deduplicate.db", batch_limit:int = 200, flush_interval: float = 1.0, lock_detect_timeout: float = 8.0, max_retries:int=3, auto_migrate:bool=True):
|
||
super().__init__(daemon=True)
|
||
self.db_path = str(db_path)
|
||
self.batch_limit = batch_limit
|
||
self.flush_interval = flush_interval
|
||
self.lock_detect_timeout = lock_detect_timeout
|
||
self.max_retries = max_retries
|
||
self.auto_migrate = auto_migrate
|
||
|
||
self._conn: Optional[sqlite3.Connection] = None
|
||
self._queue: "queue.Queue[Tuple[str, Any]]" = queue.Queue()
|
||
self._stop_event = threading.Event()
|
||
self._last_write_time = 0.0
|
||
self._consecutive_failures = 0
|
||
|
||
# -------------------------
|
||
# database connection
|
||
# -------------------------
|
||
def _connect(self):
|
||
try:
|
||
conn = sqlite3.connect(
|
||
self.db_path,
|
||
timeout=3,
|
||
isolation_level=None,
|
||
check_same_thread=False,
|
||
)
|
||
conn.execute("PRAGMA journal_mode=WAL;")
|
||
conn.execute("PRAGMA synchronous=NORMAL;")
|
||
conn.execute(
|
||
"""
|
||
CREATE TABLE IF NOT EXISTS files (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
file_path TEXT UNIQUE,
|
||
file_hash TEXT,
|
||
file_size INTEGER,
|
||
file_mtime REAL,
|
||
created_at TEXT
|
||
);
|
||
"""
|
||
)
|
||
conn.execute(
|
||
"""
|
||
CREATE TABLE IF NOT EXISTS operations (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
operation_type TEXT,
|
||
file_path TEXT,
|
||
file_hash TEXT,
|
||
reason TEXT,
|
||
details TEXT,
|
||
created_at TEXT
|
||
);
|
||
"""
|
||
)
|
||
conn.commit()
|
||
logger.info(f"数据库连接成功:{self.db_path}")
|
||
return conn
|
||
except Exception as e:
|
||
logger.error(f"数据库连接失败 {self.db_path}: {e}")
|
||
return None
|
||
|
||
# -------------------------
|
||
def start(self):
|
||
super().start()
|
||
|
||
def stop(self):
|
||
self._stop_event.set()
|
||
|
||
def join(self, timeout=None):
|
||
self._stop_event.set()
|
||
super().join(timeout)
|
||
if self._conn:
|
||
try:
|
||
self._conn.commit()
|
||
self._conn.close()
|
||
except:
|
||
pass
|
||
|
||
# -------------------------
|
||
# Queue interface
|
||
# -------------------------
|
||
def enqueue_file(self, record: Dict[str, Any]):
|
||
"""
|
||
record = {
|
||
"file_path": str,
|
||
"file_hash": str,
|
||
"file_size": int,
|
||
"file_mtime": float,
|
||
"created_at": timestamp,
|
||
}
|
||
"""
|
||
self._queue.put(("file", record))
|
||
|
||
def enqueue_operation(self, record: Dict[str, Any]):
|
||
self._queue.put(("operation", record))
|
||
|
||
# -------------------------
|
||
# Writer loop
|
||
# -------------------------
|
||
def run(self):
|
||
logger.info("DatabaseWriterThread 启动")
|
||
buffer_files = []
|
||
buffer_ops = []
|
||
last_flush_time = time.time()
|
||
|
||
while not self._stop_event.is_set():
|
||
try:
|
||
item_type, data = self._queue.get(timeout=self.flush_interval)
|
||
if item_type == "file":
|
||
buffer_files.append(data)
|
||
elif item_type == "operation":
|
||
buffer_ops.append(data)
|
||
except queue.Empty:
|
||
pass
|
||
|
||
now = time.time()
|
||
need_flush = False
|
||
|
||
if len(buffer_files) >= self.batch_limit or len(buffer_ops) >= self.batch_limit:
|
||
need_flush = True
|
||
if now - last_flush_time >= self.flush_interval:
|
||
need_flush = True
|
||
|
||
if need_flush:
|
||
ok = self._flush(buffer_files, buffer_ops)
|
||
if ok:
|
||
buffer_files.clear()
|
||
buffer_ops.clear()
|
||
last_flush_time = now
|
||
|
||
self._flush(buffer_files, buffer_ops)
|
||
logger.info("DatabaseWriterThread 结束(队列已清空)")
|
||
|
||
# -------------------------
|
||
# Flush — with lock detection and recovery
|
||
# -------------------------
|
||
def _flush(self, files: List[Dict[str,Any]], ops: List[Dict[str,Any]]) -> bool:
|
||
if not self._conn:
|
||
logger.error("数据库连接失效(conn = None)尝试重新连接…")
|
||
self._conn = self._connect()
|
||
if not self._conn:
|
||
return False
|
||
|
||
if not files and not ops:
|
||
return True
|
||
|
||
start = time.time()
|
||
ok = False
|
||
last_err = None
|
||
|
||
for attempt in range(self.max_retries):
|
||
try:
|
||
cur = self._conn.cursor()
|
||
for rec in files:
|
||
cur.execute(
|
||
"""
|
||
INSERT OR REPLACE INTO files (file_path, file_hash, file_size, file_mtime, created_at)
|
||
VALUES (?, ?, ?, ?, ?)
|
||
""",
|
||
(
|
||
rec.get("file_path"),
|
||
rec.get("file_hash"),
|
||
rec.get("file_size"),
|
||
rec.get("file_mtime"),
|
||
rec.get("created_at"),
|
||
)
|
||
)
|
||
for rec in ops:
|
||
cur.execute(
|
||
"""
|
||
INSERT INTO operations (operation_type, file_path, file_hash, reason, details, created_at)
|
||
VALUES (?, ?, ?, ?, ?, ?)
|
||
""",
|
||
(
|
||
rec.get("operation_type"),
|
||
rec.get("file_path"),
|
||
rec.get("file_hash"),
|
||
rec.get("reason"),
|
||
rec.get("details"),
|
||
rec.get("created_at", datetime.now().isoformat()),
|
||
)
|
||
)
|
||
self._conn.commit()
|
||
ok = True
|
||
self._consecutive_failures = 0
|
||
break
|
||
except Exception as e:
|
||
last_err = e
|
||
logger.warning(f"批量写入数据库失败 (第 {attempt+1}/{self.max_retries} 次):{e}")
|
||
|
||
if "locked" in str(e).lower():
|
||
time.sleep(0.8 + attempt * 0.4)
|
||
continue
|
||
|
||
time.sleep(0.5)
|
||
|
||
if not ok:
|
||
self._consecutive_failures += 1
|
||
elapsed = time.time() - start
|
||
|
||
logger.error(f"写入失败超过重试次数:{last_err}")
|
||
|
||
if elapsed > self.lock_detect_timeout or "locked" in str(last_err).lower():
|
||
logger.error("检测到数据库长期锁定,尝试恢复连接…")
|
||
try:
|
||
self._conn.close()
|
||
except:
|
||
pass
|
||
self._conn = self._connect()
|
||
if self._conn:
|
||
logger.info("数据库重连成功")
|
||
return False
|
||
|
||
if self.auto_migrate:
|
||
logger.error("数据库重连失败,准备自动迁移数据库…")
|
||
return self._try_auto_migrate()
|
||
|
||
return ok
|
||
|
||
# -------------------------
|
||
# Try automatically migrating DB to a safe path
|
||
# -------------------------
|
||
def _try_auto_migrate(self) -> bool:
|
||
try:
|
||
safe_dir = "/var/db/music_deduplicate"
|
||
os.makedirs(safe_dir, exist_ok=True)
|
||
new_path = os.path.join(safe_dir, "music_deduplicate.db")
|
||
|
||
try:
|
||
shutil.copy2(self.db_path, new_path)
|
||
logger.info(f"数据库已迁移: {self.db_path} -> {new_path}")
|
||
except Exception as e:
|
||
logger.error(f"数据库迁移失败: {e}")
|
||
return False
|
||
|
||
self.db_path = new_path
|
||
self._conn = self._connect()
|
||
if self._conn:
|
||
logger.info("迁移后的数据库连接成功,继续运行")
|
||
return True
|
||
else:
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"自动迁移过程异常: {e}")
|
||
return False
|
||
|
||
# =====================================================
# 音频指纹分析
# =====================================================

# -------------------------
# 音频指纹提取(容错)
# -------------------------
class AudioFingerprint:
|
||
def __init__(self):
|
||
self.ok = LIBROSA_AVAILABLE or SCIPY_AVAILABLE
|
||
|
||
def extract(self, path: str) -> Optional[str]:
|
||
"""
|
||
提取音频指纹,返回字符串表示
|
||
"""
|
||
if not self.ok:
|
||
logger.debug(f"音频指纹模块不可用,跳过: {path}")
|
||
return None
|
||
|
||
try:
|
||
if LIBROSA_AVAILABLE:
|
||
y, sr = librosa.load(path, sr=22050, mono=True, duration=30) # 只分析前30秒
|
||
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)
|
||
# 取MFCC的均值和方差作为特征
|
||
fp_mean = np.mean(mfcc, axis=1)
|
||
fp_std = np.std(mfcc, axis=1)
|
||
fp = np.concatenate([fp_mean, fp_std])
|
||
# 将浮点数数组转换为字符串
|
||
return f"audio_{'_'.join([f'{x:.4f}' for x in fp])}"
|
||
|
||
# librosa 不可用时,用 scipy_signal
|
||
if SCIPY_AVAILABLE:
|
||
import soundfile as sf
|
||
data, sr = sf.read(path)
|
||
if data.ndim > 1:
|
||
data = data.mean(axis=1)
|
||
# 只取前30秒
|
||
max_samples = min(len(data), sr * 30)
|
||
data = data[:max_samples]
|
||
freqs, times, Sxx = scipy_signal.spectrogram(data, sr)
|
||
fp = np.mean(Sxx, axis=1)
|
||
return f"audio_{'_'.join([f'{x:.4f}' for x in fp])}"
|
||
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"提取音频指纹失败 {path}: {e}")
|
||
return None
|
||
|
||
# =====================================================
# 扫描器:多线程扫描 + 入队写数据库
# =====================================================
class FileScanner:
|
||
EXT_AUDIO = {".mp3", ".aac", ".flac", ".ogg", ".wav", ".m4a", ".wma", ".ape", ".alac"}
|
||
|
||
def __init__(self, db_writer: DatabaseWriterThread, workers:int=8):
|
||
self.db_writer = db_writer
|
||
self.workers = choose_worker_count(workers)
|
||
self.audio_fp = AudioFingerprint()
|
||
|
||
# -------------------------
|
||
def scan(self, root: str):
|
||
"""
|
||
遍历路径,将文件元数据推送到数据库队列。
|
||
"""
|
||
root = os.path.abspath(root)
|
||
logger.info(f"开始扫描路径: {root}")
|
||
|
||
file_list: List[str] = []
|
||
for base, dirs, files in os.walk(root):
|
||
for f in files:
|
||
ext = os.path.splitext(f)[1].lower()
|
||
if ext in self.EXT_AUDIO:
|
||
full = os.path.join(base, f)
|
||
file_list.append(full)
|
||
|
||
logger.info(f"扫描完成,共发现音频文件: {len(file_list)}")
|
||
|
||
with ThreadPoolExecutor(max_workers=self.workers) as ex:
|
||
futures = {ex.submit(self._process_one, path): path for path in file_list}
|
||
for fut in as_completed(futures):
|
||
try:
|
||
fut.result()
|
||
except Exception as e:
|
||
logger.error(f"处理文件异常: {e}")
|
||
|
||
# -------------------------
|
||
def _process_one(self, path: str):
|
||
"""
|
||
获取文件大小、时间、hash(快速)并提交数据库线程。
|
||
"""
|
||
try:
|
||
st = os.stat(path)
|
||
except Exception as e:
|
||
logger.debug(f"无法读取文件 stat: {path}: {e}")
|
||
return
|
||
|
||
ext = os.path.splitext(path)[1].lower()
|
||
|
||
# 轻量快速 hash(仅文件大小>1MB才计算)
|
||
file_hash = ""
|
||
if st.st_size > 1_000_000:
|
||
file_hash = file_sha256(path)
|
||
else:
|
||
file_hash = f"SMALL-{st.st_size}-{int(st.st_mtime)}"
|
||
|
||
record = {
|
||
"file_path": path,
|
||
"file_hash": file_hash,
|
||
"file_size": st.st_size,
|
||
"file_mtime": st.st_mtime,
|
||
"created_at": datetime.now().isoformat(),
|
||
}
|
||
self.db_writer.enqueue_file(record)
|
||
|
||
# =====================================================
# 相似度检测与去重决策
# =====================================================
class DuplicateFinder:
|
||
"""
|
||
基于 DB 快照进行相似群组查找
|
||
"""
|
||
def __init__(self, db_path: str):
|
||
self.db_path = db_path
|
||
self.audio_fp = AudioFingerprint()
|
||
|
||
def _read_files_from_db(self) -> List[Dict[str, Any]]:
|
||
out = []
|
||
try:
|
||
conn = sqlite3.connect(self.db_path, timeout=30)
|
||
cur = conn.cursor()
|
||
cur.execute("SELECT file_path, file_hash, file_size FROM files")
|
||
for row in cur.fetchall():
|
||
out.append({"path": row[0], "hash": row[1], "size": row[2]})
|
||
except Exception as e:
|
||
logger.warning(f"读取 DB 列表失败: {e}")
|
||
finally:
|
||
try:
|
||
conn.close()
|
||
except:
|
||
pass
|
||
return out
|
||
|
||
def group_by_name(self, files: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
|
||
groups = {}
|
||
for f in files:
|
||
# 使用文件名(不含扩展名)作为分组键
|
||
key = Path(f["path"]).stem.lower()
|
||
# 移除常见标记
|
||
key = re.sub(r"(official|version|remaster|remastered|remix|edit|radio|clean|explicit)", "", key, flags=re.IGNORECASE)
|
||
key = re.sub(r"[\(\)\[\]\{\}\.]", " ", key)
|
||
key = re.sub(r"\s+", " ", key).strip()
|
||
groups.setdefault(key, []).append(f)
|
||
return [g for g in groups.values() if len(g) > 1]
|
||
|
||
def are_audios_similar(self, a: str, b: str, threshold: float = 0.85) -> bool:
|
||
"""
|
||
判断两个音频文件是否相似
|
||
"""
|
||
try:
|
||
# 首先比较文件大小
|
||
sa = os.path.getsize(a)
|
||
sb = os.path.getsize(b)
|
||
size_ratio = min(sa, sb) / max(sa, sb) if max(sa, sb) > 0 else 0
|
||
if size_ratio > 0.95: # 大小相差小于5%
|
||
return True
|
||
except Exception:
|
||
pass
|
||
|
||
# 使用音频指纹比较
|
||
try:
|
||
fp_a = self.audio_fp.extract(a)
|
||
fp_b = self.audio_fp.extract(b)
|
||
|
||
if fp_a and fp_b:
|
||
# 解析指纹字符串
|
||
if fp_a.startswith("audio_") and fp_b.startswith("audio_"):
|
||
values_a = [float(x) for x in fp_a[6:].split("_")]
|
||
values_b = [float(x) for x in fp_b[6:].split("_")]
|
||
|
||
if len(values_a) == len(values_b):
|
||
# 计算余弦相似度
|
||
dot_product = sum(a*b for a, b in zip(values_a, values_b))
|
||
norm_a = math.sqrt(sum(a*a for a in values_a))
|
||
norm_b = math.sqrt(sum(b*b for b in values_b))
|
||
|
||
if norm_a > 0 and norm_b > 0:
|
||
similarity = dot_product / (norm_a * norm_b)
|
||
return similarity >= threshold
|
||
|
||
except Exception as e:
|
||
logger.debug(f"音频指纹比较失败: {e}")
|
||
|
||
return False
|
||
|
||
def find_audio_groups(self) -> List[List[Dict[str,Any]]]:
|
||
files = self._read_files_from_db()
|
||
name_groups = self.group_by_name(files)
|
||
result = []
|
||
|
||
for g in name_groups:
|
||
if len(g) <= 1:
|
||
continue
|
||
used = set()
|
||
for i in range(len(g)):
|
||
if i in used:
|
||
continue
|
||
base = g[i]
|
||
cluster = [base]
|
||
used.add(i)
|
||
for j in range(i+1, len(g)):
|
||
if j in used:
|
||
continue
|
||
try:
|
||
if self.are_audios_similar(base["path"], g[j]["path"]):
|
||
cluster.append(g[j])
|
||
used.add(j)
|
||
except Exception as e:
|
||
logger.debug(f"比较音频相似度失败: {e}")
|
||
pass
|
||
if len(cluster) > 1:
|
||
result.append(cluster)
|
||
|
||
logger.info(f"查找完成:发现 {len(result)} 音频候选组")
|
||
return result
|
||
|
||
# -------------------------
# MusicDeduplicator high-level operations
# -------------------------
class MusicDeduplicator:
|
||
def __init__(self, target_dirs: List[str], db_path: str="music_deduplicate.db", prefer_folder: Optional[str]=None, workers: int=0, auto_migrate: bool=True):
|
||
self.target_dirs = target_dirs
|
||
self.db_path = db_path
|
||
self.prefer_folder = prefer_folder
|
||
self.db_writer = DatabaseWriterThread(db_path=db_path, auto_migrate=auto_migrate)
|
||
# 启动写入线程
|
||
self.db_writer.start()
|
||
self.scanner = FileScanner(db_writer=self.db_writer, workers=workers)
|
||
self.finder = DuplicateFinder(db_path=self.db_path)
|
||
|
||
def scan_all(self):
|
||
for d in self.target_dirs:
|
||
self.scanner.scan(d)
|
||
|
||
def remove_groups(self, groups: List[List[Dict[str,Any]]], dry_run: bool=True, no_backup: bool=False) -> Tuple[List[str], List[str]]:
|
||
kept = []
|
||
deleted = []
|
||
for group in groups:
|
||
if not group:
|
||
continue
|
||
# choose keeper
|
||
keeper = None
|
||
if self.prefer_folder:
|
||
for f in group:
|
||
if self.prefer_folder in f["path"]:
|
||
keeper = f
|
||
break
|
||
if not keeper:
|
||
# 优先保留高比特率或大文件
|
||
keeper = max(group, key=lambda x: x.get("size", 0))
|
||
kept.append(keeper["path"])
|
||
for f in group:
|
||
p = f["path"]
|
||
if p == keeper["path"]:
|
||
continue
|
||
if dry_run:
|
||
logger.info(f"[dry-run] 删除 {p} (保留 {keeper['path']})")
|
||
self.db_writer.enqueue_operation({
|
||
"operation_type": "planned_delete",
|
||
"file_path": p,
|
||
"file_hash": f.get("hash"),
|
||
"reason": "dry_run",
|
||
"details": None,
|
||
"created_at": datetime.now().isoformat()
|
||
})
|
||
deleted.append(p)
|
||
else:
|
||
ok = safe_remove(p, no_backup=no_backup, backup_dir=None, db_writer=self.db_writer)
|
||
if ok:
|
||
deleted.append(p)
|
||
else:
|
||
logger.info(f"跳过删除(可能为硬链接或权限问题): {p}")
|
||
return kept, deleted
|
||
|
||
def run_deduplication(self, dry_run: bool=True, no_backup: bool=False) -> Dict[str,Any]:
|
||
logger.info("开始音乐去重")
|
||
self.scan_all()
|
||
logger.info("等待 db_writer 完成写入任务...")
|
||
# wait until queue is drained or timeout
|
||
start = time.time()
|
||
while not self.db_writer._queue.empty():
|
||
time.sleep(0.5)
|
||
if time.time() - start > 600:
|
||
logger.error("等待 db_writer 超过 600 秒,提前退出")
|
||
break
|
||
groups = self.finder.find_audio_groups()
|
||
kept, deleted = self.remove_groups(groups, dry_run=dry_run, no_backup=no_backup)
|
||
return {"kept": kept, "deleted": deleted, "groups": len(groups)}
|
||
|
||
# =====================================================
# CLI & Main Function
# =====================================================
def parse_args():
|
||
parser = argparse.ArgumentParser(description="音乐去重工具")
|
||
parser.add_argument(
|
||
"-d", "--dirs",
|
||
nargs="+",
|
||
required=True,
|
||
help="指定需要扫描的目录(一个或多个)"
|
||
)
|
||
parser.add_argument(
|
||
"--prefer",
|
||
type=str,
|
||
default=None,
|
||
help="优先保留的路径片段(如果匹配文件路径则优先保留)"
|
||
)
|
||
parser.add_argument(
|
||
"--dry-run",
|
||
action="store_true",
|
||
help="演示模式:仅显示将要删除的文件,不实际删除"
|
||
)
|
||
parser.add_argument(
|
||
"--no-backup",
|
||
action="store_true",
|
||
help="删除时不创建备份(谨慎)"
|
||
)
|
||
parser.add_argument(
|
||
"--workers",
|
||
type=int,
|
||
default=0,
|
||
help="扫描线程数(默认自动计算)"
|
||
)
|
||
parser.add_argument(
|
||
"--db",
|
||
type=str,
|
||
default="music_deduplicate.db",
|
||
help="使用的数据库文件"
|
||
)
|
||
parser.add_argument(
|
||
"--migrate",
|
||
action="store_true",
|
||
help="强制允许自动迁移数据库(锁死时会迁移)"
|
||
)
|
||
|
||
return parser.parse_args()
|
||
|
||
def main():
|
||
args = parse_args()
|
||
|
||
logger.info("==============================================")
|
||
logger.info(" 音乐去重工具 (Music Deduplicator) ")
|
||
logger.info("==============================================")
|
||
logger.info(f"扫描目录:{args.dirs}")
|
||
logger.info(f"数据库文件:{args.db}")
|
||
logger.info(f"优先保留路径片段:{args.prefer}")
|
||
if args.dry_run:
|
||
logger.info("警告:dry-run 模式(不会删除任何文件)")
|
||
if args.no_backup:
|
||
logger.warning("危险:已启用 --no-backup,不会创建备份!")
|
||
|
||
cleaner = MusicDeduplicator(
|
||
target_dirs=args.dirs,
|
||
db_path=args.db,
|
||
prefer_folder=args.prefer,
|
||
workers=args.workers,
|
||
auto_migrate=args.migrate,
|
||
)
|
||
|
||
result = None
|
||
|
||
try:
|
||
result = cleaner.run_deduplication(
|
||
dry_run=args.dry_run,
|
||
no_backup=args.no_backup,
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"运行清理任务发生异常: {e}", exc_info=True)
|
||
finally:
|
||
# ensure writer shutdown
|
||
try:
|
||
cleaner.db_writer.stop()
|
||
cleaner.db_writer.join(timeout=10)
|
||
except Exception:
|
||
pass
|
||
|
||
logger.info("所有任务完成。")
|
||
|
||
if result is not None:
|
||
logger.info("========== 清理结果 ==========")
|
||
logger.info(f"保留文件数: {len(result['kept'])}")
|
||
logger.info(f"删除文件数: {len(result['deleted'])}")
|
||
logger.info(f"发现相似组: {result['groups']}")
|
||
|
||
if args.dry_run:
|
||
logger.info("\n将要删除的文件列表:")
|
||
for f in result['deleted']:
|
||
logger.info(f" {f}")
|
||
else:
|
||
logger.info("\n已删除的文件列表:")
|
||
for f in result['deleted']:
|
||
logger.info(f" {f}")
|
||
|
||
if __name__ == "__main__":
|
||
main() |