You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
3.2 KiB

11 months ago
  1. const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
  2. const Provider = require("./ai-provider.js");
  3. const InheritMultiple = require("./helpers/classes.js");
  4. const UnTooled = require("./helpers/untooled.js");
  5. /**
  6. * The agent provider for the Azure OpenAI API.
  7. */
  8. class AzureOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
  9. model;
  10. constructor(_config = {}) {
  11. super();
  12. const client = new OpenAIClient(
  13. process.env.AZURE_OPENAI_ENDPOINT,
  14. new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
  15. );
  16. this._client = client;
  17. this.model = process.env.OPEN_MODEL_PREF ?? "gpt-3.5-turbo";
  18. this.verbose = true;
  19. }
  20. get client() {
  21. return this._client;
  22. }
  23. async #handleFunctionCallChat({ messages = [] }) {
  24. return await this.client
  25. .getChatCompletions(this.model, messages, {
  26. temperature: 0,
  27. })
  28. .then((result) => {
  29. if (!result.hasOwnProperty("choices"))
  30. throw new Error("Azure OpenAI chat: No results!");
  31. if (result.choices.length === 0)
  32. throw new Error("Azure OpenAI chat: No results length!");
  33. return result.choices[0].message.content;
  34. })
  35. .catch((_) => {
  36. return null;
  37. });
  38. }
  39. /**
  40. * Create a completion based on the received messages.
  41. *
  42. * @param messages A list of messages to send to the API.
  43. * @param functions
  44. * @returns The completion.
  45. */
  46. async complete(messages, functions = null) {
  47. try {
  48. let completion;
  49. if (functions.length > 0) {
  50. const { toolCall, text } = await this.functionCall(
  51. messages,
  52. functions,
  53. this.#handleFunctionCallChat.bind(this)
  54. );
  55. if (toolCall !== null) {
  56. this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
  57. this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
  58. return {
  59. result: null,
  60. functionCall: {
  61. name: toolCall.name,
  62. arguments: toolCall.arguments,
  63. },
  64. cost: 0,
  65. };
  66. }
  67. completion = { content: text };
  68. }
  69. if (!completion?.content) {
  70. this.providerLog(
  71. "Will assume chat completion without tool call inputs."
  72. );
  73. const response = await this.client.getChatCompletions(
  74. this.model,
  75. this.cleanMsgs(messages),
  76. {
  77. temperature: 0.7,
  78. }
  79. );
  80. completion = response.choices[0].message;
  81. }
  82. // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
  83. // from calling the exact same function over and over in a loop within a single chat exchange
  84. // _but_ we should enable it to call previously used tools in a new chat interaction.
  85. this.deduplicator.reset("runs");
  86. return { result: completion.content, cost: 0 };
  87. } catch (error) {
  88. throw error;
  89. }
  90. }
  91. /**
  92. * Get the cost of the completion.
  93. * Stubbed since Azure OpenAI has no public cost basis.
  94. *
  95. * @param _usage The completion to get the cost for.
  96. * @returns The cost of the completion.
  97. */
  98. getCost(_usage) {
  99. return 0;
  100. }
  101. }
  102. module.exports = AzureOpenAiProvider;