This dialog is about converting ColQwen2 to ONNX format.
🟢 Now, I have two classes. This one is a copy of the ONNX patcher that's used to convert ColQwen2 to ONNX format. But this one is valid only for the underlying model, that is Qwen2VLForConditionalGeneration, and ColQwen2 also has some modifications to call this.
🔵 What are these modifications?
🟢 The forward method is something like this.
ColQwen2
def forward(self, *args, **kwargs) -> torch.Tensor:
    kwargs.pop("output_hidden_states", None)

    # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
    if "pixel_values" in kwargs:
        offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]  # (batch_size,)
        kwargs["pixel_values"] = torch.cat(
            [pixel_sequence[:offset] for pixel_sequence, offset in zip(kwargs["pixel_values"], offsets)],
            dim=0,
        )

    position_ids, rope_deltas = self.get_rope_index(
        input_ids=kwargs["input_ids"],
        image_grid_thw=kwargs.get("image_grid_thw", None),
        video_grid_thw=None,
        attention_mask=kwargs.get("attention_mask", None),
    )

    last_hidden_states = self.inner_forward(
        *args, **kwargs, position_ids=position_ids, use_cache=False, output_hidden_states=True
    )  # (batch_size, sequence_length, hidden_size)

    proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

    # L2 normalization
    proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
    proj = proj * kwargs["attention_mask"].unsqueeze(-1)  # (batch_size, sequence_length, dim)

    if "pixel_values" in kwargs and self.mask_non_image_embeddings:
        # Pools only the image embeddings
        image_mask = (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
        proj = proj * image_mask

    return proj
PYTHON
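To make the unpadding step above concrete, here is a toy run; the grid sizes and the patch feature size are made up for illustration, not taken from a real checkpoint.
import torch

# Two images whose grids are (t=1, h=4, w=6) and (t=1, h=2, w=2),
# padded by the processor to the same number of patches.
image_grid_thw = torch.tensor([[1, 4, 6], [1, 2, 2]])     # (t, h, w) per image
pixel_values = torch.randn(2, 24, 1176)                   # padded to max h*w = 24 patches
offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]     # tensor([24, 4])
unpadded = torch.cat(
    [seq[:off] for seq, off in zip(pixel_values, offsets)], dim=0
)
print(unpadded.shape)  # torch.Size([28, 1176]): 24 + 4 real patches, padding dropped
PYTHON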
🔵 How about the inner_forward method? Does it just directly call Qwen2VLForConditionalGeneration's forward?
🟢 No, it has some conditionals:
if inputs_embeds is None:
    inputs_embeds = self.model.embed_tokens(input_ids)

    if pixel_values is not None:
        pixel_values = pixel_values.type(self.visual.get_dtype())
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if pixel_values_videos is not None:
        pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

outputs = self.model(
    input_ids=None,
    position_ids=position_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

hidden_states = outputs[0]
return hidden_states
PYTHON
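The masked_scatter step can be shown with a similar toy example; the token id and tensor sizes are made up here, not the real Qwen2-VL values.
import torch

# Image embeddings are written into the token-embedding sequence wherever
# the image token occurs; everything else is left untouched.
image_token_id = 9999
input_ids = torch.tensor([[1, image_token_id, image_token_id, 2]])
inputs_embeds = torch.zeros(1, 4, 3)   # (batch, seq, hidden)
image_embeds = torch.ones(2, 3)        # one row per image token
image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0.]): positions 1 and 2 were overwritten
PYTHON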
🟦 The patcher only modifies the past_key_values_args and moves them to a DynamicCache. There are no other changes for patching.
🟢 But the return types are also different. Qwen2VLForConditionalGeneration.forward returns a Union[Tuple, Qwen2VLCausalLMOutputWithPast] but ColQwen2.forward returns a torch.Tensor.
🔵 How is that torch.Tensor calculated from the return type of Qwen2VLForConditionalGeneration?
🟦 There are multiple return types in Qwen2VLForConditionalGeneration. It's determined by the return_dict argument.
if not return_dict:
    output = (logits,) + outputs[1:]
    return (loss,) + output if loss is not None else output

return Qwen2VLCausalLMOutputWithPast(
    loss=loss,
    logits=logits,
    past_key_values=outputs.past_key_values,
    hidden_states=outputs.hidden_states,
    attentions=outputs.attentions,
    rope_deltas=self.rope_deltas,
)
PYTHON
🔵 What's that argument in ColQwen2?
🟢 It's passed from the caller; it's a kwarg.
🔵 Is there any modification to this in the ONNX patcher?
🟢 Nope.
🔵 Then we can consider it as the default. What's the default?
🟢 The default is None. Hence it returns (logits,) + outputs[1:].
🔵 Then, logits is the first element of the output by default.
🟢 Yes, we can assume this.
🔵 What does ColQwen2 do with these logits?
🟦 By the way, it calls inner_forward with use_cache=False and output_hidden_states=True. What do these change in Qwen2VLForConditionalGeneration?
🟢 These hidden states are actually the output. In ColQwen2.forward, the inner_forward call is
last_hidden_states = self.inner_forward(
    *args, **kwargs, position_ids=position_ids, use_cache=False, output_hidden_states=True
)  # (batch_size, sequence_length, hidden_size)
PYTHON
and the output is the hidden states.
🟦 It then runs
proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)
PY
and calculates a set of projections.
🟢 custom_text_proj is defined as
self.custom_text_proj = nn.Linear(self.model.config.hidden_size, self.dim)
PY
hence these projections are fully connected layer calculations.
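As a quick shape check, here is a minimal sketch of that projection followed by the L2 normalization; the hidden size and output dim are placeholders rather than values read from a real config.
import torch
import torch.nn as nn

# Placeholder sizes; the real ones come from the model config and ColQwen2's dim.
hidden_size, dim = 1536, 128
custom_text_proj = nn.Linear(hidden_size, dim)

last_hidden_states = torch.randn(2, 10, hidden_size)   # (batch_size, sequence_length, hidden_size)
proj = custom_text_proj(last_hidden_states)            # (batch_size, sequence_length, dim)
proj = proj / proj.norm(dim=-1, keepdim=True)          # L2-normalize each token vector
print(proj.shape)                                      # torch.Size([2, 10, 128])
PYTHON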
🟦 It returns these after L2 normalization and checks whether there are pixel_values to consider.
🟢 In this case, the output from ColQwen2 is this single tensor.
🔵 Ok. Then it's actually simpler than wrapping up the whole Qwen2VLForConditionalGeneration.
🟢 It looks so, yes. But the example for ColQwen2 also has a post-processing step,
scores = processor.score_multi_vector(query_embeddings, image_embeddings)
PY
🔵 I think we can leave that post-processing for the time being. We only need multi-vector embeddings for fastembed.
🟢 Maybe it's convertible to a forward method that we can use with ONNX.
🔵 Let's see how this scoring works, then.
🟢 score_multi_vector receives two tensors or tensor lists and compares query vectors with passage vectors. It's a rather straightforward implementation that requires torch, but not transformers.
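For reference, a minimal sketch of that kind of multi-vector (MaxSim-style) scoring; it is not the exact colpali implementation, and the function name is made up.
import torch

def multi_vector_score(query_embeddings, passage_embeddings):
    # For every (query, passage) pair: all pairwise token dot products,
    # max over passage tokens, summed over query tokens.
    scores = torch.zeros(len(query_embeddings), len(passage_embeddings))
    for i, q in enumerate(query_embeddings):        # q: (query_len, dim)
        for j, p in enumerate(passage_embeddings):  # p: (passage_len, dim)
            sim = q @ p.T                           # (query_len, passage_len)
            scores[i, j] = sim.max(dim=1).values.sum()
    return scores
PYTHON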
🔵 In theory we can also convert this to an ONNX model.
🟦 We can also write a custom model for this.
🟢 It has two for loops for comparisons. Can we convert all of these?
🟦 The loop indices can be seen as dynamic_axes, and it's possible to convert the whole thing as a Torch model, then use torch.onnx.export just as for the model itself.
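A hedged sketch of what such an export could look like; it assumes the query and passage batches are first padded into single tensors (which the real score_multi_vector does not require), and the class name is made up.
import torch

class MaxSimScorer(torch.nn.Module):
    # Batched MaxSim: all pairwise token dot products, max over passage
    # tokens, summed over query tokens.
    def forward(self, query_embeddings, passage_embeddings):
        sim = torch.einsum("qnd,pmd->qpnm", query_embeddings, passage_embeddings)
        return sim.max(dim=-1).values.sum(dim=-1)  # (num_queries, num_passages)

torch.onnx.export(
    MaxSimScorer(),
    (torch.randn(2, 10, 128), torch.randn(3, 20, 128)),
    "scorer.onnx",
    input_names=["query_embeddings", "passage_embeddings"],
    output_names=["scores"],
    dynamic_axes={
        "query_embeddings": {0: "num_queries", 1: "query_length"},
        "passage_embeddings": {0: "num_passages", 1: "passage_length"},
    },
)
PYTHON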
🟢 I see, but I don't think that's what we must do now.
🔵 Yes, let's skip that for the time being and convert the model itself. We'll have multi-vectors for patches and queries at the end.
🟢 So, we'll keep the processing part, take the BatchFeature objects that the processor outputs, and send them to two models, one for images and one for text.
🔵 Yep, that's the plan. At the end we'll have two ONNX models that require BatchFeatures.
🟢 Then, we'll modify past_key_values in this argument to use DynamicCache.
🔵 Yes, that's alright. We can start by moving past_key_values_converter to a method in PatchedColQwen2.
🟢 past_key_values wasn't used much, so I completely removed it. I also began using processor outputs as the dummy input in the exporter. However, when using BatchFeature, we get
RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: BatchFeature
🔵 So, we can try dynamo first, I think. If that doesn't work, we can just collect the items as tensors and build up the BatchFeature inside the patcher.
🟢 Let's try dynamo=True for this first.
🟦 This time, it's about BatchFeature again, but the error is different: KeyError: 'Indexing with integers is not available when using Python based feature extractors'
🔵 In this case, we can just create the BatchFeature inside the patcher.
PatchedColQwen2
class PatchedColQwen2(ColQwen2):
    def forward(self, *args):
        (
            input_ids,
            inputs_embeds,
            attention_mask,
            position_ids,
            *past_key_values_args,
        ) = args

        # Convert the flat past_key_values list to a DynamicCache
        if len(past_key_values_args) == 0:
            past_key_values = None
        else:
            past_key_values = DynamicCache()
            for i in range(self.config.num_hidden_layers):
                key = past_key_values_args.pop(0)
                value = past_key_values_args.pop(0)
                past_key_values.update(key_states=key, value_states=value, layer_idx=i)

        o = super().forward(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            # position_ids=position_ids,
            past_key_values=past_key_values,
        )

        flattened_past_key_values_outputs = {
            "logits": o.logits,
        }
        output_past_key_values: DynamicCache = o.past_key_values
        for i, (key, value) in enumerate(
            zip(output_past_key_values.key_cache, output_past_key_values.value_cache)
        ):
            flattened_past_key_values_outputs[f"present.{i}.key"] = key
            flattened_past_key_values_outputs[f"present.{i}.value"] = value
        return flattened_past_key_values_outputs
PYTHON
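For completeness, a rough sketch of the "create the BatchFeature inside the patcher" idea discussed above; the class name, the image-only signature, and the argument list are assumptions rather than the final exporter code.
from colpali_engine.models import ColQwen2  # assuming the same ColQwen2 class as above
from transformers import BatchFeature

class PatchedColQwen2ForImages(ColQwen2):
    def forward(self, input_ids, attention_mask, pixel_values, image_grid_thw):
        # Re-wrap the plain tensors the exporter passes into a BatchFeature,
        # then call the original ColQwen2.forward.
        features = BatchFeature(
            data={
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "image_grid_thw": image_grid_thw,
            }
        )
        return super().forward(**features)  # (batch_size, sequence_length, dim) multi-vector tensor
PYTHON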