# image-drop/app/app.py
"""
Immich Drop Uploader Backend (FastAPI, simplified)
----------------------------------------------------
- Serves static frontend (no settings UI)
- Uploads to Immich using values from .env ONLY
- Duplicate checks (local SHA-1 DB + optional Immich bulk-check)
- WebSocket progress per session
- Ephemeral "Connected" banner via /api/ping
"""
from __future__ import annotations
import asyncio
import io
import json
import hashlib
import os
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder, MultipartEncoderMonitor
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.websockets import WebSocketState
from PIL import Image, ExifTags
from dotenv import load_dotenv
from app.config import Settings, load_settings
# ---- Load environment / defaults ----
load_dotenv()
HOST = os.getenv("HOST", "127.0.0.1")
PORT = int(os.getenv("PORT", "8080"))
STATE_DB = os.getenv("STATE_DB", "./state.db")
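# A minimal .env sketch for local runs. HOST/PORT/STATE_DB are read right above;
# the Immich connection values are consumed by load_settings() in app/config.py,
# so the exact variable names below (IMMICH_BASE_URL, IMMICH_API_KEY) are an
# assumption based on the Settings fields this module uses:
#
#   HOST=127.0.0.1
#   PORT=8080
#   STATE_DB=./state.db
#   IMMICH_BASE_URL=http://immich.local:2283/api
#   IMMICH_API_KEY=<your-api-key>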
# ---- App & static ----
app = FastAPI(title="Immich Drop Uploader (Python)")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
FRONTEND_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "frontend")
app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
# Global settings (read-only at runtime)
SETTINGS: Settings = load_settings()
# ---------- DB (local dedupe cache) ----------
def db_init() -> None:
"""Create the local SQLite table used for duplicate checks (idempotent)."""
conn = sqlite3.connect(STATE_DB)
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS uploads (
id INTEGER PRIMARY KEY AUTOINCREMENT,
checksum TEXT UNIQUE,
filename TEXT,
size INTEGER,
device_asset_id TEXT,
immich_asset_id TEXT,
created_at TEXT,
inserted_at TEXT DEFAULT CURRENT_TIMESTAMP
);
"""
)
conn.commit()
conn.close()
def db_lookup_checksum(checksum: str) -> Optional[dict]:
"""Return a record for the given checksum if seen before (None if not)."""
conn = sqlite3.connect(STATE_DB)
cur = conn.cursor()
cur.execute("SELECT checksum, immich_asset_id FROM uploads WHERE checksum = ?", (checksum,))
row = cur.fetchone()
conn.close()
if row:
return {"checksum": row[0], "immich_asset_id": row[1]}
return None
def db_lookup_device_asset(device_asset_id: str) -> bool:
"""True if a deviceAssetId has been uploaded by this service previously."""
conn = sqlite3.connect(STATE_DB)
cur = conn.cursor()
cur.execute("SELECT 1 FROM uploads WHERE device_asset_id = ?", (device_asset_id,))
row = cur.fetchone()
conn.close()
return bool(row)
def db_insert_upload(checksum: str, filename: str, size: int, device_asset_id: str, immich_asset_id: Optional[str], created_at: str) -> None:
"""Insert a newly-uploaded asset into the local cache (ignore on duplicates)."""
conn = sqlite3.connect(STATE_DB)
cur = conn.cursor()
cur.execute(
"INSERT OR IGNORE INTO uploads (checksum, filename, size, device_asset_id, immich_asset_id, created_at) VALUES (?,?,?,?,?,?)",
(checksum, filename, size, device_asset_id, immich_asset_id, created_at)
)
conn.commit()
conn.close()
db_init()
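# To inspect the dedupe cache from a shell (sketch; table layout as created above):
#   sqlite3 state.db "SELECT checksum, filename, immich_asset_id FROM uploads LIMIT 5;"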
# ---------- WebSocket hub ----------
class SessionHub:
"""Holds WebSocket connections per session and broadcasts progress updates."""
def __init__(self) -> None:
self.sessions: Dict[str, List[WebSocket]] = {}
async def connect(self, session_id: str, ws: WebSocket) -> None:
"""Register a newly accepted WebSocket under the given session id."""
self.sessions.setdefault(session_id, []).append(ws)
def _cleanup_closed(self, session_id: str) -> None:
"""Drop closed sockets and cleanup empty session buckets."""
if session_id not in self.sessions:
return
self.sessions[session_id] = [w for w in self.sessions[session_id] if w.client_state == WebSocketState.CONNECTED]
if not self.sessions[session_id]:
del self.sessions[session_id]
async def send(self, session_id: str, payload: dict) -> None:
"""Broadcast a JSON payload to all sockets for one session."""
conns = self.sessions.get(session_id, [])
for ws in list(conns):
try:
await ws.send_text(json.dumps(payload))
except Exception:
try:
await ws.close()
except Exception:
pass
self._cleanup_closed(session_id)
async def disconnect(self, session_id: str, ws: WebSocket) -> None:
"""Remove a socket from the hub and close it (best-effort)."""
try:
await ws.close()
finally:
if session_id in self.sessions and ws in self.sessions[session_id]:
self.sessions[session_id].remove(ws)
self._cleanup_closed(session_id)
hub = SessionHub()
# ---------- Helpers ----------
def sha1_hex(file_bytes: bytes) -> str:
"""Return SHA-1 hex digest of file_bytes."""
h = hashlib.sha1()
h.update(file_bytes)
return h.hexdigest()
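# Example (standard SHA-1 test vector for the empty input):
#   sha1_hex(b"") == "da39a3ee5e6b4b0d3255bfef95601890afd80709"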
def read_exif_datetimes(file_bytes: bytes):
"""
Extract EXIF DateTimeOriginal / ModifyDate values when possible.
Returns (created, modified) as datetime or (None, None) on failure.
"""
created = modified = None
try:
with Image.open(io.BytesIO(file_bytes)) as im:
exif = getattr(im, "_getexif", lambda: None)() or {}
if exif:
tags = {ExifTags.TAGS.get(k, k): v for k, v in exif.items()}
                # Pillow's ExifTags names: 0x9003 is DateTimeOriginal, 0x9004 is
                # DateTimeDigitized, 0x0132 is DateTime (last modification).
                dt_original = tags.get("DateTimeOriginal") or tags.get("DateTimeDigitized")
                dt_modified = tags.get("DateTime") or dt_original
def parse_dt(s: str):
try:
return datetime.strptime(s, "%Y:%m:%d %H:%M:%S")
except Exception:
return None
if isinstance(dt_original, str):
created = parse_dt(dt_original)
if isinstance(dt_modified, str):
modified = parse_dt(dt_modified)
except Exception:
pass
return created, modified
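# EXIF stores timestamps as "YYYY:MM:DD HH:MM:SS", so parse_dt("2023:01:15 12:34:56")
# yields datetime(2023, 1, 15, 12, 34, 56); unparseable values fall back to None.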
def immich_headers() -> dict:
"""Headers for Immich API calls (keeps key server-side)."""
return {"Accept": "application/json", "x-api-key": SETTINGS.immich_api_key}
def immich_ping() -> bool:
"""Best-effort reachability check against a few Immich endpoints."""
if not SETTINGS.immich_api_key:
return False
base = SETTINGS.normalized_base_url
for path in ("/server-info", "/server/version", "/users/me"):
try:
r = requests.get(f"{base}{path}", headers=immich_headers(), timeout=4)
if 200 <= r.status_code < 400:
return True
except Exception:
continue
return False
def immich_bulk_check(checks: List[dict]) -> Dict[str, dict]:
"""Try Immich bulk upload check; return map id->result (or empty on failure)."""
try:
url = f"{SETTINGS.normalized_base_url}/assets/bulk-upload-check"
r = requests.post(url, headers=immich_headers(), json={"assets": checks}, timeout=10)
if r.status_code == 200:
results = r.json().get("results", [])
return {x["id"]: x for x in results}
except Exception:
pass
return {}
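# Assumed wire format, inferred from how the result is consumed in api_upload below
# (not from the Immich API docs):
#   request:  {"assets": [{"id": "<queue item id>", "checksum": "<sha1 hex>"}]}
#   response: {"results": [{"id": "...", "action": "reject", "reason": "duplicate",
#                           "assetId": "<existing asset id>"}]}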
async def send_progress(session_id: str, item_id: str, status: str, progress: int = 0, message: str = "", response_id: Optional[str] = None) -> None:
"""Push a progress update over WebSocket for one queue item."""
await hub.send(session_id, {
"item_id": item_id,
"status": status,
"progress": progress,
"message": message,
"responseId": response_id,
})
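# Each update reaching the browser is a JSON object like:
#   {"item_id": "q1", "status": "uploading", "progress": 42, "message": "", "responseId": null}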
# ---------- Routes ----------
@app.get("/", response_class=HTMLResponse)
async def index(_: Request) -> HTMLResponse:
"""Serve the SPA (frontend/index.html)."""
return FileResponse(os.path.join(FRONTEND_DIR, "index.html"))
@app.post("/api/ping")
async def api_ping() -> dict:
"""Connectivity test endpoint used by the UI to display a temporary banner."""
return {"ok": immich_ping(), "base_url": SETTINGS.normalized_base_url}
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket) -> None:
"""WebSocket endpoint for pushing per-item upload progress."""
await ws.accept()
try:
init = await ws.receive_text()
data = json.loads(init)
session_id = data.get("session_id") or "default"
except Exception:
session_id = "default"
await hub.connect(session_id, ws)
# keepalive to avoid proxy idle timeouts
try:
while True:
msg_task = asyncio.create_task(ws.receive_text())
keep_task = asyncio.create_task(asyncio.sleep(30))
done, pending = await asyncio.wait({msg_task, keep_task}, return_when=asyncio.FIRST_COMPLETED)
if msg_task in done:
_ = msg_task.result()
else:
await ws.send_text('{"type":"ping"}')
for t in pending:
t.cancel()
except WebSocketDisconnect:
await hub.disconnect(session_id, ws)
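# Client handshake sketch (hedged: "websockets" is a third-party package, not a
# dependency of this app). The first frame must carry the session id:
#
#   import asyncio, json, websockets
#   async def listen():
#       async with websockets.connect("ws://127.0.0.1:8080/ws") as ws:
#           await ws.send(json.dumps({"session_id": "demo"}))
#           while True:
#               print(await ws.recv())  # progress payloads and {"type":"ping"} keepalives
#   asyncio.run(listen())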
@app.post("/api/upload")
async def api_upload(
_: Request,
file: UploadFile,
item_id: str = Form(...),
session_id: str = Form(...),
last_modified: Optional[int] = Form(None),
):
"""Receive a file, check duplicates, forward to Immich; stream progress via WS."""
raw = await file.read()
size = len(raw)
checksum = sha1_hex(raw)
exif_created, exif_modified = read_exif_datetimes(raw)
created_at = exif_created or (datetime.fromtimestamp(last_modified / 1000) if last_modified else datetime.utcnow())
modified_at = exif_modified or created_at
created_iso = created_at.isoformat()
modified_iso = modified_at.isoformat()
device_asset_id = f"{file.filename}-{last_modified or 0}-{size}"
if db_lookup_checksum(checksum):
await send_progress(session_id, item_id, "duplicate", 100, "Duplicate (by checksum - local cache)")
return JSONResponse({"status": "duplicate", "id": None}, status_code=200)
if db_lookup_device_asset(device_asset_id):
await send_progress(session_id, item_id, "duplicate", 100, "Already uploaded from this device (local cache)")
return JSONResponse({"status": "duplicate", "id": None}, status_code=200)
await send_progress(session_id, item_id, "checking", 2, "Checking duplicates…")
bulk = immich_bulk_check([{"id": item_id, "checksum": checksum}])
if bulk.get(item_id, {}).get("action") == "reject" and bulk[item_id].get("reason") == "duplicate":
asset_id = bulk[item_id].get("assetId")
db_insert_upload(checksum, file.filename, size, device_asset_id, asset_id, created_iso)
await send_progress(session_id, item_id, "duplicate", 100, "Duplicate (server)", asset_id)
return JSONResponse({"status": "duplicate", "id": asset_id}, status_code=200)
def gen_encoder() -> MultipartEncoder:
return MultipartEncoder(fields={
"assetData": (file.filename, io.BytesIO(raw), file.content_type or "application/octet-stream"),
"deviceAssetId": device_asset_id,
"deviceId": f"python-{session_id}",
"fileCreatedAt": created_iso,
"fileModifiedAt": modified_iso,
"isFavorite": "false",
"filename": file.filename,
})
encoder = gen_encoder()
    async def do_upload():
        await send_progress(session_id, item_id, "uploading", 0, "Uploading…")
        loop = asyncio.get_running_loop()
        sent = {"pct": 0}
        def cb(monitor: MultipartEncoderMonitor) -> None:
            # Runs on the uploader thread; marshal each update back onto the event loop.
            if monitor.len:
                pct = int(monitor.bytes_read * 100 / monitor.len)
                if pct != sent["pct"]:
                    sent["pct"] = pct
                    asyncio.run_coroutine_threadsafe(send_progress(session_id, item_id, "uploading", pct), loop)
        monitor = MultipartEncoderMonitor(encoder, cb)
        headers = {"Content-Type": monitor.content_type, "x-immich-checksum": checksum, **immich_headers()}
        try:
            # requests.post blocks, so run it in a worker thread to keep the event
            # loop free to deliver progress while the upload streams.
            r = await asyncio.to_thread(requests.post, f"{SETTINGS.normalized_base_url}/assets", headers=headers, data=monitor, timeout=120)
if r.status_code in (200, 201):
data = r.json()
asset_id = data.get("id")
db_insert_upload(checksum, file.filename, size, device_asset_id, asset_id, created_iso)
status = data.get("status", "created")
await send_progress(session_id, item_id, "duplicate" if status == "duplicate" else "done", 100, status, asset_id)
return JSONResponse({"id": asset_id, "status": status}, status_code=200)
else:
try:
msg = r.json().get("message", r.text)
except Exception:
msg = r.text
await send_progress(session_id, item_id, "error", 100, msg)
return JSONResponse({"error": msg}, status_code=400)
except Exception as e:
await send_progress(session_id, item_id, "error", 100, str(e))
return JSONResponse({"error": str(e)}, status_code=500)
return await do_upload()
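# Manual upload example (field names match the Form parameters above):
#   curl -F "file=@photo.jpg" -F "item_id=q1" -F "session_id=demo" \
#        -F "last_modified=1724673028000" http://127.0.0.1:8080/api/upload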
if __name__ == "__main__":
    import uvicorn
    # Import string matches this module's location (app/app.py).
    uvicorn.run("app.app:app", host=HOST, port=PORT, reload=True)