Coverage for /builds/hweiske/ase/ase/calculators/kim/kimpy_wrappers.py: 75.81%
339 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-04-22 11:22 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-04-22 11:22 +0000
1"""
2Wrappers that provide a minimal interface to kimpy methods and objects
4Daniel S. Karls
5University of Minnesota
6"""
8import functools
9from abc import ABC
11import numpy as np
13from .exceptions import (KIMModelInitializationError, KIMModelNotFound,
14 KIMModelParameterError, KimpyError)
class LazyKimpyImport:
    """Stand-in for the optional ``kimpy`` module.

    Creating an instance imports nothing. The real module is imported the
    first time an attribute has to be forwarded to it, and is then cached on
    the instance (via ``functools.cached_property``).
    """

    @functools.cached_property
    def _kimpy(self):
        # Deferred import: kimpy is an optional dependency, so importing this
        # module must not require it.
        import kimpy as kimpy_module
        return kimpy_module

    def __getattr__(self, attr):
        # Python only falls back to __getattr__ when normal lookup fails, so
        # every other name is resolved on the lazily imported kimpy module.
        module = self._kimpy
        return getattr(module, attr)
class Wrappers:
    """Shortcuts written in a way that avoids module-level kimpy import.

    Every shortcut is a property, so ``kimpy`` is only resolved when the
    shortcut is actually accessed. Callable shortcuts come back pre-bound to
    ``check_call``.
    """

    @staticmethod
    def _checked(kimpy_callable):
        # Pre-bind a kimpy callable to check_call so that a RuntimeError it
        # raises is reported as a KimpyError.
        return functools.partial(check_call, kimpy_callable)

    @property
    def collections_create(self):
        return self._checked(kimpy.collections.create)

    @property
    def model_create(self):
        return self._checked(kimpy.model.create)

    @property
    def simulator_model_create(self):
        return self._checked(kimpy.simulator_model.create)

    @property
    def get_species_name(self):
        return self._checked(kimpy.species_name.get_species_name)

    @property
    def get_number_of_species_names(self):
        return self._checked(kimpy.species_name.get_number_of_species_names)

    @property
    def collection_item_type_portableModel(self):
        # A constant rather than a function: returned as-is, not wrapped in
        # check_call.
        return kimpy.collection_item_type.portableModel
# Module-level singletons. Constructing them does not import kimpy; the
# import happens the first time an attribute is looked up on ``kimpy``.
wrappers = Wrappers()
kimpy = LazyKimpyImport()
# Function used for casting parameter/extent indices to C-compatible ints
c_int = np.intc

# Function used for casting floating point parameter values to C-compatible
# doubles
c_double = np.double


def c_int_args(func):
    """
    Decorator for instance methods that will cast all of the args passed,
    excluding the first (which corresponds to 'self'), to C-compatible
    integers (``c_int``).

    Keyword arguments are forwarded unchanged.
    """

    @functools.wraps(func)
    def myfunc(*args, **kwargs):
        # ``self`` (args[0], if present) is left untouched; every other
        # positional argument is cast, and it is the *cast* values that are
        # forwarded to ``func``.
        head, tail = args[:1], args[1:]
        return func(*head, *map(c_int, tail), **kwargs)

    return myfunc
def check_call(f, *args, **kwargs):
    """Call a kimpy function using its arguments and, if a RuntimeError is
    raised, catch it and raise a KimpyError with the exception's
    message.

    (Starting with kimpy 2.0.0, a RuntimeError is the only exception
    type raised when something goes wrong.)

    Parameters
    ----------
    f : callable
        kimpy function or bound method to call.
    *args, **kwargs
        Forwarded unchanged to ``f``.

    Returns
    -------
    object
        Whatever ``f`` returns.

    Raises
    ------
    KimpyError
        If ``f`` raises a RuntimeError. The original error is chained as
        the cause.
    """
    try:
        return f(*args, **kwargs)
    except RuntimeError as e:
        # Not every callable has __name__ (e.g. functools.partial); fall back
        # to repr so that building the message cannot itself raise and hide
        # the real error.
        name = getattr(f, "__name__", repr(f))
        raise KimpyError(
            f'Calling kimpy function "{name}" failed:\n {e!s}') from e
