331 lines
12 KiB
Python
331 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""Proton Mail Organizer — classifies emails using a local LLM and applies labels via Proton Bridge."""
|
|
|
|
import argparse
|
|
import email
|
|
import email.header
|
|
import html
|
|
import imaplib
|
|
import json
|
|
import logging
|
|
import re
|
|
import sqlite3
|
|
import ssl
|
|
import sys
|
|
import urllib.request
|
|
import urllib.error
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
|
|
# Log record layout; main() passes it to logging.basicConfig().
LOG_FMT = "%(asctime)s %(levelname)-8s %(message)s"
log = logging.getLogger("proton-organizer")

# The state database and the default config file both sit next to this script.
_SCRIPT_DIR = Path(__file__).parent
DB_PATH = _SCRIPT_DIR.joinpath("processed.db")
DEFAULT_CONFIG = _SCRIPT_DIR.joinpath("config.local.yaml")
|
|
|
|
|
|
# ── helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
def load_config(path: Path) -> dict:
    """Parse the YAML configuration file at *path* into a dict.

    Args:
        path: Location of the YAML file.

    Returns:
        The top-level mapping of the document; an empty dict when the file
        contains no document at all.

    Raises:
        OSError: The file cannot be opened.
        yaml.YAMLError: The file is not valid YAML.
        ValueError: The top-level YAML node is not a mapping.
    """
    # Pin the encoding: the platform default (e.g. cp1252 on Windows) would
    # mangle non-ASCII values such as label names.
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load() yields None for an empty file; honour the declared return type.
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config {path} must contain a mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data
|
|
|
|
|
|
def init_db(db_path: Path) -> sqlite3.Connection:
    """Open the SQLite database at *db_path* and make sure its schema exists.

    The ``processed`` table (message_id, category, processed_at) is created
    on first use. The connection is returned open.
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS processed (
            message_id TEXT PRIMARY KEY,
            category TEXT NOT NULL,
            processed_at TEXT NOT NULL
        )
    """
    connection = sqlite3.connect(db_path)
    connection.execute(create_table_sql)
    connection.commit()
    return connection
|
|
|
|
|
|
def is_processed(conn: sqlite3.Connection, message_id: str) -> bool:
    """Return True when *message_id* already has a row in the ``processed`` table."""
    cursor = conn.execute(
        "SELECT 1 FROM processed WHERE message_id = ?", (message_id,)
    )
    row = cursor.fetchone()
    return row is not None
|
|
|
|
|
|
def mark_processed(conn: sqlite3.Connection, message_id: str, category: str) -> None:
    """Record that *message_id* was classified as *category*.

    Inserts a row into ``processed`` stamped with the current UTC time in
    ISO-8601 form, replacing any existing row for the same message id, and
    commits immediately.
    """
    # Function-scope import: the module only imports `datetime` from datetime.
    # timezone.utc needs no tz database (ZoneInfo("UTC") does) and yields the
    # same "+00:00" suffix in isoformat().
    from datetime import timezone

    processed_at = datetime.now(tz=timezone.utc).isoformat()
    conn.execute(
        "INSERT OR REPLACE INTO processed (message_id, category, processed_at) VALUES (?, ?, ?)",
        (message_id, category, processed_at),
    )
    conn.commit()
|
|
|
|
|
|
def decode_header(raw: str | None) -> str:
|
|
if not raw:
|
|
return ""
|
|
parts = email.header.decode_header(raw)
|
|
decoded = []
|
|
for data, charset in parts:
|
|
if isinstance(data, bytes):
|
|
try:
|
|
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
|
except (LookupError, UnicodeDecodeError):
|
|
decoded.append(data.decode("utf-8", errors="replace"))
|
|
else:
|
|
decoded.append(data)
|
|
return " ".join(decoded)
|
|
|
|
|
|
def _decode_payload(part: email.message.Message) -> str:
    """Return the transfer-decoded payload of *part* as text ("" when empty)."""
    payload = part.get_payload(decode=True)
    if not payload:
        return ""
    charset = part.get_content_charset() or "utf-8"
    try:
        return payload.decode(charset, errors="replace")
    except LookupError:
        # The message declares a charset Python does not know; UTF-8 with
        # replacement is better than failing the whole message.
        return payload.decode("utf-8", errors="replace")


def _strip_html(markup: str) -> str:
    """Replace HTML tags with spaces and resolve character references."""
    return html.unescape(re.sub(r"<[^>]+>", " ", markup))


def extract_text(msg: email.message.Message, max_chars: int) -> str:
    """Extract a plain-text snippet of the message body.

    For multipart messages the first non-empty ``text/plain`` part wins; a
    ``text/html`` part met before any plain part is used as a fallback with
    its tags stripped. Single-part messages are decoded directly, with tags
    stripped when the content type is ``text/html``.

    Args:
        msg: The parsed message.
        max_chars: Maximum length of the returned snippet.

    Returns:
        The body with whitespace runs collapsed to single spaces, trimmed,
        and cut to *max_chars* characters; "" when no text is found.
    """
    body = ""
    if msg.is_multipart():
        for part in msg.walk():
            content_type = part.get_content_type()
            if content_type == "text/plain":
                text = _decode_payload(part)
                if text:
                    body = text
                    break  # plain text is preferred; stop looking
            elif content_type == "text/html" and not body:
                text = _decode_payload(part)
                if text:
                    body = _strip_html(text)
    else:
        body = _decode_payload(msg)
        if body and msg.get_content_type() == "text/html":
            body = _strip_html(body)

    return re.sub(r"\s+", " ", body).strip()[:max_chars]
|
|
|
|
|
|
# ── Proton Bridge IMAP ──────────────────────────────────────────────────────
|
|
|
|
class ProtonClient:
    """Small IMAP client used to label and archive messages via Proton Bridge.

    All message operations use IMAP UID commands, so the identifiers returned
    by :meth:`fetch_uids` stay valid after other messages are expunged.
    """

    def __init__(self, email_addr: str, bridge_password: str,
                 host: str = "127.0.0.1", port: int = 1143):
        """Connect to the IMAP server, upgrade to TLS and log in.

        Args:
            email_addr: Account name used for LOGIN.
            bridge_password: Password used for LOGIN.
            host: IMAP host; defaults to the loopback address.
            port: IMAP port; defaults to 1143.

        Raises:
            imaplib.IMAP4.error: STARTTLS or LOGIN was rejected.
            OSError: The connection could not be established.
        """
        self.email = email_addr
        # Certificate and hostname verification are disabled.
        # NOTE(review): presumably because the Bridge presents a self-signed
        # certificate on loopback — confirm, and do not reuse with a remote host.
        ctx = ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        self.conn = imaplib.IMAP4(host, port)
        try:
            self.conn.starttls(ssl_context=ctx)
            self.conn.login(email_addr, bridge_password)
        except Exception:
            # Do not leak the socket when the handshake or login fails.
            self.conn.shutdown()
            raise

    @staticmethod
    def _quote_mailbox(name: str) -> str:
        """Return *name* as an IMAP quoted string (imaplib sends names verbatim)."""
        if len(name) >= 2 and name.startswith('"') and name.endswith('"'):
            return name  # already quoted by the caller
        escaped = name.replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def fetch_uids(self, mailbox: str = "INBOX", search: str = "ALL",
                   batch_size: int = 50) -> list[bytes]:
        """Select *mailbox* and return UIDs of messages matching *search*.

        Returns:
            At most *batch_size* UIDs, taken from the end of the server's
            result list and returned in reverse order.

        Raises:
            imaplib.IMAP4.error: The mailbox cannot be selected or the search
                is rejected.
        """
        status, detail = self.conn.select(self._quote_mailbox(mailbox))
        if status != "OK":
            raise imaplib.IMAP4.error(f"Cannot select mailbox {mailbox!r}: {detail}")
        status, data = self.conn.uid("SEARCH", search)
        if status != "OK":
            raise imaplib.IMAP4.error(f"SEARCH {search!r} failed: {data}")
        # imaplib reports "no untagged response" as [None].
        uids = (data[0] or b"").split() if data else []
        return list(reversed(uids[-batch_size:]))

    def fetch_message(self, uid: bytes) -> email.message.Message:
        """Download the full message with the given UID and parse it.

        Raises:
            imaplib.IMAP4.error: The server returned no message data.
        """
        status, data = self.conn.uid("FETCH", uid, "(RFC822)")
        if status != "OK" or not data or not isinstance(data[0], tuple):
            raise imaplib.IMAP4.error(f"Could not fetch message UID {uid!r}: {status} {data}")
        return email.message_from_bytes(data[0][1])

    def apply_label(self, uid: bytes, label: str):
        """Apply a label by copying the message to the label folder."""
        mailbox = self._quote_mailbox(label)
        try:
            self.conn.create(mailbox)
        except imaplib.IMAP4.error as exc:
            # Best effort: the COPY below reports the real outcome.
            log.debug("CREATE %s failed (continuing): %s", label, exc)
        result = self.conn.uid("COPY", uid, mailbox)
        if result[0] != "OK":
            log.warning("Failed to copy to label %s: %s", label, result)

    def archive(self, uid: bytes):
        """Archive: move from INBOX to Archive folder.

        The message is only flagged deleted after the copy succeeded, so a
        failed copy never loses mail.

        Raises:
            imaplib.IMAP4.error: The copy or the flag update was rejected;
                the message is left in place.
        """
        result = self.conn.uid("COPY", uid, self._quote_mailbox("Archive"))
        if result[0] != "OK":
            raise imaplib.IMAP4.error(
                f"Failed to copy UID {uid!r} to Archive; message left in place: {result}"
            )
        result = self.conn.uid("STORE", uid, "+FLAGS", "(\\Deleted)")
        if result[0] != "OK":
            raise imaplib.IMAP4.error(f"Failed to flag UID {uid!r} as deleted: {result}")
        # EXPUNGE removes every message flagged \Deleted in the selected mailbox.
        self.conn.expunge()

    def close(self):
        """Close the selected mailbox and log out, ignoring errors from either step."""
        # The two steps are guarded separately so that a failing CLOSE (e.g.
        # no mailbox selected) cannot prevent the LOGOUT.
        try:
            self.conn.close()
        except Exception:
            pass  # best-effort cleanup
        try:
            self.conn.logout()
        except Exception:
            pass  # best-effort cleanup
|
|
|
|
|
|
# ── Ollama LLM ───────────────────────────────────────────────────────────────
|
|
|
|
def _build_classification_prompt(categories, subject, sender, body_snippet):
    """Render the single-shot classification prompt sent to the model."""
    cat_descriptions = "\n".join(
        f"- **{name}**: {info['description']}" for name, info in categories.items()
    )
    category_names = ", ".join(categories.keys())

    return f"""Classify this email into exactly ONE category. Reply with ONLY the category name, nothing else.

