import { Grid } from '@northvolt/ui'
import type { Message as TMessage } from 'client/model'
import PlottingToolOutput from './PlottingToolOutput'
import SearchToolUseCard from './SearchToolUseCard'
import WombatQueryCard from './WombatQueryCard'

/**
 * A single tool call joined with the message that carries its output.
 *
 * LangChain reports a tool call's input (in a message's `tool_calls`) and its
 * output (a separate `role: 'tool'` message) as distinct messages; this shape
 * stitches the two halves back together for rendering.
 */
export type ToolCallWithOutput = {
  id: string
  name: string
  args: Record<string, string>

  /** Timestamp of the output message, stringified; absent until an output is matched. */
  created_at?: string

  // At most one of the two output fields is set: `output_json` when the output
  // content parsed as JSON, otherwise `output_text`. Neither is set when no
  // output message matched the call.
  /** Raw output content that could not be parsed as JSON. */
  output_text?: string
  /** Output content parsed as JSON. */
  output_json?: any
}

/** Tools whose output is rendered full-width below the card grid instead of as a compact card. */
const FULL_SIZE_TOOL_NAMES = ['Plotting']

interface ToolUseCardProps {
  toolMessages: TMessage[]
}

/**
 * Renders the tool calls contained in `toolMessages`.
 *
 * Calls are first joined with their output messages, then split in two groups:
 * - compact tools get one card each (`WombatQueryCard` for `WombatQuery`,
 *   `SearchToolUseCard` for anything else), laid out in at most 4 columns;
 * - full-size tools (`FULL_SIZE_TOOL_NAMES`) are rendered after the cards,
 *   each spanning every column of the grid.
 */
export default function ToolUseSection({ toolMessages }: ToolUseCardProps) {
  const joinedToolData = joinToolInputOutput(toolMessages)
  const modalCardTools = joinedToolData.filter(
    toolData => !FULL_SIZE_TOOL_NAMES.includes(toolData.name),
  )
  const fullSizeOutputTools = joinedToolData.filter(toolData =>
    FULL_SIZE_TOOL_NAMES.includes(toolData.name),
  )
  // Never 0 columns, even when every tool is full-size.
  const toolGridColumns = Math.min(4, modalCardTools.length || 1)

  return (
    <Grid container columns={toolGridColumns} spacing={1}>
      {modalCardTools.map((toolData, index) => (
        <Grid xs={1} key={index} sx={{ py: 1 }}>
          {toolData.name === 'WombatQuery' ? (
            <WombatQueryCard toolData={toolData} />
          ) : (
            <SearchToolUseCard toolData={toolData} />
          )}
        </Grid>
      ))}
      {fullSizeOutputTools.map((toolData, index) => {
        if (toolData.name !== 'Plotting') {
          return null
        }
        // `xs` is measured in units of the container's `columns` (the compact
        // cards use xs={1} for one cell), so spanning the full row means
        // xs={toolGridColumns}; a hard-coded 12 would overflow whenever the
        // grid has fewer than 12 columns (here it has 1-4).
        return (
          <Grid xs={toolGridColumns} key={index} sx={{ py: 1 }}>
            <PlottingToolOutput toolData={toolData} />
          </Grid>
        )
      })}
    </Grid>
  )
}

/**
 * Pairs every tool call found in `toolMessages` with its output message.
 *
 * Calls are collected from each message's `tool_calls`; outputs are the
 * messages with `role === 'tool'`. A call matches the first output whose
 * `tool_call_id` equals the call's `id` or its `call_id` (either field is
 * accepted).
 *
 * The matched output's `content` is parsed as JSON into `output_json`; if
 * parsing throws, the raw content is stored in `output_text` instead. Calls
 * with no matching output have neither field (nor `created_at`) set.
 *
 * @param toolMessages - chat messages containing both tool calls and tool outputs
 * @returns one entry per tool call, in the order the calls appear
 */
function joinToolInputOutput(toolMessages: TMessage[]): ToolCallWithOutput[] {
  // any: the LangChain tool-call typings don't describe what actually arrives
  // (e.g. calls identified by `call_id` instead of `id`).
  const toolCalls: any[] = toolMessages.flatMap(msg => msg.tool_calls ?? [])
  const toolOutputs = toolMessages.filter(msg => msg.role === 'tool')

  const joinedResults: ToolCallWithOutput[] = []
  for (const call of toolCalls) {
    const output = toolOutputs.find(
      msg =>
        // Without this guard an output lacking `tool_call_id` would match any
        // call lacking `id` or `call_id`, since undefined === undefined.
        msg.tool_call_id != null &&
        (msg.tool_call_id === call.id || msg.tool_call_id === call.call_id),
    )

    const joinedResult: ToolCallWithOutput = {
      // Calls that only carry `call_id` would otherwise get an undefined id.
      id: call.id ?? call.call_id,
      name: call.name,
      args: call.args,
    }
    if (output) {
      joinedResult.created_at = output.created_at?.toString()
      try {
        joinedResult.output_json = JSON.parse(output.content)
      } catch {
        // Not JSON: keep the content verbatim.
        joinedResult.output_text = output.content
      }
    }
    joinedResults.push(joinedResult)
  }
  return joinedResults
}