def check_call_wrapper(func):
    """Decorator form of ``check_call``.

    Every call of the decorated function is routed through ``check_call``,
    so a RuntimeError it raises surfaces as a KimpyError. The wrapper
    carries ``func``'s name, docstring and ``__wrapped__``.
    """
    def checked(*args, **kwargs):
        return check_call(func, *args, **kwargs)

    return functools.wraps(func)(checked)
class ModelCollections:
    """
    KIM Portable Models and Simulator Models are installed/managed into
    different "collections". In order to search through the different
    KIM API model collections on the system, a corresponding object must
    be instantiated. For more on model collections, see the KIM API's
    install file:
    https://github.com/openkim/kim-api/blob/master/INSTALL
    """

    def __init__(self):
        self.collection = wrappers.collections_create()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        # Nothing to release explicitly.
        pass

    def get_item_type(self, model_name):
        """Return the KIM API collection item type of ``model_name``.

        Raises
        ------
        KIMModelNotFound
            If kimpy cannot find the model in any collection. The underlying
            KimpyError is chained as the cause.
        """
        try:
            model_type = check_call(self.collection.get_item_type, model_name)
        except KimpyError as exc:
            msg = (
                f"Could not find model {model_name} installed in any of the "
                "KIM API model collections on this system. See "
                "https://openkim.org/doc/usage/obtaining-models/ for "
                "instructions on installing models."
            )
            raise KIMModelNotFound(msg) from exc

        return model_type

    @property
    def initialized(self):
        """True once ``__init__`` has created the underlying collection."""
        return hasattr(self, "collection")
