# -*- coding: utf-8 -*-
"""
video_deduplicate.py — 视频去重专用版本
特性概览:
- 多线程扫描 + 单线程 DatabaseWriterThread 写入(永不出现 database is locked
- safe_remove硬链接保护策略 C
- 容错导入 OpenCV/PIL/imagehash 等(功能降级)
- 自动检测写入阻塞并自动恢复
- 基于视频指纹和内容相似度的智能去重
"""
from __future__ import annotations
import os
import sys
import time
import warnings
import threading
import queue
import hashlib
import shutil
import sqlite3
import logging
import argparse
import math
import re
from pathlib import Path
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple
warnings.filterwarnings("ignore", category=UserWarning, module="numba")
# -------------------------
# logging helper
# -------------------------
def setup_logging(log_level=logging.INFO, log_file="video_deduplicate.log"):
logging.basicConfig(
level=log_level,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(log_file, encoding="utf-8"),
logging.StreamHandler(sys.stdout),
],
)
return logging.getLogger(__name__)
logger = setup_logging()
# -------------------------
# safe optional imports (robust)
# -------------------------
VIDEO_PROCESSING_AVAILABLE = False
try:
import numpy as np # type: ignore
try:
import cv2 # type: ignore
import imagehash # type: ignore
from PIL import Image # type: ignore
VIDEO_PROCESSING_AVAILABLE = True
except Exception as e:
VIDEO_PROCESSING_AVAILABLE = False
logger.warning(f"视频处理库导入失败: {e}")
except Exception as e:
logger.warning(f"科学栈初始化失败: {e}")
# -------------------------
# utils
# -------------------------
def choose_worker_count(requested: Optional[int] = None) -> int:
if requested and requested > 0:
return requested
try:
cpu = os.cpu_count() or 1
return min(32, max(4, cpu * 2))
except Exception:
return 4
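# Example (assuming an 8-core machine): choose_worker_count() -> min(32, max(4, 8 * 2)) = 16;
# an explicit positive "requested" value is returned unchanged.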
def file_sha256(path: str, block_size: int = 65536) -> str:
h = hashlib.sha256()
try:
with open(path, "rb") as f:
for block in iter(lambda: f.read(block_size), b""):
h.update(block)
return h.hexdigest()
except Exception as e:
logger.debug(f"计算哈希失败 {path}: {e}")
return ""
# -------------------------
# safe_remove (hard-link protection: strategy C)
# -------------------------
def safe_remove(path: str, no_backup: bool=False, backup_dir: Optional[str]=None, db_writer: Optional["DatabaseWriterThread"]=None) -> bool:
try:
st = os.stat(path)
except Exception as e:
logger.warning(f"无法访问文件 {path}: {e}")
return False
if getattr(st, "st_nlink", 1) > 1:
logger.info(f"文件有多个硬链接,跳过删除以保护硬链接: {path}")
if db_writer:
db_writer.enqueue_operation({
"operation_type": "skip_delete_hardlink",
"file_path": path,
"file_hash": None,
"reason": "hardlink_skip",
"details": None
})
return False
if backup_dir and not no_backup:
try:
os.makedirs(backup_dir, exist_ok=True)
dest = os.path.join(backup_dir, os.path.basename(path))
shutil.move(path, dest)
logger.info(f"已将文件移动到备份目录: {path} -> {dest}")
if db_writer:
db_writer.enqueue_operation({
"operation_type": "backup_move",
"file_path": path,
"file_hash": None,
"reason": "moved_to_backup",
"details": dest
})
return True
except Exception as e:
logger.warning(f"移动到备份目录失败 {path}: {e}")
try:
os.remove(path)
logger.info(f"已删除文件: {path}")
if db_writer:
db_writer.enqueue_operation({
"operation_type": "delete",
"file_path": path,
"file_hash": None,
"reason": "deleted",
"details": None
})
return True
except Exception as e:
logger.error(f"删除文件失败 {path}: {e}")
return False
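# Usage sketch for safe_remove (paths and the writer are hypothetical): files with more than
# one hard link are skipped, otherwise the file is moved to backup_dir (when given and
# no_backup is False) or deleted, and the outcome is enqueued on db_writer if one is passed.
#   safe_remove("/data/movies/dup.mkv", backup_dir="/data/.dedup_backup", db_writer=writer)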
# -------------------------
# DatabaseWriterThread (with detection & auto-migrate)
# -------------------------
class DatabaseWriterThread(threading.Thread):
"""
Single-threaded DB writer with:
- internal queue for files/ops
- lock detection and automatic recovery
- optional automatic DB migration to a safe directory
"""
def __init__(self, db_path: str = "video_deduplicate.db", batch_limit:int = 200, flush_interval: float = 1.0, lock_detect_timeout: float = 8.0, max_retries:int=3, auto_migrate:bool=True):
super().__init__(daemon=True)
self.db_path = str(db_path)
self.batch_limit = batch_limit
self.flush_interval = flush_interval
self.lock_detect_timeout = lock_detect_timeout
self.max_retries = max_retries
self.auto_migrate = auto_migrate
self._conn: Optional[sqlite3.Connection] = None
self._queue: "queue.Queue[Tuple[str, Any]]" = queue.Queue()
self._stop_event = threading.Event()
self._last_write_time = 0.0
self._consecutive_failures = 0
# -------------------------
# database connection
# -------------------------
def _connect(self):
try:
conn = sqlite3.connect(
self.db_path,
timeout=3,
isolation_level=None,
check_same_thread=False,
)
conn.execute("PRAGMA journal_mode=WAL;")
conn.execute("PRAGMA synchronous=NORMAL;")
conn.execute(
"""
CREATE TABLE IF NOT EXISTS files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_path TEXT UNIQUE,
file_hash TEXT,
file_size INTEGER,
file_mtime REAL,
created_at TEXT
);
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS operations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
operation_type TEXT,
file_path TEXT,
file_hash TEXT,
reason TEXT,
details TEXT,
created_at TEXT
);
"""
)
conn.commit()
logger.info(f"数据库连接成功:{self.db_path}")
return conn
except Exception as e:
logger.error(f"数据库连接失败 {self.db_path}: {e}")
return None
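# The tables created above can be audited afterwards with plain sqlite3 (a read-only
# sketch, not part of the tool itself):
#   import sqlite3
#   conn = sqlite3.connect("video_deduplicate.db")
#   for row in conn.execute("SELECT operation_type, file_path, reason FROM operations ORDER BY id DESC LIMIT 20"):
#       print(row)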
# -------------------------
def start(self):
# Establish the database connection first, then start the thread
self._conn = self._connect()
if not self._conn:
logger.error("无法连接数据库,线程启动失败")
return
super().start()
def stop(self):
self._stop_event.set()
def join(self, timeout=None):
self._stop_event.set()
super().join(timeout)
if self._conn:
try:
self._conn.commit()
self._conn.close()
except:
pass
# -------------------------
# Queue interface
# -------------------------
def enqueue_file(self, record: Dict[str, Any]):
"""
record = {
"file_path": str,
"file_hash": str,
"file_size": int,
"file_mtime": float,
"created_at": timestamp,
}
"""
self._queue.put(("file", record))
def enqueue_operation(self, record: Dict[str, Any]):
self._queue.put(("operation", record))
# -------------------------
# Writer loop
# -------------------------
def run(self):
logger.info("DatabaseWriterThread 启动")
buffer_files = []
buffer_ops = []
last_flush_time = time.time()
while not self._stop_event.is_set():
try:
item_type, data = self._queue.get(timeout=self.flush_interval)
if item_type == "file":
buffer_files.append(data)
elif item_type == "operation":
buffer_ops.append(data)
except queue.Empty:
pass
now = time.time()
need_flush = False
if len(buffer_files) >= self.batch_limit or len(buffer_ops) >= self.batch_limit:
need_flush = True
if now - last_flush_time >= self.flush_interval:
need_flush = True
if need_flush:
ok = self._flush(buffer_files, buffer_ops)
if ok:
buffer_files.clear()
buffer_ops.clear()
last_flush_time = now
# One final flush before the thread exits
self._flush(buffer_files, buffer_ops)
logger.info("DatabaseWriterThread 结束(队列已清空)")
# -------------------------
# Flush — with lock detection and recovery
# -------------------------
def _flush(self, files: List[Dict[str,Any]], ops: List[Dict[str,Any]]) -> bool:
if not self._conn:
logger.error("数据库连接失效conn = None尝试重新连接…")
self._conn = self._connect()
if not self._conn:
return False
if not files and not ops:
return True
start = time.time()
ok = False
last_err = None
for attempt in range(self.max_retries):
try:
cur = self._conn.cursor()
for rec in files:
cur.execute(
"""
INSERT OR REPLACE INTO files (file_path, file_hash, file_size, file_mtime, created_at)
VALUES (?, ?, ?, ?, ?)
""",
(
rec.get("file_path"),
rec.get("file_hash"),
rec.get("file_size"),
rec.get("file_mtime"),
rec.get("created_at"),
)
)
for rec in ops:
cur.execute(
"""
INSERT INTO operations (operation_type, file_path, file_hash, reason, details, created_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
rec.get("operation_type"),
rec.get("file_path"),
rec.get("file_hash"),
rec.get("reason"),
rec.get("details"),
rec.get("created_at", datetime.now().isoformat()),
)
)
self._conn.commit()
ok = True
self._consecutive_failures = 0
break
except Exception as e:
last_err = e
logger.warning(f"批量写入数据库失败 (第 {attempt+1}/{self.max_retries} 次){e}")
if "locked" in str(e).lower():
time.sleep(0.8 + attempt * 0.4)
continue
time.sleep(0.5)
if not ok:
self._consecutive_failures += 1
elapsed = time.time() - start
logger.error(f"写入失败超过重试次数:{last_err}")
if elapsed > self.lock_detect_timeout or "locked" in str(last_err).lower():
logger.error("检测到数据库长期锁定,尝试恢复连接…")
try:
self._conn.close()
except:
pass
self._conn = self._connect()
if self._conn:
logger.info("数据库重连成功")
return False
if self.auto_migrate:
logger.error("数据库重连失败,准备自动迁移数据库…")
return self._try_auto_migrate()
return ok
# -------------------------
# Try automatically migrating DB to a safe path
# -------------------------
def _try_auto_migrate(self) -> bool:
try:
safe_dir = "/var/db/video_deduplicate"
os.makedirs(safe_dir, exist_ok=True)
new_path = os.path.join(safe_dir, "video_deduplicate.db")
try:
shutil.copy2(self.db_path, new_path)
logger.info(f"数据库已迁移: {self.db_path} -> {new_path}")
except Exception as e:
logger.error(f"数据库迁移失败: {e}")
return False
self.db_path = new_path
self._conn = self._connect()
if self._conn:
logger.info("迁移后的数据库连接成功,继续运行")
return True
else:
return False
except Exception as e:
logger.error(f"自动迁移过程异常: {e}")
return False
# =====================================================
# Fingerprint analysis
# =====================================================
# -------------------------
# Video fingerprint extraction (fault-tolerant)
# -------------------------
class VideoFingerprint:
def __init__(self):
self.ok = VIDEO_PROCESSING_AVAILABLE
def extract(self, path: str) -> Optional[str]:
"""
提取视频指纹,返回字符串表示
"""
if not self.ok:
logger.debug(f"视频指纹模块不可用,跳过: {path}")
return None
try:
cap = cv2.VideoCapture(path)
if not cap.isOpened():
logger.error(f"打开视频失败: {path}")
return None
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if frame_count <= 0:
cap.release()
return None
sample_count = 5 # 采样5帧
step = max(1, frame_count // sample_count)
fingerprints = []
for i in range(0, frame_count, step):
if len(fingerprints) >= sample_count:
break
cap.set(cv2.CAP_PROP_POS_FRAMES, i)
ok, frame = cap.read()
if not ok:
continue
# convert to grayscale
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
# compute the pHash
try:
pil_img = Image.fromarray(gray)
phash = imagehash.phash(pil_img)
fingerprints.append(str(phash))
except Exception as e:
logger.debug(f"计算pHash失败: {e}")
continue
cap.release()
if not fingerprints:
return None
# return the fingerprint as a string
return f"video_{'_'.join(fingerprints)}"
except Exception as e:
logger.error(f"提取视频指纹失败 {path}: {e}")
return None
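# Shape of the result (hash values below are made up): a readable video yields something like
# "video_8f373714acfcf4d0_..." with one pHash string per sampled frame (16 hex characters
# with imagehash defaults); None means the libraries are missing or the file could not be decoded.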
# =====================================================
# Scanner: multi-threaded scan + enqueue records for the DB writer
# =====================================================
class FileScanner:
EXT_VIDEO = {".mp4", ".mkv", ".avi", ".rmvb", ".mov", ".wmv", ".flv", ".ts", ".m2ts", ".webm"}
def __init__(self, db_writer: DatabaseWriterThread, workers:int=8):
self.db_writer = db_writer
self.workers = choose_worker_count(workers)
self.video_fp = VideoFingerprint()
# -------------------------
def scan(self, root: str):
"""
遍历路径,将文件元数据推送到数据库队列。
"""
root = os.path.abspath(root)
logger.info(f"开始扫描路径: {root}")
file_list: List[str] = []
for base, dirs, files in os.walk(root):
for f in files:
ext = os.path.splitext(f)[1].lower()
if ext in self.EXT_VIDEO:
full = os.path.join(base, f)
file_list.append(full)
logger.info(f"扫描完成,共发现视频文件: {len(file_list)}")
with ThreadPoolExecutor(max_workers=self.workers) as ex:
futures = {ex.submit(self._process_one, path): path for path in file_list}
for fut in as_completed(futures):
try:
fut.result()
except Exception as e:
logger.error(f"处理文件异常: {e}")
# -------------------------
def _process_one(self, path: str):
"""
获取文件大小、时间、hash快速并提交数据库线程。
"""
try:
st = os.stat(path)
except Exception as e:
logger.debug(f"无法读取文件 stat: {path}: {e}")
return
ext = os.path.splitext(path)[1].lower()
# lightweight fast hash (a full hash is only computed for files larger than 1 MB)
file_hash = ""
if st.st_size > 1_000_000:
file_hash = file_sha256(path)
else:
file_hash = f"SMALL-{st.st_size}-{int(st.st_mtime)}"
record = {
"file_path": path,
"file_hash": file_hash,
"file_size": st.st_size,
"file_mtime": st.st_mtime,
"created_at": datetime.now().isoformat(),
}
self.db_writer.enqueue_file(record)
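# Record sketch (numbers are illustrative): a 500 KB clip gets the pseudo-hash
# "SMALL-512000-1700000000" (size plus mtime) instead of a full SHA-256, which keeps the
# scan fast while still telling most small files apart.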
# =====================================================
# Similarity detection and deduplication decisions
# =====================================================
def phash_distance(h1: str, h2: str) -> int:
"""
计算两个 phash 字符串的汉明距离
"""
try:
b1 = int(str(h1), 16)
b2 = int(str(h2), 16)
x = b1 ^ b2
return bin(x).count("1")
except Exception:
return 128 # large
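# Worked example (hashes are illustrative): "8f373714acfcf4d0" vs "8f373714acfcf4d2" differ
# only in the last nibble (0x0 ^ 0x2 = 0b10), so phash_distance returns 1; unparsable input
# falls back to the large sentinel 128 so it never counts as a match.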
class DuplicateFinder:
"""
基于 DB 快照进行相似群组查找
"""
def __init__(self, db_path: str):
self.db_path = db_path
self.video_fp = VideoFingerprint()
def _read_files_from_db(self) -> List[Dict[str, Any]]:
out = []
try:
conn = sqlite3.connect(self.db_path, timeout=30)
cur = conn.cursor()
cur.execute("SELECT file_path, file_hash, file_size FROM files")
for row in cur.fetchall():
out.append({"path": row[0], "hash": row[1], "size": row[2]})
except Exception as e:
logger.warning(f"读取 DB 列表失败: {e}")
finally:
try:
conn.close()
except:
pass
return out
def group_by_name(self, files: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
groups = {}
for f in files:
key = Path(f["path"]).stem.lower()
# remove common tokens
key = re.sub(r"(1080p|720p|2160p|4k|x264|x265|h264|h265|hevc|bluray|web-dl|webdl|bdrip|brrip)", "", key)
key = re.sub(r"[\._\-]+", " ", key).strip()
groups.setdefault(key, []).append(f)
return [g for g in groups.values() if len(g) > 1]
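# Normalization example (hypothetical filenames): "Movie.Name.2023.1080p.x264.mkv" and
# "Movie.Name.2023.720p.HEVC.mkv" both reduce to the key "movie name 2023" and therefore
# land in the same candidate group; keys that end up with a single file are dropped.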
def are_videos_similar(self, a: str, b: str, phash_thresh: int = 10) -> bool:
"""
判断两个视频是否相似
"""
try:
# compare file sizes first
sa = os.path.getsize(a)
sb = os.path.getsize(b)
if sa == sb:
return True
except Exception:
pass
# fall back to comparing video fingerprints
try:
fp_a = self.video_fp.extract(a)
fp_b = self.video_fp.extract(b)
if fp_a and fp_b:
# parse the fingerprint strings
parts_a = fp_a.split("_")[1:] # 去掉开头的"video_"
parts_b = fp_b.split("_")[1:]
if len(parts_a) == len(parts_b) and len(parts_a) > 0:
# count the matching frames
matches = 0
for x, y in zip(parts_a, parts_b):
dist = phash_distance(x, y)
if dist <= phash_thresh:
matches += 1
ratio = matches / len(parts_a)
if ratio >= 0.6: # 60%的帧相似
return True
except Exception as e:
logger.debug(f"视频指纹比较失败: {e}")
return False
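# Decision sketch (numbers are illustrative): two files with exactly the same byte size match
# immediately; otherwise up to five sampled frames are compared and, say, 4 of 5 frames within
# pHash distance 10 gives a ratio of 0.8 >= 0.6, so the pair is reported as similar.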
def find_video_groups(self) -> List[List[Dict[str,Any]]]:
files = self._read_files_from_db()
name_groups = self.group_by_name(files)
result = []
for g in name_groups:
if len(g) <= 1:
continue
used = set()
for i in range(len(g)):
if i in used:
continue
base = g[i]
cluster = [base]
used.add(i)
for j in range(i+1, len(g)):
if j in used:
continue
try:
if self.are_videos_similar(base["path"], g[j]["path"]):
cluster.append(g[j])
used.add(j)
except Exception as e:
logger.debug(f"比较视频相似度失败: {e}")
pass
if len(cluster) > 1:
result.append(cluster)
logger.info(f"查找完成:发现 {len(result)} 视频候选组")
return result
# -------------------------
# VideoDeduplicator high-level operations
# -------------------------
class VideoDeduplicator:
def __init__(self, target_dirs: List[str], db_path: str="video_deduplicate.db", prefer_folder: Optional[str]=None, workers: int=0, auto_migrate: bool=True):
self.target_dirs = target_dirs
self.db_path = db_path
self.prefer_folder = prefer_folder
self.db_writer = DatabaseWriterThread(db_path=db_path, auto_migrate=auto_migrate)
# start the writer thread
self.db_writer.start()
self.scanner = FileScanner(db_writer=self.db_writer, workers=workers)
self.finder = DuplicateFinder(db_path=self.db_path)
def scan_all(self):
for d in self.target_dirs:
self.scanner.scan(d)
def remove_groups(self, groups: List[List[Dict[str,Any]]], dry_run: bool=True, no_backup: bool=False) -> Tuple[List[str], List[str]]:
kept = []
deleted = []
for group in groups:
if not group:
continue
# choose keeper
keeper = None
if self.prefer_folder:
for f in group:
if self.prefer_folder in f["path"]:
keeper = f
break
if not keeper:
keeper = max(group, key=lambda x: x.get("size", 0))
kept.append(keeper["path"])
for f in group:
p = f["path"]
if p == keeper["path"]:
continue
if dry_run:
logger.info(f"[dry-run] 删除 {p} (保留 {keeper['path']})")
self.db_writer.enqueue_operation({
"operation_type": "planned_delete",
"file_path": p,
"file_hash": f.get("hash"),
"reason": "dry_run",
"details": None,
"created_at": datetime.now().isoformat()
})
deleted.append(p)
else:
ok = safe_remove(p, no_backup=no_backup, backup_dir=None, db_writer=self.db_writer)
if ok:
deleted.append(p)
else:
logger.info(f"跳过删除(可能为硬链接或权限问题): {p}")
return kept, deleted
def run_deduplication(self, dry_run: bool=True, no_backup: bool=False) -> Dict[str,Any]:
logger.info("开始视频去重")
self.scan_all()
logger.info("等待 db_writer 完成写入任务...")
# wait until queue is drained or timeout
start = time.time()
while not self.db_writer._queue.empty():
time.sleep(0.5)
if time.time() - start > 600:
logger.error("等待 db_writer 超过 600 秒,提前退出")
break
groups = self.finder.find_video_groups()
kept, deleted = self.remove_groups(groups, dry_run=dry_run, no_backup=no_backup)
return {"kept": kept, "deleted": deleted, "groups": len(groups)}
# =====================================================
# CLI & Main Function
# =====================================================
def parse_args():
parser = argparse.ArgumentParser(description="视频去重工具")
parser.add_argument(
"-d", "--dirs",
nargs="+",
required=True,
help="指定需要扫描的目录(一个或多个)"
)
parser.add_argument(
"--prefer",
type=str,
default=None,
help="优先保留的路径片段(如果匹配文件路径则优先保留)"
)
parser.add_argument(
"--dry-run",
action="store_true",
help="演示模式:仅显示将要删除的文件,不实际删除"
)
parser.add_argument(
"--no-backup",
action="store_true",
help="删除时不创建备份(谨慎)"
)
parser.add_argument(
"--workers",
type=int,
default=0,
help="扫描线程数(默认自动计算)"
)
parser.add_argument(
"--db",
type=str,
default="video_deduplicate.db",
help="使用的数据库文件"
)
parser.add_argument(
"--migrate",
action="store_true",
help="强制允许自动迁移数据库(锁死时会迁移)"
)
return parser.parse_args()
def main():
args = parse_args()
logger.info("==============================================")
logger.info(" 视频去重工具 (Video Deduplicator) ")
logger.info("==============================================")
logger.info(f"扫描目录:{args.dirs}")
logger.info(f"数据库文件:{args.db}")
logger.info(f"优先保留路径片段:{args.prefer}")
if args.dry_run:
logger.info("警告dry-run 模式(不会删除任何文件)")
if args.no_backup:
logger.warning("危险:已启用 --no-backup不会创建备份")
cleaner = VideoDeduplicator(
target_dirs=args.dirs,
db_path=args.db,
prefer_folder=args.prefer,
workers=args.workers,
auto_migrate=args.migrate,
)
result = None
try:
result = cleaner.run_deduplication(
dry_run=args.dry_run,
no_backup=args.no_backup,
)
except Exception as e:
logger.error(f"运行清理任务发生异常: {e}", exc_info=True)
finally:
# ensure writer shutdown
try:
cleaner.db_writer.stop()
cleaner.db_writer.join(timeout=10)
except Exception:
pass
logger.info("所有任务完成。")
if result is not None:
logger.info("========== 清理结果 ==========")
logger.info(f"保留文件数: {len(result['kept'])}")
logger.info(f"删除文件数: {len(result['deleted'])}")
logger.info(f"发现相似组: {result['groups']}")
if args.dry_run:
logger.info("\n将要删除的文件列表:")
for f in result['deleted']:
logger.info(f" {f}")
else:
logger.info("\n已删除的文件列表:")
for f in result['deleted']:
logger.info(f" {f}")
if __name__ == "__main__":
main()