You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

116 lines
3.2 KiB

11 months ago
  1. const OpenAI = require("openai");
  2. const Provider = require("./ai-provider.js");
  3. const InheritMultiple = require("./helpers/classes.js");
  4. const UnTooled = require("./helpers/untooled.js");
  /**
   * The agent provider for Oobabooga (text-web-gen-ui).
   *
   * Requests go through an `openai` SDK client whose base URL and API key are
   * read from the TEXT_GEN_WEB_UI_BASE_PATH / TEXT_GEN_WEB_UI_API_KEY
   * environment variables. Tool calling is delegated to the inherited
   * `functionCall` helper (presumably supplied by UnTooled - confirm).
   */
  class TextWebGenUiProvider extends InheritMultiple([Provider, UnTooled]) {
    model;

    /**
     * Build the provider and its API client.
     *
     * @param {object} _config Currently unused; accepted so the constructor
     *   keeps the same call shape for existing callers.
     */
    constructor(_config = {}) {
      super();
      this._client = new OpenAI({
        baseURL: process.env.TEXT_GEN_WEB_UI_BASE_PATH,
        // NOTE(review): falls back to null (not undefined) when no key is set;
        // presumably so the SDK does not reject a missing key - confirm.
        apiKey: process.env.TEXT_GEN_WEB_UI_API_KEY ?? null,
        maxRetries: 3,
      });
      this.model = null; // text-web-gen-ui does not have a model pref.
      this.verbose = true;
    }

    /**
     * The API client created in the constructor.
     */
    get client() {
      return this._client;
    }

    /**
     * Run a single deterministic (temperature 0) chat completion on behalf of
     * the tool-calling helper.
     *
     * This is best-effort: any failure (request error, missing or empty
     * `choices`) is logged and reported as `null` instead of being thrown, so
     * the caller can fall back to a plain chat completion.
     *
     * @param {{messages?: object[]}} params Messages to send to the API.
     * @returns {Promise<string|null>} Content of the first choice, or `null`
     *   when no usable response was obtained.
     */
    async #handleFunctionCallChat({ messages = [] }) {
      try {
        const result = await this.client.chat.completions.create({
          model: this.model,
          temperature: 0,
          messages,
        });

        if (!Object.prototype.hasOwnProperty.call(result, "choices"))
          throw new Error("Oobabooga chat: No results!");
        if (result.choices.length === 0)
          throw new Error("Oobabooga chat: No results length!");
        return result.choices[0].message.content;
      } catch (error) {
        // Keep the null contract, but leave a trace of why the call failed
        // instead of discarding the error entirely.
        this.providerLog(
          `Function call chat failed - ${error?.message ?? error}`
        );
        return null;
      }
    }

    /**
     * Create a completion based on the received messages.
     *
     * When `functions` is non-empty the tool-calling helper runs first: a valid
     * tool call is returned immediately as `functionCall`, otherwise its text
     * (if any) becomes the result. With no functions, or no text, a plain chat
     * completion is requested instead.
     *
     * @param messages A list of messages to send to the API.
     * @param functions Tool definitions the model may call. Optional; `null`
     *   or an empty list skips tool calling.
     * @returns The completion: `{ result, cost }`, plus `functionCall` when a
     *   tool call was detected (in which case `result` is `null`).
     */
    async complete(messages, functions = null) {
      let completion;

      // `functions` defaults to null, so it must be guarded before `.length`.
      if (functions?.length > 0) {
        const { toolCall, text } = await this.functionCall(
          messages,
          functions,
          this.#handleFunctionCallChat.bind(this)
        );

        if (toolCall !== null) {
          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
          return {
            result: null,
            functionCall: {
              name: toolCall.name,
              arguments: toolCall.arguments,
            },
            cost: 0,
          };
        }
        completion = { content: text };
      }

      if (!completion?.content) {
        this.providerLog(
          "Will assume chat completion without tool call inputs."
        );
        const response = await this.client.chat.completions.create({
          model: this.model,
          messages: this.cleanMsgs(messages),
        });
        completion = response.choices[0].message;
      }

      // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
      // from calling the exact same function over and over in a loop within a single chat exchange
      // _but_ we should enable it to call previously used tools in a new chat interaction.
      this.deduplicator.reset("runs");
      return {
        result: completion.content,
        cost: 0,
      };
    }

    /**
     * Get the cost of the completion.
     *
     * @param _usage The completion to get the cost for (ignored).
     * @returns The cost of the completion, always 0.
     * Stubbed since Oobabooga has no cost basis.
     */
    getCost(_usage) {
      return 0;
    }
  }
  105. module.exports = TextWebGenUiProvider;