This is a dialog about converting ColQwen2 to ONNX format.

๐Ÿข Now, I have two classes. This one is the copy of ONNX Patcher thatโ€™s used to convert ColQwen2 to ONNX format. But this one is valid only for the underlying model, that is Qwen2VLForConditionalGeneration and ColQwen2 also has some modifications to call this.

๐Ÿ‡ What are these modifications?

๐Ÿข The forward method is something like this.

ColQwen2

```python
    def forward(self, *args, **kwargs) -> torch.Tensor:
        kwargs.pop("output_hidden_states", None)

        # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
        if "pixel_values" in kwargs:
            offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]  # (batch_size,)
            kwargs["pixel_values"] = torch.cat(
                [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
                dim=0,
            )

        position_ids, rope_deltas = self.get_rope_index(
            input_ids=kwargs["input_ids"],
            image_grid_thw=kwargs.get("image_grid_thw", None),
            video_grid_thw=None,
            attention_mask=kwargs.get("attention_mask", None),
        )
        last_hidden_states = self.inner_forward(
            *args, **kwargs, position_ids=position_ids, use_cache=False, output_hidden_states=True
        )  # (batch_size, sequence_length, hidden_size)

        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization
        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)

        if "pixel_values" in kwargs and self.mask_non_image_embeddings:
            # Pools only the image embeddings
            image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
            proj = proj * image_mask
        return proj
```
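
To make the unpadding step concrete, here is a minimal, self-contained sketch with made-up shapes (the real tensors come from ColQwen2Processor): each image contributes h * w patches according to image_grid_thw, and the padded rows beyond that count are dropped before concatenation.

```python
import torch

# Made-up example: 2 images, padded to 6 patches each, 4 values per patch
pixel_values = torch.randn(2, 6, 4)
image_grid_thw = torch.tensor([[1, 2, 2], [1, 2, 3]])  # (t, h, w) per image

# The real patch count per image is h * w (t is 1 for still images)
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # tensor([4, 6])

# Drop the padding rows and concatenate along the patch axis
unpadded = torch.cat([seq[:off] for seq, off in zip(pixel_values, offsets)], dim=0)
print(unpadded.shape)  # torch.Size([10, 4])
```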

๐Ÿ‡ How about the inner_forward method? Does it just directly call Qwen2VLForConditionalGenerations forward?

๐Ÿข No, it has some conditionals:

```python
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.get_dtype())
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        return hidden_states
```
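
The masked_scatter step is the interesting part: the positions of image tokens in inputs_embeds are overwritten, in order, by the rows of image_embeds. A minimal sketch with made-up shapes and a made-up image_token_id:

```python
import torch

hidden_size, image_token_id = 8, 99  # made-up values for illustration
input_ids = torch.tensor([[1, 99, 99, 2]])  # two image-token placeholders
inputs_embeds = torch.zeros(1, 4, hidden_size)  # token embeddings
image_embeds = torch.ones(2, hidden_size)  # one row per image token

image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
# Positions 1 and 2 now hold the image embeddings; positions 0 and 3 are untouched.
```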

🦊 The patcher only modifies the past_key_values_args and moves them into a DynamicCache. There are no other changes for patching.

🐢 But the return types are also different. Qwen2VLForConditionalGeneration.forward returns a Union[Tuple, Qwen2VLCausalLMOutputWithPast], but ColQwen2.forward returns a torch.Tensor.

🐇 How is that torch.Tensor calculated from the return value of Qwen2VLForConditionalGeneration? ❓

🦊 There are multiple return types in Qwen2VLForConditionalGeneration. It's determined by the return_dict argument.


```python
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
```

๐Ÿ‡ Whatโ€™s that argument in ColQwen2?

๐Ÿข Itโ€™s passed from the caller, itโ€™s a kwarg.

๐Ÿ‡ Is there any modification to this in the ONNX patcher?

๐Ÿข Nope.

๐Ÿ‡ Then we can consider it as the default. Whatโ€™s the default?

๐Ÿข Default is None. Hence it returns (logits,) + outputs[1:].

๐Ÿ‡ Then, logits are the first parameter in default.

๐Ÿข Yes, we can assume this.

๐Ÿ‡ What does ColQwen2 do with these logits?

🦊 By the way, it calls inner_forward with use_cache=False and output_hidden_states=True. What do these change in Qwen2VLForConditionalGeneration?

🐢 With those, the hidden states are what actually gets output. In ColQwen2.forward, the inner_forward call is

```python
        last_hidden_states = self.inner_forward(
            *args, **kwargs, position_ids=position_ids, use_cache=False, output_hidden_states=True
        )  # (batch_size, sequence_length, hidden_size)
```

and the output is the hidden states.

🦊 It then runs

```python
        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)
```

and calculates a set of projections.

๐Ÿข custom_set_proj is defined as

```python
        self.custom_text_proj = nn.Linear(self.model.config.hidden_size, self.dim)
```

hence these projections are computed by a fully connected layer.
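
As a concrete sketch, assuming hidden_size=1536 and dim=128 (the real values come from the model config), every token's hidden state is mapped down to the multi-vector dimension and then normalized:

```python
import torch
import torch.nn as nn

hidden_size, dim = 1536, 128  # assumed values for illustration
custom_text_proj = nn.Linear(hidden_size, dim)

last_hidden_states = torch.randn(2, 10, hidden_size)  # (batch, seq, hidden)
proj = custom_text_proj(last_hidden_states)           # (batch, seq, dim)
proj = proj / proj.norm(dim=-1, keepdim=True)         # per-token L2 normalization
```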

🦊 It returns these after L2 normalization and checks if there are pixel_values to consider.

🐢 In this case, the output from ColQwen2 is this single tensor.

🐇 Ok. Then it's actually simpler than wrapping up the whole Qwen2VLForConditionalGeneration.

🐢 It looks so, yes. But the example for ColQwen2 also has a post-processing step:

```python
scores = processor.score_multi_vector(query_embeddings, image_embeddings)
```

๐Ÿ‡ I think we can leave that post-processing for the time being. We only need multi-vector embeddings for fastembed.

๐Ÿข Maybe itโ€™s convertible to a forward method that we can use with ONNX

๐Ÿ‡ Letโ€™s see how this scoring works, then.

๐Ÿข score_multi_vector receives two tensors or tensor lists and compares query vectors with passage vectors. Itโ€™s rather a straigthforward implementation that requires torch, but not transformers.

๐Ÿ‡ In theory we can also convert this to an ONNX model.

๐ŸฆŠ We can also write a custom model for this.

๐Ÿข It has two for loops for comparisons. Can we convert all of these?

๐ŸฆŠ The loop indices can be seen as dynamic_axes and itโ€™s possible to convert the whole thing as a Torch model, then use torch.onnx.export just similar to the model itself.
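
A hedged sketch of that idea, using a hypothetical MaxSimScorer wrapper that is not part of colpali_engine:

```python
import torch

class MaxSimScorer(torch.nn.Module):
    def forward(self, query_embeddings, image_embeddings):
        sim = torch.einsum("qd,pd->qp", query_embeddings, image_embeddings)
        return sim.max(dim=1).values.sum()

torch.onnx.export(
    MaxSimScorer(),
    (torch.randn(12, 128), torch.randn(700, 128)),  # dummy inputs for tracing
    "maxsim.onnx",  # hypothetical output path
    input_names=["query_embeddings", "image_embeddings"],
    output_names=["score"],
    dynamic_axes={
        "query_embeddings": {0: "query_tokens"},
        "image_embeddings": {0: "patch_tokens"},
    },
)
```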

๐Ÿข I see, but I donโ€™t think thatโ€™s what we must do now.

๐Ÿ‡ Yes, letโ€™s skip that for the time being and convert the model itself. Weโ€™ll have multivectors for patches and queries at the end.

๐Ÿข So, weโ€™ll keep the processing part, send BatchFeature objects that are output from the processor, and send this to two models, one for images and one for text.

๐Ÿ‡ Yep, thatโ€™s the plan. At the end weโ€™ll have two ONNX models that require BatchFeatures.

๐Ÿข Then, weโ€™ll modify past_key_values in this argument to use DynamicCache.

๐Ÿ‡ Yes, thatโ€™s alright. We can start by moving past_key_values_converter to a method in PatchedColQwen2.

๐Ÿข past_key_values wasnโ€™t used much so I completely removed it. I also began to use processor outputs as dummy input in the exporter. However, when using BatchFeature, we get

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: BatchFeature

๐Ÿ‡ So, we can try dynamo first I think. If that doesnโ€™t work, we can just collect items as tensors and build up batch feature inside the patcher.

๐Ÿข Letโ€™s try dynamo=True for this first.

๐ŸฆŠ This time, itโ€™s about BatchFeature again, but the error is different: KeyError: 'Indexing with integers is not available when using Python based feature extractors'

๐Ÿ‡ In this case, we can just create BatchFeature inside the patcher.

PatchedColQwen2


```python
from colpali_engine.models import ColQwen2
from transformers import DynamicCache


class PatchedColQwen2(ColQwen2):
    def forward(self, *args):
        (
            input_ids,
            inputs_embeds,
            attention_mask,
            position_ids,
            *past_key_values_args,
        ) = args
        # Convert the flat past_key_values list to a DynamicCache
        if len(past_key_values_args) == 0:
            past_key_values = None
        else:
            past_key_values = DynamicCache()
            for i in range(self.config.num_hidden_layers):
                key = past_key_values_args.pop(0)
                value = past_key_values_args.pop(0)
                past_key_values.update(key_states=key, value_states=value, layer_idx=i)

        o = super().forward(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            # position_ids=position_ids,
            past_key_values=past_key_values,
        )

        # Flatten the outputs back into named tensors for the ONNX graph
        flattened_past_key_values_outputs = {
            "logits": o.logits,
        }
        output_past_key_values: DynamicCache = o.past_key_values
        for i, (key, value) in enumerate(
            zip(output_past_key_values.key_cache, output_past_key_values.value_cache)
        ):
            flattened_past_key_values_outputs[f"present.{i}.key"] = key
            flattened_past_key_values_outputs[f"present.{i}.value"] = value

        return flattened_past_key_values_outputs
```
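
For reference, here is a sketch of how this patched forward might be driven by the exporter, assuming the dummy inputs come straight from ColQwen2Processor; the model id, input names, and dynamic axes are illustrative, and patched_model is a PatchedColQwen2 instance constructed elsewhere.

```python
import torch
from colpali_engine.models import ColQwen2Processor

processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")  # illustrative model id
batch = processor.process_queries(["what is multi-vector retrieval?"])

# Positional args must match PatchedColQwen2.forward's unpacking order;
# inputs_embeds and position_ids are left as None, past_key_values are omitted.
torch.onnx.export(
    patched_model,  # a PatchedColQwen2 instance, constructed elsewhere
    (batch["input_ids"], None, batch["attention_mask"], None),
    "colqwen2-text.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],  # matches the first key of the flattened outputs above
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
    },
)
```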