from flask import Flask, request, Response, jsonify
import requests
import os
import json
import logging

# Base URL of the upstream Mistral REST API; handlers append the resource path.
MISTRAL_URL = "https://api.mistral.ai/v1"

# Bearer token sent upstream. This is None when MISTRAL_API_KEY is not set.
API_KEY = os.environ.get("MISTRAL_API_KEY")

# Root logger at INFO so the per-request log lines below are emitted.
logging.basicConfig(level=logging.INFO)

# The WSGI application object that the route decorators register against.
app = Flask(__name__)


# Top-level request fields that are forwarded upstream; everything else is
# dropped by clean_request().
ALLOWED_FIELDS = {
    "model",
    "messages",
    "temperature",
    "top_p",
    "max_tokens",
    "stream",
    "tools",
    "tool_choice",
    "stop",
    "seed",
    "frequency_penalty",
    "presence_penalty"
}

# Keys kept on each entry of "messages". Built once instead of per message.
_ALLOWED_MESSAGE_FIELDS = frozenset(
    {"role", "content", "tool_calls", "tool_call_id", "name"}
)


def clean_request(data: dict) -> dict:
    """Return a copy of *data* containing only whitelisted fields.

    Top-level keys not listed in ``ALLOWED_FIELDS`` are dropped. When
    ``messages`` is a list, every dict entry is reduced to the keys in
    ``_ALLOWED_MESSAGE_FIELDS``; the values of the kept keys are not modified.

    Malformed input never raises here:

    * a non-dict *data* (for example ``None`` from an empty body) yields ``{}``;
    * a ``messages`` value that is not a list is passed through unchanged;
    * a message entry that is not a dict is passed through unchanged.

    Args:
        data: The decoded JSON request body.

    Returns:
        A new dict; *data* itself is not mutated.
    """
    if not isinstance(data, dict):
        return {}

    cleaned = {k: v for k, v in data.items() if k in ALLOWED_FIELDS}

    messages = cleaned.get("messages")
    if isinstance(messages, list):
        cleaned["messages"] = [
            {k: v for k, v in m.items() if k in _ALLOWED_MESSAGE_FIELDS}
            if isinstance(m, dict)
            else m
            for m in messages
        ]

    return cleaned


# (connect, read) timeouts in seconds for upstream chat calls. The read
# timeout is generous because, when streaming, it bounds the gap between
# chunks rather than the whole response.
_CHAT_TIMEOUT = (10, 300)


def _error_response(message, status, error_type="proxy_error"):
    """Build an OpenAI-style JSON error body paired with an HTTP status."""
    return jsonify({"error": {"message": message, "type": error_type}}), status


def _forward_upstream_body(upstream):
    """Relay a complete upstream response, keeping its status code.

    Falls back to the raw bytes when the upstream body is not valid JSON.
    """
    try:
        return jsonify(upstream.json()), upstream.status_code
    except ValueError:
        return Response(
            upstream.content,
            status=upstream.status_code,
            content_type=upstream.headers.get("Content-Type", "text/plain"),
        )


def _relay_sse(upstream):
    """Yield upstream server-sent-event lines and always close the connection."""
    try:
        for line in upstream.iter_lines():
            if not line:
                continue

            decoded = line.decode("utf-8", errors="replace")

            if not decoded.startswith("data:"):
                # Non-data SSE fields are forwarded as-is.
                yield decoded + "\n\n"
                continue

            payload = decoded[5:].strip()
            yield f"data: {payload}\n\n"

            if payload == "[DONE]":
                break
    except requests.RequestException:
        # Headers are already sent, so the only option is to end the stream.
        logging.exception("Upstream stream aborted")
    finally:
        # Runs on normal completion and on client disconnect (GeneratorExit),
        # so the upstream connection is released in both cases.
        upstream.close()


@app.route("/v1/chat/completions", methods=["POST"])
def chat():
    """Proxy a chat-completion request to Mistral.

    The JSON body is reduced with ``clean_request`` and posted upstream.
    When ``stream`` is truthy and the upstream call succeeds, the reply is
    relayed as ``text/event-stream``; otherwise the upstream body and status
    code are returned as-is.

    Returns:
        400 if the body is not a JSON object, 502 if the upstream cannot be
        reached or read, otherwise the upstream response.
    """
    raw = request.get_json(silent=True)

    logging.info("➡️ REQUEST RECEIVED")

    # A missing body, invalid JSON, or a JSON value that is not an object
    # cannot be forwarded; reject it instead of crashing further down.
    if not isinstance(raw, dict):
        return _error_response(
            "Request body must be a JSON object", 400, "invalid_request_error"
        )

    data = clean_request(raw)
    wants_stream = bool(data.get("stream", False))

    logging.info("🧹 CLEAN REQUEST SENT TO MISTRAL")

    try:
        r = requests.post(
            f"{MISTRAL_URL}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json",
            },
            json=data,
            stream=wants_stream,
            timeout=_CHAT_TIMEOUT,
        )
    except requests.RequestException:
        logging.exception("Upstream chat request failed")
        return _error_response("Upstream request to Mistral failed", 502)

    # ---------------- STREAMING ----------------
    # Only stream on success: an upstream error is a regular JSON body and
    # must keep its status code instead of becoming a 200 event stream.
    if wants_stream and r.ok:
        return Response(_relay_sse(r), content_type="text/event-stream")

    # ---------------- NON-STREAM / UPSTREAM ERROR ----------------
    try:
        return _forward_upstream_body(r)
    except requests.RequestException:
        logging.exception("Failed to read upstream chat response")
        return _error_response("Upstream response from Mistral was interrupted", 502)
    finally:
        r.close()


@app.route("/v1/models", methods=["GET"])
def models():
    """Proxy the upstream model listing.

    Returns:
        The upstream body and status code. A non-JSON upstream body is
        relayed as raw bytes. If the upstream cannot be reached within the
        timeout, a JSON error with status 502 is returned.
    """
    try:
        r = requests.get(
            f"{MISTRAL_URL}/models",
            headers={"Authorization": f"Bearer {API_KEY}"},
            timeout=30,  # never let a stalled upstream block the worker forever
        )
    except requests.RequestException:
        logging.exception("Upstream models request failed")
        return jsonify(
            {"error": {"message": "Upstream request to Mistral failed",
                       "type": "proxy_error"}}
        ), 502

    try:
        body = r.json()
    except ValueError:
        # Upstream answered with something that is not JSON; relay it as-is.
        return Response(
            r.content,
            status=r.status_code,
            content_type=r.headers.get("Content-Type", "text/plain"),
        )

    return jsonify(body), r.status_code


@app.route("/health")
def health():
    """Liveness probe: always report a fixed ``ok`` status."""
    status_body = dict(status="ok")
    return status_body


if __name__ == "__main__":
    # Start Flask's built-in server when this file is run as a script.
    bind_host, bind_port = "0.0.0.0", 8880
    app.run(host=bind_host, port=bind_port)