34 changes: 32 additions & 2 deletions caikit_nlp/data_model/generation.py
@@ -15,10 +15,10 @@
"""
# Standard
from enum import Enum
from typing import List
from typing import List, Optional

# First Party
from caikit.core import DataObjectBase
from caikit.core import DataObjectBase, dataobject

# First party
import alog
@@ -71,3 +71,33 @@ class TuningConfig(DataObjectBase):
    # num_layers: int # Optional - The number of layers in the base transformer model
    #
    # encoder_hidden_size: int # Optional - The hidden size of the prompt encoder.


@dataobject(package="caikit_data_model.caikit_nlp")
class DecodingParameters(DataObjectBase):
    @dataobject
    class ExponentialDecayLengthPenalty(DataObjectBase):
        start_index: int
        decay_factor: float

    repetition_penalty: float
    exponential_decay_length_penalty: ExponentialDecayLengthPenalty


@dataobject(package="caikit_data_model.caikit_nlp")
class SamplingParameters(DataObjectBase):
    temperature: float
    top_k: int
    top_p: float
    typical_p: float
    seed: Optional[int]
@alex-jw-brooks (Collaborator), Aug 29, 2023:

It might be a good idea to set some defaults! The TGIS defaults are here.

Most of the time this doesn't matter, because 0 temperature (in the IBM fork) indicates greedy decoding, so top_k, top_p, typical_p, etc. won't be used, since they apply only to sampling.

TGI doesn't use temperature 0 as a toggle though, so it would also be nice in case those APIs are ever more unified - currently there are some small divergences with things like prompt IDs. I'm not sure whether our raw generation modules are compatible with it or not.

Contributor (PR author):

I haven't seen us set defaults on the data models themselves, only in the inference methods. I don't have a strong opinion on this; I'm just trying to understand whether that is the general direction caikit is moving in.

Collaborator:

I think even if we set defaults on the DM, they won't propagate to the proto, so the defaults here would be guided by the .run functions themselves.

Collaborator:

Good point - my main concern with leaving it up to run is that it's easy for defaults to get out of sync if we have multiple modules relying on them.

I guess an alternative is either to have a builder for getting these objects with sensible default values, or to have consts passed to the run function 🤔 Is the intent with this type to have a parameter that is this DM object type, or to take primitives and build this object in the requests?
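
To make the builder idea above concrete, here is a minimal sketch (not code from this PR) of a helper that returns a SamplingParameters object pre-populated with shared defaults, so that multiple modules' .run functions cannot drift out of sync. The helper name and the default values are assumptions standing in for the TGIS defaults referenced in the thread.

# Hypothetical helper, for illustration only - not part of this PR.
from typing import Optional

from caikit_nlp.data_model import SamplingParameters


def default_sampling_parameters(seed: Optional[int] = None) -> SamplingParameters:
    """Build a SamplingParameters object with assumed shared defaults."""
    return SamplingParameters(
        temperature=1.0,  # assumed default; 0 would toggle greedy decoding in the IBM fork
        top_k=0,  # assumed default: top-k filtering disabled
        top_p=1.0,  # assumed default: nucleus sampling disabled
        typical_p=1.0,  # assumed default: typical sampling disabled
        seed=seed,  # no fixed seed unless the caller provides one
    )

Each module's run method could then call default_sampling_parameters() when the caller passes nothing, keeping a single source of truth for the defaults.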



@dataobject(package="caikit_data_model.caikit_nlp")
class StoppingCriteria(DataObjectBase):
    max_new_tokens: int
    min_new_tokens: int
    time_limit_millis: int
    stop_sequences: List[str]
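
Touching on the other question in the thread (DM object parameter vs. primitives), below is a hypothetical sketch, not taken from this PR or from caikit's actual module API, of a run signature that accepts the new data model objects directly; the module name is made up for illustration.

# Hypothetical consumer of the new data model objects - illustration only.
from typing import Optional

from caikit_nlp.data_model import (
    DecodingParameters,
    SamplingParameters,
    StoppingCriteria,
)


class ExampleTextGenerationModule:  # made-up name for this sketch
    def run(
        self,
        text: str,
        sampling: Optional[SamplingParameters] = None,
        decoding: Optional[DecodingParameters] = None,
        stopping: Optional[StoppingCriteria] = None,
    ):
        # Callers pass the DM objects; defaults would be filled in here
        # (e.g. via a shared builder) rather than duplicated per module.
        ...

The alternative discussed in the thread - taking primitives and assembling these objects inside the request handlers - keeps the public signature flat but pushes default handling back into each module.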
113 changes: 113 additions & 0 deletions tests/data_model/test_generation.py
@@ -0,0 +1,113 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from caikit_nlp.data_model import (
    DecodingParameters,
    SamplingParameters,
    StoppingCriteria,
)

## Setup #########################################################################

dummy_exponential_decay_length_penalty = (
    DecodingParameters.ExponentialDecayLengthPenalty(start_index=1, decay_factor=0.95)
)
dummy_decoding_parameters = DecodingParameters(
    repetition_penalty=1.2,
    exponential_decay_length_penalty=dummy_exponential_decay_length_penalty,
)

dummy_sampling_parameters = SamplingParameters(
    temperature=0.5, top_k=0, top_p=0, typical_p=0.2, seed=42
)

dummy_stopping_criteria = StoppingCriteria(
    max_new_tokens=200, min_new_tokens=50, time_limit_millis=0, stop_sequences=["Test"]
)

## Tests ########################################################################

### Decoding Parameters
def test_decoding_parameters_all_fields_accessible():
    assert dummy_decoding_parameters.repetition_penalty == 1.2
    assert dummy_decoding_parameters.exponential_decay_length_penalty.start_index == 1
    assert (
        dummy_decoding_parameters.exponential_decay_length_penalty.decay_factor == 0.95
    )


def test_decoding_parameters_from_proto_and_back():
    new = DecodingParameters.from_proto(dummy_decoding_parameters.to_proto())
    assert new.repetition_penalty == 1.2
    assert new.exponential_decay_length_penalty.start_index == 1
    assert new.exponential_decay_length_penalty.decay_factor == 0.95


def test_decoding_parameters_from_json_and_back():
    new = DecodingParameters.from_json(dummy_decoding_parameters.to_json())
    assert new.repetition_penalty == 1.2
    assert new.exponential_decay_length_penalty.start_index == 1
    assert new.exponential_decay_length_penalty.decay_factor == 0.95


### Sampling Parameters
def test_sampling_parameters_all_fields_accessible():
    assert dummy_sampling_parameters.temperature == 0.5
    assert dummy_sampling_parameters.top_k == 0
    assert dummy_sampling_parameters.top_p == 0
    assert dummy_sampling_parameters.typical_p == 0.2
    assert dummy_sampling_parameters.seed == 42


def test_sampling_parameters_from_proto_and_back():
    new = SamplingParameters.from_proto(dummy_sampling_parameters.to_proto())
    assert new.temperature == 0.5
    assert new.top_k == 0
    assert new.top_p == 0
    assert new.typical_p == 0.2
    assert new.seed == 42


def test_sampling_parameters_from_json_and_back():
    new = SamplingParameters.from_json(dummy_sampling_parameters.to_json())
    assert new.temperature == 0.5
    assert new.top_k == 0
    assert new.top_p == 0
    assert new.typical_p == 0.2
    assert new.seed == 42


### Stopping Criteria
def test_stopping_criteria_all_fields_accessible():
    assert dummy_stopping_criteria.max_new_tokens == 200
    assert dummy_stopping_criteria.min_new_tokens == 50
    assert dummy_stopping_criteria.time_limit_millis == 0
    assert dummy_stopping_criteria.stop_sequences == ["Test"]


def test_stopping_criteria_from_proto_and_back():
    new = StoppingCriteria.from_proto(dummy_stopping_criteria.to_proto())
    assert new.max_new_tokens == 200
    assert new.min_new_tokens == 50
    assert new.time_limit_millis == 0
    assert new.stop_sequences == ["Test"]


def test_stopping_criteria_from_json_and_back():
    new = StoppingCriteria.from_json(dummy_stopping_criteria.to_json())
    assert new.max_new_tokens == 200
    assert new.min_new_tokens == 50
    assert new.time_limit_millis == 0
    assert new.stop_sequences == ["Test"]