FEAT: Adding general websocket target class and accompanying unit tests #1351
kmarsh77 wants to merge 1 commit into Azure:main
Conversation
        logger.info("Successfully connected to websocket")
        return websocket

    async def send_message(self, message: str, conversation_id: str) -> None:
nit: rename to send_message_async
        self._existing_conversation = existing_convo if existing_convo is not None else {}
        self._websockets_kwargs = websockets_kwargs or {}

    async def connect(self) -> Any:
            endpoint (str): the target endpoint
            initialization_strings (List[str]): These are the connection/initialization strings that must be sent after connecting to websocket in order to initiate conversation
            response_parser: (Callable): Function that takes raw websocket message and tries to parse response message; message is discarded if function fails
            message_builder: (Callable): Function that takes prompt and builds the message to send with it
Rather than a callable for message_builder, I think this would be more extensible to require a MessageStringNormalizer
So in one of the chatbots I tested, a content-length parameter was required in each message that needed to be the length of the message. The callable message_builder allowed implementing this logic, and I was thinking there could be other use cases e.g. a websocket target that requires unique message IDs to be calculated.
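A minimal sketch of the kind of message_builder described in the reply above, assuming a hypothetical JSON envelope. The field names `id`, `content_length`, and `text` and the function name `build_message` are illustrative only and do not come from any real target or from this PR:

```python
import json
import uuid


def build_message(prompt: str) -> str:
    """Hypothetical message_builder: wraps the prompt in an assumed JSON envelope."""
    payload = {
        # Unique per-message ID, as some targets might require.
        "id": str(uuid.uuid4()),
        # Length of the prompt in UTF-8 bytes; a real target may count differently.
        "content_length": len(prompt.encode("utf-8")),
        "text": prompt,
    }
    return json.dumps(payload)
```

Logic like the computed length or the generated ID is what a plain string template could not express, which is the case being made for a callable here.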
        # Listen for responses
        receive_messages = asyncio.create_task(self.receive_messages(conversation_id=conversation_id))

        result = await asyncio.wait_for(receive_messages, timeout=30.0)  # Wait for all responses to be received
probably should have this as an init param
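One way the suggestion could look, as a rough sketch rather than the PR's actual code. The class name `TimeoutSketch`, the parameter name `response_timeout`, and the method `wait_for_response` are all assumptions made for illustration:

```python
import asyncio


class TimeoutSketch:
    """Hypothetical stand-in for the target; only the timeout handling is shown."""

    def __init__(self, response_timeout: float = 30.0) -> None:
        # Assumed parameter name; the default keeps the currently hard-coded 30 seconds.
        self._response_timeout = response_timeout

    async def wait_for_response(self, receive_coro):
        # asyncio.wait_for cancels the awaitable and raises a timeout error on expiry.
        return await asyncio.wait_for(receive_coro, timeout=self._response_timeout)
```

Keeping 30.0 as the default would leave existing behavior unchanged for callers who do not pass the new argument.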
import pytest
from websockets.exceptions import ConnectionClosed
from websockets.frames import Close
One of my biggest concerns is that this is hard for us to test. Open to ideas here. Ideally, there would be some sort of public socket implementation, or something we could chat with that's free or cheap, where we're not breaking their TOS.
Then if you added a notebook demoing usage and chatting with it, it'd be included in our integration tests.
I agree it's tricky. One thing I haven't looked closely at is the realtime API; based on their use of different event types my initial thinking was it wouldn't work with this target class, but it would be interesting to see if it's possible and then I can use that as the example. I leave for vacation later today but will work on this when I get back in a week (if not today).
Description
To address feature request #1037, this commit adds a general WebsocketTarget class for targets that communicate over websockets instead of HTTP. Because this target class is intended to work with all websocket targets, the user has to do some heavy lifting, which is described below.
After the websocket connection is established, targets typically require several initialization/connection messages to establish a conversation with the LLM. As such, the WebsocketTarget class requires a list of strings containing these messages as input. After connecting to the provided websocket endpoint, the WebsocketTarget class iterates through the list and sends each initialization message. These initialization messages can be obtained by connecting to the target normally through a proxy and extracting the messages from the proxy history.
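The initialization step can be illustrated roughly as follows. This is a sketch, not the PR's implementation: `FakeWebsocket` is a hypothetical in-memory stand-in for a real connection, and `send_initialization_strings` is an assumed helper name:

```python
import asyncio
from typing import List


class FakeWebsocket:
    """In-memory stand-in for a websocket connection (hypothetical, for illustration)."""

    def __init__(self) -> None:
        self.sent: List[str] = []

    async def send(self, message: str) -> None:
        self.sent.append(message)


async def send_initialization_strings(websocket, initialization_strings: List[str]) -> None:
    # Send each captured initialization message in the order it was recorded.
    for init_message in initialization_strings:
        await websocket.send(init_message)
```

The order matters because the strings replay a handshake captured from a real session.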
Additionally, there is no standard format for websocket messages. As such, the WebsocketTarget class requires two callables as input: response_parser and message_builder. The response parser takes a raw websocket message and extracts the LLM's response; it is expected to fail if the message does not contain the actual response, which allows messages from the server that do not contain the response (e.g. analytics messages) to be discarded. The message builder takes the adversarial prompt as input and returns a formatted websocket message with the prompt injected.
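As an illustration of the "fail to discard" contract for response_parser, here is a sketch under an assumed message format. The `{"type": "reply", "text": ...}` shape and the helper name `try_parse` are hypothetical and not taken from the PR:

```python
import json
from typing import Callable, Optional


def parse_response(raw_message: str) -> str:
    """Hypothetical response_parser for an assumed {"type": "reply", "text": ...} format."""
    data = json.loads(raw_message)
    if data["type"] != "reply":
        raise ValueError("not a response message")
    return data["text"]


def try_parse(raw_message: str, response_parser: Callable[[str], str]) -> Optional[str]:
    # A parser failure means the message carried no response, so it is discarded.
    try:
        return response_parser(raw_message)
    except Exception:
        return None
```

With this shape, an analytics message or a malformed frame yields `None`, and only messages the parser accepts are treated as the LLM's response.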
Finally, websocket LLMs typically send one or more greeting messages after the connection is established. To prevent these messages from being interpreted as the response to the first adversarial prompt, greeting messages are discarded; the number of initial messages to discard is determined by the discard_initial_messages argument.
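A small sketch of the greeting-discard behavior, assuming the connection object exposes an async `recv()` method. `GreetingWebsocket` and `discard_greetings` are hypothetical names used only for this example:

```python
import asyncio
from typing import List


class GreetingWebsocket:
    """In-memory stand-in that replays queued incoming messages (hypothetical)."""

    def __init__(self, incoming: List[str]) -> None:
        self._incoming = list(incoming)

    async def recv(self) -> str:
        return self._incoming.pop(0)


async def discard_greetings(websocket, discard_initial_messages: int) -> None:
    # Read and drop the configured number of greeting messages.
    for _ in range(discard_initial_messages):
        await websocket.recv()
```

After the greetings are drained, the next message read from the connection is the first candidate response to the prompt.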
Tests and Documentation
Unit tests have been added for the WebsocketTarget class.
To develop this target class, I used public websites whose chatbots communicate over websockets. I am unsure about the legality of using those chatbots in a working example, especially since it would require including the server-specific connection strings and message format used by the chatbot. Please let me know if there's a better way to approach this.