class PortableModel:
    """Creates a KIM API Portable Model object and provides a minimal
    interface to it.

    Parameters
    ----------
    model_name : str
        Name of the KIM Portable Model to instantiate.
    debug : bool
        If true, diagnostic information (e.g. the model's units) is printed
        to stdout.

    Raises
    ------
    KIMModelInitializationError
        If the model does not accept the requested units.
    """

    def __init__(self, model_name, debug):
        self.model_name = model_name
        self.debug = debug

        # Create KIM API Model object with zero-based particle numbering,
        # requesting length/energy/charge/temperature/time units of
        # A, eV, e, K and ps respectively.
        units_accepted, self.kim_model = wrappers.model_create(
            kimpy.numbering.zeroBased,
            kimpy.length_unit.A,
            kimpy.energy_unit.eV,
            kimpy.charge_unit.e,
            kimpy.temperature_unit.K,
            kimpy.time_unit.ps,
            self.model_name,
        )

        if not units_accepted:
            raise KIMModelInitializationError(
                "Requested units not accepted in kimpy.model.create"
            )

        if self.debug:
            l_unit, e_unit, c_unit, te_unit, ti_unit = check_call(
                self.kim_model.get_units
            )
            print(f"Length unit is: {l_unit}")
            print(f"Energy unit is: {e_unit}")
            print(f"Charge unit is: {c_unit}")
            print(f"Temperature unit is: {te_unit}")
            print(f"Time unit is: {ti_unit}")
            print()

        self._create_parameters()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        # Nothing to release explicitly.
        pass

    @check_call_wrapper
    def _get_number_of_parameters(self):
        return self.kim_model.get_number_of_parameters()

    def _create_parameters(self):
        """Populate ``self._parameters``: one parameter object per model
        parameter registered with the KIM API, keyed by parameter name."""

        def _kim_model_parameter(**kwargs):
            # Choose the concrete parameter class from the dtype string.
            dtype = kwargs["dtype"]

            if dtype == "Integer":
                return KIMModelParameterInteger(**kwargs)
            elif dtype == "Double":
                return KIMModelParameterDouble(**kwargs)
            else:
                raise KIMModelParameterError(
                    f"Invalid model parameter type {dtype}. Supported types "
                    "'Integer' and 'Double'."
                )

        self._parameters = {}
        num_params = self._get_number_of_parameters()
        for index_param in range(num_params):
            parameter_metadata = self._get_one_parameter_metadata(index_param)
            name = parameter_metadata["name"]

            self._parameters[name] = _kim_model_parameter(
                kim_model=self.kim_model,
                dtype=parameter_metadata["dtype"],
                extent=parameter_metadata["extent"],
                name=name,
                description=parameter_metadata["description"],
                parameter_index=index_param,
            )

    def get_model_supported_species_and_codes(self):
        """Get all of the supported species for this model and their
        corresponding integer codes that are defined in the KIM API

        Returns
        -------
        species : list of str
            Abbreviated chemical symbols of all species the model
            supports (e.g. ["Mo", "S"])

        codes : list of int
            Integer codes used by the model for each species (order
            corresponds to the order of ``species``)
        """
        species = []
        codes = []
        num_kim_species = wrappers.get_number_of_species_names()

        # Walk every species name known to the KIM API and keep the ones
        # this model reports as supported.
        for i in range(num_kim_species):
            species_name = wrappers.get_species_name(i)

            species_is_supported, code = self.get_species_support_and_code(
                species_name)

            if species_is_supported:
                species.append(str(species_name))
                codes.append(code)

        return species, codes

    @check_call_wrapper
    def clear_then_refresh(self):
        self.kim_model.clear_then_refresh()

    @c_int_args
    def _get_parameter_metadata(self, index_parameter):
        """Return ``(dtype, extent, name, description)`` of one parameter.

        Raises KIMModelParameterError if kimpy cannot provide the metadata.
        """
        try:
            dtype, extent, name, description = check_call(
                self.kim_model.get_parameter_metadata, index_parameter
            )
        except KimpyError as e:
            raise KIMModelParameterError(
                "Failed to retrieve metadata for "
                f"parameter at index {index_parameter}"
            ) from e

        return dtype, extent, name, description

    def parameters_metadata(self):
        """Metadata associated with all model parameters.

        Returns
        -------
        dict
            Metadata associated with all model parameters.
        """
        return {
            param_name: param.metadata
            for param_name, param in self._parameters.items()
        }

    def parameter_names(self):
        """Names of model parameters registered in the KIM API.

        Returns
        -------
        tuple
            Names of model parameters registered in the KIM API
        """
        return tuple(self._parameters.keys())

    def get_parameters(self, **kwargs):
        """
        Get the values of one or more model parameter arrays.

        Given the names of one or more model parameters and a set of indices
        for each of them, retrieve the corresponding elements of the relevant
        model parameter arrays.

        Parameters
        ----------
        **kwargs
            Names of the model parameters and the indices whose values should
            be retrieved.

        Returns
        -------
        dict
            The requested indices and the values of the model's parameters.

        Note
        ----
        The output of this method can be used as input of
        ``set_parameters``.

        Example
        -------
        To get `epsilons` and `sigmas` in the LJ universal model for Mo-Mo
        (index 4879), Mo-S (index 2006) and S-S (index 1980) interactions::

            >>> LJ = 'LJ_ElliottAkerson_2015_Universal__MO_959249795837_003'
            >>> calc = KIM(LJ)
            >>> calc.get_parameters(epsilons=[4879, 2006, 1980],
            ...                     sigmas=[4879, 2006, 1980])
            {'epsilons': [[4879, 2006, 1980],
                          [4.47499, 4.421814057295943, 4.36927]],
             'sigmas': [[4879, 2006, 1980],
                        [2.74397, 2.30743, 1.87089]]}
        """
        parameters = {}
        for parameter_name, index_range in kwargs.items():
            parameters.update(
                self._get_one_parameter(parameter_name, index_range))
        return parameters

    def set_parameters(self, **kwargs):
        """
        Set the values of one or more model parameter arrays.

        Given the names of one or more model parameters and a set of indices
        and corresponding values for each of them, mutate the corresponding
        elements of the relevant model parameter arrays.

        Parameters
        ----------
        **kwargs
            Names of the model parameters to mutate and the corresponding
            indices and values to set.

        Returns
        -------
        dict
            The requested indices and the values of the model's parameters
            that were set.

        Example
        -------
        To set `epsilons` in the LJ universal model for Mo-Mo (index 4879),
        Mo-S (index 2006) and S-S (index 1980) interactions to 5.0, 4.5, and
        4.0, respectively::

            >>> LJ = 'LJ_ElliottAkerson_2015_Universal__MO_959249795837_003'
            >>> calc = KIM(LJ)
            >>> calc.set_parameters(epsilons=[[4879, 2006, 1980],
            ...                               [5.0, 4.5, 4.0]])
            {'epsilons': [[4879, 2006, 1980],
                          [5.0, 4.5, 4.0]]}
        """
        parameters = {}
        for parameter_name, parameter_data in kwargs.items():
            index_range, values = parameter_data
            self._set_one_parameter(parameter_name, index_range, values)
            parameters[parameter_name] = parameter_data

        return parameters

    def _lookup_parameter(self, parameter_name):
        """Return the parameter object registered as ``parameter_name``.

        Raises
        ------
        KIMModelParameterError
            If this model has no parameter with that name.
        """
        if parameter_name not in self._parameters:
            raise KIMModelParameterError(
                f"Parameter '{parameter_name}' is not "
                "supported by this model. "
                "Please check that the parameter name is spelled correctly."
            )
        return self._parameters[parameter_name]

    def _get_one_parameter(self, parameter_name, index_range):
        """
        Retrieve value of one or more components of a model parameter array.

        Parameters
        ----------
        parameter_name : str
            Name of model parameter registered in the KIM API.
        index_range : int or list
            Zero-based index (int) or indices (list of int) specifying the
            component(s) of the corresponding model parameter array that are
            to be retrieved.

        Returns
        -------
        dict
            The requested indices and the corresponding values of the model
            parameter array.
        """
        return self._lookup_parameter(parameter_name).get_values(index_range)

    def _set_one_parameter(self, parameter_name, index_range, values):
        """
        Set the value of one or more components of a model parameter array.

        Parameters
        ----------
        parameter_name : str
            Name of model parameter registered in the KIM API.
        index_range : int or list
            Zero-based index (int) or indices (list of int) specifying the
            component(s) of the corresponding model parameter array that are
            to be mutated.
        values : int/float or list
            Value(s) to assign to the component(s) of the model parameter
            array specified by ``index_range``.
        """
        self._lookup_parameter(parameter_name).set_values(index_range, values)

    def _get_one_parameter_metadata(self, index_parameter):
        """
        Get metadata associated with a single model parameter.

        Parameters
        ----------
        index_parameter : int
            Zero-based index used by the KIM API to refer to this model
            parameter.

        Returns
        -------
        dict
            Metadata associated with the requested model parameter.
        """
        dtype, extent, name, description = self._get_parameter_metadata(
            index_parameter)
        parameter_metadata = {
            "name": name,
            # repr() of the kimpy data type; _create_parameters compares this
            # string against "Integer" / "Double".
            "dtype": repr(dtype),
            "extent": extent,
            "description": description,
        }
        return parameter_metadata

    @check_call_wrapper
    def compute(self, compute_args_wrapped, release_GIL):
        """Run the model on the arguments registered in
        ``compute_args_wrapped``; ``release_GIL`` is forwarded to kimpy."""
        return self.kim_model.compute(
            compute_args_wrapped.compute_args, release_GIL)

    @check_call_wrapper
    def get_species_support_and_code(self, species_name):
        return self.kim_model.get_species_support_and_code(species_name)

    @check_call_wrapper
    def get_influence_distance(self):
        return self.kim_model.get_influence_distance()

    @check_call_wrapper
    def get_neighbor_list_cutoffs_and_hints(self):
        return self.kim_model.get_neighbor_list_cutoffs_and_hints()

    def compute_arguments_create(self):
        """Create a ComputeArguments object bound to this model."""
        return ComputeArguments(self, self.debug)

    @property
    def initialized(self):
        """True once ``__init__`` has created the KIM API model object."""
        return hasattr(self, "kim_model")
