You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

117 lines
3.2 KiB

11 months ago
  1. const OpenAI = require("openai");
  2. const Provider = require("./ai-provider.js");
  3. const InheritMultiple = require("./helpers/classes.js");
  4. const UnTooled = require("./helpers/untooled.js");
  5. /**
  6. * The agent provider for the KoboldCPP provider.
  7. */
  8. class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) {
  9. model;
  10. constructor(_config = {}) {
  11. super();
  12. const model = process.env.KOBOLD_CPP_MODEL_PREF ?? null;
  13. const client = new OpenAI({
  14. baseURL: process.env.KOBOLD_CPP_BASE_PATH?.replace(/\/+$/, ""),
  15. apiKey: null,
  16. maxRetries: 3,
  17. });
  18. this._client = client;
  19. this.model = model;
  20. this.verbose = true;
  21. }
  22. get client() {
  23. return this._client;
  24. }
  25. async #handleFunctionCallChat({ messages = [] }) {
  26. return await this.client.chat.completions
  27. .create({
  28. model: this.model,
  29. temperature: 0,
  30. messages,
  31. })
  32. .then((result) => {
  33. if (!result.hasOwnProperty("choices"))
  34. throw new Error("KoboldCPP chat: No results!");
  35. if (result.choices.length === 0)
  36. throw new Error("KoboldCPP chat: No results length!");
  37. return result.choices[0].message.content;
  38. })
  39. .catch((_) => {
  40. return null;
  41. });
  42. }
  43. /**
  44. * Create a completion based on the received messages.
  45. *
  46. * @param messages A list of messages to send to the API.
  47. * @param functions
  48. * @returns The completion.
  49. */
  50. async complete(messages, functions = null) {
  51. try {
  52. let completion;
  53. if (functions.length > 0) {
  54. const { toolCall, text } = await this.functionCall(
  55. messages,
  56. functions,
  57. this.#handleFunctionCallChat.bind(this)
  58. );
  59. if (toolCall !== null) {
  60. this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
  61. this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
  62. return {
  63. result: null,
  64. functionCall: {
  65. name: toolCall.name,
  66. arguments: toolCall.arguments,
  67. },
  68. cost: 0,
  69. };
  70. }
  71. completion = { content: text };
  72. }
  73. if (!completion?.content) {
  74. this.providerLog(
  75. "Will assume chat completion without tool call inputs."
  76. );
  77. const response = await this.client.chat.completions.create({
  78. model: this.model,
  79. messages: this.cleanMsgs(messages),
  80. });
  81. completion = response.choices[0].message;
  82. }
  83. // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
  84. // from calling the exact same function over and over in a loop within a single chat exchange
  85. // _but_ we should enable it to call previously used tools in a new chat interaction.
  86. this.deduplicator.reset("runs");
  87. return {
  88. result: completion.content,
  89. cost: 0,
  90. };
  91. } catch (error) {
  92. throw error;
  93. }
  94. }
  95. /**
  96. * Get the cost of the completion.
  97. *
  98. * @param _usage The completion to get the cost for.
  99. * @returns The cost of the completion.
  100. * Stubbed since KoboldCPP has no cost basis.
  101. */
  102. getCost(_usage) {
  103. return 0;
  104. }
  105. }
  106. module.exports = KoboldCPPProvider;