/*
 * Decompiled with CFR 0.152.
 */
package io.openliberty.mcp.internal;

import com.ibm.websphere.ras.Tr;
import com.ibm.websphere.ras.TraceComponent;
import com.ibm.websphere.ras.annotation.InjectedTrace;
import com.ibm.websphere.ras.annotation.TraceObjectField;
import com.ibm.websphere.ras.annotation.TraceOptions;
import com.ibm.ws.ffdc.FFDCFilter;
import com.ibm.ws.ffdc.annotation.FFDCIgnore;
import com.ibm.ws.kernel.service.util.ServiceCaller;
import com.ibm.ws.ras.instrument.annotation.InjectedFFDC;
import io.openliberty.mcp.content.Content;
import io.openliberty.mcp.internal.Capabilities;
import io.openliberty.mcp.internal.McpProtocolVersion;
import io.openliberty.mcp.internal.McpRequestTracker;
import io.openliberty.mcp.internal.McpSession;
import io.openliberty.mcp.internal.McpSessionStore;
import io.openliberty.mcp.internal.McpTransport;
import io.openliberty.mcp.internal.RequestMethod;
import io.openliberty.mcp.internal.ToolDescription;
import io.openliberty.mcp.internal.ToolMetadata;
import io.openliberty.mcp.internal.ToolRegistry;
import io.openliberty.mcp.internal.ToolResult;
import io.openliberty.mcp.internal.config.McpConfiguration;
import io.openliberty.mcp.internal.exceptions.jsonrpc.HttpResponseException;
import io.openliberty.mcp.internal.exceptions.jsonrpc.JSONRPCErrorCode;
import io.openliberty.mcp.internal.exceptions.jsonrpc.JSONRPCException;
import io.openliberty.mcp.internal.requests.CancellationImpl;
import io.openliberty.mcp.internal.requests.ExecutionRequestId;
import io.openliberty.mcp.internal.requests.McpInitializeParams;
import io.openliberty.mcp.internal.requests.McpNotificationParams;
import io.openliberty.mcp.internal.requests.McpRequestId;
import io.openliberty.mcp.internal.requests.McpToolCallParams;
import io.openliberty.mcp.internal.responses.McpInitializeResult;
import io.openliberty.mcp.messaging.Cancellation;
import io.openliberty.mcp.tools.ToolCallException;
import io.openliberty.mcp.tools.ToolResponse;
import jakarta.enterprise.context.spi.CreationalContext;
import jakarta.enterprise.inject.spi.BeanManager;
import jakarta.inject.Inject;
import jakarta.json.bind.Jsonb;
import jakarta.json.bind.JsonbBuilder;
import jakarta.servlet.ServletConfig;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;