class KIMModelParameter(ABC):
    """A single parameter array of a KIM Portable Model.

    Concrete subclasses set two class attributes:

    ``_dtype_c``
        Callable that casts a value before it is passed to kimpy.
    ``_dtype_accessor``
        Name of the ``kim_model`` method used to read one component.
    """

    def __init__(self, kim_model, dtype, extent,
                 name, description, parameter_index):
        self._kim_model = kim_model
        self._dtype = dtype
        self._extent = extent
        self._name = name
        self._description = description

        # Ensure that parameter_index is cast to a C-compatible integer. This
        # is necessary because this is passed to kimpy.
        self._parameter_index = c_int(parameter_index)

    @property
    def metadata(self):
        """Dict with this parameter's dtype, extent, name and description."""
        return {
            "dtype": self._dtype,
            "extent": self._extent,
            "name": self._name,
            "description": self._description,
        }

    @c_int_args
    def _get_one_value(self, index_extent):
        """Return component ``index_extent`` of this parameter array.

        Raises KIMModelParameterError if kimpy rejects the access.
        """
        get_parameter = getattr(self._kim_model, self._dtype_accessor)
        try:
            return check_call(
                get_parameter, self._parameter_index, index_extent)
        except KimpyError as exception:
            raise KIMModelParameterError(
                f"Failed to access component {index_extent} of model "
                f"parameter of type '{self._dtype}' at parameter index "
                f"{self._parameter_index}"
            ) from exception

    def _set_one_value(self, index_extent, value):
        """Set component ``index_extent`` of this parameter array to
        ``value`` (cast with ``_dtype_c``).

        Raises KIMModelParameterError if kimpy rejects the update; the
        underlying KimpyError is chained as the cause.
        """
        value_typecast = self._dtype_c(value)

        try:
            check_call(
                self._kim_model.set_parameter,
                self._parameter_index,
                c_int(index_extent),
                value_typecast,
            )
        except KimpyError as exception:
            raise KIMModelParameterError(
                f"Failed to set component {index_extent} at parameter index "
                f"{self._parameter_index} to {self._dtype} value "
                f"{value_typecast}"
            ) from exception

    def get_values(self, index_range):
        """Read one component (int index) or several (1-D list of indices).

        Returns
        -------
        dict
            ``{name: [index_range, values]}`` where ``values`` is a scalar
            for an int index and a list for a list of indices.

        Raises
        ------
        KIMModelParameterError
            If ``index_range`` is neither an int nor a 1-D list.
        """
        index_range_dim = np.ndim(index_range)
        if index_range_dim == 0:
            values = self._get_one_value(index_range)
        elif index_range_dim == 1:
            values = [self._get_one_value(idx) for idx in index_range]
        else:
            raise KIMModelParameterError(
                "Index range must be an integer or a list of integers"
            )
        return {self._name: [index_range, values]}

    def set_values(self, index_range, values):
        """Write one component (int index) or several (1-D list of indices).

        ``index_range`` and ``values`` must have the same shape.

        Raises
        ------
        AssertionError
            If the shapes of ``index_range`` and ``values`` differ.
        KIMModelParameterError
            If ``index_range`` is neither an int nor a 1-D list.
        """
        index_range_dim = np.ndim(index_range)
        values_dim = np.ndim(values)

        # Check the shape of index_range and values. An explicit raise
        # (rather than ``assert``) keeps the check active under
        # ``python -O``, where a length mismatch would otherwise be silently
        # truncated by zip(). The exception type is unchanged for callers.
        msg = "index_range and values must have the same shape"
        if index_range_dim != values_dim:
            raise AssertionError(msg)

        if index_range_dim == 0:
            self._set_one_value(index_range, values)
        elif index_range_dim == 1:
            if len(index_range) != len(values):
                raise AssertionError(msg)
            for idx, value in zip(index_range, values):
                self._set_one_value(idx, value)
        else:
            raise KIMModelParameterError(
                "Index range must be an integer or a list of integers"
            )
