You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

118 lines
3.3 KiB

11 months ago
  1. const OpenAI = require("openai");
  2. const Provider = require("./ai-provider.js");
  3. const InheritMultiple = require("./helpers/classes.js");
  4. const UnTooled = require("./helpers/untooled.js");
  /**
   * The agent provider for the FireworksAI provider.
   * We wrap FireworksAI in UnTooled because its tool-calling may not be supported
   * for specific models and this normalizes that.
   */
  class FireworksAIProvider extends InheritMultiple([Provider, UnTooled]) {
    /** Model identifier sent with every chat completion request. */
    model;

    /**
     * Build a provider backed by an OpenAI SDK client pointed at the
     * Fireworks AI inference endpoint.
     *
     * @param {Object} [config={}] Provider options.
     * @param {string} [config.model] Model identifier to use; defaults to
     *   "accounts/fireworks/models/llama-v3p1-8b-instruct".
     */
    constructor(config = {}) {
      const { model = "accounts/fireworks/models/llama-v3p1-8b-instruct" } =
        config;
      super();

      this._client = new OpenAI({
        baseURL: "https://api.fireworks.ai/inference/v1",
        apiKey: process.env.FIREWORKS_AI_LLM_API_KEY,
        // Ask the SDK not to retry failed requests on its own.
        maxRetries: 0,
      });
      this.model = model;
      this.verbose = true;
    }

    /**
     * The OpenAI SDK client created in the constructor.
     */
    get client() {
      return this._client;
    }

    /**
     * Chat callback handed to the inherited `functionCall` helper. Sends the
     * messages at temperature 0 and returns the text of the first choice.
     *
     * This is deliberately best-effort: any failure (request error, missing or
     * empty `choices`) is logged and reported as `null` instead of being thrown,
     * so `complete` can fall back to a plain chat completion.
     *
     * @param {Object} params
     * @param {Array} [params.messages=[]] Messages to send to the API.
     * @returns {Promise<string|null>} The first choice's content, or `null` on failure.
     */
    async #handleFunctionCallChat({ messages = [] }) {
      try {
        const result = await this.client.chat.completions.create({
          model: this.model,
          temperature: 0,
          messages,
        });

        if (!Object.hasOwn(result, "choices"))
          throw new Error("FireworksAI chat: No results!");
        if (result.choices.length === 0)
          throw new Error("FireworksAI chat: No results length!");

        return result.choices[0].message.content;
      } catch (error) {
        // Keep the original "null on failure" contract, but leave a trace of why.
        this.providerLog(
          `FireworksAI function call chat failed: ${error?.message ?? error}`
        );
        return null;
      }
    }

    /**
     * Create a completion based on the received messages.
     *
     * When `functions` is a non-empty list, the inherited `functionCall` helper is
     * asked first; a detected tool call is returned immediately without running
     * a regular completion. Otherwise (or when that step produced no text) a
     * plain chat completion is requested.
     *
     * @param messages A list of messages to send to the API.
     * @param functions Optional list of callable tool definitions; `null`,
     *   `undefined` or an empty list means "no tools".
     * @returns The completion: `{ result, cost }`, or
     *   `{ result: null, functionCall: { name, arguments }, cost }` for a tool call.
     * @throws {Error} If the plain chat completion request fails or returns no message.
     */
    async complete(messages, functions = null) {
      let completion;

      // `functions` defaults to null, so its length must be read defensively.
      const hasFunctions = (functions?.length ?? 0) > 0;
      if (hasFunctions) {
        const { toolCall, text } = await this.functionCall(
          messages,
          functions,
          this.#handleFunctionCallChat.bind(this)
        );

        if (toolCall !== null) {
          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
          return {
            result: null,
            functionCall: {
              name: toolCall.name,
              arguments: toolCall.arguments,
            },
            cost: 0,
          };
        }
        completion = { content: text };
      }

      if (!completion?.content) {
        this.providerLog(
          "Will assume chat completion without tool call inputs."
        );
        const response = await this.client.chat.completions.create({
          model: this.model,
          messages: this.cleanMsgs(messages),
        });

        // Fail with a readable message instead of a TypeError on an empty response.
        const message = response?.choices?.[0]?.message;
        if (!message) throw new Error("FireworksAI chat: No results!");
        completion = message;
      }

      // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
      // from calling the exact same function over and over in a loop within a single chat exchange
      // _but_ we should enable it to call previously used tools in a new chat interaction.
      this.deduplicator.reset("runs");
      return {
        result: completion.content,
        cost: 0,
      };
    }

    /**
     * Get the cost of the completion.
     *
     * @param _usage The completion to get the cost for (ignored).
     * @returns The cost of the completion; this provider always reports 0.
     */
    getCost(_usage) {
      return 0;
    }
  }
  107. module.exports = FireworksAIProvider;