app/__init__.py (new file, empty)

app/app.py (new file, 352 lines)
@@ -0,0 +1,352 @@
"""
Immich Drop Uploader – Backend (FastAPI, simplified)
----------------------------------------------------
- Serves static frontend (no settings UI)
- Uploads to Immich using values from .env ONLY
- Duplicate checks (local SHA-1 DB + optional Immich bulk-check)
- WebSocket progress per session
- Ephemeral "Connected" banner via /api/ping
"""

from __future__ import annotations

import asyncio
import hashlib
import io
import json
import os
import sqlite3
from datetime import datetime
from typing import Dict, List, Optional

import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder, MultipartEncoderMonitor
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.websockets import WebSocketState
from PIL import Image, ExifTags
from dotenv import load_dotenv

from app.config import Settings, load_settings

# ---- Load environment / defaults ----
load_dotenv()
HOST = os.getenv("HOST", "127.0.0.1")
PORT = int(os.getenv("PORT", "8080"))
STATE_DB = os.getenv("STATE_DB", "./state.db")
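
# A minimal example .env — a sketch built from the defaults read in this file
# and in app/config.py; the API key value is a placeholder, not a real key:
#
#   HOST=127.0.0.1
#   PORT=8080
#   STATE_DB=./state.db
#   IMMICH_BASE_URL=http://127.0.0.1:2283/api
#   IMMICH_API_KEY=<your-immich-api-key>
#   MAX_CONCURRENT=3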
# ---- App & static ----
|
||||
app = FastAPI(title="Immich Drop Uploader (Python)")
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
FRONTEND_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "frontend")
|
||||
app.mount("/static", StaticFiles(directory=FRONTEND_DIR), name="static")
|
||||
|
||||
# Global settings (read-only at runtime)
|
||||
SETTINGS: Settings = load_settings()
|
||||
|
||||

# ---------- DB (local dedupe cache) ----------

def db_init() -> None:
    """Create the local SQLite table used for duplicate checks (idempotent)."""
    conn = sqlite3.connect(STATE_DB)
    cur = conn.cursor()
    cur.execute(
        """
        CREATE TABLE IF NOT EXISTS uploads (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            checksum TEXT UNIQUE,
            filename TEXT,
            size INTEGER,
            device_asset_id TEXT,
            immich_asset_id TEXT,
            created_at TEXT,
            inserted_at TEXT DEFAULT CURRENT_TIMESTAMP
        );
        """
    )
    conn.commit()
    conn.close()


def db_lookup_checksum(checksum: str) -> Optional[dict]:
    """Return a record for the given checksum if seen before (None if not)."""
    conn = sqlite3.connect(STATE_DB)
    cur = conn.cursor()
    cur.execute("SELECT checksum, immich_asset_id FROM uploads WHERE checksum = ?", (checksum,))
    row = cur.fetchone()
    conn.close()
    if row:
        return {"checksum": row[0], "immich_asset_id": row[1]}
    return None


def db_lookup_device_asset(device_asset_id: str) -> bool:
    """True if a deviceAssetId has been uploaded by this service previously."""
    conn = sqlite3.connect(STATE_DB)
    cur = conn.cursor()
    cur.execute("SELECT 1 FROM uploads WHERE device_asset_id = ?", (device_asset_id,))
    row = cur.fetchone()
    conn.close()
    return bool(row)


def db_insert_upload(checksum: str, filename: str, size: int, device_asset_id: str, immich_asset_id: Optional[str], created_at: str) -> None:
    """Insert a newly-uploaded asset into the local cache (ignore on duplicates)."""
    conn = sqlite3.connect(STATE_DB)
    cur = conn.cursor()
    cur.execute(
        "INSERT OR IGNORE INTO uploads (checksum, filename, size, device_asset_id, immich_asset_id, created_at) VALUES (?,?,?,?,?,?)",
        (checksum, filename, size, device_asset_id, immich_asset_id, created_at)
    )
    conn.commit()
    conn.close()


db_init()

# ---------- WebSocket hub ----------

class SessionHub:
    """Holds WebSocket connections per session and broadcasts progress updates."""

    def __init__(self) -> None:
        self.sessions: Dict[str, List[WebSocket]] = {}

    async def connect(self, session_id: str, ws: WebSocket) -> None:
        """Register a newly accepted WebSocket under the given session id."""
        self.sessions.setdefault(session_id, []).append(ws)

    def _cleanup_closed(self, session_id: str) -> None:
        """Drop closed sockets and clean up empty session buckets."""
        if session_id not in self.sessions:
            return
        self.sessions[session_id] = [w for w in self.sessions[session_id] if w.client_state == WebSocketState.CONNECTED]
        if not self.sessions[session_id]:
            del self.sessions[session_id]

    async def send(self, session_id: str, payload: dict) -> None:
        """Broadcast a JSON payload to all sockets for one session."""
        conns = self.sessions.get(session_id, [])
        for ws in list(conns):
            try:
                await ws.send_text(json.dumps(payload))
            except Exception:
                try:
                    await ws.close()
                except Exception:
                    pass
        self._cleanup_closed(session_id)

    async def disconnect(self, session_id: str, ws: WebSocket) -> None:
        """Remove a socket from the hub and close it (best-effort)."""
        try:
            await ws.close()
        finally:
            if session_id in self.sessions and ws in self.sessions[session_id]:
                self.sessions[session_id].remove(ws)
            self._cleanup_closed(session_id)


hub = SessionHub()
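
# Wire-format notes, derived from ws_endpoint and send_progress below: the
# client's first frame is JSON like {"session_id": "<id>"}; the server then
# pushes one update per queue item shaped like
#   {"item_id": "...", "status": "uploading", "progress": 42,
#    "message": "", "responseId": null}
# plus an occasional keepalive frame {"type": "ping"}.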

# ---------- Helpers ----------

def sha1_hex(file_bytes: bytes) -> str:
    """Return SHA-1 hex digest of file_bytes."""
    h = hashlib.sha1()
    h.update(file_bytes)
    return h.hexdigest()


def read_exif_datetimes(file_bytes: bytes):
    """
    Extract EXIF DateTimeOriginal / ModifyDate values when possible.
    Returns (created, modified) as datetime or (None, None) on failure.
    """
    created = modified = None
    try:
        with Image.open(io.BytesIO(file_bytes)) as im:
            exif = getattr(im, "_getexif", lambda: None)() or {}
            if exif:
                tags = {ExifTags.TAGS.get(k, k): v for k, v in exif.items()}
                dt_original = tags.get("DateTimeOriginal") or tags.get("CreateDate")
                dt_modified = tags.get("ModifyDate") or dt_original

                def parse_dt(s: str):
                    try:
                        return datetime.strptime(s, "%Y:%m:%d %H:%M:%S")
                    except Exception:
                        return None

                if isinstance(dt_original, str):
                    created = parse_dt(dt_original)
                if isinstance(dt_modified, str):
                    modified = parse_dt(dt_modified)
    except Exception:
        pass
    return created, modified


def immich_headers() -> dict:
    """Headers for Immich API calls (keeps key server-side)."""
    return {"Accept": "application/json", "x-api-key": SETTINGS.immich_api_key}


def immich_ping() -> bool:
    """Best-effort reachability check against a few Immich endpoints."""
    if not SETTINGS.immich_api_key:
        return False
    base = SETTINGS.normalized_base_url
    for path in ("/server-info", "/server/version", "/users/me"):
        try:
            r = requests.get(f"{base}{path}", headers=immich_headers(), timeout=4)
            if 200 <= r.status_code < 400:
                return True
        except Exception:
            continue
    return False


def immich_bulk_check(checks: List[dict]) -> Dict[str, dict]:
    """Try Immich bulk upload check; return map id->result (or empty on failure)."""
    try:
        url = f"{SETTINGS.normalized_base_url}/assets/bulk-upload-check"
        r = requests.post(url, headers=immich_headers(), json={"assets": checks}, timeout=10)
        if r.status_code == 200:
            results = r.json().get("results", [])
            return {x["id"]: x for x in results}
    except Exception:
        pass
    return {}
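
# For reference, the request/response shapes assumed here — inferred from how
# the result is consumed in api_upload below, a sketch rather than the
# authoritative Immich API contract:
#   request:  {"assets": [{"id": "<client-item-id>", "checksum": "<sha1-hex>"}]}
#   response: {"results": [{"id": "<client-item-id>", "action": "reject",
#                           "reason": "duplicate", "assetId": "<immich-id>"}]}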

async def send_progress(session_id: str, item_id: str, status: str, progress: int = 0, message: str = "", response_id: Optional[str] = None) -> None:
    """Push a progress update over WebSocket for one queue item."""
    await hub.send(session_id, {
        "item_id": item_id,
        "status": status,
        "progress": progress,
        "message": message,
        "responseId": response_id,
    })


# ---------- Routes ----------

@app.get("/", response_class=HTMLResponse)
async def index(_: Request) -> FileResponse:
    """Serve the SPA (frontend/index.html)."""
    return FileResponse(os.path.join(FRONTEND_DIR, "index.html"))


@app.post("/api/ping")
async def api_ping() -> dict:
    """Connectivity test endpoint used by the UI to display a temporary banner."""
    return {"ok": immich_ping(), "base_url": SETTINGS.normalized_base_url}
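
# Quick connectivity check from a shell — a sketch assuming the defaults above:
#   curl -X POST http://127.0.0.1:8080/api/ping
#   -> {"ok": true, "base_url": "http://127.0.0.1:2283/api"}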
@app.websocket("/ws")
|
||||
async def ws_endpoint(ws: WebSocket) -> None:
|
||||
"""WebSocket endpoint for pushing per-item upload progress."""
|
||||
await ws.accept()
|
||||
try:
|
||||
init = await ws.receive_text()
|
||||
data = json.loads(init)
|
||||
session_id = data.get("session_id") or "default"
|
||||
except Exception:
|
||||
session_id = "default"
|
||||
await hub.connect(session_id, ws)
|
||||
|
||||
# keepalive to avoid proxy idle timeouts
|
||||
try:
|
||||
while True:
|
||||
msg_task = asyncio.create_task(ws.receive_text())
|
||||
keep_task = asyncio.create_task(asyncio.sleep(30))
|
||||
done, pending = await asyncio.wait({msg_task, keep_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||
if msg_task in done:
|
||||
_ = msg_task.result()
|
||||
else:
|
||||
await ws.send_text('{"type":"ping"}')
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
except WebSocketDisconnect:
|
||||
await hub.disconnect(session_id, ws)
|
||||
|
||||
@app.post("/api/upload")
|
||||
async def api_upload(
|
||||
_: Request,
|
||||
file: UploadFile,
|
||||
item_id: str = Form(...),
|
||||
session_id: str = Form(...),
|
||||
last_modified: Optional[int] = Form(None),
|
||||
):
|
||||
"""Receive a file, check duplicates, forward to Immich; stream progress via WS."""
|
||||
raw = await file.read()
|
||||
size = len(raw)
|
||||
checksum = sha1_hex(raw)
|
||||
|
||||
exif_created, exif_modified = read_exif_datetimes(raw)
|
||||
created_at = exif_created or (datetime.fromtimestamp(last_modified / 1000) if last_modified else datetime.utcnow())
|
||||
modified_at = exif_modified or created_at
|
||||
created_iso = created_at.isoformat()
|
||||
modified_iso = modified_at.isoformat()
|
||||
|
||||
device_asset_id = f"{file.filename}-{last_modified or 0}-{size}"
|
||||
|
||||
if db_lookup_checksum(checksum):
|
||||
await send_progress(session_id, item_id, "duplicate", 100, "Duplicate (by checksum - local cache)")
|
||||
return JSONResponse({"status": "duplicate", "id": None}, status_code=200)
|
||||
if db_lookup_device_asset(device_asset_id):
|
||||
await send_progress(session_id, item_id, "duplicate", 100, "Already uploaded from this device (local cache)")
|
||||
return JSONResponse({"status": "duplicate", "id": None}, status_code=200)
|
||||
|
||||
await send_progress(session_id, item_id, "checking", 2, "Checking duplicates…")
|
||||
bulk = immich_bulk_check([{"id": item_id, "checksum": checksum}])
|
||||
if bulk.get(item_id, {}).get("action") == "reject" and bulk[item_id].get("reason") == "duplicate":
|
||||
asset_id = bulk[item_id].get("assetId")
|
||||
db_insert_upload(checksum, file.filename, size, device_asset_id, asset_id, created_iso)
|
||||
await send_progress(session_id, item_id, "duplicate", 100, "Duplicate (server)", asset_id)
|
||||
return JSONResponse({"status": "duplicate", "id": asset_id}, status_code=200)
|
||||
|
||||
def gen_encoder() -> MultipartEncoder:
|
||||
return MultipartEncoder(fields={
|
||||
"assetData": (file.filename, io.BytesIO(raw), file.content_type or "application/octet-stream"),
|
||||
"deviceAssetId": device_asset_id,
|
||||
"deviceId": f"python-{session_id}",
|
||||
"fileCreatedAt": created_iso,
|
||||
"fileModifiedAt": modified_iso,
|
||||
"isFavorite": "false",
|
||||
"filename": file.filename,
|
||||
})
|
||||
|
||||
encoder = gen_encoder()
|
||||
|
||||
async def do_upload():
|
||||
await send_progress(session_id, item_id, "uploading", 0, "Uploading…")
|
||||
sent = {"pct": 0}
|
||||
def cb(monitor: MultipartEncoderMonitor) -> None:
|
||||
if monitor.len:
|
||||
pct = int(monitor.bytes_read * 100 / monitor.len)
|
||||
if pct != sent["pct"]:
|
||||
sent["pct"] = pct
|
||||
asyncio.create_task(send_progress(session_id, item_id, "uploading", pct))
|
||||
monitor = MultipartEncoderMonitor(encoder, cb)
|
||||
headers = {"Accept": "application/json", "Content-Type": monitor.content_type, "x-immich-checksum": checksum, **immich_headers()}
|
||||
try:
|
||||
r = requests.post(f"{SETTINGS.normalized_base_url}/assets", headers=headers, data=monitor, timeout=120)
|
||||
if r.status_code in (200, 201):
|
||||
data = r.json()
|
||||
asset_id = data.get("id")
|
||||
db_insert_upload(checksum, file.filename, size, device_asset_id, asset_id, created_iso)
|
||||
status = data.get("status", "created")
|
||||
await send_progress(session_id, item_id, "duplicate" if status == "duplicate" else "done", 100, status, asset_id)
|
||||
return JSONResponse({"id": asset_id, "status": status}, status_code=200)
|
||||
else:
|
||||
try:
|
||||
msg = r.json().get("message", r.text)
|
||||
except Exception:
|
||||
msg = r.text
|
||||
await send_progress(session_id, item_id, "error", 100, msg)
|
||||
return JSONResponse({"error": msg}, status_code=400)
|
||||
except Exception as e:
|
||||
await send_progress(session_id, item_id, "error", 100, str(e))
|
||||
return JSONResponse({"error": str(e)}, status_code=500)
|
||||
|
||||
return await do_upload()
|
||||
|
||||
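
# Example client call — a sketch assuming the defaults above; the field names
# match the Form(...) parameters of api_upload:
#   curl -F "file=@photo.jpg" -F "item_id=1" -F "session_id=demo" \
#        http://127.0.0.1:8080/api/upload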
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run("backend.main:app", host=HOST, port=PORT, reload=True)
|
||||

app/config.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""
Config loader for the Immich Drop Uploader (Python).
Reads ONLY from .env; there is NO runtime mutation from the UI.
"""

from __future__ import annotations

import os
from dataclasses import dataclass


@dataclass
class Settings:
    """App settings loaded from environment variables (.env)."""
    immich_base_url: str
    immich_api_key: str
    max_concurrent: int = 3

    @property
    def normalized_base_url(self) -> str:
        """Return the base URL without a trailing slash for clean joining and display."""
        return self.immich_base_url.rstrip("/")


def load_settings() -> Settings:
    """Load settings from .env, applying defaults when absent."""
    base = os.getenv("IMMICH_BASE_URL", "http://127.0.0.1:2283/api")
    api_key = os.getenv("IMMICH_API_KEY", "")
    try:
        maxc = int(os.getenv("MAX_CONCURRENT", "3"))
    except ValueError:
        maxc = 3
    return Settings(immich_base_url=base, immich_api_key=api_key, max_concurrent=maxc)
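
# Illustrative behavior of normalized_base_url (values are examples only):
#   Settings(immich_base_url="http://immich.local/api/", immich_api_key="k").normalized_base_url
#   -> "http://immich.local/api"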