class KIMModelParameterInteger(KIMModelParameter):
    """Integer-valued KIM model parameter."""

    _dtype_accessor = "get_parameter_int"
    _dtype_c = c_int


class KIMModelParameterDouble(KIMModelParameter):
    """Double-precision KIM model parameter."""

    _dtype_accessor = "get_parameter_double"
    _dtype_c = c_double
class ComputeArguments:
    """Creates a KIM API ComputeArguments object from a KIM Portable
    Model object and configures it for ASE. A ComputeArguments object
    is associated with a KIM Portable Model and is used to inform the
    KIM API of what the model can compute. It is also used to
    register the data arrays that allow the KIM API to pass the atomic
    coordinates to the model and retrieve the corresponding energy and
    forces, etc.

    Raises
    ------
    KIMModelInitializationError
        If the model *requires* a compute argument other than
        partialEnergy / partialForces, or requires any compute callback.
    """

    def __init__(self, kim_model_wrapped, debug):
        self.kim_model_wrapped = kim_model_wrapped
        self.debug = debug

        # Create KIM API ComputeArguments object
        self.compute_args = check_call(
            self.kim_model_wrapped.kim_model.compute_arguments_create
        )

        self._check_compute_arguments()
        self._check_compute_callbacks()

    def _check_compute_arguments(self):
        """Reject models that demand arguments other than energy/forces."""
        arg_names = kimpy.compute_argument_name
        count = arg_names.get_number_of_compute_argument_names()
        if self.debug:
            print(f"Number of compute_args: {count}")

        for index in range(count):
            name = check_call(arg_names.get_compute_argument_name, index)
            dtype = check_call(
                arg_names.get_compute_argument_data_type, name)

            arg_support = self.get_argument_support_status(name)

            if self.debug:
                print(
                    "Compute Argument name {:21} is of type {:7} "
                    "and has support "
                    "status {}".format(str(name), str(dtype),
                                       str(arg_support))
                )

            # Only energy and forces may be marked "required" by the model.
            if arg_support == kimpy.support_status.required:
                is_energy = not (name != arg_names.partialEnergy)
                if not is_energy and name != arg_names.partialForces:
                    raise KIMModelInitializationError(
                        f"Unsupported required ComputeArgument {name}"
                    )

    def _check_compute_callbacks(self):
        """Reject models that mark any compute callback as required."""
        callbacks = kimpy.compute_callback_name
        count = callbacks.get_number_of_compute_callback_names()
        if self.debug:
            print()
            print(f"Number of callbacks: {count}")

        for index in range(count):
            name = check_call(callbacks.get_compute_callback_name, index)

            support_status = self.get_callback_support_status(name)

            if self.debug:
                print(
                    "Compute callback {:17} has support status {}".format(
                        str(name), support_status
                    )
                )

            # Cannot handle any "required" callbacks
            if support_status == kimpy.support_status.required:
                raise KIMModelInitializationError(
                    f"Unsupported required ComputeCallback: {name}"
                )

    @check_call_wrapper
    def set_argument_pointer(self, compute_arg_name, data_object):
        return self.compute_args.set_argument_pointer(
            compute_arg_name, data_object)

    @check_call_wrapper
    def get_argument_support_status(self, name):
        return self.compute_args.get_argument_support_status(name)

    @check_call_wrapper
    def get_callback_support_status(self, name):
        return self.compute_args.get_callback_support_status(name)

    @check_call_wrapper
    def set_callback(self, compute_callback_name,
                     callback_function, data_object):
        return self.compute_args.set_callback(
            compute_callback_name, callback_function, data_object
        )

    @check_call_wrapper
    def set_callback_pointer(
            self, compute_callback_name, callback, data_object):
        return self.compute_args.set_callback_pointer(
            compute_callback_name, callback, data_object
        )

    def update(
        self, num_particles, species_code, particle_contributing,
        coords, energy, forces
    ):
        """Register model input and output in the kim_model object."""
        names = kimpy.compute_argument_name
        registrations = (
            (names.numberOfParticles, num_particles),
            (names.particleSpeciesCodes, species_code),
            (names.particleContributing, particle_contributing),
            (names.coordinates, coords),
            (names.partialEnergy, energy),
            (names.partialForces, forces),
        )
        for arg_name, data in registrations:
            self.set_argument_pointer(arg_name, data)

        if self.debug:
            print("Debug: called update_kim")
            print()