@TraceObjectField(fieldName="tc", fieldDesc="Lcom/ibm/websphere/ras/TraceComponent;")
@InjectedFFDC
@TraceOptions
public class McpServlet
extends HttpServlet {
    private static final long serialVersionUID = 1L;
    private static final TraceComponent tc = Tr.register(McpServlet.class, (String)"MCP", (String)"io.openliberty.mcp.internal.resources.CWMCM");
    private static final ServiceCaller<McpConfiguration> mcpConfigService = new ServiceCaller(McpServlet.class, McpConfiguration.class);
    private Jsonb jsonb;
    @Inject
    BeanManager bm;
    @Inject
    McpSessionStore sessionStore;
    @Inject
    McpRequestTracker requestTracker;

    public void init(ServletConfig config) throws ServletException {
        super.init(config);
        this.jsonb = JsonbBuilder.create();
    }

    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        McpTransport transport = new McpTransport(req, resp, this.jsonb);
        String excpetionMessage = Tr.formatMessage((TraceComponent)tc, (String)"CWMCM0009I.get.disallowed", (Object[])new Object[0]);
        HttpResponseException e = new HttpResponseException(405, excpetionMessage).withHeader("Allow", "POST");
        transport.sendHttpException(e);
    }

    @FFDCIgnore(value={JSONRPCException.class, HttpResponseException.class})
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException, JSONRPCException {
        McpTransport transport = new McpTransport(req, resp, this.jsonb);
        try {
            McpSession session;
            Boolean stateless = mcpConfigService.run(config -> {
                boolean s = config.isStateless();
                return s;
            }).orElse(false);
            transport.init(this.sessionStore);
            RequestMethod method = transport.getMcpRequest().getRequestMethod();
            if (!stateless.booleanValue() && method != RequestMethod.INITIALIZE && method != RequestMethod.PING && (session = transport.getSession()) == null) {
                throw new HttpResponseException(400, "Missing Mcp-Session-Id header");
            }
            this.callRequest(transport);
        }
        catch (JSONRPCException e) {
            transport.sendJsonRpcException(e);
        }
        catch (HttpResponseException e) {
            transport.sendHttpException(e);
        }
        catch (Exception e) {
            FFDCFilter.processException((Throwable)e, (String)"io.openliberty.mcp.internal.McpServlet", (String)"119", (Object)((Object)this), (Object[])new Object[]{req, resp});
            transport.sendError(e);
        }
    }

    protected void callRequest(McpTransport transport) throws JSONRPCException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, IOException {
        RequestMethod method = transport.getMcpRequest().getRequestMethod();
        switch (method) {
            case TOOLS_CALL: {
                this.callTool(transport);
                break;
            }
            case TOOLS_LIST: {
                this.listTools(transport);
                break;
            }
            case INITIALIZE: {
                this.initialize(transport);
                break;
            }
            case INITIALIZED: {
                this.initialized(transport);
                break;
            }
            case PING: {
                this.ping(transport);
                break;
            }
            case CANCELLED: {
                this.cancelRequest(transport);
                break;
            }
            default: {
                throw new JSONRPCException(JSONRPCErrorCode.METHOD_NOT_FOUND, List.of(String.valueOf(String.valueOf((Object)method) + " not found")));
            }
        }
    }

    protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        boolean stateless = Boolean.TRUE.equals(mcpConfigService.run(McpConfiguration::isStateless).orElse(false));
        if (stateless) {
            resp.sendError(404, "Session not found");
            return;
        }
        String sessionId = req.getHeader("Mcp-Session-Id");
        if (sessionId == null) {
            resp.sendError(400, "Missing Mcp-Session-Id");
            return;
        }
        if (this.sessionStore.isValid(sessionId)) {
            McpSession session = this.sessionStore.getSession(sessionId);
            if (session != null) {
                this.requestTracker.cancelSessionRequests(sessionId);
            }
            resp.setStatus(200);
        } else {
            resp.sendError(404, "Session not found");
        }
    }

    @FFDCIgnore(value={IllegalAccessException.class, IllegalArgumentException.class})
    private void callTool(McpTransport transport) {
        ExecutionRequestId requestId = this.createOngoingRequestId(transport);
        McpToolCallParams params = transport.getParams(McpToolCallParams.class);
        if (requestId != null && this.requestTracker.isOngoingRequest(requestId)) {
            throw new JSONRPCException(JSONRPCErrorCode.INVALID_PARAMS, Tr.formatMessage((TraceComponent)tc, (String)"CWMCM0008E.invalid.request.params", (Object[])new Object[]{requestId.id()}));
        }
        try {
            if (params.getMetadata() == null) {
                if (TraceComponent.isAnyTracingEnabled() && tc.isEventEnabled()) {
                    Tr.event((Object)((Object)this), (TraceComponent)tc, (String)("Attempt to call non-existant tool: " + params.getName()), (Object[])new Object[0]);
                }
                throw new JSONRPCException(JSONRPCErrorCode.INVALID_PARAMS, List.of("Method " + params.getName() + " not found"));
            }
            if (params.getMetadata().returnsCompletionStage()) {
                this.callToolMethodAndSendResponseAsync(transport, requestId, params, params.getMethod());
            } else {
                this.callToolSynchronously(transport, requestId, params, params.getMethod());
            }
        }
        catch (IllegalAccessException e) {
            throw new JSONRPCException(JSONRPCErrorCode.INTERNAL_ERROR, List.of("Could not call " + params.getName()));
        }
        catch (IllegalArgumentException e) {
            throw new JSONRPCException(JSONRPCErrorCode.INVALID_PARAMS, List.of("Incorrect arguments in params"));
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @FFDCIgnore(value={JSONRPCException.class, InvocationTargetException.class})
    private void callToolSynchronously(McpTransport transport, ExecutionRequestId requestId, McpToolCallParams params, Method method) throws IllegalAccessException, IllegalArgumentException {
        CreationalContext cc = this.bm.createCreationalContext(null);
        if (requestId != null) {
            CancellationImpl cancellation = new CancellationImpl();
            cancellation.setRequestId(requestId);
            this.requestTracker.registerOngoingRequest(requestId, cancellation);
        }
        try {
            Object bean = this.bm.getReference(params.getBean(), (Type)params.getBean().getBeanClass(), cc);
            Object[] arguments = params.getArguments(this.jsonb);
            this.addSpecialArguments(arguments, requestId, params.getMetadata());
            if (TraceComponent.isAnyTracingEnabled() && tc.isEventEnabled()) {
                Tr.event((Object)((Object)this), (TraceComponent)tc, (String)("Calling tool " + params.getMetadata().name()), (Object[])arguments);
            }
            Object result = method.invoke(bean, arguments);
            transport.sendResponse(this.evaluateToolResponse(result, params));
        }
        catch (JSONRPCException e) {
            throw e;
        }
        catch (InvocationTargetException e) {
            Throwable t = e.getCause();
            if (this.isBusinessException(t, params)) {
                transport.sendResponse(this.toErrorResponse(e.getCause()));
            } else {
                Tr.error((TraceComponent)tc, (String)"CWMCM0010E.internal.server.error.detailed", (Object[])new Object[]{params.getMetadata().name(), e.getCause()});
                transport.sendResponse(ToolResponse.error((String)Tr.formatMessage((TraceComponent)tc, (String)"CWMCM0011E.internal.server.error", (Object[])new Object[0])));
            }
        }
        finally {
            this.cleanup(requestId, (CreationalContext<Object>)cc, params);
        }
    }

    @FFDCIgnore(value={InvocationTargetException.class})
    private void callToolMethodAndSendResponseAsync(McpTransport transport, ExecutionRequestId requestId, McpToolCallParams params, Method method) throws IllegalAccessException, IllegalArgumentException {
        CreationalContext cc = this.bm.createCreationalContext(null);
        if (requestId != null) {
            CancellationImpl cancellation = new CancellationImpl();
            cancellation.setRequestId(requestId);
            this.requestTracker.registerOngoingRequest(requestId, cancellation);
        }
        try {
            Object bean = this.bm.getReference(params.getBean(), (Type)params.getBean().getBeanClass(), cc);
            Object[] arguments = params.getArguments(this.jsonb);
            this.addSpecialArguments(arguments, requestId, params.getMetadata());
            if (TraceComponent.isAnyTracingEnabled() && tc.isEventEnabled()) {
                Tr.event((Object)((Object)this), (TraceComponent)tc, (String)("Calling tool " + params.getMetadata().name()), (Object[])arguments);
            }
            CompletionStage<Object> stage = (CompletionStage<Object>)method.invoke(bean, arguments);
            stage = stage.thenApply(result -> this.evaluateToolResponse(result, params)).exceptionally(throwable -> {
                Tr.error((TraceComponent)tc, (String)"CWMCM0010E.internal.server.error.detailed", (Object[])new Object[]{params.getMetadata().name(), throwable.getCause()});
                return ToolResponse.error((String)Tr.formatMessage((TraceComponent)tc, (String)"CWMCM0011E.internal.server.error", (Object[])new Object[0]));
            });
            transport.sendResultAsync(stage).whenComplete((result, throwable) -> this.cleanup(requestId, (CreationalContext<Object>)cc, params));
        }
        catch (InvocationTargetException e) {
            Throwable t = e.getCause();
            if (this.isBusinessException(t, params)) {
                transport.sendResponse(this.toErrorResponse(e.getCause()));
            } else {
                Tr.error((TraceComponent)tc, (String)"CWMCM0010E.internal.server.error.detailed", (Object[])new Object[]{params.getMetadata().name(), e.getCause()});
                transport.sendResponse(ToolResponse.error((String)Tr.formatMessage((TraceComponent)tc, (String)"CWMCM0011E.internal.server.error", (Object[])new Object[0])));
            }
            this.cleanup(requestId, (CreationalContext<Object>)cc, params);
        }
    }

    /*
     * WARNING - void declaration
     */
    private void cleanup(ExecutionRequestId requestId, CreationalContext<Object> cc, McpToolCallParams params) {
        if (requestId != null && this.requestTracker.isOngoingRequest(requestId)) {
            this.requestTracker.deregisterOngoingRequest(requestId);
        }
        try {
            cc.release();
        }
        catch (Exception exception) {
            void ex;
            FFDCFilter.processException((Throwable)exception, (String)"io.openliberty.mcp.internal.McpServlet", (String)"302", (Object)((Object)this), (Object[])new Object[]{requestId, cc, params});
            Tr.warning((TraceComponent)tc, (String)"CWMCM0012E.bean.release.fail", (Object[])new Object[]{ex, params.getName()});
        }
    }

    private Object evaluateToolResponse(Object result, McpToolCallParams params) {
        List list;
        boolean includeStructuredContent = params.getMetadata().annotation().structuredContent();
        if (result instanceof ToolResponse) {
            ToolResponse response = (ToolResponse)result;
            return response;
        }
        if (result instanceof List && !(list = (List)result).isEmpty() && list.stream().allMatch(item -> item instanceof Content)) {
            List contents = list;
            return ToolResponse.success((List)contents);
        }
        if (result instanceof Content) {
            Content content = (Content)result;
            return ToolResponse.success((Content[])new Content[]{content});
        }
        if (result instanceof String) {
            String s = (String)result;
            return ToolResponse.success((String)s);
        }
        if (includeStructuredContent) {
            return ToolResponse.structuredSuccess((String)this.jsonb.toJson(result), (Object)result);
        }
        return ToolResponse.success((String)Objects.toString(result));
    }

    private boolean isBusinessException(Throwable t, McpToolCallParams params) {
        if (t instanceof ToolCallException) {
            return true;
        }
        if (params != null && params.getMetadata() != null) {
            for (Class<? extends Throwable> clazz : params.getMetadata().businessExceptions()) {
                if (!clazz.isAssignableFrom(t.getClass())) continue;
                return true;
            }
        }
        return false;
    }

    private ToolResponse toErrorResponse(Throwable t) {
        String msg = t.getMessage() != null ? t.getMessage() : t.getClass().getSimpleName();
        return ToolResponse.error((String)msg);
    }

    private void addSpecialArguments(Object[] argumentsArray, ExecutionRequestId requestId, ToolMetadata toolMetadata) {
        for (ToolMetadata.SpecialArgumentMetadata argMetadata : toolMetadata.specialArguments()) {
            switch (argMetadata.typeResolution().specialArgsType()) {
                case CANCELLATION: {
                    CancellationImpl cancellation = requestId != null ? (CancellationImpl)this.requestTracker.getOngoingRequestCancellation(requestId) : new CancellationImpl();
                    argumentsArray[argMetadata.index()] = cancellation;
                    break;
                }
            }
        }
    }

    private void listTools(McpTransport transport) throws IOException {
        ToolRegistry toolRegistry = ToolRegistry.get();
        LinkedList<ToolDescription> response = new LinkedList<ToolDescription>();
        if (toolRegistry.hasTools()) {
            for (ToolMetadata tmd : toolRegistry.getAllTools()) {
                response.add(new ToolDescription(tmd));
            }
            ToolResult toolResult = new ToolResult(response);
            transport.sendResponse(toolResult);
        }
    }

    @FFDCIgnore(value={NoSuchElementException.class})
    private void initialize(McpTransport transport) throws IOException {
        McpProtocolVersion version;
        McpInitializeParams params = transport.getParams(McpInitializeParams.class);
        try {
            version = McpProtocolVersion.parse(params.getProtocolVersion());
        }
        catch (NoSuchElementException e) {
            version = McpProtocolVersion.V_2025_06_18;
        }
        if (TraceComponent.isAnyTracingEnabled() && tc.isEventEnabled()) {
            Tr.event((Object)((Object)this), (TraceComponent)tc, (String)("Client initializing: " + String.valueOf(params.getClientInfo())), (Object[])new Object[]{params.getCapabilities()});
        }
        String sessionId = this.sessionStore.createSession();
        Capabilities.ServerCapabilities caps = Capabilities.ServerCapabilities.of(new Capabilities.Tools(false));
        McpInitializeResult.ServerInfo info = new McpInitializeResult.ServerInfo("test-server", "Test Server", "0.1");
        McpInitializeResult result = new McpInitializeResult(version, caps, info, null);
        transport.setResponseHeader("Mcp-Session-Id", sessionId);
        transport.sendResponse(result);
    }

    private void initialized(McpTransport transport) {
        if (TraceComponent.isAnyTracingEnabled() && tc.isEventEnabled()) {
            Tr.event((Object)((Object)this), (TraceComponent)tc, (String)"Client initialized", (Object[])new Object[0]);
        }
        transport.sendEmptyResponse();
    }

    private void ping(McpTransport transport) {
        transport.sendResponse(new Object());
    }

    private void cancelRequest(McpTransport transport) {
        Cancellation cancellation;
        McpNotificationParams notificationParams = transport.getMcpRequest().getParams(McpNotificationParams.class, this.jsonb);
        McpRequestId mcpReqId = notificationParams.getRequestId();
        String sessionId = transport.getSessionId();
        if (sessionId == null) {
            transport.sendEmptyResponse();
            return;
        }
        ExecutionRequestId requestId = new ExecutionRequestId(mcpReqId, sessionId);
        Optional<String> reason = Optional.ofNullable(notificationParams.getReason());
        if (TraceComponent.isAnyTracingEnabled() && tc.isEventEnabled()) {
            Tr.event((Object)((Object)this), (TraceComponent)tc, (String)("Cancellation requested for " + String.valueOf(requestId)), (Object[])new Object[0]);
        }
        if ((cancellation = this.requestTracker.getOngoingRequestCancellation(requestId)) != null) {
            if (TraceComponent.isAnyTracingEnabled() && tc.isEventEnabled()) {
                Tr.event((Object)((Object)this), (TraceComponent)tc, (String)"Cancelling task", (Object[])new Object[0]);
            }
            ((CancellationImpl)cancellation).cancel(reason);
        }
        transport.sendEmptyResponse();
    }

    private ExecutionRequestId createOngoingRequestId(McpTransport transport) {
        String sessionId = transport.getSessionId();
        if (sessionId != null) {
            return new ExecutionRequestId(transport.getMcpRequest().id(), sessionId);
        }
        return null;
    }
}