Categories:
{cat_descriptions}

Email:
From: {sender}
Subject: {subject}
Body: {body_snippet[:1000]}

Reply with one of: {category_names}"""


def _match_category(categories, answer):
    """Map the lower-cased model answer to a configured category name, or None."""
    if not answer:
        return None
    # Compare case-insensitively but return the name exactly as configured.
    by_lower = {}
    for name in categories:
        by_lower.setdefault(name.lower(), name)

    # An answer that is exactly a category name (modulo quoting/markdown) wins.
    exact = by_lower.get(answer.strip(" \t\r\n\"'`*.:"))
    if exact is not None:
        return exact

    # Otherwise accept a name embedded in a longer answer; longest first so a
    # name is never shadowed by a shorter name it contains ("work"/"network").
    for lowered in sorted(by_lower, key=len, reverse=True):
        if lowered in answer:
            return by_lower[lowered]
    return None


def classify_email(ollama_url, model, categories, subject, sender, body_snippet):
    """Ask the LLM behind *ollama_url* to pick one category for an email.

    Args:
        ollama_url: Base URL of the server; ``/api/generate`` is appended.
        model: Model name placed in the request.
        categories: Mapping of category name to a dict with a
            ``description`` entry.
        subject: Decoded subject line.
        sender: Decoded From header.
        body_snippet: Body text; only the first 1000 characters are sent.

    Returns:
        The matching key of *categories*, or ``"personal"`` when the answer
        names no known category.

    Raises:
        urllib.error.URLError: The request failed (logged, then re-raised).
    """
    prompt = _build_classification_prompt(categories, subject, sender, body_snippet)
    payload = json.dumps({
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": 0.1, "num_predict": 20},
    }).encode()

    req = urllib.request.Request(
        f"{ollama_url.rstrip('/')}/api/generate",
        data=payload,
        headers={"Content-Type": "application/json"},
    )

    try:
        with urllib.request.urlopen(req, timeout=60) as resp:
            result = json.loads(resp.read())
    except urllib.error.URLError as e:
        log.error("Ollama request failed: %s", e)
        raise

    # Tolerate a missing or null "response" field.
    answer = result.get("response") if isinstance(result, dict) else None
    raw_response = (answer or "").strip().lower()
    # Drop any <think>...</think> block emitted before the answer.
    raw_response = re.sub(r"<think>.*?</think>", "", raw_response, flags=re.DOTALL).strip()

    matched = _match_category(categories, raw_response)
    if matched is not None:
        return matched

    log.warning("LLM returned unexpected category %r, defaulting to 'personal'", raw_response)
    return "personal"
|
|
|
|
|
|
# ── main ─────────────────────────────────────────────────────────────────────
|
|
|
|
def _first_matching_rule(rules, sender):
    """Return the first rule with a sender pattern contained in *sender*, else None."""
    sender_lc = sender.lower()
    for rule in rules:
        if any(pattern.lower() in sender_lc for pattern in rule["senders"]):
            return rule
    return None


def run(config_path, dry_run=False, reprocess=False, limit=None):
    """Classify one batch of messages and apply labels.

    Each message is first checked against the sender rules from the config;
    a matching rule decides the folder without calling the LLM. Otherwise
    the message is classified, labelled and, if the category asks for it,
    archived. Unless *reprocess* is set, messages already recorded in the
    database are skipped. With *dry_run* (argument or config) nothing is
    changed and nothing is recorded.

    Args:
        config_path: Path of the YAML configuration file.
        dry_run: Only log what would be done.
        reprocess: Handle messages even when already recorded.
        limit: Overrides the configured batch size when truthy.
    """
    cfg = load_config(config_path)
    proton_cfg, ollama_cfg, categories = cfg["proton"], cfg["ollama"], cfg["categories"]
    proc_cfg = cfg.get("processing", {})

    batch_size = limit if limit else proc_cfg.get("batch_size", 50)
    max_body = proc_cfg.get("max_body_chars", 2000)
    if not dry_run:
        dry_run = proc_cfg.get("dry_run", False)
    mailbox = proc_cfg.get("mailbox", "INBOX")
    rules = cfg.get("rules", [])

    log.info("Connecting to Proton Bridge as %s", proton_cfg["email"])
    client = ProtonClient(
        proton_cfg["email"],
        proton_cfg["bridge_password"],
        host=proton_cfg.get("host", "127.0.0.1"),
        port=proton_cfg.get("port", 1143),
    )
    db = init_db(DB_PATH)

    try:
        uids = client.fetch_uids(mailbox=mailbox, batch_size=batch_size)
        total = len(uids)
        log.info("Fetched %d message UIDs", total)

        # One counter per category plus the bookkeeping counters.
        stats = dict.fromkeys(categories, 0)
        stats.update(rules=0, skipped=0, errors=0)

        for position, uid in enumerate(uids, start=1):
            try:
                msg = client.fetch_message(uid)
                message_id = msg.get("Message-ID", f"uid-{uid.decode()}")
                subject = decode_header(msg.get("Subject"))
                sender = decode_header(msg.get("From"))

                if not reprocess and is_processed(db, message_id):
                    stats["skipped"] += 1
                    continue

                # Sender-based rules take precedence over the LLM.
                rule = _first_matching_rule(rules, sender)
                if rule is not None:
                    folder = rule["folder"]
                    category = rule.get("category", "personal")
                    log.info("[%d/%d] Rule match: %s (from: %s) → %s",
                             position, total, subject[:60], sender[:40], folder)
                    if dry_run:
                        log.info(" [DRY RUN] Would move to: %s", folder)
                    else:
                        client.apply_label(uid, folder)
                        mark_processed(db, message_id, category)
                    stats["rules"] += 1
                    continue

                body = extract_text(msg, max_body)
                log.info("[%d/%d] Classifying: %s (from: %s)",
                         position, total, subject[:60], sender[:40])

                category = classify_email(
                    ollama_cfg["url"], ollama_cfg["model"],
                    categories, subject, sender, body,
                )
                category_cfg = categories[category]
                label = category_cfg["label"]
                log.info(" → %s (%s)", category, label)

                should_archive = category_cfg.get("archive", False)

                if dry_run:
                    log.info(" [DRY RUN] Would apply label: %s%s", label,
                             " + archive" if should_archive else "")
                else:
                    client.apply_label(uid, label)
                    if should_archive:
                        client.archive(uid)
                        log.info(" 📥 Archived")
                    mark_processed(db, message_id, category)

                stats[category] = stats.get(category, 0) + 1

            except Exception as e:
                # One bad message must not abort the batch.
                log.error("Error processing UID %s: %s", uid, e)
                stats["errors"] += 1

        log.info("Done! Stats: %s", json.dumps(stats, indent=2))

    finally:
        client.close()
        db.close()
|
|
|
|
|
|
def main():
    """Command-line entry point: parse options, set up logging, start the run.

    Exits with status 1 when the config file does not exist.
    """
    cli = argparse.ArgumentParser(
        description="Proton Mail Organizer — LLM-powered email classification",
    )
    cli.add_argument("-c", "--config", type=Path, default=DEFAULT_CONFIG)
    cli.add_argument("-n", "--dry-run", action="store_true")
    cli.add_argument("--reprocess", action="store_true")
    cli.add_argument("--limit", type=int, default=None)
    cli.add_argument("-v", "--verbose", action="store_true")
    opts = cli.parse_args()

    log_level = logging.DEBUG if opts.verbose else logging.INFO
    logging.basicConfig(level=log_level, format=LOG_FMT)

    config_file = opts.config
    if not config_file.exists():
        log.error("Config not found: %s", config_file)
        sys.exit(1)

    run(
        config_file,
        dry_run=opts.dry_run,
        reprocess=opts.reprocess,
        limit=opts.limit,
    )


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|