class SimulatorModel:
    """Creates a KIM API Simulator Model object and provides a minimal
    interface to it. This is only necessary in this package in order to
    extract any information about a given simulator model because it is
    generally embedded in a shared object.

    Calls into the kimpy simulator-model object go through ``check_call``,
    so a RuntimeError raised by kimpy surfaces as a KimpyError, consistent
    with the rest of this module.
    """

    def __init__(self, model_name):
        # Create a KIM API Simulator Model object for this model
        self.model_name = model_name
        self.simulator_model = wrappers.simulator_model_create(self.model_name)

        # Need to close template map in order to access simulator
        # model metadata
        check_call(self.simulator_model.close_template_map)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        # Nothing to release explicitly.
        pass

    @property
    def simulator_name(self):
        """Name of the simulator this model is written for (the version
        returned alongside it is discarded)."""
        simulator_name, _ = check_call(
            self.simulator_model.get_simulator_name_and_version
        )
        return simulator_name

    @property
    def num_supported_species(self):
        """Number of species the simulator model supports.

        Raises KIMModelInitializationError if the model reports zero.
        """
        num_supported_species = check_call(
            self.simulator_model.get_number_of_supported_species
        )
        if num_supported_species == 0:
            raise KIMModelInitializationError(
                "Unable to determine supported species of "
                "simulator model {}.".format(self.model_name)
            )
        return num_supported_species

    @property
    def supported_species(self):
        """Tuple of the species supported by the simulator model."""
        return tuple(
            check_call(self.simulator_model.get_supported_species, spec_code)
            for spec_code in range(self.num_supported_species)
        )

    @property
    def num_metadata_fields(self):
        return check_call(self.simulator_model.get_number_of_simulator_fields)

    @property
    def metadata(self):
        """Dict mapping each simulator field name to its list of lines.

        Re-queried from kimpy on every access.
        """
        sm_metadata_fields = {}
        for field in range(self.num_metadata_fields):
            extent, field_name = check_call(
                self.simulator_model.get_simulator_field_metadata, field
            )
            sm_metadata_fields[field_name] = [
                check_call(
                    self.simulator_model.get_simulator_field_line, field, ln
                )
                for ln in range(extent)
            ]

        return sm_metadata_fields

    @property
    def supported_units(self):
        """First line of the ``units`` metadata field.

        Raises KIMModelInitializationError if that field is missing or
        empty.
        """
        try:
            supported_units = self.metadata["units"][0]
        except (KeyError, IndexError) as exc:
            raise KIMModelInitializationError(
                "Unable to determine supported units of "
                "simulator model {}.".format(self.model_name)
            ) from exc

        return supported_units

    @property
    def atom_style(self):
        """
        See if a 'model-init' field exists in the SM metadata and, if
        so, whether it contains any entries including an "atom_style"
        command. This is specific to LAMMPS SMs and is only required
        for using the LAMMPSrun calculator because it uses
        lammps.inputwriter to create a data file. All other content in
        'model-init', if it exists, is ignored.

        Returns the second whitespace-separated token of the last matching
        line, or None if no line mentions "atom_style".
        """
        atom_style = None
        for ln in self.metadata.get("model-init", []):
            if "atom_style" in ln:
                atom_style = ln.split()[1]

        return atom_style

    @property
    def model_defn(self):
        return self.metadata["model-defn"]

    @property
    def initialized(self):
        """True once ``__init__`` has created the simulator-model object."""
        return hasattr(self, "simulator_model")