You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

120 lines
3.4 KiB

11 months ago
  1. const OpenAI = require("openai");
  2. const Provider = require("./ai-provider.js");
  3. const InheritMultiple = require("./helpers/classes.js");
  4. const UnTooled = require("./helpers/untooled.js");
  /**
   * The agent provider for Mistral.
   *
   * Mistral limits which models can call tools, and even for those models the
   * model names change and don't match the docs. Even with the right model,
   * native tool calling still fails because the API is not truly OpenAI
   * compatible. It is therefore easier to wrap this provider with UnTooled,
   * which only exchanges plain text and works far more reliably.
   */
  class MistralProvider extends InheritMultiple([Provider, UnTooled]) {
    model;

    /**
     * @param {object} [config={}] - Provider options.
     * @param {string} [config.model="mistral-medium"] - Mistral model name to request.
     */
    constructor(config = {}) {
      super();
      const { model = "mistral-medium" } = config;
      // The OpenAI SDK client is pointed at Mistral's base URL; the API key is
      // read from the MISTRAL_API_KEY environment variable.
      this._client = new OpenAI({
        baseURL: "https://api.mistral.ai/v1",
        apiKey: process.env.MISTRAL_API_KEY,
        maxRetries: 3,
      });
      this.model = model;
      this.verbose = true;
    }

    /** The OpenAI SDK client configured for the Mistral API. */
    get client() {
      return this._client;
    }

    /**
     * Chat callback handed to UnTooled's `functionCall`.
     *
     * Sends `messages` to the configured model at temperature 0 and returns the
     * text of the first choice.
     *
     * @param {{messages?: Array<object>}} args - Chat request arguments.
     * @returns {Promise<string|null>} The first choice's message content, or
     *   `null` if the request failed or returned no choices.
     */
    async #handleFunctionCallChat({ messages = [] }) {
      try {
        const result = await this.client.chat.completions.create({
          model: this.model,
          temperature: 0,
          messages,
        });
        if (!result?.choices) throw new Error("Mistral chat: No results!");
        if (result.choices.length === 0)
          throw new Error("Mistral chat: No results length!");
        return result.choices[0].message.content;
      } catch (error) {
        // Failures deliberately resolve to `null` instead of rejecting (the
        // contract this callback has always had); log the reason so the
        // failure is not silent.
        this.providerLog(`Mistral chat failed: ${error?.message ?? error}`);
        return null;
      }
    }

    /**
     * Create a completion based on the received messages.
     *
     * When tool definitions are supplied, UnTooled's text-based tool calling is
     * attempted first. If it yields a tool call, that call is returned for the
     * caller to run. Otherwise the text from that attempt is used. If that text
     * is empty, or no tools were supplied, a plain chat completion is requested.
     *
     * @param {Array<object>} messages - A list of messages to send to the API.
     * @param {Array<object>|null} [functions=null] - Tool definitions available
     *   to the model; `null` or an empty array skips tool calling.
     * @returns {Promise<{result: string|null, functionCall?: {name: string, arguments: *}, cost: number}>}
     *   `{ result: null, functionCall, cost: 0 }` when a tool call was found,
     *   otherwise `{ result, cost: 0 }` with the completion text.
     * @throws {Error} If the plain chat completion request fails or returns no choices.
     */
    async complete(messages, functions = null) {
      let completion;

      // `functions` defaults to null, so guard before reading `.length`.
      if (functions?.length > 0) {
        const { toolCall, text } = await this.functionCall(
          messages,
          functions,
          this.#handleFunctionCallChat.bind(this)
        );

        if (toolCall != null) {
          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
          return {
            result: null,
            functionCall: {
              name: toolCall.name,
              arguments: toolCall.arguments,
            },
            cost: 0,
          };
        }
        completion = { content: text };
      }

      if (!completion?.content) {
        this.providerLog(
          "Will assume chat completion without tool call inputs."
        );
        const response = await this.client.chat.completions.create({
          model: this.model,
          messages: this.cleanMsgs(messages),
        });
        completion = response?.choices?.[0]?.message;
        if (!completion) throw new Error("Mistral chat: No results!");
      }

      // The Deduplicator inherited via UnTooled mostly prevents the agent from
      // calling the exact same function over and over in a loop within a single
      // chat exchange, but previously used tools should be callable again in a
      // new chat interaction, so clear its run history once a text answer is ready.
      this.deduplicator.reset("runs");
      return {
        result: completion.content,
        cost: 0,
      };
    }

    /**
     * Get the cost of the completion. This provider always reports 0.
     *
     * @param _usage The completion usage (unused).
     * @returns {number} Always 0.
     */
    getCost(_usage) {
      return 0;
    }
  }

  module.exports = MistralProvider;